From a29d870883fae30fcdee88c274bab9c8cb1c5cf6 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 16 May 2025 09:07:20 -0700 Subject: [PATCH 01/57] [webgpu] enable f16 on vulkan/nvidia GPUs (#24782) ### Description enable f16 on vulkan/nvidia GPUs ### Motivation and Context --- .../external/onnxruntime_external_deps.cmake | 47 +++++++++++-------- .../dawn_force_enable_f16_nvidia_vulkan.patch | 19 ++++++++ 2 files changed, 46 insertions(+), 20 deletions(-) create mode 100644 cmake/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index d967e806eb5a3..c6a5960d9b9da 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -708,30 +708,37 @@ if (onnxruntime_USE_WEBGPU) EXCLUDE_FROM_ALL ) else() + set(ONNXRUNTIME_Dawn_PATCH_COMMAND + # The dawn.patch contains the following changes: + # + # - (private) Allow WGPUBufferImpl class to destroy the buffer in the destructor + # In native implementation, wgpuBufferRelease will trigger the buffer destroy (if refcount decreased to 0). But + # in emwgpu implementation, the buffer destroy won't happen. This change adds a destructor to the buffer class + # to destroy the buffer when the refcount is 0 for non-external buffers. + # + # - (private) Remove hard-coded CMAKE_OSX_DEPLOYMENT_TARGET in Dawn's CMake files + # https://github.com/microsoft/onnxruntime/pull/23729 + # + # - (private) Reduce unsafe buffer usage warning in aligned_storage.h + # https://github.com/microsoft/onnxruntime/pull/24308 + # The patch disables the UNSAFE_BUFFER_USAGE warning around the AlignedStorage struct in aligned_storage.h. This is done + # by using TINT_BEGIN_DISABLE_WARNING and TINT_END_DISABLE_WARNING macros, which helps in warnings related to unsafe buffer usage + # usage when compiling the code, making the build process cleaner and faster. + # + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch && + + # The dawn_force_enable_f16_nvidia_vulkan.patch contains the following changes: + # + # - (private) Force enable f16 support for NVIDIA Vulkan + # Dawn disabled f16 support for NVIDIA Vulkan by default because of crashes in f16 CTS tests (crbug.com/tint/2164). + # Since the crashes are limited to specific GPU models, we patched Dawn to remove the restriction. + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch) + onnxruntime_fetchcontent_declare( dawn URL ${DEP_URL_dawn} URL_HASH SHA1=${DEP_SHA1_dawn} - # # All previous patches are merged into the upstream dawn project. We don't need to apply any patches right now. - # # if we need to apply patches in the future, we can uncomment the following line. - # - # The dawn.patch contains the following changes: - # - # - (private) Allow WGPUBufferImpl class to destroy the buffer in the destructor - # In native implementation, wgpuBufferRelease will trigger the buffer destroy (if refcount decreased to 0). But - # in emwgpu implementation, the buffer destroy won't happen. This change adds a destructor to the buffer class - # to destroy the buffer when the refcount is 0 for non-external buffers. - # - # - (private) Remove hard-coded CMAKE_OSX_DEPLOYMENT_TARGET in Dawn's CMake files - # https://github.com/microsoft/onnxruntime/pull/23729 - # - # - (private) Reduce unsafe buffer usage warning in aligned_storage.h - # https://github.com/microsoft/onnxruntime/pull/24308 - # The patch disables the UNSAFE_BUFFER_USAGE warning around the AlignedStorage struct in aligned_storage.h. This is done - # by using TINT_BEGIN_DISABLE_WARNING and TINT_END_DISABLE_WARNING macros, which helps in warnings related to unsafe buffer usage - # usage when compiling the code, making the build process cleaner and faster. - # - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch + PATCH_COMMAND ${ONNXRUNTIME_Dawn_PATCH_COMMAND} EXCLUDE_FROM_ALL ) endif() diff --git a/cmake/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch b/cmake/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch new file mode 100644 index 0000000000000..2d999a456fdec --- /dev/null +++ b/cmake/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch @@ -0,0 +1,19 @@ +diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp +index 158f10764c..a324c101ed 100644 +--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp ++++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp +@@ -269,11 +269,9 @@ void PhysicalDevice::InitializeSupportedFeaturesImpl() { + mDeviceInfo.shaderFloat16Int8Features.shaderFloat16 == VK_TRUE && + mDeviceInfo._16BitStorageFeatures.storageBuffer16BitAccess == VK_TRUE && + mDeviceInfo._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess == VK_TRUE) { +- // TODO(crbug.com/tint/2164): Investigate crashes in f16 CTS tests to enable on NVIDIA. +- if (!gpu_info::IsNvidia(GetVendorId())) { +- EnableFeature(Feature::ShaderF16); +- shaderF16Enabled = true; +- } ++ // ONNX Runtime Patch: enable shaderF16 on all devices. ++ EnableFeature(Feature::ShaderF16); ++ shaderF16Enabled = true; + } + + if (mDeviceInfo.HasExt(DeviceExt::DrawIndirectCount) && From a1693667c0fb4819bd1b20c4910d58a883a3768f Mon Sep 17 00:00:00 2001 From: Ashrit Shetty Date: Fri, 16 May 2025 09:29:13 -0700 Subject: [PATCH 02/57] Telemetry field to indicate debugger is attached (#24777) ### Description This commit adds a telemetry field to indicate if a debugger is attached to the process. ### Motivation and Context This is useful for ignoring events coming from processes being debugged. --- onnxruntime/core/platform/windows/telemetry.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 0775e19c5654b..2385bae65d491 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -183,6 +183,7 @@ void WindowsTelemetry::LogProcessInfo() const { // Telemetry info TraceLoggingUInt8(0, "schemaVersion"), TraceLoggingString(ORT_VERSION, "runtimeVersion"), + TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"), TraceLoggingBool(isRedist, "isRedist")); process_info_logged = true; From b660e08be843007e150ff3a351c48bd364ad13a1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 16 May 2025 11:32:55 -0700 Subject: [PATCH 03/57] [doc] Update README.md for Node.js binding (#24783) ### Description mark Linux x64 supports webgpu ### Motivation and Context --- js/node/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/node/README.md b/js/node/README.md index b8414546c4729..272ec6ef561c2 100644 --- a/js/node/README.md +++ b/js/node/README.md @@ -27,13 +27,13 @@ The following table lists the supported versions of ONNX Runtime Node.js binding | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 | | ------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | -| WebGPU | ✔️ \[1] | ✔️ \[1] | ❌ \[2] | ❌ \[2] | ✔️ \[1] | ✔️ \[1] | +| WebGPU | ✔️ \[1] | ✔️ \[1] | ✔️ \[1] | ❌ \[2] | ✔️ \[1] | ✔️ \[1] | | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ | | CUDA | ❌ | ❌ | ✔️\[3] | ❌ | ❌ | ❌ | | CoreML | ❌ | ❌ | ❌ | ❌ | ✔️ | ✔️ | - \[1]: WebGPU support is currently experimental. -- \[2]: WebGPU support is not available on Linux x64 and arm64 yet in the pre-built binaries. +- \[2]: WebGPU support is not available on Linux arm64 yet in the pre-built binaries. - \[3]: CUDA v12. See [CUDA EP Installation](#cuda-ep-installation) for details. To use on platforms without pre-built binaries, you can build Node.js binding from source and consume it by `npm install /js/node/`. See also [instructions](https://onnxruntime.ai/docs/build/inferencing.html#apis-and-language-bindings) for building ONNX Runtime Node.js binding locally. From 77afa69c472659f8ae87a975faaef74dc391e589 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 16 May 2025 11:33:11 -0700 Subject: [PATCH 04/57] add "enable_generic_interface" build flag to node package (#24763) ### Description Add '--enable_generic_interface' build flag to the node package for Windows (both x64 and arm64) builds. ### Motivation and Context --- .../azure-pipelines/c-api-noopenmp-packaging-pipelines.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index cf213c47195c4..3006eebd2d3b5 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -155,7 +155,7 @@ extends: IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-onnxruntime-nodejs-win-x64' StageName: 'Windows_Nodejs_Packaging_x64' - BuildCommand: --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" + BuildCommand: --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" --enable_generic_interface BuildArch: 'x64' EnvSetupScript: 'setup_env.bat' sln_platform: 'x64' @@ -167,7 +167,7 @@ extends: IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-onnxruntime-nodejs-win-arm64' StageName: 'Windows_Nodejs_Packaging_arm64' - BuildCommand: --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" + BuildCommand: --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" --enable_generic_interface BuildArch: 'x64' EnvSetupScript: 'setup_env.bat' sln_platform: 'arm64' From 323b5f4cda0aa2f944dab04b1ae8d89d631fdaa0 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 16 May 2025 13:18:18 -0700 Subject: [PATCH 05/57] Disable a test for QNN to unblock the build pipeline (#24791) ### Description Disable a test for QNN to unblock the build pipeline. Should be caused by a combination of PR changes. --- onnxruntime/test/providers/cpu/nn/pool_op_test.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 36150d03a7d36..8edbd417544c4 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -229,7 +229,9 @@ TEST(PoolTest, MaxPool1D_case2) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + + // QNN test failed. Caused by a combination of most recent changes, will fix it + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); } TEST(PoolTest, MaxPool1D_case3) { From 1025905a80a55ae50db612d5ce95b6d12462dc2a Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 16 May 2025 14:12:51 -0700 Subject: [PATCH 06/57] Fix nightly packaging pipelines (#24789) --- .../Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj | 2 +- .../github/azure-pipelines/nuget/templates/dml-vs-2022.yml | 2 +- .../github/azure-pipelines/stages/py-win-gpu-stage.yml | 1 - .../c-api-artifacts-package-and-publish-steps-windows.yml | 6 ++++++ 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj index ee3c8c69aa2ae..54b9925710296 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj @@ -71,7 +71,7 @@ Include="$(NativeBuildOutputDir)\onnxruntime_providers_*.dll; $(NativeBuildOutputDir)\onnxruntime_providers_*.pdb; $(NativeBuildOutputDir)\custom_op_library*.dll; - $(NativeBuildOutputDir)\example_plugin_ep.dll"> + $(NativeBuildOutputDir)\example_plugin_ep*.dll"> PreserveNewest false diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 2a09eba776353..a87b85eaac256 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -151,7 +151,7 @@ stages: - ${{if or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: - - template: publish-symbolrequestprod-api.yml + - template: ../../templates/publish-symbolrequestprod-api.yml parameters: ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: symbolExpiryTime: 60 diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index 0a88391dd4ad6..de0a8f10b82be 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -140,7 +140,6 @@ stages: ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: symbolExpiryTime: 60 includePublicSymbolServer: true - symbolsFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' symbolsArtifactName: onnxruntime_gpu_win_x64_${{ parameters.PYTHON_VERSION }} symbolsVersion: $(Build.BuildId) symbolProject: 'ONNX Runtime' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index adf9c91e602a0..72343613d6b26 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -119,6 +119,12 @@ steps: DoEsrp: ${{parameters.DoEsrp}} Pattern: '*.dll,*.exe' + - task: DeleteFiles@1 + displayName: 'Delete CodeSignSummary*.md' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\${{parameters.artifactName}}' + Contents: 'CodeSignSummary*.md' + - task: ArchiveFiles@2 inputs: rootFolderOrFile: '$(Build.BinariesDirectory)\${{parameters.artifactName}}' From 8983424d9a8d0a39d065b0e353d6fd3f2b2a638c Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 16 May 2025 14:48:57 -0700 Subject: [PATCH 07/57] [CUDA] Upgrade cutlass to 3.9.2 (#24794) ### Description Upgrade cutlass to 3.9.2 ### Motivation and Context To work on new features. --- cmake/deps.txt | 2 +- .../bert/cutlass_fmha/fmha_launch_template.h | 9 ++++---- .../cuda/bert/cutlass_fmha/kernel_forward.h | 22 +++++++++---------- .../epilogue/thread/fused_activations.h | 21 ------------------ 4 files changed, 17 insertions(+), 37 deletions(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index a10bede254007..e08f4a30ccaeb 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -52,7 +52,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/8a1772a0c5c447df2d18e re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 -cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.zip;e49b2b964163d27765a5002d210a2f3c73771835 +cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.9.2.zip;b7f8dc4a879765127ce31dfeabd31c556c80ec79 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0c12f53da76d0c31b03b9f0f8ec8f3b4.zip;239063aee4946a9af147b473a4c3da78ba7413b4 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 8d8f735e3ed34..5aeda0f74e92b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -142,15 +142,16 @@ template ; + typename Attention::Params p; { // set parameters p.query_ptr = const_cast(reinterpret_cast(params.query)); p.key_ptr = const_cast(reinterpret_cast(params.key)); p.value_ptr = const_cast(reinterpret_cast(params.value)); p.attn_bias_ptr = const_cast(reinterpret_cast(params.attn_bias)); - p.seqstart_q_ptr = params.seqstart_q_ptr; - p.seqstart_k_ptr = params.seqstart_k_ptr; - p.seqlen_k_ptr = params.seqlen_k_ptr; + p.seqstart_q_ptr = const_cast(params.seqstart_q_ptr); + p.seqstart_k_ptr = const_cast(params.seqstart_k_ptr); + p.seqlen_k_ptr = const_cast(params.seqlen_k_ptr); p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward p.output_ptr = reinterpret_cast(params.output); @@ -260,7 +261,7 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { params.v_head_size % AlignedAK::kAlignmentV == 0; DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { - LaunchCutlassFmha(params); + LaunchCutlassFmha(params); })); #if defined(_MSC_VER) && !defined(__clang__) diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h index f35d6c2e6c8dc..41691d823f528 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,6 +38,7 @@ #include #include +#include #include #include "cutlass/fast_math.h" @@ -71,8 +72,6 @@ #include "41_fused_multi_head_attention/gemm_kernel_utils.h" #include "41_fused_multi_head_attention/transform/tile_smem_loader.h" -#include - using namespace gemm_kernel_utils; namespace { @@ -174,9 +173,10 @@ struct AttentionKernel { scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim] scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value] scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] - const int32_t* seqstart_q_ptr = nullptr; - const int32_t* seqstart_k_ptr = nullptr; - const int32_t* seqlen_k_ptr = nullptr; + int32_t* seqstart_q_ptr = nullptr; + int32_t* seqstart_k_ptr = nullptr; + + int32_t* seqlen_k_ptr = nullptr; uint32_t causal_diagonal_offset = 0; // Output tensors @@ -1105,15 +1105,15 @@ struct AttentionKernel { using EpilogueOutputOp = typename cutlass::epilogue:: thread::MemoryEfficientAttentionNormalize< typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, output_t, output_accum_t>::type, output_accum_t, DefaultOp::kCount, typename DefaultOp::ElementAccumulator, ElementCompute, - kIsFirst, - kIsLast, + kIsFirst::value, + kIsLast::value, cutlass::Array>; using Epilogue = typename cutlass::epilogue::threadblock:: EpiloguePipelined< @@ -1121,7 +1121,7 @@ struct AttentionKernel { typename MM1::Mma::Operator, DefaultEpilogue::kPartitionsK, typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, typename MM1::OutputTileIterator, typename MM1::OutputTileIteratorAccum>::type, typename DefaultEpilogue:: @@ -1139,7 +1139,7 @@ struct AttentionKernel { int col = blockN * MM1::Mma::Shape::kN; auto source_iter = createOutputAccumIter(col); auto dest_iter = call_conditional< - kIsLast, + kIsLast::value, decltype(createOutputIter), decltype(createOutputAccumIter)>:: apply(createOutputIter, createOutputAccumIter, col); diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h index da8cb6d294efd..644caa950e5a4 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h @@ -67,27 +67,6 @@ __forceinline__ __device__ float tanh_opt(float x) { #endif } -///////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct GELU_taylor { - static bool const kIsHeavy = true; - - CUTLASS_DEVICE - float operator()(float const& z) const { - float k0 = static_cast(0.7978845608028654); - float k1 = static_cast(0.044715); - - return static_cast( - cutlass::constants::half() * z * - (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } - - using Params = LinearCombinationGenericParams; - - CUTLASS_DEVICE - float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); } -}; - } // namespace thread } // namespace epilogue } // namespace cutlass From 9681fe68a660bbb44acae70d54057e1657411c0d Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 16 May 2025 15:37:42 -0700 Subject: [PATCH 08/57] [C# docs] fix file-scoped namespace not supported in C# 8 (#24793) ### Description Changes the namespace declaration from ```C# namespace Microsoft.ML.OnnxRuntime.CompileApi; // Code ``` to ```C# namespace Microsoft.ML.OnnxRuntime.CompileApi { // Code } ``` ### Motivation and Context file-scoped namespaces are not supported in C# 8.0, which results in an error in our documentation publishing: https://github.com/microsoft/onnxruntime/actions/workflows/publish-csharp-apidocs.yml --- .../NativeCompileApiMethods.shared.cs | 283 +++++++++--------- 1 file changed, 142 insertions(+), 141 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs index 3a87f87d124e9..602bcc6caf7f8 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -1,152 +1,153 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace Microsoft.ML.OnnxRuntime.CompileApi; - using System; using System.Runtime.InteropServices; -// NOTE: The order of the APIs in this struct should match exactly that in OrtCompileApi -// See onnxruntime/core/session/compile_api.cc. -[StructLayout(LayoutKind.Sequential)] -public struct OrtCompileApi -{ - public IntPtr ReleaseModelCompilationOptions; - public IntPtr CreateModelCompilationOptionsFromSessionOptions; - public IntPtr ModelCompilationOptions_SetInputModelPath; - public IntPtr ModelCompilationOptions_SetInputModelFromBuffer; - public IntPtr ModelCompilationOptions_SetOutputModelPath; - public IntPtr ModelCompilationOptions_SetOutputModelExternalInitializersFile; - public IntPtr ModelCompilationOptions_SetOutputModelBuffer; - public IntPtr ModelCompilationOptions_SetEpContextEmbedMode; - public IntPtr CompileModel; -} - -internal class NativeMethods +namespace Microsoft.ML.OnnxRuntime.CompileApi { - private static OrtCompileApi _compileApi; - - // - // Define the delegate signatures, and a static member for each to hold the marshaled function pointer. - // - // We populate the static members in the constructor of this class. - // - // The C# code will call the C++ API through the delegate instances in the static members. - // - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseModelCompilationOptions(IntPtr /* OrtModelCompilationOptions* */ options); - public DOrtReleaseModelCompilationOptions OrtReleaseModelCompilationOptions; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtCreateModelCompilationOptionsFromSessionOptions( - IntPtr /* const OrtEnv* */ env, - IntPtr /* const OrtSessionOptions* */ sessionOptions, - out IntPtr /* OrtModelCompilationOptions** */ outOptions); - public DOrtCreateModelCompilationOptionsFromSessionOptions - OrtCreateModelCompilationOptionsFromSessionOptions; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelPath( - IntPtr /* OrtModelCompilationOptions* */ options, - byte[] /* const ORTCHAR_T* */ inputModelPath); - public DOrtModelCompilationOptions_SetInputModelPath OrtModelCompilationOptions_SetInputModelPath; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelFromBuffer( - IntPtr /* OrtModelCompilationOptions* */ options, - byte[] /* const void* */ inputModelData, - UIntPtr /* size_t */ inputModelDataSize); - public DOrtModelCompilationOptions_SetInputModelFromBuffer - OrtModelCompilationOptions_SetInputModelFromBuffer; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelPath( - IntPtr /* OrtModelCompilationOptions* */ options, - byte[] /* const ORTCHAR_T* */ outputModelPath); - public DOrtModelCompilationOptions_SetOutputModelPath OrtModelCompilationOptions_SetOutputModelPath; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile( - IntPtr /* OrtModelCompilationOptions* */ options, - byte[] /* const ORTCHAR_T* */ externalInitializersFilePath, - UIntPtr /* size_t */ externalInitializerSizeThreshold); - public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile - OrtModelCompilationOptions_SetOutputModelExternalInitializersFile; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelBuffer( - IntPtr /* OrtModelCompilationOptions* */ options, - IntPtr /* OrtAllocator* */ allocator, - ref IntPtr /* void** */ outputModelBufferPtr, - ref UIntPtr /* size_t* */ outputModelBufferSizePtr); - public DOrtModelCompilationOptions_SetOutputModelBuffer OrtModelCompilationOptions_SetOutputModelBuffer; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextEmbedMode( - IntPtr /* OrtModelCompilationOptions* */ options, - bool embedEpContextInModel); - public DOrtModelCompilationOptions_SetEpContextEmbedMode OrtModelCompilationOptions_SetEpContextEmbedMode; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtCompileModel( - IntPtr /* const OrtEnv* */ env, - IntPtr /* const OrtModelCompilationOptions* */ modelOptions); - public DOrtCompileModel OrtCompileModel; - - internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) + // NOTE: The order of the APIs in this struct should match exactly that in OrtCompileApi + // See onnxruntime/core/session/compile_api.cc. + [StructLayout(LayoutKind.Sequential)] + public struct OrtCompileApi { + public IntPtr ReleaseModelCompilationOptions; + public IntPtr CreateModelCompilationOptionsFromSessionOptions; + public IntPtr ModelCompilationOptions_SetInputModelPath; + public IntPtr ModelCompilationOptions_SetInputModelFromBuffer; + public IntPtr ModelCompilationOptions_SetOutputModelPath; + public IntPtr ModelCompilationOptions_SetOutputModelExternalInitializersFile; + public IntPtr ModelCompilationOptions_SetOutputModelBuffer; + public IntPtr ModelCompilationOptions_SetEpContextEmbedMode; + public IntPtr CompileModel; + } -#if NETSTANDARD2_0 - IntPtr compileApiPtr = getCompileApi(); - _compileApi = (OrtCompileApi)Marshal.PtrToStructure(compileApiPtr, typeof(OrtCompileApi)); -#else - _compileApi = (OrtCompileApi)getCompileApi(); -#endif - - OrtReleaseModelCompilationOptions = - (DOrtReleaseModelCompilationOptions)Marshal.GetDelegateForFunctionPointer( - _compileApi.ReleaseModelCompilationOptions, - typeof(DOrtReleaseModelCompilationOptions)); - - OrtCreateModelCompilationOptionsFromSessionOptions = - (DOrtCreateModelCompilationOptionsFromSessionOptions)Marshal.GetDelegateForFunctionPointer( - _compileApi.CreateModelCompilationOptionsFromSessionOptions, - typeof(DOrtCreateModelCompilationOptionsFromSessionOptions)); - - OrtModelCompilationOptions_SetInputModelPath = - (DOrtModelCompilationOptions_SetInputModelPath)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetInputModelPath, - typeof(DOrtModelCompilationOptions_SetInputModelPath)); - - OrtModelCompilationOptions_SetInputModelFromBuffer = - (DOrtModelCompilationOptions_SetInputModelFromBuffer)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetInputModelFromBuffer, - typeof(DOrtModelCompilationOptions_SetInputModelFromBuffer)); - - OrtModelCompilationOptions_SetOutputModelPath = - (DOrtModelCompilationOptions_SetOutputModelPath)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetOutputModelPath, - typeof(DOrtModelCompilationOptions_SetOutputModelPath)); - - OrtModelCompilationOptions_SetOutputModelExternalInitializersFile = - (DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetOutputModelExternalInitializersFile, - typeof(DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)); - - OrtModelCompilationOptions_SetOutputModelBuffer = - (DOrtModelCompilationOptions_SetOutputModelBuffer)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetOutputModelBuffer, - typeof(DOrtModelCompilationOptions_SetOutputModelBuffer)); - - OrtModelCompilationOptions_SetEpContextEmbedMode = - (DOrtModelCompilationOptions_SetEpContextEmbedMode)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetEpContextEmbedMode, - typeof(DOrtModelCompilationOptions_SetEpContextEmbedMode)); - - OrtCompileModel = - (DOrtCompileModel)Marshal.GetDelegateForFunctionPointer( - _compileApi.CompileModel, - typeof(DOrtCompileModel)); + internal class NativeMethods + { + private static OrtCompileApi _compileApi; + + // + // Define the delegate signatures, and a static member for each to hold the marshaled function pointer. + // + // We populate the static members in the constructor of this class. + // + // The C# code will call the C++ API through the delegate instances in the static members. + // + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseModelCompilationOptions(IntPtr /* OrtModelCompilationOptions* */ options); + public DOrtReleaseModelCompilationOptions OrtReleaseModelCompilationOptions; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateModelCompilationOptionsFromSessionOptions( + IntPtr /* const OrtEnv* */ env, + IntPtr /* const OrtSessionOptions* */ sessionOptions, + out IntPtr /* OrtModelCompilationOptions** */ outOptions); + public DOrtCreateModelCompilationOptionsFromSessionOptions + OrtCreateModelCompilationOptionsFromSessionOptions; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelPath( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ inputModelPath); + public DOrtModelCompilationOptions_SetInputModelPath OrtModelCompilationOptions_SetInputModelPath; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelFromBuffer( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const void* */ inputModelData, + UIntPtr /* size_t */ inputModelDataSize); + public DOrtModelCompilationOptions_SetInputModelFromBuffer + OrtModelCompilationOptions_SetInputModelFromBuffer; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelPath( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ outputModelPath); + public DOrtModelCompilationOptions_SetOutputModelPath OrtModelCompilationOptions_SetOutputModelPath; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ externalInitializersFilePath, + UIntPtr /* size_t */ externalInitializerSizeThreshold); + public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile + OrtModelCompilationOptions_SetOutputModelExternalInitializersFile; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelBuffer( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* OrtAllocator* */ allocator, + ref IntPtr /* void** */ outputModelBufferPtr, + ref UIntPtr /* size_t* */ outputModelBufferSizePtr); + public DOrtModelCompilationOptions_SetOutputModelBuffer OrtModelCompilationOptions_SetOutputModelBuffer; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextEmbedMode( + IntPtr /* OrtModelCompilationOptions* */ options, + bool embedEpContextInModel); + public DOrtModelCompilationOptions_SetEpContextEmbedMode OrtModelCompilationOptions_SetEpContextEmbedMode; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCompileModel( + IntPtr /* const OrtEnv* */ env, + IntPtr /* const OrtModelCompilationOptions* */ modelOptions); + public DOrtCompileModel OrtCompileModel; + + internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) + { + + #if NETSTANDARD2_0 + IntPtr compileApiPtr = getCompileApi(); + _compileApi = (OrtCompileApi)Marshal.PtrToStructure(compileApiPtr, typeof(OrtCompileApi)); + #else + _compileApi = (OrtCompileApi)getCompileApi(); + #endif + + OrtReleaseModelCompilationOptions = + (DOrtReleaseModelCompilationOptions)Marshal.GetDelegateForFunctionPointer( + _compileApi.ReleaseModelCompilationOptions, + typeof(DOrtReleaseModelCompilationOptions)); + + OrtCreateModelCompilationOptionsFromSessionOptions = + (DOrtCreateModelCompilationOptionsFromSessionOptions)Marshal.GetDelegateForFunctionPointer( + _compileApi.CreateModelCompilationOptionsFromSessionOptions, + typeof(DOrtCreateModelCompilationOptionsFromSessionOptions)); + + OrtModelCompilationOptions_SetInputModelPath = + (DOrtModelCompilationOptions_SetInputModelPath)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModelPath, + typeof(DOrtModelCompilationOptions_SetInputModelPath)); + + OrtModelCompilationOptions_SetInputModelFromBuffer = + (DOrtModelCompilationOptions_SetInputModelFromBuffer)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModelFromBuffer, + typeof(DOrtModelCompilationOptions_SetInputModelFromBuffer)); + + OrtModelCompilationOptions_SetOutputModelPath = + (DOrtModelCompilationOptions_SetOutputModelPath)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelPath, + typeof(DOrtModelCompilationOptions_SetOutputModelPath)); + + OrtModelCompilationOptions_SetOutputModelExternalInitializersFile = + (DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelExternalInitializersFile, + typeof(DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)); + + OrtModelCompilationOptions_SetOutputModelBuffer = + (DOrtModelCompilationOptions_SetOutputModelBuffer)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelBuffer, + typeof(DOrtModelCompilationOptions_SetOutputModelBuffer)); + + OrtModelCompilationOptions_SetEpContextEmbedMode = + (DOrtModelCompilationOptions_SetEpContextEmbedMode)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetEpContextEmbedMode, + typeof(DOrtModelCompilationOptions_SetEpContextEmbedMode)); + + OrtCompileModel = + (DOrtCompileModel)Marshal.GetDelegateForFunctionPointer( + _compileApi.CompileModel, + typeof(DOrtCompileModel)); + } } } From 8c76d5c877dc04489d3120bb7164355bc8d7fc56 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 17 May 2025 17:09:43 -0700 Subject: [PATCH 09/57] [build] fix build break caused by inconsistent version for dlpack (#24802) ### Description Currently some required ADO pipeline fails because of version mismatch between vcpkg build and non vcpkg build. This PR fixes the failed builds. ### Motivation and Context --- cgmanifests/cgmanifest.json | 12 +----------- cmake/deps.txt | 2 +- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index f29857a231eb9..bf889e9fb61a8 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "bee4d1dd8dc1ee4a1fd8fa6a96476c2f8b7492a3", + "commitHash": "5c210da409e7f1e51ddf445134a4376fdbd70d7d", "repositoryUrl": "https://github.com/dmlc/dlpack.git" } } @@ -316,16 +316,6 @@ "comments": "gtest-ios-framework" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c", - "repositoryUrl": "https://github.com/dmlc/dlpack.git" - }, - "comments": "dlpack" - } - }, { "component": { "Type": "other", diff --git a/cmake/deps.txt b/cmake/deps.txt index e08f4a30ccaeb..728241840f723 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -16,7 +16,7 @@ abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240722.0.zip coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 -dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 +dlpack;https://github.com/dmlc/dlpack/archive/5c210da409e7f1e51ddf445134a4376fdbd70d7d.zip;e499c86e4e5c5268a87661d7ea39c27fae10907c # This Eigen commit id matches the eigen archive being consumed from https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip # prior to the 3.4.1 RC changing the bits and invalidating the hash. # it contains changes on top of 3.4.0 which are required to fix build issues. From 1b5628a265d15a99d31b6ad9db7c8e342f416d8a Mon Sep 17 00:00:00 2001 From: Hector Li Date: Sat, 17 May 2025 22:07:04 -0700 Subject: [PATCH 10/57] Validate ep.context_file_path option (#24797) ### Description Validate ep.context_file_path option, make sure it failed if it's not valid file path --- .../core/framework/graph_partitioner.cc | 2 +- .../test/providers/qnn/qnn_ep_context_test.cc | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 8ed5eeaa8d44f..b39d0dbd25f8d 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -770,7 +770,7 @@ static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_ bool allow_overwrite_output_model = false) { if (!ep_context_path.empty()) { context_cache_path = ep_context_path; - if (!context_cache_path.has_filename()) { + if (!(context_cache_path.has_filename() && context_cache_path.extension() != "")) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder."); } } else if (!model_path.empty()) { diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 8d840b1a3d45f..aeb10b8b4294b 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -681,6 +681,49 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) { } } +// Set ep.context_file_path to invalid file path, check the error message +TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected2) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + bool single_ep_node = true; + BuildGraphWithQAndNonQ(single_ep_node)(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string ep_context_onnx_file = "./ep_context_folder_not_expected/invalid_file"; + std::remove(ep_context_onnx_file.c_str()); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + try { + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); + FAIL(); // Should not get here! + } catch (const Ort::Exception& excpt) { + ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT); + ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder.")); + } +} + // Create session 1 to generate context binary file // Create session 2 to do same thing, make sure session 2 failed because file exist already // Make sure no new file over write from session 2 From c4fd8764131bfa0645f230b908edaf8c11ef21b3 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 18 May 2025 10:21:45 -0700 Subject: [PATCH 11/57] [CI] re-enable wasm CPU tests (#24801) ### Description 1. re-enable wasm CPU tests. It was originally enabled but was later disabled in a change that treat wasm build as cross-compiling. 2. Use build.py to populate the environment variables. ### Motivation and Context --- .../linux-wasm-ci-build-and-test-workflow.yml | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index e53626d879dd1..d74d9e9a4f0bf 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -57,29 +57,38 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Install EMSDK + - name: Build (simd + threads) run: | - set -ex - cd ${{ github.workspace }}/cmake/external/emsdk - ./emsdk install 4.0.4 - ./emsdk activate 4.0.4 + python ./tools/ci_build/build.py \ + ${{ env.common_build_args }} \ + --build_dir ${{ github.workspace }}/build/wasm_inferencing \ + --skip_tests + working-directory: ${{ github.workspace }} + + - name: Test (Node.js) (simd + threads) + # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. + if: ${{ inputs.build_config == 'Debug' }} + run: | + python ./tools/ci_build/build.py \ + ${{ env.common_build_args }} \ + --build_dir ${{ github.workspace }}/build/wasm_inferencing \ + --test + working-directory: ${{ github.workspace }} - - name: Build and test (browser) (simd + threads) + - name: Test (browser) (simd + threads) + # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. + if: ${{ inputs.build_config == 'Debug' }} run: | - set -e -x - source ${{ github.workspace }}/cmake/external/emsdk/emsdk_env.sh - cd '${{ github.workspace }}' python ./tools/ci_build/build.py \ ${{ env.common_build_args }} \ --build_dir ${{ github.workspace }}/build/wasm_inferencing \ - --wasm_run_tests_in_browser + --wasm_run_tests_in_browser \ + --test + working-directory: ${{ github.workspace }} - name: Build (simd + threads + JSEP) if: ${{ inputs.build_jsep == true }} run: | - set -e -x - source ${{ github.workspace }}/cmake/external/emsdk/emsdk_env.sh - cd '${{ github.workspace }}' python ./tools/ci_build/build.py \ ${{ env.common_build_args }} \ --build_dir ${{ github.workspace }}/build/wasm_inferencing_jsep \ @@ -87,13 +96,11 @@ jobs: --use_webnn \ --target onnxruntime_webassembly \ --skip_tests + working-directory: ${{ github.workspace }} - name: Build (simd + threads + WebGPU experimental) if: ${{ inputs.build_webgpu == true }} run: | - set -e -x - source ${{ github.workspace }}/cmake/external/emsdk/emsdk_env.sh - cd '${{ github.workspace }}' python ./tools/ci_build/build.py \ ${{ env.common_build_args }} \ --build_dir ${{ github.workspace }}/build/wasm_inferencing_webgpu \ @@ -102,6 +109,7 @@ jobs: --use_webnn \ --target onnxruntime_webassembly \ --skip_tests + working-directory: ${{ github.workspace }} - name: Create Artifacts if: ${{ inputs.skip_publish != true }} From 5c97cb11ae22b9d0bc28b22db847d82ca01a1c85 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Mon, 19 May 2025 23:41:21 +0800 Subject: [PATCH 12/57] [WebNN] Support ConvInteger op (#24803) WebNN doesn't provide a dedicated op for `ConvInteger`, this PR supports `ConvInteger` op by decomposing it into `DequantizeLinear x, w -> Conv -> Cast (to int32)`. BTW, adds `ConvInteger` to layout sensitive op list for layout transformation when the preferred layout is NHWC. --- js/web/docs/webnn-operators.md | 1 + .../onnx_transpose_optimization.cc | 2 +- .../core/providers/webnn/builders/helper.h | 2 +- .../webnn/builders/impl/conv_op_builder.cc | 97 +++++++++++++++---- 4 files changed, 80 insertions(+), 22 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 981a684154df1..d9a030f320c6c 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -25,6 +25,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s | Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | | | Concat | ai.onnx(7-10, 11-12, 13+) | concat | | | Conv | ai.onnx(7-10, 11+) | conv2d | Only supports 3-D or 4-D input and 'W' (weight) | +| ConvInteger | ai.onnx(10+) | cast, conv2d, dequantizeLinear | Only supports 3-D or 4-D input and 'W' (weight) | | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | Only supports 3-D or 4-D input and 'W' (weight) | | Cos | ai.onnx(7+) | cos | | | CumSum | ai.onnx(11-13, 14+) | cumulativeSum | 'axis' input should be a constant | diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 93c7efc9ca167..ac128011c0b9f 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -3308,7 +3308,7 @@ const std::unordered_set& GetLayoutSensitiveOps() { "BatchNormalization", "InstanceNormalization", // convolutions - "Conv", "QLinearConv", "ConvTranspose", + "Conv", "ConvInteger", "QLinearConv", "ConvTranspose", // pooling "AveragePool", "LpPool", "MaxPool", "MaxUnpool", diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 6556c293f81bf..072273a137557 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -203,6 +203,7 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe // Some ONNX ops are supported by decomposed WebNN ops. const std::map> decomposed_op_map = { + {"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}}, {"GroupQueryAttention", {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND", "softmax", "transpose", "where"}}, @@ -228,7 +229,6 @@ const std::map op_map = { {"Clip", "clamp"}, {"Concat", "concat"}, {"Conv", "conv2d"}, - {"ConvInteger", "conv2dInteger"}, {"ConvTranspose", "convTranspose2d"}, {"Cos", "cos"}, {"CumSum", "cumulativeSum"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 1924c3cb5e698..b9383a63fe307 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -30,6 +30,8 @@ class ConvOpBuilder : public BaseOpBuilder { const WebnnDeviceType device_type, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -52,18 +54,19 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, const logging::Logger& logger) { NodeAttrHelper helper(node); const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); // Add Padding. AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); std::vector pads_out; - if (node.OpType() == "Conv" || node.OpType() == "ConvInteger") { + if (op_type == "Conv" || op_type == "ConvInteger") { // Calculate explicit padding for autoPad. if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc)); pads = pads_out; } - } else if (node.OpType() == "ConvTranspose") { + } else if (op_type == "ConvTranspose") { std::vector output_shape = helper.Get("output_shape", std::vector{-1, -1}); // Appending 1's if it is ConvTranspose 1d and output shape is provided. if (output_shape.size() == 1 && is_conv1d && output_shape[0] != -1) { @@ -103,7 +106,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, options.set("padding", emscripten::val::array(GetNarrowedIntfromInt64(padding))); // Add bias if present. - if (input_defs.size() > 2) { + if (input_defs.size() > 2 && op_type != "ConvInteger") { options.set("bias", model_builder.GetOperand(input_defs[2]->Name())); } @@ -219,6 +222,8 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3; const bool is_constant_weight = Contains(initializers, weight_name); + + emscripten::val common_options = emscripten::val::object(); // Support conv1d by prepending a 1 or 2 size dimensions. if (is_conv1d) { // Reshape input. @@ -230,7 +235,9 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N input_shape.push_back(1); } std::vector new_shape = GetNarrowedIntfromInt64(input_shape); - input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + common_options.set("label", node.Name() + "_reshape_input"); + input = model_builder.GetBuilder().call("reshape", input, + emscripten::val::array(new_shape), common_options); weight_shape.resize(4, 1); // Ensure 4D by appending 1's if needed. strides.resize(2, 1); // Ensure 2D by appending 1's if needed. @@ -277,16 +284,14 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (!is_nhwc || !is_constant_weight) { // The weight_shape has been appended 1's, reshape weight operand. std::vector new_shape = GetNarrowedIntfromInt64(weight_shape); - emscripten::val reshape_options = emscripten::val::object(); - reshape_options.set("label", node.Name() + "_reshape_filter"); + common_options.set("label", node.Name() + "_reshape_filter"); filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape), - reshape_options); + common_options); } } - emscripten::val transpose_options = emscripten::val::object(); if (is_nhwc && !is_constant_weight) { // For NHWC preferred layout, if the weight is input: // - Transpose it from iohw -> ohwi for convTranspose. @@ -298,6 +303,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } else { perm = {0, 2, 3, 1}; // L_0231 } + emscripten::val transpose_options = emscripten::val::object(); transpose_options.set("permutation", emscripten::val::array(perm)); transpose_options.set("label", node.Name() + "_transpose_filter"); filter = model_builder.GetBuilder().call("transpose", filter, transpose_options); @@ -306,20 +312,48 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (op_type == "Conv") { output = model_builder.GetBuilder().call("conv2d", input, filter, options); } else if (op_type == "ConvInteger") { - emscripten::val x_zero_point = emscripten::val::null(); - emscripten::val w_zero_point = emscripten::val::null(); - if (input_defs.size() >= 3) { + // WebNN doesn't provide a dedicated op for ConvInteger, it can be simply decomposed by + // DequantizeLinear x, w -> Conv -> Cast (to int32) + int32_t x_type; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], x_type, logger), "Cannot get data type of input x"); + + emscripten::val x_zero_point, w_zero_point, x_scale, w_scale; + if (TensorExists(input_defs, 2)) { x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - x_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); + x_zero_point = model_builder.CreateOrGetConstant(x_type, 0); } - if (input_defs.size() >= 4) { + + // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // The x_zero_point must be a scalar and the scale input should have the same shape as the zero point input. + // So the x_scale must be a scalar too. + x_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); + // Dequantize x to Float32 + common_options.set("label", node.Name() + "_dequantized_x"); + input = model_builder.GetBuilder().call("dequantizeLinear", input, x_scale, x_zero_point, + common_options); + + if (TensorExists(input_defs, 3)) { w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); + std::vector w_zero_point_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[3], w_zero_point_shape, logger), "Cannot get shape of w_zero_point"); + w_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, + GetNarrowedIntfromInt64(w_zero_point_shape)); } else { - w_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); + w_zero_point = model_builder.CreateOrGetConstant(x_type, 0); + w_scale = x_scale; } - output = model_builder.GetBuilder().call("conv2dInteger", - input, x_zero_point, filter, w_zero_point, options); + // Dequantize w to Float32 + common_options.set("label", node.Name() + "_dequantized_w"); + filter = model_builder.GetBuilder().call("dequantizeLinear", filter, w_scale, w_zero_point, + common_options); + // Conv with dequantized x and w + options.set("label", node.Name() + "_conv_dequantized_inputs"); + output = model_builder.GetBuilder().call("conv2d", input, filter, options); + + // Cast the result to int32 + common_options.set("label", node.Name() + "_cast_output"); + output = model_builder.GetBuilder().call("cast", output, emscripten::val("int32"), common_options); } else { output = model_builder.GetBuilder().call("convTranspose2d", input, filter, options); } @@ -330,12 +364,11 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N std::vector output_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape"); std::vector new_shape = GetNarrowedIntfromInt64(output_shape); - emscripten::val reshape_options = emscripten::val::object(); - reshape_options.set("label", node.Name() + "_reshape_output"); + common_options.set("label", node.Name() + "_reshape_output"); output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape), - reshape_options); + common_options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); @@ -410,7 +443,31 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + if (op_type == "ConvInteger") { + // The first decomposed op of ConvInteger is DequantizeLinear, and so + // we only need to ensure it supports the input0_type. + return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); + } else { + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + } +} + +bool ConvOpBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& output = *node.OutputDefs()[0]; + const std::string_view op_type = node.OpType(); + int32_t output_type; + if (!GetType(output, output_type, logger)) { + return false; + } + + if (op_type == "ConvInteger") { + // The last decomposed op of ConvInteger is Cast, and so + // we only need to ensure it supports the output_type. + return IsDataTypeSupportedByOp("Cast", output_type, wnn_limits, "output", "Output", logger); + } else { + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger); + } } void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { From 5915bc840796ecada1ff9fe88972be8b76344434 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 19 May 2025 13:22:53 -0700 Subject: [PATCH 13/57] enable using vcpkg for Dawn (#24699) ### Description Enable vcpkg for webgpu ### Motivation and Context --- .../macos-ci-build-and-test-workflow.yml | 2 +- .github/workflows/windows_webgpu.yml | 21 +- cmake/CMakeLists.txt | 24 +++ .../external/onnxruntime_external_deps.cmake | 182 +++++++++--------- cmake/onnxruntime.cmake | 4 +- cmake/onnxruntime_java.cmake | 34 ++++ cmake/onnxruntime_nodejs.cmake | 9 +- cmake/onnxruntime_providers_webgpu.cmake | 30 ++- cmake/onnxruntime_python.cmake | 30 +++ cmake/vcpkg-ports/dawn/dawn.patch | 59 ++++++ .../dawn_force_enable_f16_nvidia_vulkan.patch | 19 ++ .../dawn/dawn_vcpkg_integration.patch | 125 ++++++++++++ cmake/vcpkg-ports/dawn/portfile.cmake | 138 +++++++++++++ cmake/vcpkg-ports/dawn/vcpkg.json | 62 ++++++ cmake/vcpkg.json | 22 +-- .../main/java/ai/onnxruntime/OnnxRuntime.java | 9 +- .../java/ai/onnxruntime/InferenceTest.java | 3 + requirements-dev.txt | 1 + .../github/linux/python/requirements.txt | 2 + .../github/windows/python/requirements.txt | 2 + 20 files changed, 666 insertions(+), 112 deletions(-) create mode 100644 cmake/vcpkg-ports/dawn/dawn.patch create mode 100644 cmake/vcpkg-ports/dawn/dawn_force_enable_f16_nvidia_vulkan.patch create mode 100644 cmake/vcpkg-ports/dawn/dawn_vcpkg_integration.patch create mode 100644 cmake/vcpkg-ports/dawn/portfile.cmake create mode 100644 cmake/vcpkg-ports/dawn/vcpkg.json diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index dfe97f8370e99..9e276751bd3d0 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -51,7 +51,7 @@ jobs: --build_objc --build_java --build_wheel - ${{ inputs.use_webgpu && '--use_webgpu' || '' }} + ${{ inputs.use_webgpu && '--use_webgpu --cmake_extra_defines onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=ON' || '' }} ${{ inputs.use_xnnpack && '--use_xnnpack' || '' }} ${{ inputs.use_coreml && '--use_coreml' || '' }} --use_vcpkg --use_vcpkg_ms_internal_asset_cache diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 999025f560674..70e8ea7e2792f 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -19,6 +19,9 @@ jobs: webgpu_build_x64_RelWithDebInfo: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] timeout-minutes: 300 + strategy: + matrix: + vcpkg_option: [novcpkg, vcpkg] env: OrtPackageId: Microsoft.ML.OnnxRuntime OnnxRuntimeBuildDirectory: ${{ github.workspace }} @@ -107,7 +110,23 @@ jobs: - name: Build and Test shell: pwsh run: | - python.exe ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --build_dir ${{ github.workspace }} --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_nodejs --use_webgpu --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=ON + python.exe ${{ github.workspace }}\tools\ci_build\build.py ` + --config RelWithDebInfo ` + --build_dir ${{ github.workspace }} ` + --skip_submodule_sync ` + --build_csharp ` + --parallel ` + --use_binskim_compliant_compile_flags ` + --cmake_generator "Visual Studio 17 2022" ` + --build_shared_lib ` + --enable_onnx_tests ` + --build_nodejs ` + --build_java ` + --use_webgpu ` + ${{ matrix.vcpkg_option == 'vcpkg' && '--use_vcpkg' || '' }} ` + --cmake_extra_defines ` + onnxruntime_BUILD_UNIT_TESTS=ON ` + onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=ON if ($lastExitCode -ne 0) { exit $lastExitCode } diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 121799e16ee97..2451fe9f4008b 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1066,6 +1066,30 @@ endif() if (onnxruntime_USE_WEBGPU) list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu) + + if (onnxruntime_USE_VCPKG AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (NOT onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=OFF") + endif() + if (onnxruntime_USE_EXTERNAL_DAWN) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_USE_EXTERNAL_DAWN=ON") + endif() + if (onnxruntime_CUSTOM_DAWN_SRC_PATH) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with a custom dawn source path") + endif() + if (WIN32) + if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_ENABLE_DAWN_BACKEND_VULKAN=ON") + endif() + if (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_ENABLE_DAWN_BACKEND_D3D12=OFF") + endif() + if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP=ON") + endif() + endif() + endif() + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) list(APPEND ORT_PROVIDER_FLAGS -DBUILD_DAWN_MONOLITHIC_LIBRARY=1) endif() diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index c6a5960d9b9da..4ed74f1315cb5 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -625,90 +625,97 @@ endif() if (onnxruntime_USE_WEBGPU) - set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE) - set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE) - set(DAWN_BUILD_TESTS OFF CACHE BOOL "" FORCE) - if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) - set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE) - set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE) - - if (onnxruntime_USE_EXTERNAL_DAWN) - message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.") + if (onnxruntime_USE_VCPKG AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # vcpkg does not support Emscripten yet + find_package(dawn REQUIRED) + else() + # + # Please keep the following in sync with cmake/vcpkg-ports/dawn/portfile.cmake + # + set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE) + set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE) + set(DAWN_BUILD_TESTS OFF CACHE BOOL "" FORCE) + if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE) + + if (onnxruntime_USE_EXTERNAL_DAWN) + message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.") + endif() + else() + # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size + set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) endif() - else() - # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size - set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) - endif() - - if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP) - set(DAWN_ENABLE_DESKTOP_GL ON CACHE BOOL "" FORCE) - set(DAWN_ENABLE_OPENGLES ON CACHE BOOL "" FORCE) - set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING ON CACHE BOOL "" FORCE) - set(DAWN_USE_GLFW ON CACHE BOOL "" FORCE) - set(DAWN_USE_WINDOWS_UI ON CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_WRITER ON CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_VALIDATOR ON CACHE BOOL "" FORCE) - else() - set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) - set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) - set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) - set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) - endif() - - # disable things we don't use - set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) - set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) - - set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving - set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. - - # SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V. - set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE) - if (WIN32) - # building this requires the HLSL writer to be enabled in Tint. TBD if that we need either of these to be ON. - set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE) - set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE) - - if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12)) - message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.") - endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) - set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE) - set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE) + if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP) + set(DAWN_ENABLE_DESKTOP_GL ON CACHE BOOL "" FORCE) + set(DAWN_ENABLE_OPENGLES ON CACHE BOOL "" FORCE) + set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING ON CACHE BOOL "" FORCE) + set(DAWN_USE_GLFW ON CACHE BOOL "" FORCE) + set(DAWN_USE_WINDOWS_UI ON CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_WRITER ON CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_VALIDATOR ON CACHE BOOL "" FORCE) else() - set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) + set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) + set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) + set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE) - else() - set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE) + + # disable things we don't use + set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) + set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) + + set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving + set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. + + # SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V. + set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE) + + if (WIN32) + # building this requires the HLSL writer to be enabled in Tint. TBD if that we need either of these to be ON. + set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE) + set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE) + + if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12)) + message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.") + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE) + set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE) + endif() + # We are currently always using the D3D12 backend. + set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE) endif() - # We are currently always using the D3D12 backend. - set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE) endif() - endif() - if (onnxruntime_CUSTOM_DAWN_SRC_PATH) - # use the custom dawn source path if provided - # - # specified as: - # build.py --use_webgpu --cmake_extra_defines "onnxruntime_CUSTOM_DAWN_SRC_PATH=" - onnxruntime_fetchcontent_declare( - dawn - SOURCE_DIR ${onnxruntime_CUSTOM_DAWN_SRC_PATH} - EXCLUDE_FROM_ALL - ) - else() - set(ONNXRUNTIME_Dawn_PATCH_COMMAND + if (onnxruntime_CUSTOM_DAWN_SRC_PATH) + # use the custom dawn source path if provided + # + # specified as: + # build.py --use_webgpu --cmake_extra_defines "onnxruntime_CUSTOM_DAWN_SRC_PATH=" + onnxruntime_fetchcontent_declare( + dawn + SOURCE_DIR ${onnxruntime_CUSTOM_DAWN_SRC_PATH} + EXCLUDE_FROM_ALL + ) + else() + set(ONNXRUNTIME_Dawn_PATCH_COMMAND # The dawn.patch contains the following changes: # # - (private) Allow WGPUBufferImpl class to destroy the buffer in the destructor @@ -734,16 +741,17 @@ if (onnxruntime_USE_WEBGPU) # Since the crashes are limited to specific GPU models, we patched Dawn to remove the restriction. ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch) - onnxruntime_fetchcontent_declare( - dawn - URL ${DEP_URL_dawn} - URL_HASH SHA1=${DEP_SHA1_dawn} - PATCH_COMMAND ${ONNXRUNTIME_Dawn_PATCH_COMMAND} - EXCLUDE_FROM_ALL - ) - endif() + onnxruntime_fetchcontent_declare( + dawn + URL ${DEP_URL_dawn} + URL_HASH SHA1=${DEP_SHA1_dawn} + PATCH_COMMAND ${ONNXRUNTIME_Dawn_PATCH_COMMAND} + EXCLUDE_FROM_ALL + ) + endif() - onnxruntime_fetchcontent_makeavailable(dawn) + onnxruntime_fetchcontent_makeavailable(dawn) + endif() if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 1b124e3bb3f74..f6130f8c518a6 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -158,8 +158,8 @@ if(onnxruntime_BUILD_SHARED_LIB) target_link_options(onnxruntime PRIVATE "LINKER:-exported_symbols_list,${SYMBOL_FILE}") set_target_properties(onnxruntime PROPERTIES MACOSX_RPATH TRUE - SKIP_BUILD_RPATH TRUE - INSTALL_RPATH_USE_LINK_PATH FALSE + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH "@loader_path" BUILD_WITH_INSTALL_NAME_DIR TRUE INSTALL_NAME_DIR @rpath) endif() diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 1227264e595ed..25ceb63df1f19 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -58,6 +58,15 @@ file(GLOB onnxruntime4j_native_src onnxruntime_add_shared_library_module(onnxruntime4j_jni ${onnxruntime4j_native_src}) set_property(TARGET onnxruntime4j_jni PROPERTY C_STANDARD 11) +if (APPLE) + set_target_properties(onnxruntime4j_jni PROPERTIES + MACOSX_RPATH TRUE + SKIP_BUILD_RPATH TRUE + INSTALL_RPATH_USE_LINK_PATH FALSE + BUILD_WITH_INSTALL_NAME_DIR TRUE + INSTALL_NAME_DIR @rpath) +endif() + # depend on java sources. if they change, the JNI should recompile add_dependencies(onnxruntime4j_jni onnxruntime4j) onnxruntime_add_include_to_target(onnxruntime4j_jni onnxruntime_session) @@ -166,6 +175,28 @@ if (WIN32) if (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() + if (onnxruntime_USE_WEBGPU) + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + if (onnxruntime_USE_VCPKG) + add_custom_command( + TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ + $ + ${JAVA_PACKAGE_LIB_DIR}/ + ) + else() + add_custom_command( + TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different + $/dxil.dll + $/dxcompiler.dll + ${JAVA_PACKAGE_LIB_DIR}/ + ) + endif() + endif() + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) + endif() + endif() endif() else() add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) @@ -188,6 +219,9 @@ else() if (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() + if (onnxruntime_USE_WEBGPU AND onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) + endif() endif() # run the build process (this copies the results back into CMAKE_CURRENT_BINARY_DIR) diff --git a/cmake/onnxruntime_nodejs.cmake b/cmake/onnxruntime_nodejs.cmake index 355575be3bcf7..4e09400ac84b8 100644 --- a/cmake/onnxruntime_nodejs.cmake +++ b/cmake/onnxruntime_nodejs.cmake @@ -74,8 +74,13 @@ endif() if (onnxruntime_USE_WEBGPU) set(NODEJS_BINDING_USE_WEBGPU "--use_webgpu") if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - list(APPEND NODEJS_DLL_DEPS "$/dxil.dll") - list(APPEND NODEJS_DLL_DEPS "$/dxcompiler.dll") + if (onnxruntime_USE_VCPKG) + list(APPEND NODEJS_DLL_DEPS "$") + list(APPEND NODEJS_DLL_DEPS "$") + else() + list(APPEND NODEJS_DLL_DEPS "$/dxil.dll") + list(APPEND NODEJS_DLL_DEPS "$/dxcompiler.dll") + endif() endif() if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) list(APPEND NODEJS_DLL_DEPS "$") diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 4bbca7b1b811a..7a7e0d39fcd2d 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -59,8 +59,20 @@ list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll") endif() - list(APPEND onnxruntime_providers_webgpu_dll_deps "$") + if (onnxruntime_USE_VCPKG) + # Fix Dawn vcpkg build issue (missing IMPORTED_IMPLIB and IMPORTED_LOCATION for target dawn::webgpu_dawn) + get_target_property(webgpu_dawn_target_IMPORTED_IMPLIB dawn::webgpu_dawn IMPORTED_IMPLIB) + if (NOT webgpu_dawn_target_IMPORTED_IMPLIB) + set_target_properties(dawn::webgpu_dawn PROPERTIES IMPORTED_IMPLIB "webgpu_dawn.lib") + endif() + get_target_property(webgpu_dawn_target_IMPORTED_LOCATION dawn::webgpu_dawn IMPORTED_LOCATION) + if (NOT webgpu_dawn_target_IMPORTED_LOCATION) + set_target_properties(dawn::webgpu_dawn PROPERTIES IMPORTED_LOCATION "webgpu_dawn.dll") + endif() + endif() endif() + + list(APPEND onnxruntime_providers_webgpu_dll_deps "$") else() if (NOT onnxruntime_USE_EXTERNAL_DAWN) target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native) @@ -70,11 +82,19 @@ if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) # Ensure dxil.dll and dxcompiler.dll exist in the output directory $ - add_dependencies(onnxruntime_providers_webgpu copy_dxil_dll) - add_dependencies(onnxruntime_providers_webgpu dxcompiler) + if (onnxruntime_USE_VCPKG) + find_package(directx-dxc CONFIG REQUIRED) + target_link_libraries(onnxruntime_providers_webgpu Microsoft::DirectXShaderCompiler) + target_link_libraries(onnxruntime_providers_webgpu Microsoft::DXIL) + list(APPEND onnxruntime_providers_webgpu_dll_deps "$") + list(APPEND onnxruntime_providers_webgpu_dll_deps "$") + else() + add_dependencies(onnxruntime_providers_webgpu copy_dxil_dll) + add_dependencies(onnxruntime_providers_webgpu dxcompiler) - list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxil.dll") - list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxcompiler.dll") + list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxil.dll") + list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxcompiler.dll") + endif() endif() if (onnxruntime_providers_webgpu_dll_deps) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 8f7a96e052fa1..7b91e65306bdb 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -1073,6 +1073,36 @@ if (onnxruntime_USE_QNN) endif() endif() +if (onnxruntime_USE_WEBGPU) + if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + if (onnxruntime_USE_VCPKG) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $ + $/onnxruntime/capi/ + ) + else() + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $/dxil.dll + $/dxcompiler.dll + $/onnxruntime/capi/ + ) + endif() + endif() + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $/onnxruntime/capi/ + ) + endif() +endif() + if (onnxruntime_USE_VSINPU) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/vcpkg-ports/dawn/dawn.patch b/cmake/vcpkg-ports/dawn/dawn.patch new file mode 100644 index 0000000000000..1fe66d2cf917d --- /dev/null +++ b/cmake/vcpkg-ports/dawn/dawn.patch @@ -0,0 +1,59 @@ +diff --git a/src/cmake/DawnCompilerPlatformFlags.cmake b/src/cmake/DawnCompilerPlatformFlags.cmake +index 50638e2456..efa42711e6 100644 +--- a/src/cmake/DawnCompilerPlatformFlags.cmake ++++ b/src/cmake/DawnCompilerPlatformFlags.cmake +@@ -63,7 +63,3 @@ endif () + if (MSVC AND NOT COMPILER_IS_CLANG_CL) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") + endif () +- +-if (TARGET_MACOS) +- set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version" FORCE) +-endif () +\ No newline at end of file +diff --git a/third_party/emdawnwebgpu/webgpu.cpp b/third_party/emdawnwebgpu/webgpu.cpp +index 5bfac41dcc..71a153daaa 100644 +--- a/third_party/emdawnwebgpu/webgpu.cpp ++++ b/third_party/emdawnwebgpu/webgpu.cpp +@@ -692,6 +692,7 @@ struct WGPUBufferImpl final : public EventSource, + WGPUBufferImpl(const EventSource* source, bool mappedAtCreation); + // Injection constructor used when we already have a backing Buffer. + WGPUBufferImpl(const EventSource* source, WGPUBufferMapState mapState); ++ ~WGPUBufferImpl(); + + void Destroy(); + const void* GetConstMappedRange(size_t offset, size_t size); +@@ -1361,6 +1362,12 @@ WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, + RefCountedWithExternalCount(kImportedFromJS), + mMapState(mapState) {} + ++WGPUBufferImpl::~WGPUBufferImpl() { ++ if (!IsImported()) { ++ Destroy(); ++ } ++} ++ + void WGPUBufferImpl::Destroy() { + emwgpuBufferDestroy(this); + AbortPendingMap("Buffer was destroyed before mapping was resolved."); +diff --git a/src/tint/utils/memory/aligned_storage.h b/src/tint/utils/memory/aligned_storage.h +index c532c4fc38..19c950af4c 100644 +--- a/src/tint/utils/memory/aligned_storage.h ++++ b/src/tint/utils/memory/aligned_storage.h +@@ -31,6 +31,9 @@ + #include + + #include "src/tint/utils/memory/bitcast.h" ++#include "src/tint/utils/macros/compiler.h" ++ ++TINT_BEGIN_DISABLE_WARNING(UNSAFE_BUFFER_USAGE); + + namespace tint { + +@@ -50,4 +53,6 @@ struct alignas(alignof(T)) AlignedStorage { + + } // namespace tint + ++TINT_END_DISABLE_WARNING(UNSAFE_BUFFER_USAGE); ++ + #endif // SRC_TINT_UTILS_MEMORY_ALIGNED_STORAGE_H_ diff --git a/cmake/vcpkg-ports/dawn/dawn_force_enable_f16_nvidia_vulkan.patch b/cmake/vcpkg-ports/dawn/dawn_force_enable_f16_nvidia_vulkan.patch new file mode 100644 index 0000000000000..2d999a456fdec --- /dev/null +++ b/cmake/vcpkg-ports/dawn/dawn_force_enable_f16_nvidia_vulkan.patch @@ -0,0 +1,19 @@ +diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp +index 158f10764c..a324c101ed 100644 +--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp ++++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp +@@ -269,11 +269,9 @@ void PhysicalDevice::InitializeSupportedFeaturesImpl() { + mDeviceInfo.shaderFloat16Int8Features.shaderFloat16 == VK_TRUE && + mDeviceInfo._16BitStorageFeatures.storageBuffer16BitAccess == VK_TRUE && + mDeviceInfo._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess == VK_TRUE) { +- // TODO(crbug.com/tint/2164): Investigate crashes in f16 CTS tests to enable on NVIDIA. +- if (!gpu_info::IsNvidia(GetVendorId())) { +- EnableFeature(Feature::ShaderF16); +- shaderF16Enabled = true; +- } ++ // ONNX Runtime Patch: enable shaderF16 on all devices. ++ EnableFeature(Feature::ShaderF16); ++ shaderF16Enabled = true; + } + + if (mDeviceInfo.HasExt(DeviceExt::DrawIndirectCount) && diff --git a/cmake/vcpkg-ports/dawn/dawn_vcpkg_integration.patch b/cmake/vcpkg-ports/dawn/dawn_vcpkg_integration.patch new file mode 100644 index 0000000000000..6e97475c8ad53 --- /dev/null +++ b/cmake/vcpkg-ports/dawn/dawn_vcpkg_integration.patch @@ -0,0 +1,125 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index b46b68204b..3e985ae3cd 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -127,6 +127,8 @@ if (DAWN_SUPPORTS_GLFW_FOR_WINDOWING) + set(BUILD_SAMPLES ON) + endif() + ++option(DAWN_ENABLE_VCPKG "Enable vcpkg integration" OFF) ++ + option(DAWN_ENABLE_ASAN "Enable address sanitizer" OFF) + option(DAWN_ENABLE_INSTALL "Enable install step for Dawn libraries" OFF) + option(DAWN_ENABLE_TSAN "Enable thread sanitizer" OFF) +@@ -439,16 +441,25 @@ set(TINT_SPIRV_TOOLS_DIR ${DAWN_SPIRV_TOOLS_DIR}) + ################################################################################ + # Run on all subdirectories + ################################################################################ +-if (DAWN_BUILD_PROTOBUF AND EXISTS "${DAWN_PROTOBUF_DIR}/cmake") +- if (("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") AND WIN32) +- set(protobuf_HAVE_BUILTIN_ATOMICS 1) ++if (DAWN_ENABLE_VCPKG) ++ find_package(absl REQUIRED) ++ find_package(SPIRV-Headers REQUIRED) ++ find_package(SPIRV-Tools REQUIRED) ++ if (DAWN_USE_BUILT_DXC) ++ find_package(directx-dxc CONFIG REQUIRED) + endif() ++else() ++ if (DAWN_BUILD_PROTOBUF AND EXISTS "${DAWN_PROTOBUF_DIR}/cmake") ++ if (("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") AND WIN32) ++ set(protobuf_HAVE_BUILTIN_ATOMICS 1) ++ endif() + +- # Needs to come before SPIR-V Tools +- include("third_party/protobuf.cmake") +-endif() ++ # Needs to come before SPIR-V Tools ++ include("third_party/protobuf.cmake") ++ endif() + +-add_subdirectory(third_party) ++ add_subdirectory(third_party) ++endif() + + # TODO(crbug.com/tint/455): Tint does not currently build with CMake when + # BUILD_SHARED_LIBS=1, so always build it as static for now. +diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt +index d3128bf764..319a847311 100644 +--- a/src/dawn/native/CMakeLists.txt ++++ b/src/dawn/native/CMakeLists.txt +@@ -865,7 +865,9 @@ if (DAWN_ENABLE_D3D12) + if (DAWN_USE_BUILT_DXC) + target_compile_definitions(dawn_native PRIVATE "DAWN_USE_BUILT_DXC") + target_compile_definitions(dawn_native_objects PRIVATE "DAWN_USE_BUILT_DXC") +- add_dependencies(dawn_native copy_dxil_dll) ++ if (NOT DAWN_ENABLE_VCPKG) ++ add_dependencies(dawn_native copy_dxil_dll) ++ endif() + endif() + endif() + +@@ -942,5 +944,9 @@ endif () + # They happen because dxcompiler is declared a shared library and bundle_libraries + # doesn't work well with shared libs + if (DAWN_USE_BUILT_DXC) +- target_link_libraries(dawn_native PRIVATE dxcompiler) ++ if (DAWN_ENABLE_VCPKG) ++ target_link_libraries(dawn_native PRIVATE Microsoft::DirectXShaderCompiler) ++ else() ++ target_link_libraries(dawn_native PRIVATE dxcompiler) ++ endif() + endif() +diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt +index 8692171222..b3da2fbbbf 100644 +--- a/src/tint/CMakeLists.txt ++++ b/src/tint/CMakeLists.txt +@@ -214,13 +214,21 @@ function(tint_default_compile_options TARGET) + endfunction() + + function(tint_spvheaders_compile_options TARGET) +- target_link_libraries(${TARGET} PRIVATE SPIRV-Headers) +- target_include_directories(${TARGET} PRIVATE "${TINT_SPIRV_HEADERS_DIR}/include") ++ if (DAWN_ENABLE_VCPKG) ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Headers::SPIRV-Headers) ++ else () ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Headers) ++ target_include_directories(${TARGET} PRIVATE "${TINT_SPIRV_HEADERS_DIR}/include") ++ endif() + endfunction() + + function(tint_spvtools_compile_options TARGET) +- target_link_libraries(${TARGET} PRIVATE SPIRV-Tools) +- target_include_directories(${TARGET} PRIVATE "${TINT_SPIRV_TOOLS_DIR}/include") ++ if (DAWN_ENABLE_VCPKG) ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Tools-static) ++ else () ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Tools) ++ target_include_directories(${TARGET} PRIVATE "${TINT_SPIRV_TOOLS_DIR}/include") ++ endif() + endfunction() + + function(tint_lib_compile_options TARGET) +@@ -562,12 +570,16 @@ function(tint_target_add_external_dependencies TARGET KIND) + target_link_libraries(${TARGET} PRIVATE + SPIRV-Tools-opt + ) +- target_include_directories(${TARGET} PRIVATE +- "${TINT_SPIRV_TOOLS_DIR}" +- "${TINT_SPIRV_TOOLS_DIR}/include" +- "${TINT_SPIRV_TOOLS_DIR}/source" +- "${spirv-tools_BINARY_DIR}" +- ) ++ if (DAWN_ENABLE_VCPKG) ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Tools-static) ++ else () ++ target_include_directories(${TARGET} PRIVATE ++ "${TINT_SPIRV_TOOLS_DIR}" ++ "${TINT_SPIRV_TOOLS_DIR}/include" ++ "${TINT_SPIRV_TOOLS_DIR}/source" ++ "${spirv-tools_BINARY_DIR}" ++ ) ++ endif() + elseif(${DEPENDENCY} STREQUAL "thread") + find_package(Threads REQUIRED) + target_link_libraries(${TARGET} PRIVATE Threads::Threads) diff --git a/cmake/vcpkg-ports/dawn/portfile.cmake b/cmake/vcpkg-ports/dawn/portfile.cmake new file mode 100644 index 0000000000000..1c53f8316c372 --- /dev/null +++ b/cmake/vcpkg-ports/dawn/portfile.cmake @@ -0,0 +1,138 @@ +# NOTE: dynamic library vs. static library +# +# We are building Dawn as a shared library `webgpu_dawn`. However, we need to set the `BUILD_SHARED_LIBS` option to +# `OFF` in this portfile. See the explanation below. +# +# In CMake convention, the `BUILD_SHARED_LIBS` option is used to control whether a library is built as a shared library or a static library. +# However, in the Dawn repository, there are multiple targets. Instead of building each target as a shared library, Dawn +# uses a CMake option `DAWN_BUILD_MONOLITHIC_LIBRARY` to control whether to build a monolithic dynamic library. +# +# When `DAWN_BUILD_MONOLITHIC_LIBRARY` is set to `ON`, a single library is built that contains all the targets. The +# library is always built as a shared library, regardless of the value of `BUILD_SHARED_LIBS`. +# +# In the vcpkg migration, we found that when both `DAWN_BUILD_MONOLITHIC_LIBRARY` and `BUILD_SHARED_LIBS` are set to `ON`, the build process will fail with some unexpected errors. +# So we need to set `BUILD_SHARED_LIBS` to `OFF` in this mode. +# +# The following function call ensures BUILD_SHARED_LIBS is set to OFF. +vcpkg_check_linkage(ONLY_STATIC_LIBRARY) + +if(VCPKG_TARGET_IS_EMSCRIPTEN) + message(FATAL_ERROR "This port is currently not supported on Emscripten.") +endif() + +set(onnxruntime_vcpkg_DAWN_OPTIONS) + +list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + + # enable the vcpkg flag + -DDAWN_ENABLE_VCPKG=ON + + # fetch dependencies is disabled when using vcpkg + -DDAWN_FETCH_DEPENDENCIES=OFF + + -DDAWN_BUILD_SAMPLES=OFF + -DDAWN_ENABLE_NULL=OFF + -DDAWN_BUILD_TESTS=OFF +) + +if (NOT VCPKG_TARGET_IS_EMSCRIPTEN) + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + + -DDAWN_BUILD_MONOLITHIC_LIBRARY=ON + -DDAWN_ENABLE_INSTALL=ON + + -DDAWN_ENABLE_DESKTOP_GL=OFF + -DDAWN_ENABLE_OPENGLES=OFF + -DDAWN_SUPPORTS_GLFW_FOR_WINDOWING=OFF + -DDAWN_USE_GLFW=OFF + -DDAWN_USE_WINDOWS_UI=OFF + -DTINT_BUILD_GLSL_WRITER=OFF + -DTINT_BUILD_GLSL_VALIDATOR=OFF + + -DDAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG=OFF + -DDAWN_USE_X11=OFF + + -DTINT_BUILD_TESTS=OFF + -DTINT_BUILD_CMD_TOOLS=OFF + -DTINT_BUILD_IR_BINARY=OFF + -DTINT_BUILD_SPV_READER=OFF + -DTINT_BUILD_WGSL_WRITER=ON + + -DDAWN_ENABLE_SPIRV_VALIDATION=OFF + + # explicitly set the jinja2 and markupsafe directories to empty strings + # when they are empty, the python script will import them from the system + # + # pip install jinja2 markupsafe + # + -DDAWN_JINJA2_DIR= + -DDAWN_MARKUPSAFE_DIR= + ) +endif() + +if(VCPKG_TARGET_IS_WINDOWS) + # feature detection on Windows + vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS + FEATURES + windows-use-d3d12 onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_D3D12 + windows-use-vulkan onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_VULKAN + ) + + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_USE_BUILT_DXC=ON + -DTINT_BUILD_HLSL_WRITER=ON + ) + + if((NOT onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_VULKAN) AND(NOT onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_D3D12)) + message(FATAL_ERROR "At least one of \"windows-use-d3d12\" or \"windows-use-vulkan\" must be enabled when using Dawn on Windows.") + endif() + + if(onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_VULKAN) + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_VULKAN=ON + -DTINT_BUILD_SPV_WRITER=ON + ) + else() + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_VULKAN=OFF + ) + endif() + + if(onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_D3D12) + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_D3D12=ON + ) + else() + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_D3D12=OFF + ) + endif() + + # We are currently always using the D3D12 backend. + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_D3D11=OFF + ) +endif() + +vcpkg_from_github( + OUT_SOURCE_PATH SOURCE_PATH + REPO google/dawn + REF "${VERSION}" + SHA512 9771e0be45ad2b85e4d85e12cbf03b9c9b4cc297e8f819e6277d8f02821adb671bf420fd13e241be4f6d7795a3acf0d0a38649c6e0e38a523a6ec0f042591efe + + PATCHES + dawn.patch + dawn_force_enable_f16_nvidia_vulkan.patch + dawn_vcpkg_integration.patch +) + +vcpkg_cmake_configure( + SOURCE_PATH "${SOURCE_PATH}" + WINDOWS_USE_MSBUILD + OPTIONS + ${onnxruntime_vcpkg_DAWN_OPTIONS} + + # MAYBE_UNUSED_VARIABLES +) + +vcpkg_cmake_install() diff --git a/cmake/vcpkg-ports/dawn/vcpkg.json b/cmake/vcpkg-ports/dawn/vcpkg.json new file mode 100644 index 0000000000000..0ea8627f7e17c --- /dev/null +++ b/cmake/vcpkg-ports/dawn/vcpkg.json @@ -0,0 +1,62 @@ +{ + "name": "dawn", + "version-string": "4cb1f9be152a4fa6bb695c08cd707ab078a1e2fb", + "port-version": 1, + "description": "Dawn, a native WebGPU implementation.", + "homepage": "https://dawn.googlesource.com/dawn", + "license": "BSD-3-Clause", + "dependencies": [ + { "name": "vcpkg-cmake", "host": true }, + { "name": "vcpkg-cmake-config", "host": true }, + { "name": "abseil", "version>=": "20250127.1" }, + { "name": "protobuf", "version>=": "3.21.12" }, + { + "name": "spirv-headers", + "version>=": "1.4.304.1", + "platform": "!emscripten" + }, + { + "name": "spirv-tools", + "version>=": "1.4.304.1", + "platform": "!emscripten" + }, + { + "name": "vulkan-headers", + "version>=": "1.4.304.1#1", + "platform": "(windows | linux) & (arm64 | x64)" + }, + { + "name": "vulkan-loader", + "version>=": "1.4.304.1", + "platform": "(windows | linux) & (arm64 | x64)" + }, + { + "name": "vulkan-utility-libraries", + "version>=": "1.4.304.1", + "platform": "(windows | linux) & (arm64 | x64)" + } + ], + "features": { + "windows-use-d3d12": { + "description": "Enable D3D12 backend on Windows.", + "dependencies": [ + { + "name": "directx-dxc", + "version>=": "2025-02-20#1", + "platform": "windows & !arm32" + }, + { + "name": "directx-headers", + "version>=": "1.615.0", + "platform": "windows & !arm32" + } + ] + }, + "windows-use-vulkan": { + "description": "Enable Vulkan backend on Windows." + } + }, + "default-features": [ + { "name": "windows-use-d3d12", "platform": "windows & !arm32" } + ] +} diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index f46abddfa028f..193ba6fe5cad5 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -77,27 +77,23 @@ "features": { "tests": { "description": "Build ONNXRuntime unit tests", - "dependencies": [ - "gtest" - ] + "dependencies": ["gtest"] }, "xnnpack-ep": { "description": "Build with XNNPack EP", - "dependencies": [ - "xnnpack" - ] + "dependencies": ["xnnpack"] }, "coreml-ep": { "description": "Build with CoreML EP", - "dependencies": [ - "fp16" - ] + "dependencies": ["fp16"] }, "dml-ep": { - "description": "Build with CoreML EP", - "dependencies": [ - "directx-headers" - ] + "description": "Build with DirectML EP", + "dependencies": ["directx-headers"] + }, + "webgpu-ep": { + "description": "Build with WebGPU EP", + "dependencies": [{ "name": "dawn", "platform": "!emscripten" }] } }, "overrides": [ diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index c28c79f1e723e..fd813eff2f575 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -79,6 +79,9 @@ final class OnnxRuntime { /** The short name of the ONNX runtime QNN provider library */ static final String ONNXRUNTIME_LIBRARY_QNN_NAME = "onnxruntime_providers_qnn"; + /** The short name of the WebGPU DAWN library */ + static final String ONNXRUNTIME_LIBRARY_WEBGPU_DAWN_NAME = "webgpu_dawn"; + /** The OS & CPU architecture string */ private static final String OS_ARCH_STR = initOsArch(); @@ -162,6 +165,10 @@ static synchronized void init() throws IOException { // the ONNX Runtime native library will load it extractProviderLibrary(ONNXRUNTIME_LIBRARY_SHARED_NAME); + // Extract and prepare the Dawn shared library (if present) but don't try to load it, + // the ONNX Runtime native library will load it + extractProviderLibrary(ONNXRUNTIME_LIBRARY_WEBGPU_DAWN_NAME); + if (!isAndroid()) { load(ONNXRUNTIME_LIBRARY_NAME); } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 7246738fd4406..c3f9d345078fe 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -2133,6 +2133,9 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet provid case QNN: options.addQnn(Collections.singletonMap("backend_type", "cpu")); break; + case WEBGPU: + options.addWebGPU(Collections.emptyMap()); + break; case VITIS_AI: case RK_NPU: case MI_GRAPH_X: diff --git a/requirements-dev.txt b/requirements-dev.txt index b95b85781a398..e89edaa33e98e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ cerberus flatbuffers jinja2 +markupsafe numpy onnx onnxmltools diff --git a/tools/ci_build/github/linux/python/requirements.txt b/tools/ci_build/github/linux/python/requirements.txt index 3ca025514ea3d..f499dae947b4f 100644 --- a/tools/ci_build/github/linux/python/requirements.txt +++ b/tools/ci_build/github/linux/python/requirements.txt @@ -9,3 +9,5 @@ sympy==1.12 flatbuffers psutil onnxscript==0.2.3 ; python_version < '3.13' +jinja2 +markupsafe diff --git a/tools/ci_build/github/windows/python/requirements.txt b/tools/ci_build/github/windows/python/requirements.txt index 2b222c4b1d4a4..d292f6edacde2 100644 --- a/tools/ci_build/github/windows/python/requirements.txt +++ b/tools/ci_build/github/windows/python/requirements.txt @@ -9,3 +9,5 @@ sympy==1.12 flatbuffers psutil onnxscript==0.2.3 +jinja2 +markupsafe From b45c7b69ac17a5934895927ddb2560bb5fb5ac95 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 19 May 2025 13:23:06 -0700 Subject: [PATCH 14/57] upgrade emsdk to 4.0.8 (#24798) ### Description upgrade emsdk to 4.0.8 ### Motivation and Context --- .gitmodules | 2 +- cmake/external/emsdk | 2 +- cmake/onnxruntime_webassembly.cmake | 2 +- tools/ci_build/build_args.py | 2 +- .../github/azure-pipelines/templates/linux-wasm-ci.yml | 8 ++++---- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.gitmodules b/.gitmodules index 7656fc429d005..b5bff01d89850 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ [submodule "cmake/external/emsdk"] path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git - branch = 4.0.4 + branch = 4.0.8 diff --git a/cmake/external/emsdk b/cmake/external/emsdk index 074211759c17c..419021fa04042 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit 074211759c17c646164d3271ca1d155cc174f78e +Subproject commit 419021fa040428bc69ef1559b325addb8e10211f diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index bfb73e14ce7a4..f00292fade52d 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -196,7 +196,7 @@ else() onnxruntime_util re2::re2 ) - set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'") + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue','HEAP8','HEAPU8','HEAP32','HEAPU32'") if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'") diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 215ad77335083..807c8b327c780 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -342,7 +342,7 @@ def add_webassembly_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for WebAssembly (WASM) platform builds.""" parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly.") parser.add_argument("--build_wasm_static_lib", action="store_true", help="Build WebAssembly static library.") - parser.add_argument("--emsdk_version", default="4.0.4", help="Specify version of emsdk.") + parser.add_argument("--emsdk_version", default="4.0.8", help="Specify version of emsdk.") parser.add_argument("--enable_wasm_simd", action="store_true", help="Enable WebAssembly SIMD.") parser.add_argument("--enable_wasm_relaxed_simd", action="store_true", help="Enable WebAssembly Relaxed SIMD.") parser.add_argument("--enable_wasm_threads", action="store_true", help="Enable WebAssembly multi-threading.") diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 0e60bf8e2e26d..aa434699fbe02 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -88,15 +88,15 @@ jobs: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.4 ccache-git-emscripten-64bit - ./emsdk activate 4.0.4 ccache-git-emscripten-64bit + ./emsdk install 4.0.8 ccache-git-emscripten-64bit + ./emsdk activate 4.0.8 ccache-git-emscripten-64bit displayName: 'emsdk install and activate ccache for emscripten' - ${{if eq(parameters.WithCache, false)}}: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.4 - ./emsdk activate 4.0.4 + ./emsdk install 4.0.8 + ./emsdk activate 4.0.8 displayName: 'emsdk install and activate ccache for emscripten' - template: build-linux-wasm-step.yml From 8eaeec3dcfdac571cb6256cd1c5711fb76193c56 Mon Sep 17 00:00:00 2001 From: "Chunye Wang@AMD" Date: Mon, 19 May 2025 13:40:14 -0700 Subject: [PATCH 15/57] [VITISAI] pass all session configs to vitisai ep for `Ort::CompileModel` flow (#24799) ### Description convert all session configs, i.e. key-value pairs into provider options, the key prefixed with `ort_session_config.` ### Motivation and Context #24445 has a bug when `Ort::CompileModel` is used, not all session config are passed to VITISAI EP backend. It is because that the `session_option` which holds a reference to `VitisiAIExectuionProviderFactory` is not as same as the `session_option` used for `Ort::CompileModel`. `Ort::CompileModel` create another `session_option` behind scene. The symptom of this bug is that only the session configs in the first `SessionOptions` object is passed to `VitisiAIExectuionProviderFactory` and session configs in the second `SessionOptions` are not, so that VITISAI EP backend sometimes assumes that ep.cache_context is not enabled, and then ep context cache model is not created properly. --- onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index fc5cce4257ebe..6849bcfc21f88 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -52,6 +52,8 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider(const for (const auto& [key, value] : config_options_map) { if (key.rfind(key_prefix, 0) == 0) { provider_options[key.substr(key_prefix.size())] = value; + } else { + provider_options["ort_session_config." + key] = value; } } From 9983680253da54d1c40cb5210e52071f37585eaf Mon Sep 17 00:00:00 2001 From: Nenad Banfic <46795300+nenad1002@users.noreply.github.com> Date: Mon, 19 May 2025 15:24:33 -0700 Subject: [PATCH 16/57] SkipSimplifiedLayerNorm + QuickGelu bfloat16 CUDA implementation (#24772) ### Description SkipSimplifiedLayerNorm + QuickGelu bfloat16 CUDA implementation #24772 ### Motivation and Context --- docs/ContribOperators.md | 4 +- docs/OperatorKernels.md | 4 +- .../cuda/activation/activations.cc | 3 +- .../cuda/activation/activations_impl.cu | 3 +- .../contrib_ops/cuda/bert/layer_norm.cuh | 18 ++++++++ .../contrib_ops/cuda/bert/skip_layer_norm.cc | 43 +++++++++++++------ .../cuda/bert/skip_layer_norm_impl.cu | 9 +++- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 ++ .../core/graph/contrib_ops/bert_defs.cc | 2 +- 9 files changed, 69 insertions(+), 21 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index dbe7d9b85092a..b29fe7adb0da4 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5585,8 +5585,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16)
-
Constrain input and output types to float or half tensors.
+
T : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output to float tensors.
U : tensor(float)
Constrain mean and inv_std_var to float tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1c30e67534a0c..86b490f8f4c43 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -959,7 +959,7 @@ Do not modify directly.* |QOrderedMatMul|*in* A:**Q**
*in* scale_A:**S**
*in* B:**Q**
*in* scale_B:**S**
*in* scale_Y:**S**
*in* bias:**S**
*in* C:**Q**
*in* scale_C:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float16)
**T2** = tensor(int8), tensor(uint8)| |QuantizeWithOrder|*in* input:**F**
*in* scale_input:**S**
*out* output:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| -|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |RelativePositionBias|*in* bias_table:**T**
*in* query_length:**U**
*in* key_length:**U**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -968,7 +968,7 @@ Do not modify directly.* |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| -|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| +|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_row_indices:**M**
*in* block_col_indices:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc index 6303858b9bd48..0c4d42b328510 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.cc +++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc @@ -44,7 +44,8 @@ namespace cuda { #define UNARY_ACTIVATION_OP_HFD(name, ver, domain) \ UNARY_ACTIVATION_OP_TYPED(name, ver, domain, MLFloat16) \ UNARY_ACTIVATION_OP_TYPED(name, ver, domain, float) \ - UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double) + UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double) \ + UNARY_ACTIVATION_OP_TYPED(name, ver, domain, BFloat16) UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain); diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu index 36f33fbb24c18..a11691d22d8be 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu @@ -62,7 +62,8 @@ struct OP_QuickGelu : public CtxQuickGelu { #define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \ SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \ SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \ - SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double) + SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double) \ + SPECIALIZED_UNARY_ACTIVATION_IMPL(name, BFloat16) #define UNARY_ACTIVATION_OP_NAME(name) \ UNARY_ACTIVATION_IMPL(name); \ diff --git a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh index ff3178b56c2a6..0953161dc0d44 100644 --- a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh @@ -25,6 +25,7 @@ limitations under the License. #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/shared_inc/cuda_call.h" #include +#include #include #include @@ -60,6 +61,15 @@ __device__ inline half2 AddHalf2(const half2 a, const half2 b) { #endif } +template <> +__device__ inline nv_bfloat16 Rsqrt(const nv_bfloat16& x) { + return hrsqrt(x); +} + +__device__ inline nv_bfloat162 AddHalf2(const nv_bfloat162 a, const nv_bfloat162 b) { + return __hadd2(a, b); +} + struct KeyValuePairSum { __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, const cub::KeyValuePair& b) { @@ -78,6 +88,14 @@ struct KeyValuePairSum { const cub::KeyValuePair& b) { return cub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); } + + __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { + const nv_bfloat162 a2 = __halves2bfloat162(a.key, a.value); + const nv_bfloat162 b2 = __halves2bfloat162(b.key, b.value); + const nv_bfloat162 res = AddHalf2(a2, b2); + return cub::KeyValuePair(__low2bfloat16(res), __high2bfloat16(res)); + } }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 428b903c03682..92ae7e81fb5bd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -34,6 +34,7 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) using namespace ONNX_NAMESPACE; @@ -106,19 +107,35 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr); } else { - LaunchSkipLayerNormKernel( - Stream(ctx), - reinterpret_cast(output->MutableData()), - sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, - reinterpret_cast(input->Data()), - reinterpret_cast(skip->Data()), - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, - reinterpret_cast(gamma->Data()), - (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - epsilon_, - hidden_size, - row_count, - skip_size); + if constexpr (std::is_same_v) { + LaunchSkipLayerNormKernel( + Stream(ctx), + reinterpret_cast(output->MutableData()), + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, + reinterpret_cast(input->Data()), + reinterpret_cast(skip->Data()), + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, + reinterpret_cast(gamma->Data()), + (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, + epsilon_, + hidden_size, + row_count, + skip_size); + } else { + LaunchSkipLayerNormKernel( + Stream(ctx), + reinterpret_cast(output->MutableData()), + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, + reinterpret_cast(input->Data()), + reinterpret_cast(skip->Data()), + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, + reinterpret_cast(gamma->Data()), + (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, + epsilon_, + hidden_size, + row_count, + skip_size); + } } CUDA_RETURN_IF_ERROR(cudaGetLastError()); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index 50c8e4b5e0398..a1dcab0a6bf89 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -30,6 +30,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/layer_norm.cuh" #include "contrib_ops/cuda/bert/skip_layer_norm_impl.h" #include +#include namespace onnxruntime { namespace contrib { @@ -49,6 +50,11 @@ half maybe2half(float x) { return __float2half_rn(x); } +template <> +nv_bfloat16 maybe2half(float x) { + return __float2bfloat16_rn(x); +} + // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192}; @@ -263,7 +269,8 @@ SKIPLAYERNORM_IMPL(float, true); SKIPLAYERNORM_IMPL(float, false); SKIPLAYERNORM_IMPL(half, true); SKIPLAYERNORM_IMPL(half, false); - +SKIPLAYERNORM_IMPL(nv_bfloat16, true); +SKIPLAYERNORM_IMPL(nv_bfloat16, false); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index b8931bf1ea0f8..17f3433aed38a 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -42,6 +42,7 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasAdd); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, QuickGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, QuickGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QuickGelu); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QuickGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, TransposeMatMul); // backward compatibility class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, TransposeMatMul); // backward compatibility class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, TransposeMatMul); // backward compatibility @@ -129,6 +130,7 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipLayerNormalization); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SkipSimplifiedLayerNormalization); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu); @@ -256,6 +258,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, // backward compatibility @@ -339,6 +342,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7eea7d218e278..238dd8d4573de 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1669,7 +1669,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "with shape (batch_size, sequence_length, hidden_size) or (token_count, hidden_size).", "T", OpSchema::Optional) - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); From 809d9a0e0c74e054a979f24f501af2d16be6be80 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 May 2025 22:44:55 +0000 Subject: [PATCH 17/57] Bump ruff from 0.11.9 to 0.11.10 (#24804) --- requirements-lintrunner.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 37df56216bde9..e66ec3bb58d74 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -3,6 +3,6 @@ lintrunner==0.12.7 lintrunner-adapters==0.12.4 # RUFF -ruff==0.11.9 +ruff==0.11.10 # CLANGFORMAT clang-format==19.1.7 From 64aa7f02408e1d782aae72a13ccb7cfa50b9ea79 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 19 May 2025 21:41:14 -0700 Subject: [PATCH 18/57] [vcpkg-ports] set EOL to LF for app .patch files (#24812) ### Description ### Motivation and Context --- cmake/vcpkg-ports/.gitattributes | 1 + 1 file changed, 1 insertion(+) create mode 100644 cmake/vcpkg-ports/.gitattributes diff --git a/cmake/vcpkg-ports/.gitattributes b/cmake/vcpkg-ports/.gitattributes new file mode 100644 index 0000000000000..9812ceb1ffd9b --- /dev/null +++ b/cmake/vcpkg-ports/.gitattributes @@ -0,0 +1 @@ +*.patch text eol=lf From 8c1156e07c48eedfebaa1cde9a2af7f13b147f3c Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 19 May 2025 23:33:58 -0700 Subject: [PATCH 19/57] [Compile] Add API to set compile flags; Fix behavior when no nodes are compiled (#24695) ### Description #### Original compile approach where an EPContext model is generated as a side-effect of creating a session: - **Restore** previous behavior where: - compiling a model that generates no EPContext nodes is silently ignored (nothing is generated and no error is reported) - compiling a previously compiled model is silently ignored (nothing is generated and no error is reported) #### Explicit compile API: - **Retains** current behavior where compiling a model that does not generate EPContext nodes still generates a model by default. - Adds C/C++/C#/Python API called `setFlags` that allows the user to specify what is considered an error. - `OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED`: CompileModel() returns `ORT_FAIL` if no nodes were compiled. - `OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS`: CompileModel() returns `ORT_FAIL` if a file with the same filename as the output model exists. - Adds logic to detect when the user is trying to compile a previously compiled model and returns an `ORT_INVALID_GRAPH` error with a relevant error message. ### Motivation and Context A previous [PR changed the default behavior](https://github.com/microsoft/onnxruntime/commit/b4f7a905b0d636b71bd486c0ef702eb5a44eadf2#diff-e2d3910ae7593ee7ba4fd74e53f738fa973ae2fc32c069f1088ba458b91f8d4bL809) of the original "implicit" compilation approach. This PR was motivated by restoring the original behavior that users currently depend on. At the same time, we want to allow users of the new explicit compile API to determine what is considered an error. --- .../CompileModel.shared.cs | 22 +++ .../NativeCompileApiMethods.shared.cs | 61 +++--- .../CompileApiTests.cs | 58 ++++++ .../core/session/onnxruntime_c_api.h | 27 +++ .../core/session/onnxruntime_cxx_api.h | 1 + .../core/session/onnxruntime_cxx_inline.h | 5 + onnxruntime/__init__.py | 1 + .../core/framework/graph_partitioner.cc | 52 +++-- onnxruntime/core/framework/session_options.h | 31 ++- onnxruntime/core/session/compile_api.cc | 23 ++- onnxruntime/core/session/compile_api.h | 2 + .../core/session/model_compilation_options.cc | 18 +- .../core/session/model_compilation_options.h | 7 + .../onnxruntime_inference_collection.py | 5 + .../onnxruntime_pybind_model_compiler.cc | 7 +- .../onnxruntime_pybind_model_compiler.h | 4 +- .../python/onnxruntime_pybind_state.cc | 12 +- .../test/providers/qnn/qnn_ep_context_test.cc | 179 +++++++++++++++++- .../onnxruntime_test_python_compile_api.py | 51 ++++- 19 files changed, 514 insertions(+), 52 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs index 9f42bf2247529..c348184658e7e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs @@ -6,6 +6,18 @@ namespace Microsoft.ML.OnnxRuntime using System; using System.Runtime.InteropServices; + /// + /// Flags representing options to enable when compiling a model. + /// Matches OrtCompileApiFlags in the ORT C API. + /// + [Flags] + public enum OrtCompileApiFlags : uint + { + NONE = 0, + ERROR_IF_NO_NODES_COMPILED = 1 << 0, + ERROR_IF_OUTPUT_FILE_EXISTS = 1 << 1, + } + /// /// This class is used to set options for model compilation, and to produce a compiled model using those options. /// See https://onnxruntime.ai/docs/api/c/ for further details of various options. @@ -108,6 +120,16 @@ public void SetEpContextEmbedMode(bool embed) NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed)); } + /// + /// Sets flags from OrtCompileApiFlags that represent one or more boolean options to enable. + /// + /// bitwise OR of flags in OrtCompileApiFlags to enable. + public void SetFlags(OrtCompileApiFlags flags) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags)); + } + internal IntPtr Handle => handle; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs index 602bcc6caf7f8..3edc25b307a21 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -20,6 +20,7 @@ public struct OrtCompileApi public IntPtr ModelCompilationOptions_SetOutputModelBuffer; public IntPtr ModelCompilationOptions_SetEpContextEmbedMode; public IntPtr CompileModel; + public IntPtr ModelCompilationOptions_SetFlags; } internal class NativeMethods @@ -43,7 +44,7 @@ internal class NativeMethods IntPtr /* const OrtEnv* */ env, IntPtr /* const OrtSessionOptions* */ sessionOptions, out IntPtr /* OrtModelCompilationOptions** */ outOptions); - public DOrtCreateModelCompilationOptionsFromSessionOptions + public DOrtCreateModelCompilationOptionsFromSessionOptions OrtCreateModelCompilationOptionsFromSessionOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -57,7 +58,7 @@ public DOrtCreateModelCompilationOptionsFromSessionOptions IntPtr /* OrtModelCompilationOptions* */ options, byte[] /* const void* */ inputModelData, UIntPtr /* size_t */ inputModelDataSize); - public DOrtModelCompilationOptions_SetInputModelFromBuffer + public DOrtModelCompilationOptions_SetInputModelFromBuffer OrtModelCompilationOptions_SetInputModelFromBuffer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -71,7 +72,7 @@ public DOrtModelCompilationOptions_SetInputModelFromBuffer IntPtr /* OrtModelCompilationOptions* */ options, byte[] /* const ORTCHAR_T* */ externalInitializersFilePath, UIntPtr /* size_t */ externalInitializerSizeThreshold); - public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile + public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile OrtModelCompilationOptions_SetOutputModelExternalInitializersFile; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -94,60 +95,72 @@ public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile IntPtr /* const OrtModelCompilationOptions* */ modelOptions); public DOrtCompileModel OrtCompileModel; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetFlags( + IntPtr /* OrtModelCompilationOptions* */ options, + uint flags); + public DOrtModelCompilationOptions_SetFlags OrtModelCompilationOptions_SetFlags; + internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) { - #if NETSTANDARD2_0 +#if NETSTANDARD2_0 IntPtr compileApiPtr = getCompileApi(); _compileApi = (OrtCompileApi)Marshal.PtrToStructure(compileApiPtr, typeof(OrtCompileApi)); - #else +#else _compileApi = (OrtCompileApi)getCompileApi(); - #endif +#endif - OrtReleaseModelCompilationOptions = + OrtReleaseModelCompilationOptions = (DOrtReleaseModelCompilationOptions)Marshal.GetDelegateForFunctionPointer( - _compileApi.ReleaseModelCompilationOptions, + _compileApi.ReleaseModelCompilationOptions, typeof(DOrtReleaseModelCompilationOptions)); - OrtCreateModelCompilationOptionsFromSessionOptions = + OrtCreateModelCompilationOptionsFromSessionOptions = (DOrtCreateModelCompilationOptionsFromSessionOptions)Marshal.GetDelegateForFunctionPointer( - _compileApi.CreateModelCompilationOptionsFromSessionOptions, + _compileApi.CreateModelCompilationOptionsFromSessionOptions, typeof(DOrtCreateModelCompilationOptionsFromSessionOptions)); - OrtModelCompilationOptions_SetInputModelPath = + OrtModelCompilationOptions_SetInputModelPath = (DOrtModelCompilationOptions_SetInputModelPath)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetInputModelPath, + _compileApi.ModelCompilationOptions_SetInputModelPath, typeof(DOrtModelCompilationOptions_SetInputModelPath)); - OrtModelCompilationOptions_SetInputModelFromBuffer = + OrtModelCompilationOptions_SetInputModelFromBuffer = (DOrtModelCompilationOptions_SetInputModelFromBuffer)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetInputModelFromBuffer, + _compileApi.ModelCompilationOptions_SetInputModelFromBuffer, typeof(DOrtModelCompilationOptions_SetInputModelFromBuffer)); - OrtModelCompilationOptions_SetOutputModelPath = + OrtModelCompilationOptions_SetOutputModelPath = (DOrtModelCompilationOptions_SetOutputModelPath)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetOutputModelPath, + _compileApi.ModelCompilationOptions_SetOutputModelPath, typeof(DOrtModelCompilationOptions_SetOutputModelPath)); - OrtModelCompilationOptions_SetOutputModelExternalInitializersFile = + OrtModelCompilationOptions_SetOutputModelExternalInitializersFile = (DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetOutputModelExternalInitializersFile, + _compileApi.ModelCompilationOptions_SetOutputModelExternalInitializersFile, typeof(DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)); - OrtModelCompilationOptions_SetOutputModelBuffer = + OrtModelCompilationOptions_SetOutputModelBuffer = (DOrtModelCompilationOptions_SetOutputModelBuffer)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetOutputModelBuffer, + _compileApi.ModelCompilationOptions_SetOutputModelBuffer, typeof(DOrtModelCompilationOptions_SetOutputModelBuffer)); - OrtModelCompilationOptions_SetEpContextEmbedMode = + OrtModelCompilationOptions_SetEpContextEmbedMode = (DOrtModelCompilationOptions_SetEpContextEmbedMode)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetEpContextEmbedMode, + _compileApi.ModelCompilationOptions_SetEpContextEmbedMode, typeof(DOrtModelCompilationOptions_SetEpContextEmbedMode)); - OrtCompileModel = + OrtCompileModel = (DOrtCompileModel)Marshal.GetDelegateForFunctionPointer( - _compileApi.CompileModel, + _compileApi.CompileModel, typeof(DOrtCompileModel)); + + OrtModelCompilationOptions_SetFlags = + (DOrtModelCompilationOptions_SetFlags)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetFlags, + typeof(DOrtModelCompilationOptions_SetFlags)); + } } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs index 72c165df56418..bf576b54d8b45 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs @@ -8,6 +8,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests; using System; using System.Globalization; +using System.IO; using System.Runtime.InteropServices; using Xunit; @@ -61,6 +62,63 @@ public void BasicUsage() allocator.FreeMemory(bytePtr); } + + // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate + // any compiled EPContext nodes, so expect an ORT_FAIL error. + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var output_model_file = "should_not_generate.onnx"; + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(output_model_file); + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED); + + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("Unable to compile any nodes", ex.Message); + } + + Assert.False(File.Exists(output_model_file)); // Output file should not be generated. + } + + // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS. + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var output_model_file = "squeezenet_ctx.onnx"; + + // Compile and generate an output model. + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(output_model_file); + compileOptions.CompileModel(); + Assert.True(File.Exists(output_model_file)); + + // Try to compile again with flag that prevents replacing an existing file. + // Expect failure. + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS); + + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("exists already", ex.Message); + } + + if (File.Exists(output_model_file)) + { + File.Delete(output_model_file); + } + } } } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0d2da44971b3a..0ee8effaa16d0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5782,6 +5782,21 @@ struct OrtModelEditorApi { * ORT Compile API */ +/** \brief Flags representing options to enable when compiling a model. + */ +typedef enum OrtCompileApiFlags { + // Default. Do not enable any additional compilation options. + OrtCompileApiFlags_NONE = 0, + + // Force compilation to return an error (ORT_FAIL) if no nodes were compiled. + // Otherwise, a model with basic optimizations (ORT_ENABLE_BASIC) is still generated by default. + OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED = 1 << 0, + + // Force compilation to return an error (ORT_FAIL) if a file with the same filename as the output model exists. + // Otherwise, compilation will automatically overwrite the output file if it exists. + OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS = 1 << 1, +} OrtCompileApiFlags; + /** * \brief The OrtCompileApi struct provides functions to compile ONNX models. * @@ -5964,6 +5979,18 @@ struct OrtCompileApi { * \since Version 1.22. */ ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); + + /** \brief Sets flags from OrtCompileApiFlags that represent one or more boolean options to enable. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] flags bitwise OR of flags in OrtCompileApiFlags to enable. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, + size_t flags); }; ORT_RUNTIME_CLASS(Ep); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index bf8e57894d384..c7f81264115c6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1160,6 +1160,7 @@ struct ModelCompilationOptions : detail::Base { size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags }; /** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 0d0b3198a8736..6cd52732b923b 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -832,6 +832,11 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetFlags(this->p_, flags)); + return *this; +} + namespace detail { template diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 6ef0707f4b7c6..56268369bf98a 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -30,6 +30,7 @@ NodeArg, # noqa: F401 OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 + OrtCompileApiFlags, # noqa: F401 OrtEpDevice, # noqa: F401 OrtExecutionProviderDevicePolicy, # noqa: F401 OrtHardwareDevice, # noqa: F401 diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index b39d0dbd25f8d..9a2991ab02730 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -767,7 +767,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_path, const std::filesystem::path& model_path, std::filesystem::path& context_cache_path, - bool allow_overwrite_output_model = false) { + bool error_if_output_file_exists = true) { if (!ep_context_path.empty()) { context_cache_path = ep_context_path; if (!(context_cache_path.has_filename() && context_cache_path.extension() != "")) { @@ -784,9 +784,9 @@ static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty."); } - if (std::filesystem::exists(context_cache_path) && !allow_overwrite_output_model) { + if (std::filesystem::exists(context_cache_path) && error_if_output_file_exists) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", - context_cache_path, "' exist already. Please remove the EP context model if you want to re-generate it."); + context_cache_path, "' exists already. Please remove the EP context model if you want to re-generate it."); } return Status::OK(); @@ -803,16 +803,25 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } if (all_ep_context_nodes.size() < 1) { - ORT_RETURN_IF(ep_context_gen_options.error_if_no_compiled_nodes, - "Compiled model does not contain any EPContext nodes. " - "Check that the session EPs support compilation and can execute at least one model subgraph."); - - LOGS(logger, WARNING) << "Compiled model does not contain any EPContext nodes. " - "Either the session EPs do not support compilation or " - "no subgraphs were able to be compiled."; + auto action_if_no_compiled_nodes = ep_context_gen_options.action_if_no_compiled_nodes; + + ORT_RETURN_IF(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError, + "Unable to compile any nodes. Check that the session EPs support compilation and can execute " + "at least one subgraph in the model."); + + if (action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kDontGenerateModel) { + LOGS(logger, WARNING) << "Unable to compile any nodes. ONNX Runtime will not generate a compiled model. " + "Either the session EPs do not support compilation or the model is already compiled."; + // Note: this path is only taken if a model is compiled with the original compilation approach that uses + // session options configs only. The explicit compile API instead only chooses between + // kReturnError and kGenerateModel. + return Status::OK(); + } - // we continue on to generate the compiled model which may benefit from L1 optimizations even if there are not - // EPContext nodes. + // Assert so that this is caught in a test in DEBUG builds (in case a new enum value is added) + assert(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel); + LOGS(logger, INFO) << "Unable to compile any nodes but will still generate an output model. " + "Either the session EPs do not support compilation or the model is already compiled."; } auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { @@ -833,9 +842,21 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(), context_cache_path, - ep_context_gen_options.overwrite_existing_output_file)); + ep_context_gen_options.error_if_output_file_exists)); } + // Utility function to detect a fused node with an unsupported domain. + // Ex: when compiling an already compiled model, an EPContext node in the input model would be wrapped + // into a fused node with a domain like "QNN". Such fused nodes do not pass ONNX correctness checks, so + // we should detect them here and return a better error message. Otherwise, an ORT_INVALID_GRAPH error is raised + // with a confusing error message *after* the invalid model has been saved/generated. + // Note: This only applies to the explicit compile API. The original compilation approach (via session options), + // early exits above (without error) if the model is already compiled. + auto is_invalid_fused_node = [&graph](const Node& node) { + const std::unordered_map& supported_domains = graph.DomainToVersionMap(); + return (node.NodeType() == Node::Type::Fused) && (supported_domains.find(node.Domain()) == supported_domains.end()); + }; + Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path IOnnxRuntimeOpSchemaRegistryList{graph.GetSchemaRegistry()}, @@ -872,6 +893,9 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers // Use EpContext node created by the EPs if name matched, otherwise use node from original model if (ep_context_node.first) { ep_graph.AddNode(*ep_context_node.second); + } else if (is_invalid_fused_node(node)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Encountered an invalid node while compiling a model. ", + "Please ensure the input model is not already compiled."); } else { ep_graph.AddNode(node); } @@ -1216,7 +1240,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::filesystem::path context_cache_path; ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(), context_cache_path, - ep_context_gen_options.overwrite_existing_output_file)); + ep_context_gen_options.error_if_output_file_exists)); } // We use this only if Resource Aware Partitioning is enabled for any of the EPs diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 89a43c4f71ee6..b95b38d007fbb 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -70,15 +70,42 @@ struct FreeDimensionOverride { using CheckLoadCancellationFn = std::function; +/// +/// Options that configure the generation of a compiled model (i.e., a model with EPContext nodes). +/// There are two ways to compile a model: +/// 1. By specifying the correct session option configurations and creating an inference session. +/// The compiled model is generated as a side-effect of session creation. +/// 2. Using an explicit compile API (see OrtCompileApi struct in onnxruntime_c_api.h). +/// +/// The default values in this struct are set to match the current/default behavior of approach 1 to maintain +/// compatibility with the older way of compiling. The explicit compile API overrides some of these values to +/// provide its own defaults (see core/session/model_compilation_options.h/cc). +/// struct EpContextModelGenerationOptions { + // Action to take if the output model does not have compiled (EPContext) nodes. + enum class ActionIfNoCompiledNodes { + // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior + // to maintain compatibility. The explicit compile API does *not* use this action. + kDontGenerateModel = 0, + + // Generate an output model even if it doesn't have compiled nodes. + // The explicit Compile API defaults to this value. + kGenerateModel, + + // Return an error if the model does not have compiled nodes. + // The explicit Compile API can be configured to this value. + kReturnError, + }; + EpContextModelGenerationOptions() = default; // Initializes from string key/value pairs in session config options. + // This initializes this struct from options set via the older compiling approach #1 above. explicit EpContextModelGenerationOptions(const ConfigOptions& config_options); bool enable = false; - bool overwrite_existing_output_file = false; - bool error_if_no_compiled_nodes = false; + bool error_if_output_file_exists = true; + ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel; bool embed_ep_context_in_model = false; std::string output_model_file_path; diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index ad128fee6cc3d..d910e3ea74b57 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -201,6 +201,21 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags, + _In_ OrtModelCompilationOptions* ort_model_compile_options, size_t flags) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetFlags(flags)); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(flags); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* ort_model_compile_options) { API_IMPL_BEGIN @@ -217,8 +232,9 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, } static constexpr OrtCompileApi ort_compile_api = { - // NOTE: The C# bindings depend on the Api order within this struct so all additions must be at the end, - // and no functions can be removed (the implementation needs to change to return an error). + // NOTE: Application compatibility with newer versions of ORT depends on the Api order within this struct so + // all new functions must be added at the end, and no functions that already exist in an officially released version + // of ORT can be reordered or removed. &OrtCompileAPI::ReleaseModelCompilationOptions, &OrtCompileAPI::CreateModelCompilationOptionsFromSessionOptions, @@ -229,6 +245,9 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, &OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, &OrtCompileAPI::CompileModel, + // End of Version 22 - DO NOT MODIFY ABOVE + + &OrtCompileAPI::ModelCompilationOptions_SetFlags, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index b8c5211526b9d..5f11b894f2004 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -28,5 +28,7 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelBuffer, _In_ OrtModelC ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* model_compile_options, bool embed_ep_context_in_model); ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options, + size_t flags); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index d0cb092f78843..5de0f03fafc08 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -18,10 +18,11 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& session_options_.value.has_explicit_ep_context_gen_options = true; session_options_.value.ep_context_gen_options = session_options.value.GetEpContextGenerationOptions(); session_options_.value.ep_context_gen_options.enable = true; - session_options_.value.ep_context_gen_options.overwrite_existing_output_file = true; - // defaulting to false to support wider usage. will log WARNING if compiling model with no context nodes. - // TODO: Add ability for user to explicitly set this. - session_options_.value.ep_context_gen_options.error_if_no_compiled_nodes = false; + session_options_.value.ep_context_gen_options.error_if_output_file_exists = false; + + // defaulting to kGenerateModel to support wider usage. + session_options_.value.ep_context_gen_options.action_if_no_compiled_nodes = + EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; // Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions. ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK()); @@ -104,6 +105,15 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m return Status::OK(); } +Status ModelCompilationOptions::SetFlags(size_t flags) { + EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options; + options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS; + options.action_if_no_compiled_nodes = + (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError + : EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; + return Status::OK(); +} + const OrtSessionOptions& ModelCompilationOptions::GetSessionOptions() const { return session_options_; } diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 9238264003645..f96f0317cdaca 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -80,6 +80,13 @@ class ModelCompilationOptions { /// Status indicating potential error Status SetEpContextEmbedMode(bool embed_ep_context_in_model); + /// + /// Sets flags representing enabled boolean options defined in OrtCompileApiFlags. + /// + /// unsigned integer set to the bitwise OR of enabled flags. + /// Status indicating success or an error + Status SetFlags(size_t flags); + /// /// Returns a reference to the session options object. /// diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 15c423d7285bc..e8e51db13bcd3 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -644,6 +644,7 @@ def __init__( embed_compiled_data_into_model: bool = False, external_initializers_file_path: str | os.PathLike | None = None, external_initializers_size_threshold: int = 1024, + flags: int = C.OrtCompileApiFlags.NONE, ): """ Creates a ModelCompiler instance. @@ -658,6 +659,8 @@ def __init__( initializers for non-compiled nodes. :param external_initializers_size_threshold: Defaults to 1024. Ignored if `external_initializers_file_path` is None or empty. Initializers larger than this threshold are stored in the external initializers file. + :param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of + flags in onnxruntime.OrtCompileApiFlags. """ input_model_path: str | os.PathLike | None = None input_model_bytes: bytes | None = None @@ -688,6 +691,7 @@ def __init__( embed_compiled_data_into_model, external_initializers_file_path, external_initializers_size_threshold, + flags, ) else: self._model_compiler = C.ModelCompiler( @@ -697,6 +701,7 @@ def __init__( embed_compiled_data_into_model, external_initializers_file_path, external_initializers_size_threshold, + flags, ) def compile_to_file(self, output_model_path: str | None = None): diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index 8bb7ee2098caf..4676efa13440b 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -19,7 +19,8 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr(env, sess_options, PrivateConstructorTag{}); ModelCompilationOptions& compile_options = model_compiler->model_compile_options_; @@ -38,6 +39,10 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptrTrue to embed compiled binary data into EPContext nodes. /// The file into which to store initializers for non-compiled /// nodes. + /// Flags from OrtCompileApiFlags /// Ignored if 'external_initializers_file_path' is empty. /// Initializers with a size greater than this threshold are dumped into the external file. /// A Status indicating error or success. @@ -44,7 +45,8 @@ class PyModelCompiler { std::string&& input_model_path_or_bytes, bool input_model_is_path, bool embed_compiled_data_into_model = false, const std::string& external_initializers_file_path = {}, - size_t external_initializers_size_threshold = 1024); + size_t external_initializers_size_threshold = 1024, + size_t flags = 0); // Note: Creation should be done via Create(). This constructor is public so that it can be called from // std::make_shared(). diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index aa2c0cc6a0f86..12a44e65e247b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -2782,6 +2782,11 @@ including arg name, arg type (contains both type and shape).)pbdoc") .value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested) .export_values(); + py::enum_(m, "OrtCompileApiFlags", py::arithmetic()) + .value("NONE", OrtCompileApiFlags_NONE) + .value("ERROR_IF_NO_NODES_COMPILED", OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) + .value("ERROR_IF_OUTPUT_FILE_EXISTS", OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS); + py::class_(m, "ModelCompiler", R"pbdoc(This is the class used to compile an ONNX model.)pbdoc") .def(py::init([](const PySessionOptions& sess_options, @@ -2789,14 +2794,16 @@ including arg name, arg type (contains both type and shape).)pbdoc") bool is_path, bool embed_compiled_data_into_model = false, std::string external_initializers_file_path = {}, - size_t external_initializers_size_threshold = 1024) { + size_t external_initializers_size_threshold = 1024, + size_t flags = OrtCompileApiFlags_NONE) { #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr result; OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options, std::move(path_or_bytes), is_path, embed_compiled_data_into_model, external_initializers_file_path, - external_initializers_size_threshold)); + external_initializers_size_threshold, + flags)); return result; #else ORT_UNUSED_PARAMETER(sess_options); @@ -2805,6 +2812,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_UNUSED_PARAMETER(embed_compiled_data_into_model); ORT_UNUSED_PARAMETER(external_initializers_file_path); ORT_UNUSED_PARAMETER(external_initializers_size_threshold); + ORT_UNUSED_PARAMETER(flags); ORT_THROW("Compile API is not supported in this build."); #endif })) diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index aeb10b8b4294b..6ef831c8ecd6f 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -14,6 +14,8 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" +#define ORT_MODEL_FOLDER ORT_TSTR("testdata/") + using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; @@ -361,6 +363,9 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelFromPath) { // Make sure the compiled model was generated and has the expected number of EPContext nodes. ASSERT_TRUE(std::filesystem::exists(output_model_file)); CheckEpContextNodeCounts(output_model_file, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_file, so))); } // Test using the CompileModel() API with settings: @@ -396,6 +401,9 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelAsBuffer_Embe // Make sure the compiled model was generated and has the expected number of EPContext nodes. ASSERT_TRUE(std::filesystem::exists(output_model_file)); CheckEpContextNodeCounts(output_model_file, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_file, so))); } // Test using the CompileModel() API with settings: @@ -436,6 +444,12 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) { // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + + { + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, so))); + } + allocator.Free(output_model_buffer); } @@ -479,6 +493,10 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options))); + allocator.Free(output_model_buffer); } @@ -503,6 +521,10 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options))); + allocator.Free(output_model_buffer); } } @@ -554,9 +576,164 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, so))); + allocator.Free(output_model_buffer); } +// Test that the explicit compile API can be configured to return an error if the output model does not +// have EPContext nodes. +TEST_F(QnnHTPBackendTests, CompileApi_SetFlags_ErrorIfNoCompiledNodes) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const ORTCHAR_T* output_model_file = ORT_TSTR("should_not_be_generated.onnx"); + std::filesystem::remove(output_model_file); + + // Initialize session options with only CPU EP, which will not be able to compile any nodes. + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + + // Call CompileModel() but expect an error status. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_EQ(status.GetErrorCode(), ORT_FAIL); + ASSERT_THAT(status.GetErrorMessage(), testing::HasSubstr("Unable to compile any nodes")); + + // Make sure that the output file was *NOT* generated. + ASSERT_FALSE(std::filesystem::exists(output_model_file)); +} + +// Test that the explicit compile API can be configured to return an error if the output model already exists and +// would have been overwritten. +TEST_F(QnnHTPBackendTests, CompileApi_SetFlags_ErrorIfOutputFileAlreadyExists) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const ORTCHAR_T* output_model_file = ORT_TSTR("mul_1_ctx_.onnx"); + std::filesystem::remove(output_model_file); + + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kQnnExecutionProvider, ProviderOptions{{"backend_type", "htp"}}); + + // Compile with QNN EP. Should succeed the first time. + { + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << "CompileModel() should succeed the first time a model is compiled."; + ASSERT_TRUE(std::filesystem::exists(output_model_file)) << "compiled model should exist"; + } + + // Compiling the input model again should fail if we disallow overwriting the output file. + { + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_EQ(status.GetErrorCode(), ORT_FAIL); + ASSERT_THAT(status.GetErrorMessage(), testing::HasSubstr("exists already")); + ASSERT_TRUE(std::filesystem::exists(output_model_file)) << "original compiled model should still exist"; + } +} + +// Tests that the explicit compile API returns an error if user tries to compile a compiled model. +// This scenario is silently ignored in the original compilation approach with session option configs. +TEST_F(QnnHTPBackendTests, CompileApi_ErrorIfCompilingACompiledModel) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const ORTCHAR_T* output_model_file = ORT_TSTR("mul_1_ctx_.onnx"); + std::filesystem::remove(output_model_file); + + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kQnnExecutionProvider, ProviderOptions{{"backend_type", "htp"}}); + + // Compile with QNN EP. Should succeed the first time. + { + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << "CompileModel() should succeed the first time a model is compiled."; + ASSERT_TRUE(std::filesystem::exists(output_model_file)) << "compiled model should exist"; + } + + // Compiling the compiled model should always fail: it's already compiled! + { + const ORTCHAR_T* new_output_model_file = ORT_TSTR("should_not_be_generated.onnx"); // Should not be generated. + std::filesystem::remove(new_output_model_file); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(output_model_file); // Set the compiled model as the input! + compile_options.SetOutputModelPath(new_output_model_file); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_EQ(status.GetErrorCode(), ORT_INVALID_GRAPH); + ASSERT_THAT(status.GetErrorMessage(), testing::HasSubstr("ensure the input model is not already compiled")); + ASSERT_FALSE(std::filesystem::exists(new_output_model_file)) << "new compiled model should not be generated"; + ASSERT_TRUE(std::filesystem::exists(output_model_file)) << "original compiled model should still exist"; + } +} + +// Uses the original compiling approach with session option configs (instead of explicit compile API). +// Test that ORT does not generate an output model if the model does not contain EPContext nodes. +// Also, ORT should not return an error. +TEST_F(QnnHTPBackendTests, QnnContextBinary_OriginalCompileApproach_NoCompiledNodesDoesntGenerateOutput) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const char* output_model_file = "should_not_be_generated.onnx"; + + // Initialize session options with only CPU EP, which will not be able to compile any nodes. + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_file); + Ort::Session session(*ort_env, input_model_file, so); // Should not throw an error. + + // Make sure that the output file was *NOT* generated. + ASSERT_FALSE(std::filesystem::exists(output_model_file)); +} + +// Uses the original compiling approach with session option configs (instead of explicit compile API). +// Test that ORT does not generate an output model if the input model is already compiled. +// Also, ORT should not return an error. +TEST_F(QnnHTPBackendTests, QnnContextBinary_OriginalCompileApproach_IgnoreCompilingOfCompiledModel) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const char* output_model_file = "mul_1_ctx.onnx"; + std::filesystem::remove(output_model_file); + + ProviderOptions qnn_options = {{"backend_type", "htp"}}; + + // Compile a model with QNN. This should succeed. + { + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_file); + so.AppendExecutionProvider(kQnnExecutionProvider, qnn_options); + + Ort::Session session(*ort_env, input_model_file, so); + ASSERT_TRUE(std::filesystem::exists(output_model_file)); // check compiled model was generated. + } + + // Try compiling the compiled model again. ORT should basically ignore it. + { + const char* new_output_model_file = "should_not_be_generated.onnx"; // will not be generated! + std::filesystem::remove(new_output_model_file); + + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, new_output_model_file); + so.AppendExecutionProvider(kQnnExecutionProvider, qnn_options); + + Ort::Session session(*ort_env, ToPathString(output_model_file).c_str(), so); + + // Session creation should not throw an error. And a new output model should not have been generated. + ASSERT_FALSE(std::filesystem::exists(new_output_model_file)); + } +} + // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary // The generated Onnx model has 1 FusedMatMul node and 1 EPContext node TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) { @@ -771,7 +948,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) { FAIL(); // Should not get here! } catch (const Ort::Exception& excpt) { ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL); - ASSERT_THAT(excpt.what(), testing::HasSubstr("exist already.")); + ASSERT_THAT(excpt.what(), testing::HasSubstr("exists already.")); auto modify_time_2 = std::filesystem::last_write_time(ep_context_binary_file); ASSERT_EQ(modify_time_1, modify_time_2); } diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index 7a410d4bbeb6a..b102676860444 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -13,7 +13,7 @@ from helper import get_name import onnxruntime as onnxrt -from onnxruntime.capi.onnxruntime_pybind11_state import ModelRequiresCompilation +from onnxruntime.capi.onnxruntime_pybind11_state import Fail, ModelRequiresCompilation # handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 @@ -120,6 +120,55 @@ def test_compile_with_input_and_output_files(self): model_compiler.compile_to_file(output_model_path) self.assertTrue(os.path.exists(output_model_path)) + def test_compile_flags_error_if_no_compiled_nodes(self): + """ + Tests specifying an additional flag (OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED) that + makes compiling return an error if no compiled nodes are generated (e.g., by using CPU EP). + """ + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled1.onnx") + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + flags=onnxrt.OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED, + ) + + # Compiling should raise a Fail exception and the output model should not be generated + with self.assertRaises(Fail) as context: + model_compiler.compile_to_file(output_model_path) + self.assertIn("Unable to compile any nodes", str(context.exception)) + self.assertFalse(os.path.exists(output_model_path)) + + def test_compile_flags_error_if_output_file_exists(self): + """ + Tests specifying an additional flag (OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS) that + makes compiling return an error the output model file already exists. + """ + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled1.onnx") + + # Compile the first time (should be fine) + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + flags=onnxrt.OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS, + ) + + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) # Output model was generated + + # Compiling again should raise a Fail exception saying that the model already exists. + with self.assertRaises(Fail) as context: + model_compiler.compile_to_file(output_model_path) + self.assertIn("exists already", str(context.exception)) + def test_compile_to_file_with_input_model_in_buffer(self): """ Tests compiling an input model that is stored in a buffer. The output is saved to a file. From ac0195b6dfd6b5de3d82b227c0dfeb37c9285854 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 20 May 2025 09:03:21 -0700 Subject: [PATCH 20/57] support more data types in tensor dumper (#24813) ### Description support of dumping tensors of int8, uin t8, BFloat16, UInt4x2, and Int4x2 data types in the tensor dumper. ### Motivation and Context Help debugging of operators using these data types. --- .../contrib_ops/cpu/utils/console_dumper.h | 39 +-- .../contrib_ops/cpu/utils/dump_tensor.cc | 286 ++++++++-------- .../contrib_ops/cpu/utils/dump_tensor.h | 37 +-- .../cuda/utils/dump_cuda_tensor.cc | 307 +++++++----------- .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 51 ++- 5 files changed, 313 insertions(+), 407 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h index 64bd2b7b1855e..9f3d22b9b3c0f 100644 --- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h +++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h @@ -17,34 +17,31 @@ class IConsoleDumper { virtual ~IConsoleDumper() {} void Disable() { is_enabled_ = false; } bool IsEnabled() const { return is_enabled_; } - virtual void Print(const char* name, const float* tensor, int dim0, int dim1) const = 0; - virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const = 0; - virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0; - virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const = 0; - virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const = 0; - - virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const = 0; - virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const = 0; - virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const = 0; - virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const = 0; - - virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; - virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; - virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; - virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; - - virtual void Print(const char* name, const int32_t* tensor, gsl::span& dims) const = 0; - virtual void Print(const char* name, const int64_t* tensor, gsl::span& dims) const = 0; - virtual void Print(const char* name, const float* tensor, gsl::span& dims) const = 0; - virtual void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const = 0; + virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0; virtual void Print(const char* name, const Tensor& value) const = 0; virtual void Print(const char* name, const OrtValue& value) const = 0; virtual void Print(const char* name, int index, bool end_line) const = 0; virtual void Print(const char* name, const std::string& value, bool end_line) const = 0; - virtual void Print(const std::string& value) const = 0; +#define TENSOR_DUMPER_PRINT_TYPE(dtype) \ + virtual void Print(const char* name, const dtype* tensor, int dim0, int dim1) const = 0; \ + virtual void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const = 0; \ + virtual void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; \ + virtual void Print(const char* name, const dtype* tensor, gsl::span& dims) const = 0; + + TENSOR_DUMPER_PRINT_TYPE(int8_t) + TENSOR_DUMPER_PRINT_TYPE(uint8_t) + TENSOR_DUMPER_PRINT_TYPE(int32_t) + TENSOR_DUMPER_PRINT_TYPE(int64_t) + TENSOR_DUMPER_PRINT_TYPE(float) + TENSOR_DUMPER_PRINT_TYPE(MLFloat16) + TENSOR_DUMPER_PRINT_TYPE(BFloat16) + TENSOR_DUMPER_PRINT_TYPE(UInt4x2) + TENSOR_DUMPER_PRINT_TYPE(Int4x2) +#undef TENSOR_DUMPER_PRINT_TYPE + protected: bool is_enabled_; }; diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc index 7755f9505d99d..7cbf989a44878 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc @@ -62,6 +62,54 @@ void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di } } +template +void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, int dim3) { + std::unique_lock lock(s_mutex); + + if (s_output_thread_id) + std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl; + + if (nullptr != name) { + std::cout << std::string(name) << std::endl; + } + + if (onnxruntime::utils::kDefaultSnippetThreshold < static_cast(dim0 * dim1 * dim2 * dim3)) { + for (int i = 0; i < dim0; i++) { + std::cout << "[" << i << "]:" << std::endl; + onnxruntime::utils::PrintCpuTensorSnippet(tensor + i * dim1 * dim2 * dim3, dim1, dim2, dim3, + onnxruntime::utils::kDefaultSnippetEdgeItems); + } + } else { + for (int i = 0; i < dim0; i++) { + std::cout << "[" << i << "]:" << std::endl; + onnxruntime::utils::PrintCpuTensorFull(tensor + i * dim1 * dim2 * dim3, dim1, dim2, dim3); + } + } +} + +void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2, int dim3) { + MLDataType dataType = tensor.DataType(); + if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else { + assert(0); + } +} + void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) { MLDataType dataType = tensor.DataType(); if (dataType == DataTypeImpl::GetType()) { @@ -72,6 +120,14 @@ void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, i DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); } else if (dataType == DataTypeImpl::GetType()) { DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); } else { assert(0); } @@ -87,11 +143,23 @@ void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) { DumpCpuTensor(name, tensor.Data(), dim0, dim1); } else if (dataType == DataTypeImpl::GetType()) { DumpCpuTensor(name, tensor.Data(), dim0, dim1); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1); } else { assert(0); } } +void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0) { + DumpCpuTensor(name, tensor, 1, dim0); +} + void DumpCpuTensor(const char* name, const Tensor& tensor) { const auto& shape = tensor.Shape(); @@ -101,21 +169,33 @@ void DumpCpuTensor(const char* name, const Tensor& tensor) { std::cout << "Shape:" << shape << std::endl; size_t num_dims = shape.NumDimensions(); - if (num_dims >= 3) { - int dim0 = static_cast(shape.SizeToDimension(num_dims - 2)); - int dim1 = static_cast(shape[num_dims - 2]); - int dim2 = static_cast(shape[num_dims - 1]); + if (num_dims >= 4) { + int dim0 = static_cast(shape.SizeToDimension(num_dims - 4)); + int dim1 = static_cast(shape[num_dims - 3]); + int dim2 = static_cast(shape[num_dims - 2]); + int dim3 = static_cast(shape[num_dims - 1]); + DumpCpuTensor(nullptr, tensor, dim0, dim1, dim2, dim3); + return; + } + + if (num_dims == 3) { + int dim0 = static_cast(shape[0]); + int dim1 = static_cast(shape[1]); + int dim2 = static_cast(shape[2]); DumpCpuTensor(nullptr, tensor, dim0, dim1, dim2); return; } - auto num_items = shape.Size(); - size_t num_rows = 1; - if (num_dims > 1) { - num_rows = static_cast(shape[0]); + if (num_dims == 2) { + int dim0 = static_cast(shape[0]); + int dim1 = static_cast(shape[1]); + DumpCpuTensor(nullptr, tensor, dim0, dim1); + return; + } + + if (num_dims == 1) { + DumpCpuTensor(nullptr, tensor, static_cast(shape[0])); } - size_t row_size = num_items / num_rows; - DumpCpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } CpuTensorConsoleDumper::CpuTensorConsoleDumper() { @@ -133,84 +213,6 @@ void CpuTensorConsoleDumper::Print(const std::string& value) const { std::cout << value << std::endl; } -void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1, dim2); -} - -void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1, dim2); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1, dim2); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1, dim2); -} - -void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0 * dim1, dim2, dim3); -} - -void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0 * dim1, dim2, dim3); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0 * dim1, dim2, dim3); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0 * dim1, dim2, dim3); -} - void CpuTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { if (!is_enabled_) return; @@ -246,21 +248,39 @@ void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, b } } -void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); +void CpuTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { + if (!is_enabled_) + return; + DumpCpuTensor(name, tensor, dim0, dim1); } -void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} +#define TENSOR_DUMPER_PRINT_TYPE(dtype) \ + void CpuTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1) const { \ + if (is_enabled_) \ + DumpCpuTensor(name, tensor, dim0, dim1); \ + } \ + void CpuTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const { \ + if (is_enabled_) \ + DumpCpuTensor(name, tensor, dim0, dim1, dim2); \ + } \ + void CpuTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const { \ + if (is_enabled_) \ + DumpCpuTensor(name, tensor, dim0, dim1, dim2, dim3); \ + } \ + void CpuTensorConsoleDumper::Print(const char* name, const dtype* tensor, gsl::span& dims) const { \ + PrintTensorByDims(this, name, tensor, dims); \ + } -void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} +TENSOR_DUMPER_PRINT_TYPE(int8_t) +TENSOR_DUMPER_PRINT_TYPE(uint8_t) +TENSOR_DUMPER_PRINT_TYPE(int32_t) +TENSOR_DUMPER_PRINT_TYPE(int64_t) +TENSOR_DUMPER_PRINT_TYPE(float) +TENSOR_DUMPER_PRINT_TYPE(MLFloat16) +TENSOR_DUMPER_PRINT_TYPE(BFloat16) +TENSOR_DUMPER_PRINT_TYPE(UInt4x2) +TENSOR_DUMPER_PRINT_TYPE(Int4x2) +#undef TENSOR_DUMPER_PRINT_TYPE #else @@ -270,45 +290,6 @@ CpuTensorConsoleDumper::CpuTensorConsoleDumper() { void CpuTensorConsoleDumper::Print(const std::string&) const { } -void CpuTensorConsoleDumper::Print(const char*, const float*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const float*, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int, int) const { -} - void CpuTensorConsoleDumper::Print(const char*, const Tensor&) const { } @@ -321,17 +302,30 @@ void CpuTensorConsoleDumper::Print(const char*, int, bool) const { void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const { } -void CpuTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span&) const { +void CpuTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { } -void CpuTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span&) const { -} +#define TENSOR_DUMPER_PRINT_TYPE(dtype) \ + void CpuTensorConsoleDumper::Print(const char*, const dtype*, int, int) const { \ + } \ + void CpuTensorConsoleDumper::Print(const char*, const dtype*, int, int, int) const { \ + } \ + void CpuTensorConsoleDumper::Print(const char*, const dtype*, int, int, int, int) const { \ + } \ + void CpuTensorConsoleDumper::Print(const char*, const dtype*, gsl::span&) const { \ + } -void CpuTensorConsoleDumper::Print(const char*, const float*, gsl::span&) const { -} +TENSOR_DUMPER_PRINT_TYPE(int8_t) +TENSOR_DUMPER_PRINT_TYPE(uint8_t) +TENSOR_DUMPER_PRINT_TYPE(int32_t) +TENSOR_DUMPER_PRINT_TYPE(int64_t) +TENSOR_DUMPER_PRINT_TYPE(float) +TENSOR_DUMPER_PRINT_TYPE(MLFloat16) +TENSOR_DUMPER_PRINT_TYPE(BFloat16) +TENSOR_DUMPER_PRINT_TYPE(UInt4x2) +TENSOR_DUMPER_PRINT_TYPE(Int4x2) +#undef TENSOR_DUMPER_PRINT_TYPE -void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span&) const { -} #endif } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h index 6fc4dfd4a0671..5066c3ddbb4b3 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h @@ -14,26 +14,8 @@ class CpuTensorConsoleDumper : public IConsoleDumper { public: CpuTensorConsoleDumper(); virtual ~CpuTensorConsoleDumper() {} - void Print(const char* name, const float* tensor, int dim0, int dim1) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; - void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; - - void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; - - void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const int32_t* tensor, gsl::span& dims) const override; - void Print(const char* name, const int64_t* tensor, gsl::span& dims) const override; - void Print(const char* name, const float* tensor, gsl::span& dims) const override; - void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const override; + void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; @@ -47,6 +29,23 @@ class CpuTensorConsoleDumper : public IConsoleDumper { void Print(const char* name, const std::vector& vec, size_t max_count = 0) const { this->Print(name, vec.data(), 1, static_cast(std::min(max_count, vec.size()))); } + +#define TENSOR_DUMPER_PRINT_TYPE(dtype) \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1) const override; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const override; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const override; \ + void Print(const char* name, const dtype* tensor, gsl::span& dims) const override; + + TENSOR_DUMPER_PRINT_TYPE(int8_t) + TENSOR_DUMPER_PRINT_TYPE(uint8_t) + TENSOR_DUMPER_PRINT_TYPE(int32_t) + TENSOR_DUMPER_PRINT_TYPE(int64_t) + TENSOR_DUMPER_PRINT_TYPE(float) + TENSOR_DUMPER_PRINT_TYPE(MLFloat16) + TENSOR_DUMPER_PRINT_TYPE(BFloat16) + TENSOR_DUMPER_PRINT_TYPE(UInt4x2) + TENSOR_DUMPER_PRINT_TYPE(Int4x2) +#undef TENSOR_DUMPER_PRINT_TYPE }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index 5c39cf56dfd92..40504d8648397 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -146,6 +146,31 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di } } +void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2, int dim3) { + MLDataType dataType = tensor.DataType(); + bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU); + if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else { + std::cout << std::string(name) << std::endl; + std::cout << "The data type is not supported in DumpGpuTensor" << std::endl; + } +} + void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) { MLDataType dataType = tensor.DataType(); bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU); @@ -157,8 +182,17 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, i DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); } else { - assert(0); + std::cout << std::string(name) << std::endl; + std::cout << "The data type is not supported in DumpGpuTensor" << std::endl; } } @@ -173,11 +207,24 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) { DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); } else { - assert(0); + std::cout << std::string(name) << std::endl; + std::cout << "The data type is not supported in DumpGpuTensor" << std::endl; } } +void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0) { + DumpGpuTensor(name, tensor, 1, dim0); +} + void DumpGpuTensor(const char* name, const Tensor& tensor) { const auto& shape = tensor.Shape(); @@ -188,21 +235,33 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { std::cout << tensor.Location().ToString() << std::endl; size_t num_dims = shape.NumDimensions(); - if (num_dims >= 3) { - int dim0 = static_cast(shape.SizeToDimension(num_dims - 2)); - int dim1 = static_cast(shape[num_dims - 2]); - int dim2 = static_cast(shape[num_dims - 1]); + if (num_dims >= 4) { + int dim0 = static_cast(shape.SizeToDimension(num_dims - 4)); + int dim1 = static_cast(shape[num_dims - 3]); + int dim2 = static_cast(shape[num_dims - 2]); + int dim3 = static_cast(shape[num_dims - 1]); + DumpGpuTensor(nullptr, tensor, dim0, dim1, dim2, dim3); + return; + } + + if (num_dims == 3) { + int dim0 = static_cast(shape[0]); + int dim1 = static_cast(shape[1]); + int dim2 = static_cast(shape[2]); DumpGpuTensor(nullptr, tensor, dim0, dim1, dim2); return; } - auto num_items = shape.Size(); - size_t num_rows = 1; - if (num_dims > 1) { - num_rows = static_cast(shape[0]); + if (num_dims == 2) { + int dim0 = static_cast(shape[0]); + int dim1 = static_cast(shape[1]); + DumpGpuTensor(nullptr, tensor, dim0, dim1); + return; + } + + if (num_dims == 1) { + DumpGpuTensor(nullptr, tensor, static_cast(shape[0])); } - size_t row_size = num_items / num_rows; - DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } CudaTensorConsoleDumper::CudaTensorConsoleDumper() { @@ -218,93 +277,6 @@ void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { - Print(name, reinterpret_cast(tensor), dim0, dim1); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); -} - void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { if (is_enabled_) DumpGpuTensor(name, tensor); @@ -335,28 +307,35 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} - -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} +#define CUDA_DUMPER_PRINT_TYPE(dtype, dtype2) \ + void CudaTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1) const { \ + if (is_enabled_) \ + DumpGpuTensor(name, reinterpret_cast(tensor), dim0, dim1, true); \ + } \ + void CudaTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const { \ + if (is_enabled_) \ + DumpGpuTensor(name, reinterpret_cast(tensor), dim0, dim1, dim2, true); \ + } \ + void CudaTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const { \ + if (is_enabled_) \ + DumpGpuTensor(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3, true); \ + } \ + void CudaTensorConsoleDumper::Print(const char* name, const dtype* tensor, gsl::span& dims) const { \ + PrintTensorByDims(this, name, reinterpret_cast(tensor), dims); \ + } -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} +CUDA_DUMPER_PRINT_TYPE(int8_t, int8_t) +CUDA_DUMPER_PRINT_TYPE(uint8_t, uint8_t) +CUDA_DUMPER_PRINT_TYPE(int32_t, int32_t) +CUDA_DUMPER_PRINT_TYPE(int64_t, int64_t) +CUDA_DUMPER_PRINT_TYPE(float, float) +CUDA_DUMPER_PRINT_TYPE(MLFloat16, MLFloat16) +CUDA_DUMPER_PRINT_TYPE(BFloat16, BFloat16) +CUDA_DUMPER_PRINT_TYPE(UInt4x2, UInt4x2) +CUDA_DUMPER_PRINT_TYPE(Int4x2, Int4x2) -void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} +CUDA_DUMPER_PRINT_TYPE(half, MLFloat16) +#undef DUMPER_PRINT_TYPE #else CudaTensorConsoleDumper::CudaTensorConsoleDumper() { @@ -368,60 +347,6 @@ void CudaTensorConsoleDumper::Print(const std::string&) const { void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int, int) const { -} - void CudaTensorConsoleDumper::Print(const char*, const Tensor&) const { } @@ -434,23 +359,27 @@ void CudaTensorConsoleDumper::Print(const char*, int, bool) const { void CudaTensorConsoleDumper::Print(const char*, const std::string&, bool) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span&) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span&) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const float*, gsl::span&) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const half*, gsl::span&) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span&) const { -} +#define CUDA_DUMPER_PRINT_TYPE(dtype) \ + void CudaTensorConsoleDumper::Print(const char*, const dtype*, int, int) const { \ + } \ + void CudaTensorConsoleDumper::Print(const char*, const dtype*, int, int, int) const { \ + } \ + void CudaTensorConsoleDumper::Print(const char*, const dtype*, int, int, int, int) const { \ + } \ + void CudaTensorConsoleDumper::Print(const char*, const dtype*, gsl::span&) const { \ + } -void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, gsl::span&) const { -} +CUDA_DUMPER_PRINT_TYPE(int8_t) +CUDA_DUMPER_PRINT_TYPE(uint8_t) +CUDA_DUMPER_PRINT_TYPE(int32_t) +CUDA_DUMPER_PRINT_TYPE(int64_t) +CUDA_DUMPER_PRINT_TYPE(float) +CUDA_DUMPER_PRINT_TYPE(MLFloat16) +CUDA_DUMPER_PRINT_TYPE(BFloat16) +CUDA_DUMPER_PRINT_TYPE(UInt4x2) +CUDA_DUMPER_PRINT_TYPE(Int4x2) +CUDA_DUMPER_PRINT_TYPE(half) +#undef DUMPER_PRINT_TYPE #endif diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 631421b1623be..406e269d6e070 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -17,43 +17,30 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { virtual ~CudaTensorConsoleDumper() {} void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; - - void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const int32_t* tensor, gsl::span& dims) const override; - - void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const int64_t* tensor, gsl::span& dims) const override; - - void Print(const char* name, const float* tensor, int dim0, int dim1) const override; - void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const float* tensor, gsl::span& dims) const override; - - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const override; - - void Print(const char* name, const half* tensor, int dim0, int dim1) const; - void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; - void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const half* tensor, gsl::span& dims) const; - - void Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const; - void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const; - void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const BFloat16* tensor, gsl::span& dims) const; - void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; void Print(const char* name, const std::string& value, bool end_line) const override; - void Print(const std::string& value) const override; + +#define CUDA_DUMPER_PRINT_TYPE(dtype) \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1) const; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const; \ + void Print(const char* name, const dtype* tensor, gsl::span& dims) const; + + CUDA_DUMPER_PRINT_TYPE(int8_t) + CUDA_DUMPER_PRINT_TYPE(uint8_t) + CUDA_DUMPER_PRINT_TYPE(int32_t) + CUDA_DUMPER_PRINT_TYPE(int64_t) + CUDA_DUMPER_PRINT_TYPE(float) + CUDA_DUMPER_PRINT_TYPE(MLFloat16) + CUDA_DUMPER_PRINT_TYPE(BFloat16) + CUDA_DUMPER_PRINT_TYPE(UInt4x2) + CUDA_DUMPER_PRINT_TYPE(Int4x2) + CUDA_DUMPER_PRINT_TYPE(half) + +#undef CUDA_DUMPER_PRINT_TYPE }; } // namespace cuda From 0d469ce0c7299e487276cafd9a9aa112699ca1d5 Mon Sep 17 00:00:00 2001 From: Jeff Kilpatrick Date: Tue, 20 May 2025 10:57:10 -0700 Subject: [PATCH 21/57] QNN EP can serialize graph to DLC (#24775) ### Description This change adds support for serializing the QNN graph to the new Deep Learning Container (DLC) format. It is meant to supplement and perhaps eventually replace use of the QnnSaver backend, which emits C++ source files when `qnn_saver_path` is set. * Add support for serializing to .dlc via the QnnIr backend. * Don't silently fallback to QnnCpu when QnnSaver was explicitly selected as the execution backend. * Minor fixes. ### Motivation and Context QNN model libraries, produced by compiling the C++ files that may be produced by QnnSaver have a number of drawbacks. Most importantly, they are not cross-platform and cannot be visualized via Netron or other tools. For these reasons, we anticipate that they may eventually be deprecated in favor of DLC files. These containers typically include a platform-agnostic representation of the graph QNN's internal representation. --------- Co-authored-by: Jeff Kilpatrick --- .../core/session/onnxruntime_c_api.h | 9 + .../qnn/builder/qnn_backend_manager.cc | 146 +++++++++-- .../qnn/builder/qnn_backend_manager.h | 66 ++++- .../qnn/builder/qnn_configs_helper.h | 9 + .../core/providers/qnn/builder/qnn_def.h | 3 +- .../qnn/builder/qnn_model_wrapper.cc | 8 +- .../providers/qnn/builder/qnn_model_wrapper.h | 2 +- .../builder/qnn_node_group/qnn_node_group.cc | 2 +- .../providers/qnn/qnn_execution_provider.cc | 103 ++++++-- .../providers/qnn/qnn_execution_provider.h | 2 +- .../test/providers/qnn/qnn_basic_test.cc | 237 ++++++++++++++---- .../test/providers/qnn/qnn_test_utils.cc | 48 +++- .../test/providers/qnn/qnn_test_utils.h | 13 + 13 files changed, 548 insertions(+), 100 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0ee8effaa16d0..a2f518ae09a4b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3719,6 +3719,7 @@ struct OrtApi { * -# "gpu" * -# "htp": Default. * -# "saver" + * -# "ir" * "backend_path": File path to QNN backend library. Mutually exclusive with "backend_type". * "profiling_level": QNN profiling level. * Available options: @@ -3740,6 +3741,14 @@ struct OrtApi { * -# "low_power_saver" * -# "power_saver" * -# "sustained_high_performance" + * "dump_qnn_ir_dlc": Use the QnnIr backend library to write .dlc files for each subgraph dispatched to QNN. When + * enabled, inference results will be incorrect. Use only for debugging. + * -# "0": Default: disabled + * -# "1": enabled + * "dump_qnn_ir_dlc_dir": Set the directory into which QnnIr will be configured to write QNN graphs as .dlc files. + * Default is current working directory. + * "qnn_ir_backend_path": File path to the QnnIr backend library. If "dump_qnn_ir_dlc" is enabled, use this path + * instead of looking for the Ir backend in the standard location. * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and * may alter model/EP partitioning. Use only for debugging. diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 0009dab837525..901569b54e049 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1,3 +1,4 @@ +// // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. @@ -14,7 +15,10 @@ #include "HTP/QnnHtpContext.h" #include "HTP/QnnHtpPerfInfrastructure.h" #include "HTP/QnnHtpSystemContext.h" +#include "IR/QnnIrCommon.h" +#include "IR/QnnIrGraph.h" #include "Saver/QnnSaver.h" +#include "Saver/QnnSaverCommon.h" #include #include "core/providers/qnn/ort_api.h" @@ -51,6 +55,93 @@ static const char* DlError() { #endif } +// Workaround for a missing comma in QNN_IR_GRAPH_CUSTOM_CONFIG_INIT. +static QnnIrGraph_CustomConfig_t EmptyIrGraphConfig() { + return { + QNN_IR_GRAPH_CONFIG_OPTION_SERIALIZATION, {QNN_IR_GRAPH_SERIALIZATION_TYPE_FLAT_BUFFER, ""}}; +} + +class QnnIrConfig : public QnnSerializerConfig { + public: + QnnIrConfig(std::string backend_path, std::string dlc_dir) + : QnnSerializerConfig(std::move(backend_path)), dlc_dir_(std::move(dlc_dir)), configs_builder_(MakeConfigsBuilder()) { + } + + const QnnGraph_Config_t** Configure() override { + auto configs_builder = MakeConfigsBuilder(); + + std::filesystem::path dlc_path = (dlc_dir_ / (GetGraphName() + ".dlc")); + std::string dlc_path_str = dlc_path.string(); + gsl::not_null dlc_path_config = configs_builder.PushCustomConfig(); + dlc_path_config->option = QNN_IR_GRAPH_CONFIG_OPTION_SERIALIZATION; + dlc_path_config->serializationOption.serializationType = QNN_IR_GRAPH_SERIALIZATION_TYPE_FLAT_BUFFER; + dlc_path_config->serializationOption.outputPath = dlc_path_str.c_str(); + + gsl::not_null dlc_path_custom_config = configs_builder.PushConfig(); + dlc_path_custom_config->option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; + dlc_path_custom_config->customConfig = dlc_path_config; + + std::filesystem::create_directories(dlc_path); + + // Keep the pointer to dlc_path_str's null-terminated string alive. + std::swap(dlc_path_str, dlc_path_str_); + + std::swap(configs_builder, configs_builder_); + return configs_builder_.GetQnnConfigs(); + } + + bool SupportsArbitraryGraphConfigs() const override { + return false; + } + + private: + static QnnConfigsBuilder MakeConfigsBuilder() { + return QnnConfigsBuilder(QNN_GRAPH_CONFIG_INIT, EmptyIrGraphConfig()); + } + + std::filesystem::path dlc_dir_; + std::string dlc_path_str_; + QnnConfigsBuilder configs_builder_; +}; + +class QnnSaverConfig : public QnnSerializerConfig { + public: + QnnSaverConfig(std::string backend_path) : QnnSerializerConfig(std::move(backend_path)) {} + + const QnnGraph_Config_t** Configure() override { + return nullptr; + } + + bool SupportsArbitraryGraphConfigs() const override { + return true; + } +}; + +QnnSerializerConfig::~QnnSerializerConfig() = default; + +QnnSerializerConfig::QnnSerializerConfig(std::string backend_path) + : backend_path_(std::move(backend_path)) {} + +std::unique_ptr QnnSerializerConfig::CreateIr(std::string backend_path, std::string dlc_dir) { + return std::make_unique(std::move(backend_path), std::move(dlc_dir)); +} + +std::unique_ptr QnnSerializerConfig::CreateSaver(std::string backend_path) { + return std::make_unique(std::move(backend_path)); +} + +const std::string& QnnSerializerConfig::GetBackendPath() const { + return backend_path_; +} + +const std::string& QnnSerializerConfig::GetGraphName() const { + return graph_name_; +} + +void QnnSerializerConfig::SetGraphName(std::string graph_name) { + graph_name_ = std::move(graph_name); +} + Status ReadBinaryFromFile(const std::string& file_path, uint8_t* buffer, size_t buffer_size) { ORT_RETURN_IF(nullptr == buffer, "Binary buffer is nullptr"); std::ifstream in(file_path, std::ifstream::binary); @@ -179,6 +270,10 @@ void QnnBackendManager::SetQnnBackendType(uint32_t backend_id) { case QNN_BACKEND_ID_HTP: qnn_backend_type_ = QnnBackendType::HTP; break; + case QNN_BACKEND_ID_IR: + case QNN_BACKEND_ID_SAVER: + qnn_backend_type_ = QnnBackendType::SERIALIZER; + break; default: qnn_backend_type_ = QnnBackendType::CPU; break; @@ -209,13 +304,19 @@ Status QnnBackendManager::LoadBackend() { return Status::OK(); } +QnnSerializerConfig* QnnBackendManager::GetQnnSerializerConfig() { + return qnn_serializer_config_.get(); +} + // Loads the intended backend (e.g., HTP, CPU, etc) to get its type, and then -// sets QNN Saver as the active backend. QNN op builders will still see the intended backend (e.g., HTP) -// as the backend type to ensure they emit the expected QNN API calls. +// sets QnnSaver or QnnIr as the active backend. QNN op builders will still see the intended backend +// (e.g., HTP) as the backend type to ensure they emit the expected QNN API calls. Note, however, that +// calls to QnnBackend_validateOpConfig will be to the saver backend, not the "intended" one. // -// QNN Saver is a "debugging" backend that serializes all QNN API calls (and weights) into local files. +// QnnSaver and QnnIr are "debugging" backends that serializes all QNN API calls (and weights) into +// local files: Saver dumps to C++ sources and Ir to .dlc archives. // This information can be used to debug issues by replaying QNN API calls with another backend. -Status QnnBackendManager::LoadQnnSaverBackend() { +Status QnnBackendManager::LoadQnnSerializerBackend() { void* backend_lib_handle = nullptr; // Helper that unloads the intended backend library handle when the `unload_backend_lib` variable @@ -245,25 +346,25 @@ Status QnnBackendManager::LoadQnnSaverBackend() { auto backend_id = backend_interface_provider->backendId; SetQnnBackendType(backend_id); - // Load the QNN Saver backend and set it as the activate backend. - QnnInterface_t* saver_interface_provider{nullptr}; + // Load the serializer backend and set it as the activate backend. + QnnInterface_t* serializer_interface_provider{nullptr}; auto saver_rt = GetQnnInterfaceProvider(qnn_saver_path_.c_str(), + QnnInterface_t>(qnn_serializer_config_->GetBackendPath().c_str(), "QnnInterface_getProviders", - &backend_lib_handle_, // NOTE: QNN Saver library handle is set + &backend_lib_handle_, // NOTE: QnnSaver/Ir library handle is set {QNN_API_VERSION_MAJOR, QNN_API_VERSION_MINOR, QNN_API_VERSION_PATCH}, - &saver_interface_provider); + &serializer_interface_provider); ORT_RETURN_IF_ERROR(saver_rt); - qnn_interface_ = saver_interface_provider->QNN_INTERFACE_VER_NAME; // NOTE: QNN Saver will provide the interfaces + qnn_interface_ = serializer_interface_provider->QNN_INTERFACE_VER_NAME; // NOTE: QnnSaver/Ir will provide the interfaces Qnn_Version_t backend_interface_version = GetQnnInterfaceApiVersion(backend_interface_provider); - Qnn_Version_t saver_interface_version = GetQnnInterfaceApiVersion(saver_interface_provider); + Qnn_Version_t serializer_interface_version = GetQnnInterfaceApiVersion(serializer_interface_provider); - LOGS_DEFAULT(INFO) << "Using QNN Saver version: " << saver_interface_version.major << "." - << saver_interface_version.minor << "." << saver_interface_version.patch - << " provider name : " << saver_interface_provider->providerName; + LOGS_DEFAULT(INFO) << "Using QnnSaver/Ir version: " << serializer_interface_version.major << "." + << serializer_interface_version.minor << "." << serializer_interface_version.patch + << " provider name : " << serializer_interface_provider->providerName; LOGS_DEFAULT(INFO) << "Intended backend provider name: " << backend_interface_provider->providerName << " backend id: " << backend_id @@ -636,7 +737,7 @@ Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { configs = npu_context_configs; break; case QnnBackendType::GPU: - // Currently only this works with QnnGpu. + case QnnBackendType::SERIALIZER: configs = nullptr; break; default: @@ -644,6 +745,11 @@ Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { break; } + // Not all serialization backends allow for hardware configs to be applied. + if (qnn_serializer_config_ && !qnn_serializer_config_->SupportsArbitraryGraphConfigs()) { + configs = nullptr; + } + Qnn_ContextHandle_t context = nullptr; Qnn_ErrorHandle_t result = qnn_interface_.contextCreate(backend_handle_, device_handle_, @@ -904,10 +1010,10 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, } Status status = Status::OK(); - if (qnn_saver_path_.empty()) { + if (!qnn_serializer_config_) { status = LoadBackend(); } else { - status = LoadQnnSaverBackend(); + status = LoadQnnSerializerBackend(); } if (status.IsOK()) { LOGS(logger, VERBOSE) << "LoadBackend succeed."; @@ -1287,7 +1393,7 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { const QnnProfile_EventId_t* profile_events{nullptr}; uint32_t num_events{0}; Qnn_ErrorHandle_t result = qnn_interface_.profileGetEvents(profile_backend_handle_, &profile_events, &num_events); - if (!qnn_saver_path_.empty()) { // Using QNN Saver backend + if (qnn_serializer_config_) { // Using QNN Saver or IR backend // QNN SDK 2.28.2 returns QNN_SAVER_ERROR_DUMMY_RETVALUE, but previous QNN versions return QNN_PROFILE_NO_ERROR. // We accept both values. ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result && QNN_SAVER_ERROR_DUMMY_RETVALUE != result, @@ -1709,13 +1815,9 @@ Status QnnBackendManager::UnloadLib(void* handle) { #ifdef _WIN32 HMODULE mod = static_cast(handle); -// TODO: QNN SDK 2.17 crashes for some models/tests on Windows x64 when unloading library. -// Example: ReductionOpTest.ArgMax -#if !defined(_M_AMD64) if (FreeLibrary(mod) == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to free library."); } -#endif // !defined(_M_AMD64) mod_handles_.erase(mod); #else auto rt = ::dlclose(handle); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 137b3856d431d..b8e8081f77f27 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -32,6 +32,62 @@ namespace qnn { class QnnModel; +class QnnSerializerConfig { + public: + virtual ~QnnSerializerConfig(); + + /** + * Create a config to write a DLC file for each graph using the Ir backend. + */ + static std::unique_ptr CreateIr(std::string backend_path, std::string dlc_dir); + + /** + * Create a config to write C++ source files using the Saver backend. + */ + static std::unique_ptr CreateSaver(std::string backend_path); + + /** + * Get the path to the serializer backend. + */ + const std::string& GetBackendPath() const; + + /** + * Set the name of the graph being serialized. This value may be used to determine the name + * of the output files. + * + * \param graph_name The name of the graph being serialized. + */ + void SetGraphName(std::string graph_name); + + /** + * Get any QNN Graph configs required to configure this serializer and perform any + * preparation, such as creating output directories. + * + * \return nullptr or a null-terminated list of QnnGraph_Config_t*. + */ + virtual const QnnGraph_Config_t** Configure() = 0; + + /** + * Some serializers allow for GraphConfigs that are unrelated to serialization to be + * specified at context creation time, while others raise an error. If true, this + * serializer should be configured with graph configs for any applicable real (e.g., HTP) + * backend. + * + * \return true if the backend can be configured with non-serialization graph configs. + */ + virtual bool SupportsArbitraryGraphConfigs() const = 0; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnSerializerConfig); + + protected: + QnnSerializerConfig(std::string backend_path); + const std::string& GetGraphName() const; + + private: + std::string backend_path_; + std::string graph_name_{"graph"}; +}; + // configuration values for QnnBackendManager creation struct QnnBackendManagerConfig { std::string backend_path; @@ -39,7 +95,7 @@ struct QnnBackendManagerConfig { ProfilingLevel profiling_level; std::string profiling_file_path; ContextPriority context_priority; - std::string qnn_saver_path; + std::shared_ptr qnn_serializer_config; uint32_t device_id; QnnHtpDevice_Arch_t htp_arch; uint32_t soc_model; @@ -63,7 +119,7 @@ class QnnBackendManager : public std::enable_shared_from_this profiling_level_(config.profiling_level), profiling_file_path_(config.profiling_file_path), context_priority_(config.context_priority), - qnn_saver_path_(config.qnn_saver_path), + qnn_serializer_config_(config.qnn_serializer_config), device_id_(config.device_id), htp_arch_(config.htp_arch), soc_model_(config.soc_model) { @@ -141,6 +197,8 @@ class QnnBackendManager : public std::enable_shared_from_this Status ParseLoraConfig(std::string lora_config); + QnnSerializerConfig* GetQnnSerializerConfig(); + private: Status LoadBackend(); @@ -176,7 +234,7 @@ class QnnBackendManager : public std::enable_shared_from_this Status LoadQnnSystemLib(); - Status LoadQnnSaverBackend(); + Status LoadQnnSerializerBackend(); Status UnloadLib(void* handle); @@ -295,7 +353,7 @@ class QnnBackendManager : public std::enable_shared_from_this #ifdef _WIN32 std::set mod_handles_; #endif - const std::string qnn_saver_path_; + const std::shared_ptr qnn_serializer_config_; uint32_t device_id_ = 0; QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h index b581cd90537d9..74919b2bcd259 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h @@ -44,6 +44,15 @@ class QnnConfigsBuilder { return config_ptrs_.data(); } + /** + * Returns the number of configs that have been added to this builder, excluding any null terminator. + * + * \return The number of configs in this builder. + */ + size_t GetSize() const { + return IsNullTerminated() ? config_ptrs_.size() - 1 : config_ptrs_.size(); + } + /** * Creates and returns a reference to a new custom QNN configuration object. The object is initialized to * the QNN recommended default value. The caller is meant to override fields in this object. diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 0d7bc0ba9f4c7..a95628ae9cc7f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -70,7 +70,8 @@ enum class QnnBackendType : uint8_t { GPU, DSP, HTP, - HTP_FP16 + HTP_FP16, + SERIALIZER, }; bool IsCpuBackend(QnnBackendType backend_type); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 0f0b42bf754d7..bd22aec89102c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -53,7 +53,7 @@ bool QnnModelWrapper::IsQnnTensorWrapperExist(const std::string& name) const { return model_tensors_map_.find(name) != model_tensors_map_.end(); } -bool QnnModelWrapper::IsQnnParamExit(const std::string& param_tensor_name) const { +bool QnnModelWrapper::QnnParamExists(const std::string& param_tensor_name) const { return model_params_map_.find(param_tensor_name) != model_params_map_.end(); } @@ -121,14 +121,14 @@ bool QnnModelWrapper::AddTensorWrapper(QnnTensorWrapper&& tensor_wrapper) { } bool QnnModelWrapper::AddParamWrapper(QnnParamWrapper&& param_wrapper) { - // Keep a copy of tensor name sine it will be moved with the wrapper into model_params_map_ + // Keep a copy of tensor name since it will be moved with the wrapper into model_params_map_ std::string param_tensor_name = param_wrapper.GetParamTensorName(); if (param_tensor_name.length() == 0) { LOGS(logger_, ERROR) << "Invalid parameter encountered empty name."; return false; } - if (IsQnnParamExit(param_tensor_name) == true) { + if (QnnParamExists(param_tensor_name) == true) { return true; } @@ -159,7 +159,7 @@ bool QnnModelWrapper::CreateQnnInputOutputTensors(const std::string& qnn_node_na } // During graph patitioning, we only need to do op validation, it's not required to create Qnn graph tensor - // We only need to creat the Qnn graph tensor during Compile to create Qnn graph + // We only need to create the Qnn graph tensor during Compile to create Qnn graph if (!do_op_validation) { std::string error_string; auto rt = it->second.CreateQnnGraphTensor(qnn_interface_, graph_, qnn_node_name, tensor_created_map_, error_string); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 9ec6f470af9fd..745dfde7bfac8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -279,7 +279,7 @@ class QnnModelWrapper { std::vector& tensor_wrappers, bool do_op_validation = false); - bool IsQnnParamExit(const std::string& param_tensor_name) const; + bool QnnParamExists(const std::string& param_tensor_name) const; bool CreateQnnParamTensors(const std::string& qnn_node_name, const std::vector& param_tensor_names, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 0390a305b2df9..20b37a2fb2b22 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -116,7 +116,7 @@ static Status GetQnnNodeGroupsImpl(/*out*/ std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); + const std::vector& sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); sorted_qnn_node_group_indices.reserve(num_node_units); qnn_node_groups.reserve(num_node_units); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 65ef19f0b6c0e..c085ef7c31f0e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -36,18 +36,21 @@ const std::string kDefaultCpuBackendPath = MakeSharedLibraryPath("QnnCpu"); const std::string kDefaultGpuBackendPath = MakeSharedLibraryPath("QnnGpu"); const std::string kDefaultHtpBackendPath = MakeSharedLibraryPath("QnnHtp"); const std::string kDefaultSaverBackendPath = MakeSharedLibraryPath("QnnSaver"); +const std::string kDefaultIrBackendPath = MakeSharedLibraryPath("QnnIr"); static bool ParseBackendTypeName(std::string_view backend_type_name, std::string& backend_path) { constexpr std::string_view kCpuBackendTypeName{"cpu"}; constexpr std::string_view kGpuBackendTypeName{"gpu"}; constexpr std::string_view kHtpBackendTypeName{"htp"}; constexpr std::string_view kSaverBackendTypeName{"saver"}; + constexpr std::string_view kIrBackendTypeName{"ir"}; constexpr std::array kAllowedBackendTypeNames{ kCpuBackendTypeName, kGpuBackendTypeName, kHtpBackendTypeName, kSaverBackendTypeName, + kIrBackendTypeName, }; std::optional associated_backend_path{}; @@ -59,6 +62,8 @@ static bool ParseBackendTypeName(std::string_view backend_type_name, std::string associated_backend_path = kDefaultHtpBackendPath; } else if (backend_type_name == kSaverBackendTypeName) { associated_backend_path = kDefaultSaverBackendPath; + } else if (backend_type_name == kIrBackendTypeName) { + associated_backend_path = kDefaultIrBackendPath; } if (associated_backend_path.has_value()) { @@ -204,6 +209,51 @@ qnn::ProfilingLevel QNNExecutionProvider::GetProfilingLevelFromETWLevel(unsigned } } +static std::unique_ptr ParseSerializerBackendOptions(const ProviderOptions& provider_options_map) { + // Enable use of QNN Saver if the user provides a path the QNN Saver backend library. + static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path"; + auto qnn_saver_path_pos = provider_options_map.find(QNN_SAVER_PATH_KEY); + if (qnn_saver_path_pos != provider_options_map.end()) { + LOGS_DEFAULT(VERBOSE) << "User specified QNN Saver path: " << qnn_saver_path_pos->second; + return qnn::QnnSerializerConfig::CreateSaver(qnn_saver_path_pos->second); + } + + static const std::string DUMP_QNN_IR_DLC = "dump_qnn_ir_dlc"; + auto dump_qnn_ir_dlc = ParseBoolOption(DUMP_QNN_IR_DLC, false, provider_options_map); + + static const std::string DUMP_QNN_IR_DLC_DIR = "dump_qnn_ir_dlc_dir"; + std::string qnn_ir_dlc_dir; + auto qnn_ir_dlc_dir_pos = provider_options_map.find(DUMP_QNN_IR_DLC_DIR); + if (qnn_ir_dlc_dir_pos != provider_options_map.end()) { + qnn_ir_dlc_dir = qnn_ir_dlc_dir_pos->second; + if (dump_qnn_ir_dlc) { + LOGS_DEFAULT(INFO) << "IR DLC directory: " << qnn_ir_dlc_dir; + } else { + LOGS_DEFAULT(WARNING) << "Provided a directory for dumping QNN graphs to DLC, " + << "but did not set dump_qnn_ir_dlc to 1."; + } + } + + static const std::string QNN_IR_BACKEND_PATH = "qnn_ir_backend_path"; + std::string qnn_ir_backend_path = kDefaultIrBackendPath; + auto qnn_ir_backend_path_pos = provider_options_map.find(QNN_IR_BACKEND_PATH); + if (qnn_ir_backend_path_pos != provider_options_map.end()) { + qnn_ir_backend_path = qnn_ir_backend_path_pos->second; + if (dump_qnn_ir_dlc) { + LOGS_DEFAULT(INFO) << "IR backend path: " << qnn_ir_backend_path; + } else { + LOGS_DEFAULT(WARNING) << "Provided a path to the Ir backend for dumping QNN graphs to DLC, " + << "but did not set dump_qnn_ir_dlc to 1."; + } + } + + if (dump_qnn_ir_dlc) { + return qnn::QnnSerializerConfig::CreateIr(std::move(qnn_ir_backend_path), std::move(qnn_ir_dlc_dir)); + } + + return nullptr; +} + QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map, const ConfigOptions* config_options) : IExecutionProvider{onnxruntime::kQnnExecutionProvider} { @@ -283,6 +333,8 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "Using backend path: " << backend_path; } + std::unique_ptr qnn_serializer_config = ParseSerializerBackendOptions(provider_options_map); + std::string profiling_file_path; static const std::string PROFILING_LEVEL = "profiling_level"; qnn::ProfilingLevel profiling_level = qnn::ProfilingLevel::OFF; @@ -337,15 +389,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio ParseHtpGraphFinalizationOptimizationMode(htp_graph_finalization_opt_mode_pos->second); } - // Enable use of QNN Saver if the user provides a path the QNN Saver backend library. - static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path"; - std::string qnn_saver_path; - auto qnn_saver_path_pos = provider_options_map.find(QNN_SAVER_PATH_KEY); - if (qnn_saver_path_pos != provider_options_map.end()) { - qnn_saver_path = qnn_saver_path_pos->second; - LOGS_DEFAULT(VERBOSE) << "User specified QNN Saver path: " << qnn_saver_path; - } - static const std::string QNN_CONTEXT_PRIORITY = "qnn_context_priority"; qnn::ContextPriority context_priority = qnn::ContextPriority::NORMAL; auto qnn_context_priority_pos = provider_options_map.find(QNN_CONTEXT_PRIORITY); @@ -464,7 +507,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio profiling_level, profiling_file_path, context_priority, - qnn_saver_path, + std::move(qnn_serializer_config), device_id_, htp_arch, soc_model}); @@ -912,7 +955,7 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod return Status::OK(); } -void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const { +void QNNExecutionProvider::InitQnnHtpGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const { if (qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP) { if (htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) { gsl::not_null htp_graph_opt_config = configs_builder.PushCustomConfig(); @@ -956,9 +999,39 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector qnn_model = std::make_unique(qnn_backend_manager_.get()); - qnn::QnnConfigsBuilder graph_configs_builder(QNN_GRAPH_CONFIG_INIT, - QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); - InitQnnGraphConfigs(graph_configs_builder); + std::vector all_graph_configs; + + qnn::QnnConfigsBuilder htp_graph_configs_builder(QNN_GRAPH_CONFIG_INIT, + QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); + InitQnnHtpGraphConfigs(htp_graph_configs_builder); + + const QnnGraph_Config_t** htp_configs = htp_graph_configs_builder.GetQnnConfigs(); + if (htp_configs) { + // Reserve enough for configs + nullptr + all_graph_configs.reserve(htp_graph_configs_builder.GetSize() + 1); + for (const QnnGraph_Config_t** config = htp_configs; *config; ++config) { + all_graph_configs.push_back(*config); + } + } + + qnn::QnnSerializerConfig* qnn_serializer_config = qnn_backend_manager_->GetQnnSerializerConfig(); + if (qnn_serializer_config) { + // We don't bother reserving here to keep the API simpler. Also note that if we're here, + // we're likely debugging and not waiting for inference. + qnn_serializer_config->SetGraphName(fused_node.Name()); + const QnnGraph_Config_t** serializer_configs = qnn_serializer_config->Configure(); + if (serializer_configs) { + for (const QnnGraph_Config_t** config = serializer_configs; *config; ++config) { + all_graph_configs.push_back(*config); + } + } + } + + const QnnGraph_Config_t** all_graph_configs_ptr = nullptr; + if (!all_graph_configs.empty()) { + all_graph_configs.push_back(nullptr); + all_graph_configs_ptr = all_graph_configs.data(); + } std::string json_graph_filepath; @@ -969,7 +1042,7 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vectorComposeGraph(graph_viewer, fused_node, model_settings_, logger, - graph_configs_builder.GetQnnConfigs(), json_graph_filepath)); + all_graph_configs_ptr, json_graph_filepath)); ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs(logger)); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput(logger)); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 7769a4a453c1b..4ccb1554f8b15 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -68,7 +68,7 @@ class QNNExecutionProvider : public IExecutionProvider { void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string); - void InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const; + void InitQnnHtpGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const; qnn::ProfilingLevel GetProfilingLevelFromETWLevel(unsigned char level); diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 0212dacadbced..2b4bbc272b482 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -26,6 +26,8 @@ using namespace onnxruntime::logging; #define ORT_MODEL_FOLDER ORT_TSTR("testdata/") +constexpr std::string_view kDlcOutputDir("dlc_output"); + // in test_main.cc extern std::unique_ptr ort_env; extern "C" void ortenv_setup(); @@ -334,19 +336,61 @@ TEST_F(QnnHTPBackendTests, RunConvInt4Model) { } #endif // #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Helper function that runs an ONNX model with a NHWC Resize operator to test that -// type/shape inference succeeds during layout transformation. -// Refer to onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h. -// -// The models passed to this function are subgraphs extracted from a larger model that exhibited -// shape inferencing issues on QNN. Thus, the models are expected to have a specific input/output -// types and shapes. -static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bool enable_qnn_saver = false, - std::string htp_graph_finalization_opt_mode = "", - std::string qnn_context_priority = "", - std::string soc_model = "", - std::string htp_arch = "", - std::string device_id = "") { +enum class TestBackend { + Cpu, + Htp, + Saver, + Ir, +}; + +static std::string ToBackendLibName(TestBackend backend) { + switch (backend) { + case TestBackend::Cpu: + return "Cpu"; + case TestBackend::Htp: + return "Htp"; + case TestBackend::Saver: + return "Saver"; + case TestBackend::Ir: + return "Ir"; + default: + assert(false && "Invalid TestBackend value."); + return ""; + } +} + +static void AddSerializerConfigs(TestBackend serializer_backend, onnxruntime::ProviderOptions& options) { + std::string serializer_lib = ToBackendLibName(serializer_backend); + std::string serializer_path_key; + + switch (serializer_backend) { + case TestBackend::Ir: + serializer_path_key = "qnn_ir_backend_path"; + options["dump_qnn_ir_dlc"] = "1"; + options["dump_qnn_ir_dlc_dir"] = kDlcOutputDir; + break; + case TestBackend::Saver: + serializer_path_key = "qnn_saver_path"; + break; + default: + assert(false && "Invalid serializer backend."); + return; + } + +#if defined(_WIN32) + options[serializer_path_key] = "Qnn" + serializer_lib + ".dll"; +#else + options[serializer_path_key] = "libQnn" + serializer_lib + ".so"; +#endif +} + +static Ort::Session InitNHWCResizeModel(const ORTCHAR_T* ort_model_path, TestBackend backend, + std::optional serializer_backend = std::nullopt, + std::string htp_graph_finalization_opt_mode = "", + std::string qnn_context_priority = "", + std::string soc_model = "", + std::string htp_arch = "", + std::string device_id = "") { Ort::SessionOptions so; // Ensure all type/shape inference warnings result in errors! @@ -356,18 +400,18 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo onnxruntime::ProviderOptions options; options["offload_graph_io_quantization"] = "0"; + std::string backend_lib = ToBackendLibName(backend); + #if defined(_WIN32) - options["backend_path"] = use_htp ? "QnnHtp.dll" : "QnnCpu.dll"; - if (enable_qnn_saver) { - options["qnn_saver_path"] = "QnnSaver.dll"; - } + options["backend_path"] = "Qnn" + backend_lib + ".dll"; #else - options["backend_path"] = use_htp ? "libQnnHtp.so" : "libQnnCpu.so"; - if (enable_qnn_saver) { - options["qnn_saver_path"] = "libQnnSaver.so"; - } + options["backend_path"] = "libQnn" + backend_lib + ".so"; #endif + if (serializer_backend) { + AddSerializerConfigs(*serializer_backend, options); + } + if (!htp_graph_finalization_opt_mode.empty()) { options["htp_graph_finalization_optimization_mode"] = std::move(htp_graph_finalization_opt_mode); } @@ -392,6 +436,25 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo Ort::Session session(*ort_env, ort_model_path, so); + return session; +} + +// Helper function that runs an ONNX model with a NHWC Resize operator to test that +// type/shape inference succeeds during layout transformation. +// Refer to onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h. +// +// The models passed to this function are subgraphs extracted from a larger model that exhibited +// shape inferencing issues on QNN. Thus, the models are expected to have a specific input/output +// types and shapes. +static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, TestBackend backend, + std::optional serializer_backend = std::nullopt, + std::string htp_graph_finalization_opt_mode = "", + std::string qnn_context_priority = "", + std::string soc_model = "", + std::string htp_arch = "", + std::string device_id = "") { + Ort::Session session = InitNHWCResizeModel(ort_model_path, backend, serializer_backend, htp_graph_finalization_opt_mode, qnn_context_priority, soc_model, htp_arch, device_id); + // Input can be all zeros since we're testing for correct shape inference. std::array input0_data = {}; std::array input1_data = {}; @@ -433,25 +496,25 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo // Test shape inference of NHWC Resize operator (opset 11) that uses // the scales input. Use the QNN CPU backend. TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_scales_opset11) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_scales_opset11.onnx", false); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_scales_opset11.onnx", TestBackend::Cpu); } // Test shape inference of NHWC Resize operator (opset 18) that uses // the scales input. Use the QNN CPU backend. TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_scales_opset18) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_scales_opset18.onnx", false); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_scales_opset18.onnx", TestBackend::Cpu); } // Test shape inference of NHWC Resize operator (opset 11) that uses // the sizes input. Use the QNN CPU backend. TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_sizes_opset11) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset11.onnx", false); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset11.onnx", TestBackend::Cpu); } // Test shape inference of NHWC Resize operator (opset 18) that uses // the sizes input. Use the QNN CPU backend. TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_sizes_opset18) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", false); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", TestBackend::Cpu); } // Test that QNN Saver generates the expected files for a model meant to run on the QNN CPU backend. @@ -463,8 +526,8 @@ TEST_F(QnnCPUBackendTests, QnnSaver_OutputFiles) { ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", - false, // use_htp - true); // enable_qnn_saver + TestBackend::Cpu, // backend + TestBackend::Saver); // serializer_backend // Check that QNN Saver output files exist. EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); @@ -856,7 +919,42 @@ TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgDefaultAndRunOption) { // the sizes input. Use the QNN HTP backend. // Maps to QNN's ResizeBilinear operator. TEST_F(QnnHTPBackendTests, TestNHWCResizeShapeInference_qdq_sizes_opset18) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", true); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", TestBackend::Htp); +} + +// Test that QNN Ir generates the expected file for a model meant to run on the QNN HTP backend. + +TEST_F(QnnHTPBackendTests, QnnIr_OutputFiles) { + const auto& logger = DefaultLoggingManager().DefaultLogger(); + if (IsIRBackendSupported() == BackendSupport::UNSUPPORTED) { + LOGS(logger, WARNING) << "QNN IR backend is not available! Skipping test."; + GTEST_SKIP(); + } else if (IsIRBackendSupported() == BackendSupport::SUPPORT_ERROR) { + LOGS(logger, ERROR) << "Failed to check if QNN IR backend is available."; + FAIL(); + } + + const std::filesystem::path qnn_dlc_dir = kDlcOutputDir; + + // Remove pre-existing QNN Ir output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_dlc_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_dlc_dir)); + + InitNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + TestBackend::Htp, // backend + TestBackend::Ir); // serializer backend + + // File names are taken from graph node names. Just make sure that we got one .dlc + // in the expected directory. + ASSERT_TRUE(std::filesystem::exists(qnn_dlc_dir)); + + int file_count = 0; + for (const auto& entry : std::filesystem::directory_iterator(qnn_dlc_dir)) { + EXPECT_TRUE(entry.is_regular_file()); + EXPECT_EQ(entry.path().extension(), ".dlc"); + ++file_count; + } + EXPECT_EQ(file_count, 1); } // Test that QNN Saver generates the expected files for a model meant to run on the QNN HTP backend. @@ -868,8 +966,8 @@ TEST_F(QnnHTPBackendTests, QnnSaver_OutputFiles) { ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", - true, // use_htp - true); // enable_qnn_saver + TestBackend::Htp, // backend + TestBackend::Saver); // serializer_backend // Check that QNN Saver output files exist. EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); @@ -885,9 +983,9 @@ TEST_F(QnnHTPBackendTests, HTPGraphFinalizationOptimizationModes) { "3"}; // Mode 3 for (auto mode : graph_opt_modes) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", - true, // use_htp - false, // enable_qnn_saver - mode); // htp_graph_finalization_opt_mode + TestBackend::Htp, // backend + std::nullopt, // serializer_backend + mode); // htp_graph_finalization_opt_mode } } @@ -905,10 +1003,10 @@ TEST_F(QnnHTPBackendTests, HTPSocModels) { for (auto soc_model : soc_models) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", - true, // use_htp - false, // enable_qnn_saver - "", // htp_graph_finalization_opt_mode - "", // qnn_context_priority + TestBackend::Htp, // backend + std::nullopt, // serializer_backend + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority soc_model); } } @@ -920,23 +1018,23 @@ TEST_F(QnnHTPBackendTests, HTPArchValues) { "68"}; // v68 for (auto htp_arch : htp_archs) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", - true, // use_htp - false, // enable_qnn_saver - "", // htp_graph_finalization_opt_mode - "", // qnn_context_priority - "", // soc_model - htp_arch, // htp_arch - "0"); // device_id + TestBackend::Htp, // backend + std::nullopt, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority + "", // soc_model + htp_arch, // htp_arch + "0"); // device_id } } // Test that models run with high QNN context priority. TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", - true, // use_htp - false, // enable_qnn_saver - "", // htp_graph_finalization_opt_mode - "high"); // qnn_context_priority + TestBackend::Htp, // use_htp + std::nullopt, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "high"); // qnn_context_priority } // Create a model with Cast + Add (quantized) @@ -1287,6 +1385,49 @@ TEST_F(QnnHTPBackendTests, AutoEp_PreferNpu) { #endif // defined(WIN32) && !BUILD_QNN_EP_STATIC_LIB #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +// Test that QNN Ir generates the expected files for a model meant to run on any QNN backend. +TEST_F(QnnIRBackendTests, QnnIr_OutputFiles) { + const std::filesystem::path qnn_dlc_dir = kDlcOutputDir; + + // Remove pre-existing QNN Ir output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_dlc_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_dlc_dir)); + + InitNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + TestBackend::Ir, // backend + TestBackend::Ir); // serializer backend + + // File names are taken from graph node names. Just make sure that we got one .dlc + // in the expected directory. + ASSERT_TRUE(std::filesystem::exists(qnn_dlc_dir)); + + int file_count = 0; + for (const auto& entry : std::filesystem::directory_iterator(qnn_dlc_dir)) { + EXPECT_TRUE(entry.is_regular_file()); + EXPECT_EQ(entry.path().extension(), ".dlc"); + ++file_count; + } + EXPECT_EQ(file_count, 1); +} + +// Test that QNN Saver generates the expected files for a model meant to run on any QNN backend. +TEST(QnnSaverBackendTests, QnnSaver_OutputFiles) { + const std::filesystem::path qnn_saver_output_dir = "saver_output"; + + // Remove pre-existing QNN Saver output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_saver_output_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); + + InitNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + TestBackend::Saver, // backend + TestBackend::Saver); // serializer_backend + + // Check that QNN Saver output files exist. + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "params.bin")); +} + #endif // !defined(ORT_MINIMAL_BUILD) } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 6f8a7a9ecb602..cd163b044911c 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -309,11 +309,23 @@ void QnnHTPBackendTests::SetUp() { } } +static BackendSupport GetIRSupport(const onnxruntime::logging::Logger& logger); + +BackendSupport QnnHTPBackendTests::IsIRBackendSupported() const { + const auto& logger = DefaultLoggingManager().DefaultLogger(); + + if (cached_ir_support_ == BackendSupport::SUPPORT_UNKNOWN) { + cached_ir_support_ = test::GetIRSupport(logger); + } + + return cached_ir_support_; +} + // Testing helper function that calls QNN EP's GetCapability() function with a mock graph to check // if the QNN CPU backend is available. // TODO: Remove once the QNN CPU backend works on Windows ARM64 pipeline VM. -static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger) { - onnxruntime::Model model("Check if CPU is available", false, logger); +static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger, const std::string& backend_type = "cpu") { + onnxruntime::Model model("Check if " + backend_type + " is available", false, logger); Graph& graph = model.MainGraph(); ModelTestBuilder helper(graph); @@ -343,7 +355,7 @@ static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger) MockKernelLookup kernel_lookup; onnxruntime::GraphViewer graph_viewer(graph); std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( - {{"backend_type", "cpu"}, {"offload_graph_io_quantization", "0"}}); + {{"backend_type", backend_type}, {"offload_graph_io_quantization", "0"}}); GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability qnn_ep->SetLogger(&logger); @@ -373,6 +385,33 @@ void QnnCPUBackendTests::SetUp() { } } +static BackendSupport GetIRSupport(const onnxruntime::logging::Logger& logger) { + // QnnIr should be able to serialize any model supported by the QNN reference spec. + // Use a model that works on QnnCpu to verify QnnIr availability. + return GetCPUSupport(logger, "ir"); +} + +void QnnIRBackendTests::SetUp() { + if (cached_ir_support_ == BackendSupport::SUPPORTED) { + return; + } + + const auto& logger = DefaultLoggingManager().DefaultLogger(); + + // Determine if IR backend is supported only if we done so haven't before. + if (cached_ir_support_ == BackendSupport::SUPPORT_UNKNOWN) { + cached_ir_support_ = GetIRSupport(logger); + } + + if (cached_ir_support_ == BackendSupport::UNSUPPORTED) { + LOGS(logger, WARNING) << "QNN IR backend is not available! Skipping test."; + GTEST_SKIP(); + } else if (cached_ir_support_ == BackendSupport::SUPPORT_ERROR) { + LOGS(logger, ERROR) << "Failed to check if QNN IR backend is available."; + FAIL(); + } +} + #if defined(_WIN32) // TODO: Remove or set to SUPPORTED once HTP emulation is supported on win arm64. BackendSupport QnnHTPBackendTests::cached_htp_support_ = BackendSupport::SUPPORT_UNKNOWN; @@ -384,6 +423,9 @@ BackendSupport QnnHTPBackendTests::cached_htp_support_ = BackendSupport::SUPPORT BackendSupport QnnCPUBackendTests::cached_cpu_support_ = BackendSupport::SUPPORTED; #endif // defined(_WIN32) +BackendSupport QnnHTPBackendTests::cached_ir_support_ = BackendSupport::SUPPORT_UNKNOWN; +BackendSupport QnnIRBackendTests::cached_ir_support_ = BackendSupport::SUPPORT_UNKNOWN; + bool ReduceOpHasAxesInput(const std::string& op_type, int opset_version) { static const std::unordered_map opset_with_axes_as_input = { {"ReduceMax", 18}, diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 676460e108b0e..9fe48ddabd427 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -1100,7 +1100,11 @@ class QnnHTPBackendTests : public ::testing::Test { protected: void SetUp() override; + // Some tests need the Ir backend, which is not always available. + [[nodiscard]] BackendSupport IsIRBackendSupported() const; + static BackendSupport cached_htp_support_; // Set by the first test using this fixture. + static BackendSupport cached_ir_support_; }; // Testing fixture class for tests that require the QNN CPU backend. Checks if QNN CPU is available before the test @@ -1113,6 +1117,15 @@ class QnnCPUBackendTests : public ::testing::Test { static BackendSupport cached_cpu_support_; // Set by the first test using this fixture. }; +// Testing fixture class for tests that require the QNN Ir backend. Checks if QNN IR is available before the test +// begins. The test is skipped if the IR backend is unavailable (may occur with certain QNN versions). +class QnnIRBackendTests : public ::testing::Test { + protected: + void SetUp() override; + + static BackendSupport cached_ir_support_; // Set by the first test using this fixture. +}; + /** * Returns true if the given reduce operator type (e.g., "ReduceSum") and opset version (e.g., 13) * supports "axes" as an input (instead of an attribute). From 7af75ee1d1a7e39fc421a292f7eed2820562b11a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 May 2025 11:43:47 -0700 Subject: [PATCH 22/57] [nodejs] fix nodejs cuda12 installation (#24814) ### Description Fix Node.js linux/x64 cuda12 installation. ### Motivation and Context --- js/node/script/install-metadata.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/node/script/install-metadata.js b/js/node/script/install-metadata.js index e0186ec45d1b4..41b905ba88eaf 100644 --- a/js/node/script/install-metadata.js +++ b/js/node/script/install-metadata.js @@ -20,15 +20,15 @@ const metadata = { 'linux/x64:cuda12': { './libonnxruntime_providers_cuda.so': { package: 'nuget:linux/x64:cuda12', - path: 'runtimes/win-x64/native/libonnxruntime_providers_cuda.so', + path: 'runtimes/linux-x64/native/libonnxruntime_providers_cuda.so', }, './libonnxruntime_providers_shared.so': { package: 'nuget:linux/x64:cuda12', - path: 'runtimes/win-x64/native/libonnxruntime_providers_shared.so', + path: 'runtimes/linux-x64/native/libonnxruntime_providers_shared.so', }, './libonnxruntime_providers_tensorrt.so': { package: 'nuget:linux/x64:cuda12', - path: 'runtimes/win-x64/native/libonnxruntime_providers_tensorrt.so', + path: 'runtimes/linux-x64/native/libonnxruntime_providers_tensorrt.so', }, }, }, From adeb01600a9a118003dc22116b8a59d6ebb57c55 Mon Sep 17 00:00:00 2001 From: Yateng Hong Date: Wed, 21 May 2025 03:11:43 +0800 Subject: [PATCH 23/57] [TRT EP] Fix trt context memory sharing (#24784) ### Description Fixed a TRT context memory sharing bug where the context memory was assigned to a unique_ptr that was immediately destructed upon leaving scope. ### Motivation and Context The bug seems to be introduced by a refactor work: #15833 : ![image](https://github.com/user-attachments/assets/eec0e363-b6b1-4831-9ee4-a1b3ed45116c) --- .../onnxruntime/core/framework/allocator.h | 10 +++++-- .../tensorrt/tensorrt_execution_provider.cc | 29 ++++++++++--------- .../tensorrt/tensorrt_execution_provider.h | 2 ++ 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index ce7d4aaf652d0..15c15c6c143d2 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -203,7 +203,8 @@ class IAllocator { @returns std::unique_ptr with allocated memory and deleter. Throws if it cannot allocate memory. */ template - static IAllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes) { + static IAllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes, + bool use_reserve = false) { ValidateAllocator(ort_allocator); size_t alloc_size = count_or_bytes; @@ -215,7 +216,12 @@ class IAllocator { alloc_size = ValidatedCalcMemSizeForArray(count_or_bytes, size); } - T* p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + T* p = nullptr; + if (use_reserve) { + p = static_cast(ort_allocator->Reserve(ort_allocator, alloc_size)); + } else { + p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + } ValidateAllocation(p, alloc_size); return IAllocatorUniquePtr{p, diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 72eb2579e9d42..c30b862395e96 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1818,6 +1818,10 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { } ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); + if (context_memory_) { + context_memory_.reset(); + } + if (alloc_ != nullptr) { // This code is same as OrtApis::ReleaseAllocator defined in allocator_adapters.cc. // We can't get api inside destructor so that's why we duplicate the code here. @@ -3448,17 +3452,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading if (context_memory_sharing_enable_) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - size_t mem_size = trt_engine->getDeviceMemorySize(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > max_ctx_mem_size_) { - max_ctx_mem_size_ = mem_size; - } + // Reset the max_ctx_mem_size_ and context_memory_ since we don't have access to the allocator here. + max_ctx_mem_size_ = 0; + context_memory_ = nullptr; #if NV_TENSORRT_MAJOR < 10 trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); #else @@ -3548,7 +3544,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], - context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, + context_memory_sharing_enable_, &max_ctx_mem_size_, &context_memory_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix, engine_hw_compatible_, @@ -3587,6 +3583,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; + auto context_memory = trt_state->context_memory; auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; int num_inputs = static_cast(input_indexes.size()); int num_outputs = static_cast(output_indexes.size()); @@ -4031,8 +4028,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView #endif if (mem_size > *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; + *context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true /*use_reserve*/); } - trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + trt_context->setDeviceMemory((*context_memory).get()); } // Start CUDA graph capture. @@ -4231,6 +4229,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con output_info_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, + &context_memory_, &tensorrt_mu_}; *state = p.release(); return 0; @@ -4259,6 +4258,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + auto context_memory = trt_state->context_memory; int num_outputs = static_cast(output_indexes.size()); std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input @@ -4356,8 +4356,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con #endif if (mem_size > *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; + *context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true /*use_reserve*/); } - trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + trt_context->setDeviceMemory((*context_memory).get()); } // Start CUDA graph capture. diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index f6c8f7d7dd46b..d2e8febea2339 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -183,6 +183,7 @@ struct TensorrtFuncState { std::vector profiles; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; + IAllocatorUniquePtr* context_memory = nullptr; std::unordered_map dynamic_range_map; bool engine_decryption_enable = false; int (*engine_decryption)(const char*, char*, size_t*) = nullptr; @@ -216,6 +217,7 @@ struct TensorrtShortFuncState { std::vector> output_info; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; + IAllocatorUniquePtr* context_memory = nullptr; std::mutex* tensorrt_mu_ptr = nullptr; }; From 2aa961ec9260ef24c0bf385cca0eff24796d3bce Mon Sep 17 00:00:00 2001 From: chenweng-quic <168707118+chenweng-quic@users.noreply.github.com> Date: Wed, 21 May 2025 06:08:16 +0800 Subject: [PATCH 24/57] [QNN EP] Add LSTM op builder for QNN EP (#24646) ### Description Add LSTM support for QNN EP ### Motivation and Context Add LSTM support for QNN EP --- .../qnn/builder/op_builder_factory.cc | 4 + .../qnn/builder/op_builder_factory.h | 2 + .../qnn/builder/opbuilder/base_op_builder.cc | 4 + .../qnn/builder/opbuilder/base_op_builder.h | 30 + .../qnn/builder/opbuilder/lstm_op_builder.cc | 807 +++++++++++ .../builder/opbuilder/upsample_op_builder.cc | 70 +- .../core/providers/qnn/builder/qnn_utils.cc | 5 +- onnxruntime/core/providers/qnn/ort_api.cc | 12 + onnxruntime/core/providers/qnn/ort_api.h | 1 + .../optimizer/graph_transform_test_builder.h | 8 + onnxruntime/test/providers/qnn/lstm_test.cc | 1177 +++++++++++++++++ 11 files changed, 2062 insertions(+), 58 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc create mode 100644 onnxruntime/test/providers/qnn/lstm_test.cc diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index efb4afcb88c85..e4d768093aa37 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -181,6 +181,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { { CreateMatMulOpBuilder("MatMul", *this); } + + { + CreateLSTMOpBuilder("LSTM", *this); + } } const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index aa1039f857f8e..c1cc61ad19341 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -102,5 +102,7 @@ void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistratio void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + +void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index d7432f35e61cf..74518e2fcb7a2 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -138,6 +138,10 @@ Status BaseOpBuilder::ProcessInt64Tensors(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } for (size_t i = 0; i < input_names.size(); i++) { + if (input_names[i].size() == 0) { + // For optional inputs, the input_name is empty + continue; + } auto& input_tensorwrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[i]); // Insert cast to int32 if input dtype is int64 if (input_tensorwrapper.GetTensorDataType() == QNN_DATATYPE_INT_64) { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 5474db0590f92..5b3fa6ed3b950 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -107,6 +107,35 @@ class BaseOpBuilder : public IOpBuilder { const logging::Logger& logger, std::vector& input_names) const ORT_MUST_USE_RESULT; + template + Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, + const NodeIndex& node_index, + const std::string& node_name, + const T& scalar, + const std::string& qnn_scalar_param_name, + std::vector& param_names) const { + Qnn_Scalar_t qnn_scalar = QNN_SCALAR_INIT; + if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; + qnn_scalar.floatValue = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_UINT_32; + qnn_scalar.uint32Value = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_INT_32; + qnn_scalar.int32Value = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_BOOL_8; + qnn_scalar.bool8Value = static_cast(scalar); + } else { + ORT_RETURN_IF(true, "QNN EP: Unsupported scalar dtype"); + } + QnnParamWrapper qnn_param_wrapper(node_index, node_name, qnn_scalar_param_name, qnn_scalar); + param_names.push_back(qnn_param_wrapper.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); + return Status::OK(); + } + Status SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, @@ -140,6 +169,7 @@ class BaseOpBuilder : public IOpBuilder { {"Less", QNN_OP_ELEMENT_WISE_LESS}, {"LessOrEqual", QNN_OP_ELEMENT_WISE_LESS_EQUAL}, {"Log", QNN_OP_ELEMENT_WISE_LOG}, + {"LSTM", QNN_OP_LSTM}, {"Max", QNN_OP_ELEMENT_WISE_MAXIMUM}, {"Min", QNN_OP_ELEMENT_WISE_MINIMUM}, {"Neg", QNN_OP_ELEMENT_WISE_NEG}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc new file mode 100644 index 0000000000000..f131d58277038 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc @@ -0,0 +1,807 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { + +class LSTMOpBuilder : public BaseOpBuilder { + public: + LSTMOpBuilder() : BaseOpBuilder("LSTMOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LSTMOpBuilder); + + protected: + /* + ONNX LSTM inputs: + in[0]: X [seq_length, batch_size, input_size], the input sequences packed + in[1]: W [num_directions, 4*hidden_size, input_size], the weight tensor for the gates. Concatenation of W[iofc] and WB[iofc] + in[2]: R [num_directions, 4*hidden_size, hidden_size], the recurrence weight tensor. Concatenation of R[iofc] and RB[iofc] + + ONNX LSTM optional inputs: + in[3]: B [num_directions, 8*hidden_size], the bias tensor for input gate. Concatenation of [Wb[iofc], Rb[iofc]], and [WBb[iofc], RBb[iofc]] (if bidirectional) + in[4]: sequence_lens + in[5]: initial_h [num_directions, batch_size, hidden_size]. + in[6]: initial_c [num_directions, batch_size, hidden_size]. + in[7]: P [num_directions, 3*hidde_size], the weight tensor for peepholes. Concatenation of P[iof] and PB[iof] + + ONNX LSTM Parameters: + - activation_alpha ---> Not supported by QNN. + - activation_beta ---> Not supported by QNN. + - activations ---> Not supported by QNN. + - clip ---> Not supported by QNN since the clip in ONNX applied to iofc while QNN only apply to c. Refer + https://github.com/microsoft/onnxruntime/blob/v1.21.0/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc + - direction + - hidden_size + - input_forget ---> Not supported by QNN + - layout: The shape format of inputs X, initial_h, initial_c and outputs Y, Y_h, Y_c. + If 0, the following shapes are expected: + X.shape = [seq_length, batch_size, input_size], + Y.shape = [seq_length, num_directions, batch_size, hidden_size], + initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = [num_directions, batch_size, hidden_size]. + If 1, the following shapes are expected: + X.shape = [batch_size, seq_length, input_size], + Y.shape = [batch_size, seq_length, num_directions, hidden_size], + initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = [batch_size, num_directions, hidden_size]. + + ONNX LSTM optional outputs: + out[0]: Y [seq_length, num_directions, batch_size, hidden_size] = stack of out[0] from QNN_LSTM with varient directions + out[1]: Y_h [num_directions, batch_size, hidden_size] = stack of out[2] from QNN_LSTM with varient directions + out[2]: Y_c [num_directions, batch_size, hidden_size] = stack of out[1] from QNN_LSTM with varient directions + + QNN LSTM inputs: + in[0]: x_t: 2D of shape [batch_size, input_size] or + 3D of shape [time_steps, batch_size, input_size] if time_major + [batch_size, time_steps, input_size] else + in[1]: W_xf: input-to-forget weights [num_units, input_size] = ONNX in[1][direction, 2*hidden_size:3*hidden_size, :] + in[2]: W_xc: input-to-cell weights [num_units, input_size] = ONNX in[1][direction, 3*hidden_size:4*hidden_size, :] + in[3]: W_xo: input-to-output weights [num_units, input_size] = ONNX in[1][direction, 1*hidden_size:2*hidden_size, :] + in[4]: W_hf: recurrent-to-forget weights [num_units, output_size] = ONNX in[2][direction, 2*hidden_size:3*hidden_size, :] + in[5]: W_hc: recurrent-to-cell weights [num_units, output_size] = ONNX in[2][direction, 3*hidden_size:4*hidden_size, :] + in[6]: W_ho: recurrent-to-output weights [num_units, output_size] = ONNX in[2][direction, 1*hidden_size:2*hidden_size, :] + in[7]: b_f: forget gate bias [num_units] = ONNX in[3][direction, 2*hidden_size:3*hidden_size] + in[3][direction, 6*hidden_size:7*hidden_size] + in[8]: b_c: cell bias [num_units] = ONNX in[3][direction, 3*hidden_size:4*hidden_size] + in[3][direction, 7*hidden_size:8*hidden_size] + in[9]: b_o: output gate bias [num_units] = ONNX in[3][direction, 1*hidden_size:4*hidden_size] + in[3][direction, 5*hidden_size:6*hidden_size] + + # optional inputs + in[10]: h_t_init: hidden state init [batch_size, output_size] = ONNX in[5][direction] + in[11]: c_t_init: cell state init [batch_size, num_units] = ONNX in[6][direction] + in[12]: The input layer normalization weights ---> not supported on fp16 yet. + in[13]: The forget layer normalization weights ---> not supported on fp16 yet. + in[14]: The cell layer normalization weights ---> not supported on fp16 yet. + in[15]: The output layer normalization weights ---> not supported on fp16 yet. + in[16]: W_xi: input-to-input weights [num_units, input_size] = ONNX in[1][direction, 0*hidden_size:1*hidden_size, :] + in[17]: W_hi: recurrent-to-input weights [num_units, output_size] = ONNX in[2][direction, 0*hidden_size:1*hidden_size, :] + in[18]: W_ci: cell-to-input weights [num_units] = ONNX in[7][direction, 0*hidden_size:1*hidden_size] + in[19]: W_cf: cell-to-forget weights [num_units] = ONNX in[7][direction, 2*hidden_size:3*hidden_size] + in[20]: W_co: cell-to-output weights [num_units] = ONNX in[7][direction, 1*hidden_size:2*hidden_size] + in[21]: b_i: input gate bias [num_units] = ONNX in[3][direction, 0*hidden_size:1*hidden_size] + in[3][direction, 4*hidden_size:5*hidden_size] + in[22]: W_proj: projection weights [output_size, num_units] ---> not used + in[23]: b_proj: projection bias [output_size] ---> not used + in[24]: reset: Determines if the internal state should be reset ---> not used + + QNN LSTM Parameters: + - direction + - cell_clip_threshold ---> not used + - output_clip_threshold ---> not used + - time_major + - input_gate_qscale ---> not used since we fallback to fp16. + - forget_gate_qscale ---> not used since we fallback to fp16. + - cell_gate_qscale ---> not used since we fallback to fp16. + - output_gate_qscale ---> not used since we fallback to fp16. + - hidden_state_offset ---> not used since we fallback to fp16. + - hidden_state_qscale ---> not used since we fallback to fp16. + + QNN LSTM outputs: + out[0]: h_t 2D of shape [batch_size, output_size] or + 3D of shape [time_steps, batch_size, output_size] if time_major + [batch_size, time_steps, output_size] else + out[1]: c_t [batch_size, num_unit] + out[2]: o_t [batch_size, output_size] + + QNN LSTM optional outputs: + out[3]: input_gate [batch_size, num_unit] ---> not used + out[4]: forget_gate [batch_size, num_unit] ---> not used + out[5]: cell_gate [batch_size, num_unit] ---> not used + out[6]: output_gate [batch_size, num_unit] ---> not used + out[7]: hidden_state [batch_size, output_size] ---> not used + */ + + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + private: + Status AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::string& direction, + const std::vector& input_names, + const logging::Logger& logger, + const bool& do_op_validation, + const bool& is_bidirection, + std::vector& uni_lstm_output_names) const; + Status AddStridedSliceOrReshape(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::string& input_name, + const std::string& output_name, + const std::vector& input_shape, + const std::vector& output_shape, + const std::vector>& ranges, + const uint32_t& begin_mask, + const uint32_t& end_mask, + const uint32_t& shrink_axes, + const uint32_t& new_axes_mask, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input, + bool is_for_output) const; +}; + +Status LSTMOpBuilder::AddStridedSliceOrReshape(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::string& input_name, + const std::string& output_name, + const std::vector& input_shape, + const std::vector& output_shape, + const std::vector>& ranges, + const uint32_t& begin_mask, + const uint32_t& end_mask, + const uint32_t& shrink_axes, + const uint32_t& new_axes_mask, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input, + bool is_for_output) const { + if (qnn_model_wrapper.IsQnnTensorWrapperExist(output_name)) { + return Status::OK(); + } + // add strided_slice or reshape + // this is not general condition, only limited to caller in this builder + size_t minSize = std::min(input_shape.size(), output_shape.size()); + if (input_shape[0] == 1 && std::equal(output_shape.rbegin(), output_shape.rbegin() + minSize, input_shape.rbegin())) { + // add Reshape + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(input_name, + output_name, + input_shape, + output_shape, + tensor_data_type, + quantize_param.Copy(), + quantize_param.Copy(), + do_op_validation, + is_for_input, + is_for_output)); + } else { + // add StridedSlice + // inputs + QnnTensorWrapper input_tensorwrapper(input_name, is_for_input ? QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_NATIVE, + tensor_data_type, quantize_param.Copy(), + std::vector(input_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), + "Failed to add input tensor for inserted StridedSlice or Reshape."); + + // params + const std::string& node_name = output_name; + + // ranges + std::vector ranges_data; + for (size_t i = 0; i < ranges.size(); i++) { + for (size_t j = 0; j < 3; j++) { + ranges_data.emplace_back(SafeInt(ranges[i][j])); + } + } + QnnParamWrapper ranges_param_wrapper(node_unit.Index(), node_name, QNN_OP_STRIDED_SLICE_PARAM_RANGES, {static_cast(ranges.size()), 3}, std::move(ranges_data), true); + std::vector param_names = { + ranges_param_wrapper.GetParamTensorName(), + }; + qnn_model_wrapper.AddParamWrapper(std::move(ranges_param_wrapper)); + + // begin_mask + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, begin_mask, QNN_OP_STRIDED_SLICE_PARAM_BEGIN_MASK, param_names)); + + // end_mask + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, end_mask, QNN_OP_STRIDED_SLICE_PARAM_END_MASK, param_names)); + + // shrink_axes + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, shrink_axes, QNN_OP_STRIDED_SLICE_PARAM_SHRINK_AXES, param_names)); + + // new_axes_mask + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, new_axes_mask, QNN_OP_STRIDED_SLICE_PARAM_NEW_AXES_MASK, param_names)); + + // outputs + QnnTensorWrapper output_tensorwrapper(output_name, + is_for_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE, + tensor_data_type, + quantize_param.Copy(), + std::vector(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), + "Failed to add output tensor for inserted StridedSlice."); + // addNode + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_STRIDED_SLICE, {input_name}, + {output_name}, std::move(param_names), do_op_validation), + "Failed to create manually inserted Qnn StridedSlice node."); + } + + return Status::OK(); +} + +Status LSTMOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(qnn_model_wrapper); + ORT_UNUSED_PARAMETER(node_unit); + ORT_UNUSED_PARAMETER(logger); + if (node_unit.Inputs().size() > 4 && node_unit.Inputs()[4].node_arg.Exists()) { + TensorInfo tensor_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[4], tensor_info)); + + ORT_RETURN_IF_NOT(tensor_info.is_initializer, "QNN EP: dynamic sequence_length is not supported."); + + std::vector sequence_lens_bytes; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*tensor_info.initializer_tensor, sequence_lens_bytes)); + const size_t num_elems = sequence_lens_bytes.size() / sizeof(int32_t); + gsl::span sequence_lens{reinterpret_cast(sequence_lens_bytes.data()), num_elems}; + ORT_RETURN_IF(std::any_of(sequence_lens.begin(), + sequence_lens.end(), + [sequence_lens](int i) { return i != sequence_lens[0]; }), + "QNN EP: Only support LSTM with same sequence length."); + } + + NodeAttrHelper node_helper(node_unit); + const float clip = node_helper.Get("clip", (float)0.0); + ORT_RETURN_IF(clip != 0, + "QNN EP doesn't support non-default clip for LSTM."); + const std::vector activations = node_helper.Get("activations", std::vector{}); + ORT_RETURN_IF((activations.size() >= 3 && (activations[0] != "sigmoid" || activations[1] != "tanh" || activations[2] != "tanh")) || + (activations.size() == 6 && (activations[3] != "sigmoid" || activations[5] != "tanh" || activations[5] != "tanh")), + "QNN EP doesn't support non-default activations for LSTM."); + // TODO: Add support for layout==1 + const int64_t layout = node_helper.Get("layout", static_cast(0)); + ORT_RETURN_IF_NOT(layout == 0, + "QNN EP: Unsupport layout mode %ld for %s.", layout, node_unit.Name().c_str()); + return Status::OK(); +} + +Status LSTMOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + const auto& onnx_inputs = node_unit.Inputs(); + for (size_t i = 0; i < onnx_inputs.size(); i++) { + if (onnx_inputs[i].node_arg.Exists()) { + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, onnx_inputs[i], logger, input_names)); + } else { + input_names.emplace_back(""); + } + } + return Status::OK(); +} + +Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::string& direction, + const std::vector& input_names, + const logging::Logger& logger, + const bool& do_op_validation, + const bool& is_bidirection, + std::vector& uni_lstm_output_names) const { + ORT_UNUSED_PARAMETER(logger); + + const auto& onnx_inputs = node_unit.Inputs(); + const auto& onnx_outputs = node_unit.Outputs(); + const std::string& node_name = node_unit.Name(); + std::vector input_tensor_infos(onnx_inputs.size()); + for (size_t i = 0; i < onnx_inputs.size(); i++) { + if (onnx_inputs[i].node_arg.Exists()) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(onnx_inputs[i], input_tensor_infos[i])); + } + } + // becuase QNN LSTM three outputs are mandatory, we should provide them tensor info + std::vector output_tensor_infos(3); + for (size_t i = 0; i < 3; i++) { + if (onnx_outputs.size() > i && onnx_outputs[i].node_arg.Exists()) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(onnx_outputs[i], output_tensor_infos[i])); + } else { + output_tensor_infos[i].qnn_data_type = input_tensor_infos[0].qnn_data_type; + } + } + + NodeAttrHelper node_helper(node_unit); + const uint32_t hidden_size = node_helper.Get("hidden_size", 0); + const int32_t hidden_size_sign = SafeInt(hidden_size); + ORT_RETURN_IF_NOT(hidden_size > 0, "hidden size is not set for LSTM"); + const int64_t layout = node_helper.Get("layout", static_cast(0)); + + const uint32_t input_size = input_tensor_infos[0].shape[2]; + const uint32_t batch_size = layout == 0 ? input_tensor_infos[0].shape[1] : input_tensor_infos[0].shape[0]; + const uint32_t seq_length = layout == 0 ? input_tensor_infos[0].shape[0] : input_tensor_infos[0].shape[1]; + const int32_t direction_idx = input_tensor_infos[1].shape[0] < 2 || direction == "forward" ? 0 : 1; + + // params + std::vector param_names; + + // direction + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), direction == "forward" ? QNN_OP_LSTM_DIRECTION_FORWARD : QNN_OP_LSTM_DIRECTION_REVERSE, QNN_OP_LSTM_PARAM_DIRECTION, param_names)); + + // cell_clip_threshold + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_CELL_CLIP_THRESHOLD, param_names)); + + // output_clip_threshold + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_OUTPUT_CLIP_THRESHOLD, param_names)); + + // time_major + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, QNN_OP_LSTM_PARAM_TIME_MAJOR, param_names)); + + // // input_gate_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_INPUT_GATE_QSCALE, param_names)); + + // // forget_gate_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_FORGET_GATE_QSCALE, param_names)); + + // // cell_gate_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_CELL_GATE_QSCALE, param_names)); + + // // output_gate_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_OUTPUT_GATE_QSCALE, param_names)); + + // // hidden_state_offset + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_HIDDEN_STATE_OFFSET, param_names)); + + // // hidden_state_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_HIDDEN_STATE_QSCALE, param_names)); + + // Common LSTM cell inputs + const std::string null_tensor_name = "null_tensor"; + QnnTensorWrapper null_tensor_wrapper(null_tensor_name, QNN_TENSOR_TYPE_NULL, QNN_DATATYPE_UNDEFINED, + QnnQuantParamsWrapper(), std::vector{0}); + + qnn_model_wrapper.AddTensorWrapper(std::move(null_tensor_wrapper)); + std::vector qnn_lstm_input_names(24, null_tensor_name); + + // input W + { + // QNN in[1] = ONNX in[1][direction, 2*hidden_size:3*hidden_size, :] + // QNN in[2] = ONNX in[1][direction, 3*hidden_size:4*hidden_size, :] + // QNN in[3] = ONNX in[1][direction, 1*hidden_size:2*hidden_size, :] + // QNN in[16] = ONNX in[1][direction, 0*hidden_size:1*hidden_size, :] + uint32_t begin_mask = 0b000U; + uint32_t end_mask = 0b000U; + uint32_t shrink_axes = 0b001U; + uint32_t new_axes_mask = 0b000U; + std::vector qnn_input_indices = {1, 2, 3, 16}; + std::vector begins = {2, 3, 1, 0}; + std::vector qnn_lstm_weight_name = { + input_names[1] + "_input_to_forget_gate_weight_" + direction, + input_names[1] + "_input_to_cell_gate_weight_" + direction, + input_names[1] + "_input_to_output_gate_weight_" + direction, + input_names[1] + "_input_to_input_gate_weight_" + direction, + }; + for (size_t i = 0; i < 4; i++) { + std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, + {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}, + {0, SafeInt(input_size), 1}}; + std::vector output_shape = {hidden_size, input_size}; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[1], + /*output_name=*/qnn_lstm_weight_name[i], + /*input_shape=*/input_tensor_infos[1].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[1].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[1].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_weight_name[i]; + } + } + + // input R + { + // QNN in[4] = ONNX in[2][direction, 2*hidden_size:3*hidden_size, :] + // QNN in[5] = ONNX in[2][direction, 3*hidden_size:4*hidden_size, :] + // QNN in[6] = ONNX in[2][direction, 1*hidden_size:2*hidden_size, :] + // QNN in[17] = ONNX in[2][direction, 0*hidden_size:1*hidden_size, :] + uint32_t begin_mask = 0b000U; + uint32_t end_mask = 0b000U; + uint32_t shrink_axes = 0b001U; + uint32_t new_axes_mask = 0b000U; + std::vector qnn_input_indices = {4, 5, 6, 17}; + std::vector begins = {2, 3, 1, 0}; + std::vector qnn_lstm_weight_name = { + input_names[2] + "_recurrent_to_forget_gate_weight_" + direction, + input_names[2] + "_recurrent_to_cell_gate_weight_" + direction, + input_names[2] + "_recurrent_to_output_gate_weight_" + direction, + input_names[2] + "_recurrent_to_input_gate_weight_" + direction}; + for (size_t i = 0; i < 4; i++) { + std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, + {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}, + {0, hidden_size_sign, 1}}; + std::vector output_shape = {hidden_size, hidden_size}; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[2], + /*output_name=*/qnn_lstm_weight_name[i], + /*input_shape=*/input_tensor_infos[2].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[2].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[2].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_weight_name[i]; + } + } + + // input B + { + // QNN in[7] = ONNX in[3][direction, 2*hidden_size:3*hidden_size] + ONNX in[3][direction, 6*hidden_size:7*hidden_size] + // QNN in[8] = ONNX in[3][direction, 3*hidden_size:4*hidden_size] + ONNX in[3][direction, 7*hidden_size:8*hidden_size] + // QNN in[9] = ONNX in[3][direction, 1*hidden_size:2*hidden_size] + ONNX in[3][direction, 5*hidden_size:6*hidden_size] + // QNN in[21] = ONNX in[3][direction, 0*hidden_size:1*hidden_size] + ONNX in[3][direction, 4*hidden_size:5*hidden_size] + uint32_t begin_mask = 0b00U; + uint32_t end_mask = 0b00U; + uint32_t shrink_axes = 0b01U; + uint32_t new_axes_mask = 0b00U; + std::vector output_shape = {hidden_size}; + std::vector qnn_lstm_bias_name = { + node_name + "_forget_gate_bias_" + direction, + node_name + "_cell_gate_bias_" + direction, + node_name + "_output_gate_bias_" + direction, + node_name + "_input_gate_bias_" + direction}; + std::vector qnn_input_indices = {7, 8, 9, 21}; + if (onnx_inputs.size() > 3 && onnx_inputs[3].node_arg.Exists()) { + std::vector begins = {2, 3, 1, 0, 6, 7, 5, 4}; + std::vector onnx_lstm_bias_name = { + input_names[3] + "_input_to_forget_gate_bias_" + direction, + input_names[3] + "_input_to_cell_gate_bias_" + direction, + input_names[3] + "_input_to_output_gate_bias_" + direction, + input_names[3] + "_input_to_input_gate_bias_" + direction, + input_names[3] + "_recurrent_to_forget_gate_bias_" + direction, + input_names[3] + "_recurrent_to_cell_gate_bias_" + direction, + input_names[3] + "_recurrent_to_output_gate_bias_" + direction, + input_names[3] + "_recurrent_to_input_gate_bias_" + direction}; + for (size_t i = 0; i < 8; i++) { + std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, + {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}}; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[3], + /*output_name=*/onnx_lstm_bias_name[i], + /*input_shape=*/input_tensor_infos[3].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[3].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[3].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + } + for (size_t i = 0; i < 4; i++) { + std::vector add_input_names = {onnx_lstm_bias_name[i], onnx_lstm_bias_name[i + 4]}; + // TODO: The quantize_param should not be used directly, we should calculate an approximate quant_param here. + QnnTensorWrapper add_output_tensorwrapper(qnn_lstm_bias_name[i], QNN_TENSOR_TYPE_NATIVE, input_tensor_infos[3].qnn_data_type, + input_tensor_infos[3].quant_param.Copy(), std::vector(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(add_output_tensorwrapper)), + "QNN EP: Failed to add output tensor for inserted ElementWiseAdd node."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_ADD, + std::move(add_input_names), {qnn_lstm_bias_name[i]}, {}, do_op_validation), + "Failed to create manually inserted ElementWiseAdd node."); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_bias_name[i]; + } + } else { + // prepare zero bias + std::string zero_bias_name = node_name + "_zero_bias"; + QnnTensorWrapper zero_bias_tensor_wrapper(zero_bias_name, + QNN_TENSOR_TYPE_STATIC, + input_tensor_infos[0].qnn_data_type, + QnnQuantParamsWrapper(), + std::vector(output_shape), + std::vector(utils::GetElementSizeByType(input_tensor_infos[0].qnn_data_type) * hidden_size, 0)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(zero_bias_tensor_wrapper)), + "Failed to add additional zero bias for QNN LSTM node."); + for (size_t i = 0; i < 4; i++) { + qnn_lstm_input_names[qnn_input_indices[i]] = zero_bias_name; + } + } + } + + // input P + if (onnx_inputs.size() > 7 && onnx_inputs[7].node_arg.Exists()) { + // QNN in[18] = ONNX in[7][direction, 0*hidden_size:1*hidden_size] + // QNN in[19] = ONNX in[7][direction, 2*hidden_size:1*hidden_size] + // QNN in[20] = ONNX in[7][direction, 1*hidden_size:1*hidden_size] + uint32_t begin_mask = 0b00U; + uint32_t end_mask = 0b00U; + uint32_t shrink_axes = 0b01U; + uint32_t new_axes_mask = 0b00U; + std::vector output_shape = {hidden_size}; + std::vector qnn_input_indices = {18, 19, 20}; + std::vector begins = {0, 2, 1}; + std::vector qnn_lstm_weight_name = { + input_names[7] + "_cell_to_input_gate_weight_" + direction, + input_names[7] + "_cell_to_forget_gate_weight_" + direction, + input_names[7] + "_cell_to_output_gate_weight_" + direction}; + for (size_t i = 0; i < 3; i++) { + std::vector> ranges = { + {direction_idx, direction_idx + 1, 1}, + {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}, + }; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[7], + /*output_name=*/qnn_lstm_weight_name[i], + /*input_shape=*/input_tensor_infos[7].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[7].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[7].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_weight_name[i]; + } + } + + // input initial h, c + { + // QNN in[10] = ONNX in[5][direction_idx, :, :] + // QNN in[11] = ONNX in[6][direction_idx, :, :] + uint32_t begin_mask = 0b000U; + uint32_t end_mask = 0b000U; + uint32_t shrink_axes = 0b001U; + uint32_t new_axes_mask = 0b000U; + std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, + {0, SafeInt(batch_size), 1}, + {0, hidden_size_sign, 1}}; + std::vector src_indices = {5, 6}; + std::vector qnn_input_indices = {10, 11}; + std::vector output_shape = {batch_size, hidden_size}; + for (size_t i = 0; i < 2; i++) { + if (onnx_inputs.size() > src_indices[i] && onnx_inputs[src_indices[i]].node_arg.Exists()) { + std::string qnn_lstm_input_name = input_names[src_indices[i]] + "_" + direction; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[src_indices[i]], + /*output_name=*/qnn_lstm_input_name, + /*input_shape=*/input_tensor_infos[src_indices[i]].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[src_indices[i]].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[src_indices[i]].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_input_name; + } else { + // prepare zero initial values + std::string zero_initial_values_name = node_name + "_LSTM_initial_values_" + (i == 0 ? "h" : "c"); + QnnTensorWrapper zero_bias_tensor_wrapper(zero_initial_values_name, + QNN_TENSOR_TYPE_STATIC, + input_tensor_infos[0].qnn_data_type, + QnnQuantParamsWrapper(), + std::vector(output_shape), + std::vector(utils::GetElementSizeByType(input_tensor_infos[0].qnn_data_type) * batch_size * hidden_size, 0)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(zero_bias_tensor_wrapper)), + "Failed to add additional initial values for QNN LSTM node."); + qnn_lstm_input_names[qnn_input_indices[i]] = zero_initial_values_name; + } + } + } + + // add QNN LSTM + // since HTP doesn't not support 3d yet, add #sequence_length LSTM node + std::vector qnn_all_hidden_state_names; + qnn_all_hidden_state_names.resize(seq_length); + for (uint32_t i = 0; i < seq_length; i++) { + uint32_t sequence_idx = direction == "forward" ? i : seq_length - i - 1; + // Add LSTM inputs + std::vector qnn_lstm_input_names_i = qnn_lstm_input_names; + + // input X + { + // QNN in[0] = ONNX in[0][sequence_idx, :, :] + uint32_t begin_mask = 0b000U; + uint32_t end_mask = 0b000U; + uint32_t shrink_axes = 0b001U; + uint32_t new_axes_mask = 0b000U; + std::vector> ranges = {{SafeInt(sequence_idx), SafeInt(sequence_idx + 1), 1}, + {0, SafeInt(batch_size), 1}, + {0, SafeInt(input_size), 1}}; + std::string qnn_lstm_input_name = input_names[0] + "_cell_" + std::to_string(sequence_idx) + "_input"; + std::vector output_shape = {batch_size, input_size}; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[0], + /*output_name=*/qnn_lstm_input_name, + /*input_shape=*/input_tensor_infos[0].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[0].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[0].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names_i[0] = qnn_lstm_input_name; + } + + // outputs + std::vector qnn_lstm_output_shape = {batch_size, hidden_size}; + + std::vector qnn_lstm_output_names = { + node_name + "_QNN_LSTM_output_all_hidden_state_" + std::to_string(sequence_idx) + "_" + direction, + node_name + "_QNN_LSTM_output_cell_state_" + std::to_string(sequence_idx) + "_" + direction, + node_name + "_QNN_LSTM_output_hidden_state_" + std::to_string(sequence_idx) + "_" + direction}; + qnn_lstm_input_names[10] = qnn_lstm_output_names[2]; // update initial_h + qnn_lstm_input_names[11] = qnn_lstm_output_names[1]; // update initial_c + qnn_all_hidden_state_names[sequence_idx] = qnn_lstm_output_names[2]; + + for (size_t j = 0; j < 3; j++) { + QnnTensorWrapper output_tensorwrapper(qnn_lstm_output_names[j], + QNN_TENSOR_TYPE_NATIVE, + output_tensor_infos[j].qnn_data_type, + output_tensor_infos[j].quant_param.Copy(), + std::vector(qnn_lstm_output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), + "QNN EP: Failed to add %ldth output tensor for QNN LSTM.", j); + } + std::string lstm_node_name = node_name + "_cell_" + std::to_string(sequence_idx) + "_" + direction; + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(lstm_node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_LSTM, + std::move(qnn_lstm_input_names_i), std::move(qnn_lstm_output_names), + std::vector(param_names), do_op_validation), + "QNN EP: Failed to create Qnn LSTM node."); + } + + // pack all timestamp outputs together for onnx output[0] + std::string qnn_pack_output_name = node_name + "_QNN_LSTM_output_hidden_state_all_" + direction; + + // add pack for output[0] + std::vector pack_param_names; + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), qnn_pack_output_name, 0, QNN_OP_PACK_PARAM_AXIS, pack_param_names)); + + QnnTensorWrapper pack_output_tensorwrapper(qnn_pack_output_name, + QNN_TENSOR_TYPE_NATIVE, + output_tensor_infos[0].qnn_data_type, + output_tensor_infos[0].quant_param.Copy(), + {seq_length, batch_size, hidden_size}); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(pack_output_tensorwrapper)), + "QNN EP: Failed to add output tensor for QNN Pack."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(qnn_pack_output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_PACK, + std::move(qnn_all_hidden_state_names), {qnn_pack_output_name}, + std::move(pack_param_names), do_op_validation), + "QNN EP: Failed to create Qnn Pack node."); + + // add reshape for all outputs to align onnx output shape for unidirection + std::vector qnn_reshape_input_names = { + qnn_pack_output_name, + qnn_lstm_input_names[10], + qnn_lstm_input_names[11]}; + std::vector> qnn_lstm_output_shapes = { + {seq_length, batch_size, hidden_size}, + {batch_size, hidden_size}, + {batch_size, hidden_size}}; + // in the output shapes below, the value of 1 indicates unidirectional + std::vector> onnx_lstm_output_shapes = { + {seq_length, 1, batch_size, hidden_size}, + {1, batch_size, hidden_size}, + {1, batch_size, hidden_size}}; + for (size_t i = 0; i < 3; i++) { + if (onnx_outputs.size() > i && onnx_outputs[i].node_arg.Exists()) { + const std::string reshape_output_name = is_bidirection ? qnn_reshape_input_names[i] + "_unsqueeze_" + direction : onnx_outputs[i].node_arg.Name(); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(/*input_name=*/qnn_reshape_input_names[i], + /*output_name=*/reshape_output_name, + /*input_shape=*/qnn_lstm_output_shapes[i], + /*output_shape=*/onnx_lstm_output_shapes[i], + /*tensor_data_type=*/output_tensor_infos[i].qnn_data_type, + /*quantize_param=*/output_tensor_infos[i].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/qnn_model_wrapper.IsGraphOutput(reshape_output_name))); + uni_lstm_output_names.emplace_back(reshape_output_name); + } else { + uni_lstm_output_names.emplace_back(""); + } + } + return Status::OK(); +} + +Status LSTMOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + const auto& inputs = node_unit.Inputs(); + + NodeAttrHelper node_helper(node_unit); + std::string direction = node_helper.Get("direction", "forward"); + ORT_RETURN_IF_NOT(inputs.size() >= 3 && inputs.size() <= 8, "LSTM should receive inputs ranging from 3 to 8!"); + + if (direction == "bidirectional") { + std::vector uni_lstm_output_names_forward, uni_lstm_output_names_reverse; + ORT_RETURN_IF_ERROR(AddUnidirectionLSTM(qnn_model_wrapper, node_unit, "forward", input_names, logger, do_op_validation, true, uni_lstm_output_names_forward)); + ORT_RETURN_IF_ERROR(AddUnidirectionLSTM(qnn_model_wrapper, node_unit, "reverse", input_names, logger, do_op_validation, true, uni_lstm_output_names_reverse)); + + // Concat forward and reverse output + for (size_t i = 0; i < 3; i++) { + TensorInfo output_info = {}; + if (node_unit.Outputs().size() > i && node_unit.Outputs()[i].node_arg.Exists()) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[i], output_info)); + std::string onnx_output_name = node_unit.Outputs()[i].node_arg.Name(); + + // param + std::vector concat_param_names; + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), onnx_output_name, static_cast(output_info.shape.size() - 3), QNN_OP_CONCAT_PARAM_AXIS, concat_param_names)); + + // create tensor and add op + Qnn_TensorType_t output_tensor_type = qnn_model_wrapper.IsGraphOutput(onnx_output_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper concat_output_tensorwrapper(onnx_output_name, + output_tensor_type, + output_info.qnn_data_type, + output_info.quant_param.Copy(), + std::vector(output_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(concat_output_tensorwrapper)), + "QNN EP: Failed to add output tensor for QNN Concat."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_unit.Name(), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CONCAT, + {uni_lstm_output_names_forward[i], uni_lstm_output_names_reverse[i]}, {onnx_output_name}, + std::move(concat_param_names), do_op_validation), + "QNN EP: Failed to create Qnn Concat node."); + } + } + } else { + std::vector uni_lstm_output_names; + ORT_RETURN_IF_ERROR(AddUnidirectionLSTM(qnn_model_wrapper, node_unit, direction, input_names, logger, do_op_validation, false, uni_lstm_output_names)); + } + return Status::OK(); +} + +void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc index 48214f92b1a61..cba0eb350992f 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc @@ -55,18 +55,6 @@ class UpsampleOpBuilder : public BaseOpBuilder { const OnnxAttrInfo onnx_mode_attr = {"mode", "nearest"}; }; -static Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector& param_tensor_names, - const Qnn_Scalar_t& qnn_scalar, - const std::string& qnn_scalar_param_name) { - QnnParamWrapper qnn_param_wrapper(node_unit.Index(), node_unit.Name(), qnn_scalar_param_name, qnn_scalar); - param_tensor_names.push_back(qnn_param_wrapper.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); - - return Status::OK(); -} - Status UpsampleOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const { @@ -161,72 +149,40 @@ Status UpsampleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model qnn_op_type = (interp_mode == "nearest") ? QNN_OP_RESIZE_NEAREST_NEIGHBOR : QNN_OP_RESIZE_BILINEAR; // Parameter 'align_corners' - Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT; - qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8; - qnn_align_corners.bool8Value = false; const std::string align_corners_param_name = (qnn_op_type == QNN_OP_RESIZE_BILINEAR) ? QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS : QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_ALIGN_CORNERS; - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_align_corners, align_corners_param_name)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, align_corners_param_name, param_tensor_names)); // Parameter 'half_pixel_centers' - Qnn_Scalar_t qnn_half_pixel_centers = QNN_SCALAR_INIT; - qnn_half_pixel_centers.dataType = QNN_DATATYPE_BOOL_8; - qnn_half_pixel_centers.bool8Value = false; const std::string half_pixel_centers_param_name = (qnn_op_type == QNN_OP_RESIZE_BILINEAR) ? QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS : QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_HALF_PIXEL_CENTERS; - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_half_pixel_centers, half_pixel_centers_param_name)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, half_pixel_centers_param_name, param_tensor_names)); if (qnn_op_type == QNN_OP_RESIZE_BILINEAR) { // Parameter 'antialias' - Qnn_Scalar_t qnn_antialias = QNN_SCALAR_INIT; - qnn_antialias.dataType = QNN_DATATYPE_BOOL_8; - qnn_antialias.bool8Value = false; - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_antialias, QNN_OP_RESIZE_BILINEAR_PARAM_ANTIALIAS)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, QNN_OP_RESIZE_BILINEAR_PARAM_ANTIALIAS, param_tensor_names)); } } else { // Remain as QNN's Resize. // Parameter 'exclude_outside' - Qnn_Scalar_t qnn_exclude_outside = QNN_SCALAR_INIT; - qnn_exclude_outside.dataType = QNN_DATATYPE_BOOL_8; - qnn_exclude_outside.bool8Value = false; - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_exclude_outside, QNN_OP_RESIZE_PARAM_EXCLUDE_OUTSIDE)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, QNN_OP_RESIZE_PARAM_EXCLUDE_OUTSIDE, param_tensor_names)); // Parameter 'transformation_mode' - Qnn_Scalar_t qnn_transformation_mode = QNN_SCALAR_INIT; - qnn_transformation_mode.dataType = QNN_DATATYPE_UINT_32; - qnn_transformation_mode.uint32Value = (supported_modes.at(interp_mode) == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) - ? static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL) - : static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC); - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_transformation_mode, QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE)); + uint32_t transformation_mode = (supported_modes.at(interp_mode) == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) + ? static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL) + : static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), transformation_mode, QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE, param_tensor_names)); // Parameter 'interpolation_mode' - Qnn_Scalar_t qnn_interp_mode = QNN_SCALAR_INIT; - qnn_interp_mode.dataType = QNN_DATATYPE_UINT_32; - qnn_interp_mode.uint32Value = static_cast(supported_modes.at(interp_mode)); - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_interp_mode, QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE)); + uint32_t qnn_interp_mode = static_cast(supported_modes.at(interp_mode)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), qnn_interp_mode, QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE, param_tensor_names)); // Parameter 'nearest_mode'. Process only when 'interpolation_mode' is NEAREST. - if (qnn_interp_mode.uint32Value == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) { - Qnn_Scalar_t qnn_nearest_mode = QNN_SCALAR_INIT; - qnn_nearest_mode.dataType = QNN_DATATYPE_UINT_32; - qnn_nearest_mode.uint32Value = static_cast(QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR); - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_nearest_mode, QNN_OP_RESIZE_PARAM_NEAREST_MODE)); + if (qnn_interp_mode == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) { + uint32_t qnn_nearest_mode = static_cast(QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), qnn_nearest_mode, QNN_OP_RESIZE_PARAM_NEAREST_MODE, param_tensor_names)); } } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 4fe223d821f1c..cafd727c6a057 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -40,7 +40,7 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type) { {QNN_DATATYPE_UFIXED_POINT_8, 1}, {QNN_DATATYPE_UFIXED_POINT_16, 2}, {QNN_DATATYPE_UFIXED_POINT_32, 4}, - }; + {QNN_DATATYPE_UNDEFINED, 1}}; auto pos = data_type_to_size.find(data_type); ORT_ENFORCE(pos != data_type_to_size.end(), "Unknown QNN data type", data_type); @@ -228,6 +228,9 @@ std::ostream& operator<<(std::ostream& out, const Qnn_DataType_t& data_type) { case QNN_DATATYPE_UFIXED_POINT_4: out << "QNN_DATATYPE_UFIXED_POINT_4"; break; + case QNN_DATATYPE_UNDEFINED: + out << "QNN_DATATYPE_UNDEFINED"; + break; default: ORT_THROW("Unknown Qnn Data type"); } diff --git a/onnxruntime/core/providers/qnn/ort_api.cc b/onnxruntime/core/providers/qnn/ort_api.cc index 809593b409dad..aec09d043d2bc 100644 --- a/onnxruntime/core/providers/qnn/ort_api.cc +++ b/onnxruntime/core/providers/qnn/ort_api.cc @@ -102,6 +102,18 @@ const std::string& NodeAttrHelper::Get(const std::string& key, const std::string return def_val; } +std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + std::vector res; + for (int i = 0; i < NODE_ATTR_ITER_VAL(entry).strings_size(); i++) { + res.emplace_back(NODE_ATTR_ITER_VAL(entry).strings(i)); + } + return res; + } + + return def_val; +} + std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { const auto& values = NODE_ATTR_ITER_VAL(entry).ints(); diff --git a/onnxruntime/core/providers/qnn/ort_api.h b/onnxruntime/core/providers/qnn/ort_api.h index d25269be075de..2cb4d5c2003bc 100644 --- a/onnxruntime/core/providers/qnn/ort_api.h +++ b/onnxruntime/core/providers/qnn/ort_api.h @@ -151,6 +151,7 @@ class NodeAttrHelper { std::vector Get(const std::string& key, const std::vector& def_val) const; const std::string& Get(const std::string& key, const std::string& def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; // Convert the i() or ints() of the attribute from int64_t to int32_t int32_t Get(const std::string& key, int32_t def_val) const; diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 4e50881ad4f90..26df588eab73f 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -147,6 +147,14 @@ class ModelTestBuilder { } } + // Make optional tensor + NodeArg* MakeOptionalTensor() { + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); + std::string name; + return &graph_.GetOrCreateNodeArg(name, &type_proto); + } + template NodeArg* MakeSymbolicInput(const std::vector>& shape) { ONNX_NAMESPACE::TypeProto type_proto; diff --git a/onnxruntime/test/providers/qnn/lstm_test.cc b/onnxruntime/test/providers/qnn/lstm_test.cc new file mode 100644 index 0000000000000..4b011b9bf1108 --- /dev/null +++ b/onnxruntime/test/providers/qnn/lstm_test.cc @@ -0,0 +1,1177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "test/providers/tester_types.h" + +#include "core/graph/onnx_protobuf.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +/* + ONNX LSTM inputs: + in[0]: X [seq_length, batch_size, input_size] + in[1]: W [num_directions, 4*hidden_size, input_size] + in[2]: R [num_directions, 4*hidden_size, hidden_size] + + ONNX LSTM optional inputs: + in[3]: B [num_directions, 8*hidden_size] + in[4]: + in[5]: initial_h [num_directions, batch_size, hidden_size]. + in[6]: initial_c [num_directions, batch_size, hidden_size]. + in[7]: P [num_directions, 3*hidde_size] + + ONNX LSTM Parameters: + - activation_alpha ---> Not supported by QNN. + - activation_beta ---> Not supported by QNN. + - activations ---> Not supported by QNN. + - clip ---> Not supported by QNN since the clip in ONNX applied to iofc while QNN only apply to c. Refer + https://github.com/microsoft/onnxruntime/blob/v1.21.0/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc + - direction + - hidden_size + - input_forget ---> Not supported by QNN + - layout: The shape format of inputs X, initial_h, initial_c and outputs Y, Y_h, Y_c. + If 0, the following shapes are expected: + X.shape = [seq_length, batch_size, input_size], + Y.shape = [seq_length, num_directions, batch_size, hidden_size], + initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = [num_directions, batch_size, hidden_size]. + If 1, the following shapes are expected: + X.shape = [batch_size, seq_length, input_size], + Y.shape = [batch_size, seq_length, num_directions, hidden_size], + initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = [batch_size, num_directions, hidden_size]. + + ONNX LSTM optional outputs: + out[0]: Y [seq_length, num_directions, batch_size, hidden_size] + out[1]: Y_h [num_directions, batch_size, hidden_size] + out[2]: Y_c [num_directions, batch_size, hidden_size] + +*/ + +template +void _BuildLSTMTestCase(ModelTestBuilder& builder, + const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout, + const std::vector>& output_qparams) { + auto convert_input = [](ModelTestBuilder& builder, const TestInputDef& def) { + if (std::is_same::value) { + TestInputDef Fp16_def = ConvertToFP16InputDef(def); + return MakeTestInput(builder, Fp16_def); + } else if (std::is_same::value) { + NodeArg* input = MakeTestInput(builder, def); + QuantParams qparams = GetTestInputQuantParams(def); + return AddQDQNodePair(builder, input, qparams.scale, qparams.zero_point); + } else { + return MakeTestInput(builder, def); + } + }; + + NodeArg* inputX = convert_input(builder, X_def); + NodeArg* inputW = convert_input(builder, W_def); + NodeArg* inputR = convert_input(builder, R_def); + std::vector input_args = {inputX, inputW, inputR}; + + // optional inputs + // B + if (B_def) { + input_args.push_back(convert_input(builder, B_def->get())); + } else { + input_args.push_back(builder.MakeOptionalTensor()); + } + + // sequence length + input_args.push_back(builder.MakeOptionalTensor()); + + // H + if (H_def) { + input_args.push_back(convert_input(builder, H_def->get())); + } else { + input_args.push_back(builder.MakeOptionalTensor()); + } + + // C + if (C_def) { + input_args.push_back(convert_input(builder, C_def->get())); + } else { + input_args.push_back(builder.MakeOptionalTensor()); + } + + // P + if (P_def) { + input_args.push_back(convert_input(builder, P_def->get())); + } else { + input_args.push_back(builder.MakeOptionalTensor()); + } + + NodeArg *lstm_output_Y, *lstm_output_Y_h, *lstm_output_Y_c; + if (has_Y) { + if (std::is_same::value || std::is_same::value) { + lstm_output_Y = builder.MakeOutput(); + } else { + lstm_output_Y = builder.MakeIntermediate(); + } + } else { + lstm_output_Y = builder.MakeOptionalTensor(); + } + + if (has_Y_h) { + if (std::is_same::value || std::is_same::value) { + lstm_output_Y_h = builder.MakeOutput(); + } else { + lstm_output_Y_h = builder.MakeIntermediate(); + } + } else { + lstm_output_Y_h = builder.MakeOptionalTensor(); + } + if (has_Y_c) { + if (std::is_same::value || std::is_same::value) { + lstm_output_Y_c = builder.MakeOutput(); + } else { + lstm_output_Y_c = builder.MakeIntermediate(); + } + } else { + lstm_output_Y_c = builder.MakeOptionalTensor(); + } + + Node& lstm_node = builder.AddNode("LSTM", + input_args, + {lstm_output_Y, lstm_output_Y_h, lstm_output_Y_c}); + lstm_node.AddAttribute("direction", direction); + lstm_node.AddAttribute("hidden_size", hidden_size); + lstm_node.AddAttribute("layout", layout); + ORT_UNUSED_PARAMETER(output_qparams); + if (std::is_same::value) { + size_t i = 0; + if (has_Y) { + AddQDQNodePairWithOutputAsGraphOutput(builder, lstm_output_Y, output_qparams[i].scale, + output_qparams[i].zero_point); + i++; + } + if (has_Y_h) { + AddQDQNodePairWithOutputAsGraphOutput(builder, lstm_output_Y_h, output_qparams[i].scale, + output_qparams[i].zero_point); + i++; + } + if (has_Y_c) { + AddQDQNodePairWithOutputAsGraphOutput(builder, lstm_output_Y_c, output_qparams[i].scale, + output_qparams[i].zero_point); + i++; + } + } +} + +template +static GetTestModelFn BuildLSTMTestCase(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout) { + return [X_def, W_def, R_def, B_def, + H_def, C_def, P_def, + has_Y, has_Y_h, has_Y_c, + direction, hidden_size, layout](ModelTestBuilder& builder) { + _BuildLSTMTestCase(builder, X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout, {}); + }; +} + +template +static GetTestQDQModelFn BuildQDQLSTMTestCase(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout) { + return [X_def, W_def, R_def, B_def, + H_def, C_def, P_def, + has_Y, has_Y_h, has_Y_c, + direction, hidden_size, layout](ModelTestBuilder& builder, + std::vector>& output_qparams) { + _BuildLSTMTestCase(builder, X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout, output_qparams); + }; +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +// Runs an LSTM model on the QNN HTP backend. Checks the graph node assignment, and that inference +// outputs for QNN EP and CPU EP match. +// Note: There are accuracy on HTP in fixed point, to avoid the issue, we don't register QDQ selector for LSTM and it +// is running on HTP fp16 +template +static void RunHtpQDQLSTMOpTest(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 22, + QDQTolerance tolerance = QDQTolerance()) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + TestQDQModelAccuracy(BuildLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + BuildQDQLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + provider_options, + opset, + expected_ep_assignment, + tolerance); +} + +static void RunHtpFp16LSTMOpTest(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 22, + float tolerance = 0.004f) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + + TestFp16ModelAccuracy(BuildLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + BuildLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + provider_options, + opset, + expected_ep_assignment, + tolerance); +} + +static void RunCpuFP32LSTMOpTest(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 22, + float tolerance = 0.004f) { + ProviderOptions provider_options; + provider_options["backend_type"] = "cpu"; + + RunQnnModelTest(BuildLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + provider_options, + opset, + expected_ep_assignment, + tolerance); +} + +// QNN failed to finalize when P is provided +// TODO: Add P to unit test below once finalize issue is resolved + +// HTP QDQ +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_forward) { + std::string direction = "forward"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_reverse) { + std::string direction = "reverse"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_B) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::nullopt, // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_H) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::nullopt, // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_C) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::nullopt, // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_all_initializer) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, true, -0.5f, 0.5f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -0.5f, 0.5f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -0.5f, 0.5f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -0.5f, 0.5f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, true, -0.5f, 0.5f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, true, -0.5f, 0.5f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All, + 22, + QDQTolerance(0.008f)); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + false, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_h_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + false, // has_Y + true, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_c_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + false, // has_Y + false, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +// HTP Fp16 +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_forward) { + std::string direction = "forward"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_reverse) { + std::string direction = "reverse"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_B) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::nullopt, // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_H) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::nullopt, // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_C) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::nullopt, // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_all_initializer) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, true, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, true, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, true, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + false, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_h_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + false, // has_Y + true, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_c_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + false, // has_Y + false, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +// CPU FP32 +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_forward) { + std::string direction = "forward"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_reverse) { + std::string direction = "reverse"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_B) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::nullopt, // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_H) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::nullopt, // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_C) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::nullopt, // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_HC) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::nullopt, // initial_h + std::nullopt, // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_P) { + std::string direction = "forward"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_all_initializer) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, true, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, true, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, true, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, true, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_Y_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + false, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_Y_h_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + false, // has_Y + true, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_Y_c_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + false, // has_Y + false, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) From adb05e4b657409bceedde149395222178ad2848f Mon Sep 17 00:00:00 2001 From: quic-tirupath Date: Tue, 20 May 2025 16:03:09 -0700 Subject: [PATCH 25/57] [QNN EP] Fix 16x16 Conv translation (#24729) - QNN's 16x16 Conv doesn't support asymmetric int16 weight - Insert Convert Op to convert from asymmetric uint16 weight to symmetric int16 weight ### Description - QNN' Conv op doesn't support asymmetric INT16 weights. - 16x16 Conv operators in ONNX models fallback to CPU execution provider and reporting higher inference times. - Insert a Convert Op to convert asymmetric uint16 weight to symmetric int16 weight to schedule 16x16 Conv's on QNN EP provider. ### Motivation and Context - This fixes Graph execution failures for models contain 16x16 Conv op on QNN Execution provider - This also improves Inference times of model contain 16x16 Conv op --- .../qnn/builder/opbuilder/conv_op_builder.cc | 44 ++++++++++++++ .../core/providers/qnn/builder/qnn_utils.cc | 46 +++++++++++++++ .../core/providers/qnn/builder/qnn_utils.h | 11 ++++ onnxruntime/test/providers/qnn/conv_test.cc | 58 +++++++++++++++++++ 4 files changed, 159 insertions(+) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 23811c200213a..fbf4cbe53a812 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -280,6 +280,50 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, std::move(input_info.quant_param), std::move(actual_shape), std::move(unpacked_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + + // Workaround that inserts a QNN Convert op before input[1] (converts from quantized uint16 to signed symmetric int16) + // to avoid a QNN validation failure. + // + // QNN graph WITHOUT workaround (fails validation): + // input_0_uint16 ---> Conv ---> output_uint16 + // ^ + // | + // input_1_uint16 -----+ + // + // QNN graph WITH workaround (passes validation): + // input_0_uint16 ----------------------> Conv ---> output_uint16 + // ^ + // | + // input_1_uint16 --> Convert(to int16) --+ + + std::string weight_input_name = input_names.back(); + const auto& weight_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(weight_input_name); + + if (weight_tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_UFIXED_POINT_16) { + const auto& quant_param_wrapper = weight_tensor_wrapper.GetQnnQuantParams(); + const Qnn_QuantizeParams_t& quant_param = quant_param_wrapper.Get(); + const auto& transformed_input1_shape = weight_tensor_wrapper.GetTensorDims(); + + ORT_RETURN_IF_NOT(quant_param_wrapper.IsPerTensor(), + "Conv's INT16 weight inputs only support INT16 per-tensor quantization"); + + // Pop Conv weight. Insert Convert op after Weight + input_names.pop_back(); + const std::string& conv_output_name = node_unit.Outputs()[0].node_arg.Name(); + std::string convert_output_name = weight_input_name + "_convert_" + conv_output_name; + + ORT_RETURN_IF_ERROR(utils::InsertConvertOp(qnn_model_wrapper, + weight_input_name, + convert_output_name, + QNN_DATATYPE_UFIXED_POINT_16, + QNN_DATATYPE_SFIXED_POINT_16, + quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, + transformed_input1_shape, + true, // Symmetric + do_op_validation)); + input_names.push_back(convert_output_name); + } } // diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index cafd727c6a057..f869f33847bbf 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -1253,6 +1253,52 @@ Status TransposeFromCnhwToHwcn(std::vector&& original_input_shape_dims, output_buffer); } +// Inserts a QNN Convert operator to convert from one quantization type (e.g., uint16) to another (e.g., uint8). +// (OR) Convert from Asymmetric (e.g., UINT16) to Symmetric (e.g., INT16) quantization type +Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, + const std::string& convert_input_name, + const std::string& convert_output_name, + Qnn_DataType_t input_qnn_data_type, + Qnn_DataType_t output_qnn_data_type, + int32_t input_offset, + float input_scale, + const std::vector& output_shape, + bool output_symmetric, + bool do_op_validation) { + // Assume input is already handled. + float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(qnn::utils::GetQminQmax(input_qnn_data_type, qmin, qmax)); + double value_min = qnn::utils::Dequantize(input_offset, input_scale, qmin); + double value_max = qnn::utils::Dequantize(input_offset, input_scale, qmax); + float scale = 0.0f; + int32_t offset = 0; + ORT_RETURN_IF_ERROR(qnn::utils::GetQuantParams(static_cast(value_min), + static_cast(value_max), + output_qnn_data_type, + scale, + offset, + output_symmetric)); + + std::vector output_shape_copy = output_shape; + QnnTensorWrapper convert_output_tensorwrapper(convert_output_name, + QNN_TENSOR_TYPE_NATIVE, + output_qnn_data_type, + QnnQuantParamsWrapper(scale, offset), + std::move(output_shape_copy)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(convert_output_tensorwrapper)), "Failed to add tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(convert_output_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + "Convert", + {convert_input_name}, + {convert_output_name}, + {}, + do_op_validation), + "Failed to add node."); + return Status::OK(); +} + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index 7065a4b31f77e..eefde87630077 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -374,6 +374,17 @@ Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, std::vector& transposed_data); +Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, + const std::string& convert_input_name, + const std::string& convert_output_name, + Qnn_DataType_t input_qnn_data_type, + Qnn_DataType_t output_qnn_data_type, + int32_t input_offset, + float input_scale, + const std::vector& output_shape, + bool output_symmetric, + bool do_op_validation); + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index b15042a808c37..8232742f35a31 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -854,6 +854,34 @@ TEST_F(QnnHTPBackendTests, ConvU16U8_PerTensor_NoBias) { 21); // opset } +#ifndef __linux__ +// Test per-channel QDQ Conv with uint16 input[0], uint8 weights, and no bias. +// in0: u16, in1 (weight): s4, out: u8 +// Tests bug in QNN SDK 2.25 when validating Conv without a bias (QNN EP adds a dummy bias). +TEST_F(QnnHTPBackendTests, ConvU16U16_PerTensor_NoBias) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + + RunHTPConvOpTest("Conv", + input_def, + weight_def, + TestInputDef(), + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21); // opset +} +#endif + TEST_F(QnnHTPBackendTests, ConvU16S4_PerChannel_NoBias_LargeINT4Weight) { std::vector input_shape = {1, 3072, 1, 512}; std::vector weight_shape = {9216, 3072, 1, 1}; @@ -1309,6 +1337,36 @@ TEST_F(QnnHTPBackendTests, ConvTranspose3D_U8S8S32_PerChannel) { 13); } +#ifndef __linux__ +// Test per-channel QDQ Conv. in0: u16, in1 (weight): s8, in2 (bias): s32, out: u16 +TEST_F(QnnHTPBackendTests, ConvU16S16S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // use_qdq_contrib_ops + 13); // opset +} +#endif + // Test per-channel QDQ Conv. in0: u16, in1 (weight): s8, in2 (bias): s32, out: u16 TEST_F(QnnHTPBackendTests, ConvU16S8S32_PerChannel) { std::vector input_shape = {1, 2, 4, 4}; From 39767bf1fefcc1a7f802dec3692332c4a014be08 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 21 May 2025 09:44:48 -0700 Subject: [PATCH 26/57] Remove unused tensor dumper functions (#24821) ### Description Remove unused tensor dumper functions. Those functions are not needed any more since it is easy to make a string with `::onnxruntime::MakeString` like in `DUMP_CPU_STRING` macros. ### Motivation and Context Follow up with https://github.com/microsoft/onnxruntime/pull/24813#discussion_r2096803842. Some functions were added, but not used any more. Remove them to avoid maintenance cost. --- .../contrib_ops/cpu/utils/console_dumper.h | 3 -- .../contrib_ops/cpu/utils/dump_tensor.cc | 39 ------------------- .../contrib_ops/cpu/utils/dump_tensor.h | 4 -- .../cuda/utils/dump_cuda_tensor.cc | 34 ---------------- .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 3 -- 5 files changed, 83 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h index 9f3d22b9b3c0f..0c1d6a95dff20 100644 --- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h +++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h @@ -18,11 +18,8 @@ class IConsoleDumper { void Disable() { is_enabled_ = false; } bool IsEnabled() const { return is_enabled_; } - virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0; virtual void Print(const char* name, const Tensor& value) const = 0; virtual void Print(const char* name, const OrtValue& value) const = 0; - virtual void Print(const char* name, int index, bool end_line) const = 0; - virtual void Print(const char* name, const std::string& value, bool end_line) const = 0; virtual void Print(const std::string& value) const = 0; #define TENSOR_DUMPER_PRINT_TYPE(dtype) \ diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc index 7cbf989a44878..947311b89fbfd 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc @@ -224,36 +224,6 @@ void CpuTensorConsoleDumper::Print(const char* name, const OrtValue& value) cons Print(name, tensor); } -void CpuTensorConsoleDumper::Print(const char* name, int index, bool end_line) const { - if (!is_enabled_) - return; - - std::unique_lock lock(s_mutex); - std::cout << std::string(name) << "[" << index << "]"; - - if (end_line) { - std::cout << std::endl; - } -} - -void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, bool end_line) const { - if (!is_enabled_) - return; - - std::unique_lock lock(s_mutex); - std::cout << std::string(name) << "=" << value; - - if (end_line) { - std::cout << std::endl; - } -} - -void CpuTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - #define TENSOR_DUMPER_PRINT_TYPE(dtype) \ void CpuTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1) const { \ if (is_enabled_) \ @@ -296,15 +266,6 @@ void CpuTensorConsoleDumper::Print(const char*, const Tensor&) const { void CpuTensorConsoleDumper::Print(const char*, const OrtValue&) const { } -void CpuTensorConsoleDumper::Print(const char*, int, bool) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { -} - #define TENSOR_DUMPER_PRINT_TYPE(dtype) \ void CpuTensorConsoleDumper::Print(const char*, const dtype*, int, int) const { \ } \ diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h index 5066c3ddbb4b3..6de0439d7f8ba 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h @@ -15,12 +15,8 @@ class CpuTensorConsoleDumper : public IConsoleDumper { CpuTensorConsoleDumper(); virtual ~CpuTensorConsoleDumper() {} - void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; - void Print(const char* name, int index, bool end_line) const override; - void Print(const char* name, const std::string& value, bool end_line) const override; void Print(const std::string& value) const override; diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index 40504d8648397..b986f0ae3edad 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -272,11 +272,6 @@ void CudaTensorConsoleDumper::Print(const std::string& value) const { std::cout << value << std::endl; } -void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { if (is_enabled_) DumpGpuTensor(name, tensor); @@ -287,26 +282,6 @@ void CudaTensorConsoleDumper::Print(const char* name, const OrtValue& value) con Print(name, tensor); } -void CudaTensorConsoleDumper::Print(const char* name, int index, bool end_line) const { - if (!is_enabled_) - return; - - std::cout << std::string(name) << "[" << index << "]"; - if (end_line) { - std::cout << std::endl; - } -} - -void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, bool end_line) const { - if (!is_enabled_) - return; - - std::cout << std::string(name) << "=" << value; - if (end_line) { - std::cout << std::endl; - } -} - #define CUDA_DUMPER_PRINT_TYPE(dtype, dtype2) \ void CudaTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1) const { \ if (is_enabled_) \ @@ -344,21 +319,12 @@ CudaTensorConsoleDumper::CudaTensorConsoleDumper() { void CudaTensorConsoleDumper::Print(const std::string&) const { } -void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { -} - void CudaTensorConsoleDumper::Print(const char*, const Tensor&) const { } void CudaTensorConsoleDumper::Print(const char*, const OrtValue&) const { } -void CudaTensorConsoleDumper::Print(const char*, int, bool) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const std::string&, bool) const { -} - #define CUDA_DUMPER_PRINT_TYPE(dtype) \ void CudaTensorConsoleDumper::Print(const char*, const dtype*, int, int) const { \ } \ diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 406e269d6e070..ec034bc15341e 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -16,11 +16,8 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { CudaTensorConsoleDumper(); virtual ~CudaTensorConsoleDumper() {} - void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; - void Print(const char* name, int index, bool end_line) const override; - void Print(const char* name, const std::string& value, bool end_line) const override; void Print(const std::string& value) const override; #define CUDA_DUMPER_PRINT_TYPE(dtype) \ From fbc6b23178294bcecb8ae152e179b3a4a2e6f680 Mon Sep 17 00:00:00 2001 From: kuanyul-quic Date: Thu, 22 May 2025 00:46:58 +0800 Subject: [PATCH 27/57] [QNN EP] Fix inconsistent inputs for graph (#24751) ### Description - Match the graph input correctly - Add GetGraphInputNumber function ### Motivation and Context - The number of graph inputs and the number of tensor wrappers may not match. - For example, for ResizeNearestNeighbor op, Qnn only cares about the 1st input, so the rest of the inputs are not converted to tensor wrappers. However, these remaining inputs still appear in the graph inputs, resulting in a discrepancy in the input quantities. --- .../core/providers/qnn/builder/qnn_model.cc | 22 +++++++++++++- .../core/providers/qnn/builder/qnn_model.h | 5 ++++ .../test/providers/qnn/qnn_basic_test.cc | 29 +++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index ec84820bb7896..175a76b590895 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -331,7 +331,15 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, bool is_input) { size_t tensor_count = tensor_wrappers.size(); ORT_RETURN_IF(0 == tensor_count, "Zero tensor size!"); - qnn_tensor_infos.resize(tensor_count); + if (is_input) { + // Resize qnn_tensor_infos according to the number of graph inputs. + auto input_count = GetGraphInputCount(); + ORT_RETURN_IF(input_count < tensor_count, + "The count of graph inputs should be at least the count of tensor_wrapper!"); + qnn_tensor_infos.resize(input_count); + } else { + qnn_tensor_infos.resize(tensor_count); + } for (auto& tensor_wrapper : tensor_wrappers) { ORT_RETURN_IF(utils::QnnTensorHasDynamicShape(tensor_wrapper.GetQnnTensor()), @@ -348,6 +356,18 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, qnn_tensor_info.tensor_byte_size = static_cast(length); qnn_tensor_info.ort_index = ort_index; } + // The number of graph inputs and the number of tensor wrappers may not match. + // - For example, for ResizeNearestNeighbor op, Qnn only cares about the 1st input, + // so the rest of the inputs are not converted to tensor wrappers. + // - However, these remaining inputs still appear in the graph inputs, resulting in + // a discrepancy in the input quantities. + // If not all inputs are used, erase the empty allocations in qnn_tensor_infos. + if (is_input) { + qnn_tensor_infos.erase(std::remove_if(qnn_tensor_infos.begin(), + qnn_tensor_infos.end(), + [](QnnTensorInfo qnn_tensor_info) { return qnn_tensor_info.tensor_wrapper == nullptr; }), + qnn_tensor_infos.end()); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 6f7738f554ef0..9f10b319f1a57 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -77,6 +77,11 @@ class QnnModel { return it->second; } + // Return the number of graph inputs + size_t GetGraphInputCount() const { + return model_input_index_map_.size(); + } + size_t GetOutputIndex(const std::string& name) const { return GetInputOutputIndex(name, outputs_info_); } diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 2b4bbc272b482..a206644bc945e 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -6,6 +6,7 @@ #include #include "core/graph/constants.h" +#include "core/graph/node_attr_utils.h" #include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU #if BUILD_QNN_EP_STATIC_LIB #include "core/providers/qnn/qnn_allocator.h" // Used by QnnHTPBackendTests.UseHtpSharedMemoryAllocatorForInputs @@ -1384,6 +1385,34 @@ TEST_F(QnnHTPBackendTests, AutoEp_PreferNpu) { } #endif // defined(WIN32) && !BUILD_QNN_EP_STATIC_LIB +// Test whether QNN EP can handle the case where the number of graph inputs and +// the number of tensor wrappers do not match. +// Take Resize op as an example. +// - Qnn only cares about the 1st input, so the rest of the inputs are not converted +// to tensor wrappers. +// - However, these remaining inputs still appear in the graph inputs, +// resulting in a discrepancy in the input quantities. +TEST_F(QnnHTPBackendTests, TestMismatchedGraphInputAndTensorWrapperCount) { + onnxruntime::ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + + auto input_defs = {TestInputDef({1, 3, 10, 10}, false, -10.0f, 10.0f), + TestInputDef({0}, false, {}), + TestInputDef({4}, true, {1.0f, 1.0f, 2.0f, 2.0f})}; + auto attrs = {utils::MakeAttribute("mode", "nearest"), + utils::MakeAttribute("coordinate_transformation_mode", "asymmetric"), + utils::MakeAttribute("nearest_mode", "floor")}; + RunQnnModelTest(BuildOpTestCase("Resize", + input_defs, + {}, + attrs, + kOnnxDomain), + provider_options, + 11, + ExpectedEPNodeAssignment::All, + 0.008f); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // Test that QNN Ir generates the expected files for a model meant to run on any QNN backend. From 4f208b347e534093470703d37da9c06289943da2 Mon Sep 17 00:00:00 2001 From: Ishwar Raut Date: Wed, 21 May 2025 11:29:01 -0700 Subject: [PATCH 28/57] [NV TensorRt RTX EP] : Fix Domain check. (#24816) ### Description Small change to remove the MS Domain check on onnx model nodes ### Motivation and Context The check returns unsupported for some nodes having an MS Domain. Trt RTX supports some MS domain ops. if return unsupported these ops falls back to CPU EP @ankan-ban @chilo-ms @gedoensmax @jywu-msft Co-authored-by: iraut --- .../core/providers/nv_tensorrt_rtx/nv_execution_provider.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 6a7ff63dbc0ed..8b864c42714cb 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1989,10 +1989,6 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, if (exclude_ops_set.find(node->OpType()) != exclude_ops_set.end()) { supported_node = false; } - // Exclude contrib ops - if (node->Domain() == kMSDomain) { - supported_node = false; - } if (supported_node) { if (new_subgraph) { From 577eac9a9e2340da2c782d0ca283cff27dfd7ee7 Mon Sep 17 00:00:00 2001 From: Ashwath Shankarnarayan Date: Wed, 21 May 2025 13:11:25 -0700 Subject: [PATCH 29/57] [QNN EP] MaxPool input rank-3 auto pad bug fix (#24827) - Previously, padding for rank-3 MaxPool was only computed for auto_pad="NOTSET", using the final output shape. - Identified a broader issue during auto_pad="VALID" implementation: padding must be derived from the recalculated output shape. - Added unit tests to cover all use cases of auto_pad. - Enabled the failing unit test in the cpu pool test ### Description This PR fixes an issue in the padding calculation logic for rank-3 MaxPool operations when using auto_pad. The bug stemmed from using the final output shape (rank-3) to compute padding, rather than the correct intermediate shape (rank-4) that MaxPool actually operates on. The logic has been updated to use the reshaped rank-4 output for accurate padding computation. Unit tests have been added to validate behavior across all auto_pad modes. ### Motivation and Context While implementing support for auto_pad="VALID" in MaxPool, we discovered that the padding for MaxPool rank-3 was being calculated using the final output shape, which is rank-3. However, MaxPool internally operates on a reshaped rank-4 tensor (via pre- and post-processing reshapes). As a result, the padding logic was misaligned with the actual shape used during pooling, leading to test failures. --- .../qnn/builder/opbuilder/pool_op_builder.cc | 76 ++++++++++--------- .../test/providers/cpu/nn/pool_op_test.cc | 3 +- .../test/providers/qnn/pool_op_test.cpp | 45 +++++++++++ 3 files changed, 88 insertions(+), 36 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index ba9f7baa4c1ee..f932858eb2fd9 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -103,6 +103,36 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +static std::vector AmendOutputShapeForRank3Pool( + gsl::span input_shape, // {N, H, W, C} + gsl::span kernel_shape, // {k_h, k_w} + gsl::span strides, // {s_h, s_w} + gsl::span pads) { + assert(input_shape.size() == 4 && + kernel_shape.size() == 2 && + strides.size() == 2 && + pads.size() == 4); + + const uint32_t N = input_shape[0]; + const uint32_t H = input_shape[1]; + const uint32_t W = input_shape[2]; + const uint32_t C = input_shape[3]; + + // pad the spatial dims + uint32_t padded_H = H + pads[0] + pads[2]; + uint32_t padded_W = W + pads[1] + pads[3]; + + // floor-mode on NHWC + uint32_t out_H = (padded_H < kernel_shape[0]) + ? 0 + : (padded_H - kernel_shape[0]) / strides[0] + 1; + uint32_t out_W = (padded_W < kernel_shape[1]) + ? 0 + : (padded_W - kernel_shape[1]) / strides[1] + 1; + + return {N, out_H, out_W, C}; +} + Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, std::vector& filter_size, std::vector& pad_amount, std::vector& strides, @@ -153,6 +183,14 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, dilations = raw_dilations; } + // Max Pool rank 3 input + if (output_shape.size() == 3) { + // Calculate MaxPool output for rank-4 when input is rank 3 + output_shape = AmendOutputShapeForRank3Pool(input_shape, + filter_size, + strides, + pad_amount); + } auto total_pads_0 = (output_shape[1] - 1) * strides[0] + (filter_size[0] - 1) * dilations[0] + 1 - input_shape[1]; auto total_pads_1 = (output_shape[2] - 1) * strides[1] + (filter_size[1] - 1) * dilations[1] + 1 - input_shape[2]; if (auto_pad.compare("SAME_LOWER") != 0) { @@ -189,36 +227,6 @@ void SetPoolParam(const NodeUnit& node_unit, qnn_model_wrapper.AddParamWrapper(std::move(qnn_param)); } -std::vector ComputePoolOutputShape( - const std::vector& input_shape, // {N, H, W, C} - const std::vector& kernel_shape, // {k_h, k_w} - const std::vector& strides, // {s_h, s_w} - const std::vector& pads) { - assert(input_shape.size() == 4 && - kernel_shape.size() == 2 && - strides.size() == 2 && - pads.size() == 4); - - const uint32_t N = input_shape[0]; - const uint32_t H = input_shape[1]; - const uint32_t W = input_shape[2]; - const uint32_t C = input_shape[3]; - - // pad the spatial dims - uint32_t padded_H = H + pads[0] + pads[2]; - uint32_t padded_W = W + pads[1] + pads[3]; - - // floor-mode on NHWC - uint32_t out_H = (padded_H < kernel_shape[0]) - ? 0 - : (padded_H - kernel_shape[0]) / strides[0] + 1; - uint32_t out_W = (padded_W < kernel_shape[1]) - ? 0 - : (padded_W - kernel_shape[1]) / strides[1] + 1; - - return {N, out_H, out_W, C}; -} - Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -316,10 +324,10 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } // Calculate MaxPool output for rank-4 when input is rank 3 - auto pooled_shape = ComputePoolOutputShape(onnx_in_shape, - filter_size, - stride, - pad_amount); + auto pooled_shape = AmendOutputShapeForRank3Pool(onnx_in_shape, + filter_size, + stride, + pad_amount); SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE, std::move(filter_size_dim), std::move(filter_size), param_tensor_names, qnn_model_wrapper); SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT, std::move(pad_amount_dim), std::move(pad_amount), param_tensor_names, qnn_model_wrapper); diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 8edbd417544c4..1df640a84a64d 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -230,8 +230,7 @@ TEST(PoolTest, MaxPool1D_case2) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - // QNN test failed. Caused by a combination of most recent changes, will fix it - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool1D_case3) { diff --git a/onnxruntime/test/providers/qnn/pool_op_test.cpp b/onnxruntime/test/providers/qnn/pool_op_test.cpp index d777b1134d060..9284df6f8a4a8 100644 --- a/onnxruntime/test/providers/qnn/pool_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pool_op_test.cpp @@ -262,6 +262,51 @@ TEST_F(QnnHTPBackendTests, MaxPool_Rank3_Ceil_HTP_u8) { ExpectedEPNodeAssignment::All); } +// 1-D MaxPool HTP test for rank-3 with ceil_mode=1 and auto_pad='VALID' +TEST_F(QnnHTPBackendTests, MaxPool_Rank3_Ceil_HTP_u8_auto_pad_VALID) { + RunQDQPoolOpTest( + "MaxPool", + TestInputDef({1, 3, 3}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{3}), + utils::MakeAttribute("strides", std::vector{3}), + utils::MakeAttribute("pads", std::vector{0, 0}), + utils::MakeAttribute("dilations", std::vector{1}), + utils::MakeAttribute("ceil_mode", static_cast(1)), + utils::MakeAttribute("storage_order", static_cast(0)), + utils::MakeAttribute("auto_pad", "VALID")}, + ExpectedEPNodeAssignment::All); +} + +// 1-D MaxPool HTP test for rank-3 with ceil_mode=1 and auto_pad='SAME_UPPER' +TEST_F(QnnHTPBackendTests, MaxPool_Rank3_Ceil_HTP_u8_auto_pad_SAME_UPPER) { + RunQDQPoolOpTest( + "MaxPool", + TestInputDef({1, 3, 3}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{3}), + utils::MakeAttribute("strides", std::vector{3}), + utils::MakeAttribute("pads", std::vector{0, 0}), + utils::MakeAttribute("dilations", std::vector{1}), + utils::MakeAttribute("ceil_mode", static_cast(1)), + utils::MakeAttribute("storage_order", static_cast(0)), + utils::MakeAttribute("auto_pad", "SAME_UPPER")}, + ExpectedEPNodeAssignment::All); +} + +// 1-D MaxPool HTP test for rank-3 with ceil_mode=1 and auto_pad='SAME_LOWER' +TEST_F(QnnHTPBackendTests, MaxPool_Rank3_Ceil_HTP_u8_auto_pad_SAME_LOWER) { + RunQDQPoolOpTest( + "MaxPool", + TestInputDef({1, 3, 3}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{3}), + utils::MakeAttribute("strides", std::vector{3}), + utils::MakeAttribute("pads", std::vector{0, 0}), + utils::MakeAttribute("dilations", std::vector{1}), + utils::MakeAttribute("ceil_mode", static_cast(1)), + utils::MakeAttribute("storage_order", static_cast(0)), + utils::MakeAttribute("auto_pad", "SAME_LOWER")}, + ExpectedEPNodeAssignment::All); +} + TEST_F(QnnHTPBackendTests, MaxPool_Ceil_HTP_u8) { RunQDQPoolOpTest("MaxPool", TestInputDef({1, 2, 3, 3}, false, -10.0f, 10.0f), // Dynamic input with range [-10, 10] From cd2502a0e8a57810bb354f00c65ce4377e60a549 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 22 May 2025 09:24:10 -0700 Subject: [PATCH 30/57] Update Qnn default version to 2.34.0.250424 (#24750) ### Description Update Qnn default version to 2.34.0.250424 --- onnxruntime/test/onnx/TestCase.cc | 6 +++++ .../test/providers/cpu/math/gemm_test.cc | 26 +++++++++++++------ .../test/providers/qnn/gemm_op_test.cc | 9 ++++--- ...arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- .../c-api-noopenmp-packaging-pipelines.yml | 2 +- .../custom-nuget-packaging-pipeline.yml | 2 +- .../azure-pipelines/linux-qnn-ci-pipeline.yml | 2 +- .../azure-pipelines/py-packaging-pipeline.yml | 2 +- .../qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../stages/py-cpu-packaging-stage.yml | 2 +- .../templates/android-java-api-aar-test.yml | 2 +- .../templates/android-java-api-aar.yml | 2 +- .../azure-pipelines/templates/c-api-cpu.yml | 2 +- .../templates/jobs/download_linux_qnn_sdk.yml | 2 +- .../templates/jobs/download_win_qnn_sdk.yml | 2 +- .../templates/py-linux-qnn.yml | 2 +- .../templates/py-win-arm64-qnn.yml | 2 +- .../templates/py-win-arm64ec-qnn.yml | 2 +- .../templates/py-win-x64-qnn.yml | 2 +- .../azure-pipelines/templates/qnn-ep-win.yml | 4 +-- .../win-qnn-arm64-ci-pipeline.yml | 2 +- .../azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- 22 files changed, 50 insertions(+), 31 deletions(-) diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index f56f9ffcc7858..6dca258601339 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1455,6 +1455,12 @@ std::unique_ptr> GetBrokenTests(const std::string& provider // Fails with QNN SDK 2.17.0: // expected 7.70947 (40f6b3f3), got 7.84096 (40fae920), diff: 0.131491, tol=0.00870947 idx=419. 100 of 1715 differ broken_tests->insert({"facedetection_op8_qdq", "result differs"}); + // Fails with QNN SDK 2.34.0: + // expected 2.18661 (400bf164), got 1.48898 (3fbe96ce), diff: 0.697631, tol=0.00318661 idx=0. 8 of 8 differ + broken_tests->insert({"gemm_default_vector_bias", "result differs with 2.34"}); + // expected 0.0505495 (3d4f0d00), got 0.0506369 (3d4f68ae), diff: 8.74326e-05, tol=6.05495e-05 idx=448 + broken_tests->insert({"mobilenetv2-1.0", "result differs with 2.34"}); + broken_tests->insert({"facedetection_op8", "segfault with CPU backend, will be fixed by QNN 2.36"}); #if defined(_WIN32) && defined(_M_AMD64) // Fails with QNN SDK 2.17.0 on Windows x64: diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 400b5ab20930c..6abb3d62848f2 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -430,6 +430,7 @@ TYPED_TEST(GemmOpTypedTests, TestGemm2DBroadcast_2) { {static_cast(11.0f), static_cast(12.0f), static_cast(13.0f), static_cast(-9.0f), static_cast(-8.0f), static_cast(-7.0f)}); test.Config(run_with_tunable_op) + .ConfigExcludeEps({kQnnExecutionProvider}) // Accuracy issues with QNN CPU backend since QNN 2.34 .RunWithConfig(); } @@ -518,10 +519,8 @@ TYPED_TEST(GemmOpTypedTests, TestGemmBroadcast) { excluded_providers.insert(kOpenVINOExecutionProvider); // OpenVINO: Temporarily disabled due to accuracy issues #endif - if (b_is_initializer && !c_is_initializer) { - // Accuracy issues on QNN's CPU backend with QNN SDK version 2.17 - excluded_providers.insert(kQnnExecutionProvider); - } + // Accuracy issues with QNN CPU backend since QNN 2.34 + excluded_providers.insert(kQnnExecutionProvider); test.ConfigExcludeEps(excluded_providers) .Config(run_with_tunable_op) @@ -553,10 +552,16 @@ TYPED_TEST(GemmOpTypedTests, TestGemmTrans) { test.AddOutput("Y", {2, 3}, {static_cast(11.0f), static_cast(11.0f), static_cast(11.0f), static_cast(-9.0f), static_cast(-9.0f), static_cast(-9.0f)}); + + std::unordered_set excluded_providers; #if defined(OPENVINO_CONFIG_GPU) - test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues + excluded_providers.insert(kOpenVINOExecutionProvider); // OpenVINO: Temporarily disabled due to accuracy issues #endif - test.Config(run_with_tunable_op) + // Accuracy issues with QNN CPU backend since QNN 2.34 + excluded_providers.insert(kQnnExecutionProvider); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) .RunWithConfig(); } @@ -579,10 +584,15 @@ TYPED_TEST(GemmOpTypedTests, TestGemmTransB) { test.AddOutput("Y", {2, 3}, {static_cast(11.0f), static_cast(11.0f), static_cast(11.0f), static_cast(-9.0f), static_cast(-9.0f), static_cast(-9.0f)}); + + std::unordered_set excluded_providers; #if defined(OPENVINO_CONFIG_GPU) - test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues + excluded_providers.insert(kOpenVINOExecutionProvider); // OpenVINO: Temporarily disabled due to accuracy issues #endif - test.Config(run_with_tunable_op) + excluded_providers.insert(kQnnExecutionProvider); // Accuracy issues with QNN CPU backend since QNN 2.34 + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) .RunWithConfig(); }; run_test(false, false); diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc index a7c86806bf426..fbaf997b476da 100644 --- a/onnxruntime/test/providers/qnn/gemm_op_test.cc +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -73,8 +73,9 @@ TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) { ExpectedEPNodeAssignment::All); // Assigned to QNN EP. } +// since Qnn v2.34 value pair (120.73912, 121.73912) at index #0 don't match, which is 1 from 120.739 // Test Gemm with dynamic (i.e., not initializer) inputs (A, B, Bias). -TEST_F(QnnCPUBackendTests, Gemm_Dynamic_A_B_Bias) { +TEST_F(QnnCPUBackendTests, DISABLED_Gemm_Dynamic_A_B_Bias) { std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); @@ -110,8 +111,9 @@ TEST_F(QnnCPUBackendTests, Gemm_TransAB_Static_B_And_Bias) { ExpectedEPNodeAssignment::All); } +// Since Qnn 2.34 value pair (29.4347763, 30.4347763) at index #0 don't match, which is 1 from 29.4348 // Test Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. -TEST_F(QnnCPUBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { +TEST_F(QnnCPUBackendTests, DISABLED_Gemm_TransAB_Dynamic_B_And_Bias) { std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); @@ -123,7 +125,8 @@ TEST_F(QnnCPUBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnCPUBackendTests, Gemm_Broadcast_Bias_DynamicInputs) { +// Since Qnn 2.34 value pair (11, 10) at index #0 don't match, which is -1 from 11 +TEST_F(QnnCPUBackendTests, DISABLED_Gemm_Broadcast_Bias_DynamicInputs) { std::vector input_a_data = {1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; std::vector input_b_data(12, 1.0f); std::vector input_c_data = {1.0f, 2.0f, 3.0f}; diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index ba6a33b07e765..ab10bdfba0e0f 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 3006eebd2d3b5..7f8039d237731 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.33.0.250327 + default: 2.34.0.250424 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index b1a7c92dc3529..6ee64e4870fd5 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,7 +6,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index f08fd70d6d6cf..580f565310661 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index d19f9bde7ad75..035b4b6c17222 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.33.2.250410 + default: 2.34.0.250424 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 722a3162cfed8..63fb41ab24c68 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 9928a68b6df06..84445b117b495 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.33.2.250410 + default: 2.34.0.250424 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index d1fa72d7e4413..0c70a4f82c566 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.33.0.250327' + default: '2.34.0.250424' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 4474a6b45ef58..c94969d9e9d41 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.33.0.250327' + default: '2.34.0.250424' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 9f65fc8891e94..a6cf5b9a7713e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.33.0.250327 + default: 2.34.0.250424 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index e00e40b80b723..01dbfc5292aa9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.33.2.250410' + default: '2.34.0.250424' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 3b27060b3fcec..13cc9314caf77 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.33.2.250410' + default: '2.34.0.250424' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index c361fe678699e..a0bfd6a46a43c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: is1ES displayName: 'Whether the pipeline is running in 1ES' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index c1f47de63c38c..d28b3e9604c5d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 1a00d67bdbb2a..f300d845579bf 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 72c8323d032ed..ce22142e6c5bd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index d739724f8744a..0b8c493ae124d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.33.2.250410' + QnnSdk: '2.34.0.250424' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false @@ -125,4 +125,4 @@ stages: displayName: 'Publish Pipeline Qnn NuGet Artifact' inputs: artifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.ArtifactStagingDirectory)' \ No newline at end of file + targetPath: '$(Build.ArtifactStagingDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 93a9909e529f8..9c06edb4d03e8 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 jobs: - job: 'BUILD_QNN_EP' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index b83621d285f9a..3b41394b97bd3 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 jobs: - job: 'BUILD_QNN_EP' From ad7b0e368d453e4e6db646e141ef33e38ff83a7a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 22 May 2025 11:45:47 -0700 Subject: [PATCH 31/57] [build] disable vcpkg for Dawn temporarily (#24838) ### Description ### Motivation and Context --- .../workflows/macos-ci-build-and-test-workflow.yml | 2 +- cmake/CMakeLists.txt | 6 +++++- cmake/external/onnxruntime_external_deps.cmake | 14 ++++++++++++-- cmake/onnxruntime_java.cmake | 6 +++++- cmake/onnxruntime_nodejs.cmake | 6 +++++- cmake/onnxruntime_providers_webgpu.cmake | 12 ++++++++++-- cmake/onnxruntime_python.cmake | 6 +++++- cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch | 13 +++++++++++++ cmake/vcpkg.json | 2 +- 9 files changed, 57 insertions(+), 10 deletions(-) create mode 100644 cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index 9e276751bd3d0..dfe97f8370e99 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -51,7 +51,7 @@ jobs: --build_objc --build_java --build_wheel - ${{ inputs.use_webgpu && '--use_webgpu --cmake_extra_defines onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=ON' || '' }} + ${{ inputs.use_webgpu && '--use_webgpu' || '' }} ${{ inputs.use_xnnpack && '--use_xnnpack' || '' }} ${{ inputs.use_coreml && '--use_coreml' || '' }} --use_vcpkg --use_vcpkg_ms_internal_asset_cache diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 2451fe9f4008b..08aed0cb296a2 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1067,7 +1067,11 @@ if (onnxruntime_USE_WEBGPU) list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu) - if (onnxruntime_USE_VCPKG AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (FALSE) if (NOT onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=OFF") endif() diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 4ed74f1315cb5..e2b0432f3e8e1 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -625,7 +625,11 @@ endif() if (onnxruntime_USE_WEBGPU) - if (onnxruntime_USE_VCPKG AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (FALSE) # vcpkg does not support Emscripten yet find_package(dawn REQUIRED) else() @@ -739,7 +743,13 @@ if (onnxruntime_USE_WEBGPU) # - (private) Force enable f16 support for NVIDIA Vulkan # Dawn disabled f16 support for NVIDIA Vulkan by default because of crashes in f16 CTS tests (crbug.com/tint/2164). # Since the crashes are limited to specific GPU models, we patched Dawn to remove the restriction. - ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch) + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch && + + # The dawn_fix_copy_dxil_dll.patch contains the following changes: + # + # - (private) Fix copy of dxil.dll in Dawn + # The patch ensures the copy of dxil.dll to be done after the build step of `dxcompiler` target. + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_fix_copy_dxil_dll.patch) onnxruntime_fetchcontent_declare( dawn diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 25ceb63df1f19..a65bd9373d1b7 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -177,7 +177,11 @@ if (WIN32) endif() if (onnxruntime_USE_WEBGPU) if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - if (onnxruntime_USE_VCPKG) + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) add_custom_command( TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ diff --git a/cmake/onnxruntime_nodejs.cmake b/cmake/onnxruntime_nodejs.cmake index 4e09400ac84b8..54ac045ce135f 100644 --- a/cmake/onnxruntime_nodejs.cmake +++ b/cmake/onnxruntime_nodejs.cmake @@ -74,7 +74,11 @@ endif() if (onnxruntime_USE_WEBGPU) set(NODEJS_BINDING_USE_WEBGPU "--use_webgpu") if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - if (onnxruntime_USE_VCPKG) + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) list(APPEND NODEJS_DLL_DEPS "$") list(APPEND NODEJS_DLL_DEPS "$") else() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 7a7e0d39fcd2d..a8a79bc928dd1 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -59,7 +59,11 @@ list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll") endif() - if (onnxruntime_USE_VCPKG) + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) # Fix Dawn vcpkg build issue (missing IMPORTED_IMPLIB and IMPORTED_LOCATION for target dawn::webgpu_dawn) get_target_property(webgpu_dawn_target_IMPORTED_IMPLIB dawn::webgpu_dawn IMPORTED_IMPLIB) if (NOT webgpu_dawn_target_IMPORTED_IMPLIB) @@ -82,7 +86,11 @@ if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) # Ensure dxil.dll and dxcompiler.dll exist in the output directory $ - if (onnxruntime_USE_VCPKG) + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) find_package(directx-dxc CONFIG REQUIRED) target_link_libraries(onnxruntime_providers_webgpu Microsoft::DirectXShaderCompiler) target_link_libraries(onnxruntime_providers_webgpu Microsoft::DXIL) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 7b91e65306bdb..cf5a6c78b925c 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -1075,7 +1075,11 @@ endif() if (onnxruntime_USE_WEBGPU) if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - if (onnxruntime_USE_VCPKG) + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy diff --git a/cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch b/cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch new file mode 100644 index 0000000000000..cd4d53b4cbdb7 --- /dev/null +++ b/cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch @@ -0,0 +1,13 @@ +diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt +index cdfde38819..fc5ff76421 100644 +--- a/third_party/CMakeLists.txt ++++ b/third_party/CMakeLists.txt +@@ -352,6 +352,8 @@ function(AddSubdirectoryDXC) + TARGET copy_dxil_dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${DXIL_DLL_PATH} $ + COMMENT "Copying ${DXIL_DLL_PATH} to $") ++ # Ensure folder "$" exists when copying the dll ++ add_dependencies(copy_dxil_dll dxcompiler) + # Make dxc target depend on copy_dxil_dll + add_dependencies(dxc copy_dxil_dll) + endif() diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index 193ba6fe5cad5..7c6b2fed36d1b 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -93,7 +93,7 @@ }, "webgpu-ep": { "description": "Build with WebGPU EP", - "dependencies": [{ "name": "dawn", "platform": "!emscripten" }] + "dependencies": [] } }, "overrides": [ From 5e2adc9edfd0de93678fb8f664f7b90a63676e01 Mon Sep 17 00:00:00 2001 From: anujj Date: Fri, 23 May 2025 23:57:50 +0530 Subject: [PATCH 32/57] Switch the TRT optimization profile if multi-profile is enable (#24805) - Switch the multiple profile if multi profile is enabled. - Pass the profile index via OnRunStart option @ankan-ban @gaugarg-nv @chilo-ms --- include/onnxruntime/core/common/common.h | 2 +- .../nv_tensorrt_rtx/nv_provider_options.h | 4 ++++ onnxruntime/core/framework/config_options.h | 2 +- .../nv_tensorrt_rtx/nv_execution_provider.cc | 15 ++++++++++++++- .../nv_tensorrt_rtx/nv_execution_provider.h | 2 ++ .../nv_tensorrt_rtx/nv_execution_provider_info.cc | 1 + .../nv_tensorrt_rtx/nv_execution_provider_info.h | 1 + 7 files changed, 24 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index 10f658f52e0d9..adfd341451aed 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -302,7 +302,7 @@ inline std::wstring ToWideString(const std::wstring& s) { return s; } inline std::string ToWideString(const std::string& s) { return s; } #endif -constexpr size_t kMaxStrLen = 2048; +constexpr size_t kMaxStrLen = 4096; // Returns whether `key` is in `container`. // Like C++20's map/set contains() member function. diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index 0b1cbe6afac79..0c9095f566fad 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -32,7 +32,11 @@ constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes"; constexpr const char* kCudaGraphEnable = "nv_cuda_graph_enable"; constexpr const char* kONNXBytestream = "nv_onnx_bytestream"; constexpr const char* kONNXBytestreamSize = "nv_onnx_bytestream_size"; +constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable"; } // namespace provider_option_names +namespace run_option_names { +constexpr const char* kProfileIndex = "nv_profile_index"; +} } // namespace nv } // namespace onnxruntime diff --git a/onnxruntime/core/framework/config_options.h b/onnxruntime/core/framework/config_options.h index 028220d15fc8a..1c356d8cfca56 100644 --- a/onnxruntime/core/framework/config_options.h +++ b/onnxruntime/core/framework/config_options.h @@ -18,7 +18,7 @@ struct ConfigOptions { // Maximum key/value string lengths specified in // core/session/onnxruntime_session_options_config_keys.h static constexpr size_t kMaxKeyLength = 1024; - static constexpr size_t kMaxValueLength = 2048; + static constexpr size_t kMaxValueLength = 4096; std::unordered_map configurations; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 8b864c42714cb..696bb3edb9b85 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -5,6 +5,7 @@ #include #include #include "core/providers/shared_library/provider_api.h" +#include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/common.h" @@ -20,6 +21,7 @@ #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" +#include "core/common/parse_string.h" #include #include #include @@ -1140,6 +1142,7 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) } cuda_graph_enable_ = info.cuda_graph_enable; + multi_profile_enable_ = info.multi_profile_enable; op_types_to_exclude_ = info.op_types_to_exclude; // Validate setting @@ -1321,7 +1324,12 @@ std::unique_ptr NvExecutionProvider::GetDataTransfer() const { return std::make_unique(); } -Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { + if (multi_profile_enable_ == true) { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(nv::run_option_names::kProfileIndex); + TryParseStringWithClassicLocale(*graph_annotation_str, nv_profile_index_); + } return Status::OK(); } @@ -2683,6 +2691,11 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); cudaStream_t stream = static_cast(cuda_stream); + if (multi_profile_enable_ == true) { + if (!trt_context->setOptimizationProfileAsync(nv_profile_index_, stream)) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP select an optimization profile for the current context failed"); + } + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity // Prepare cache name diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 76044b4fc2017..35315bdc7d908 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -261,8 +261,10 @@ class NvExecutionProvider : public IExecutionProvider { int (*engine_encryption_)(const char*, char*, size_t) = nullptr; bool detailed_build_log_ = false; bool cuda_graph_enable_ = false; + bool multi_profile_enable_ = false; std::string cache_prefix_; std::string op_types_to_exclude_; + int nv_profile_index_ = 0; // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH int32_t trt_version_; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index f5ba66746c3c4..444fe1025e393 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -46,6 +46,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi .AddAssignmentToReference(nv::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) .AddAssignmentToReference(nv::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) .AddAssignmentToReference(nv::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) + .AddAssignmentToReference(nv::provider_option_names::kMultiProfileEnable, info.multi_profile_enable) .AddValueParser( nv::provider_option_names::kONNXBytestream, [&onnx_bytestream](const std::string& value_str) -> Status { diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index 626039e5ef7c8..36addd0a1ce27 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -42,6 +42,7 @@ struct NvExecutionProviderInfo { std::string profile_max_shapes{""}; std::string profile_opt_shapes{""}; bool cuda_graph_enable{false}; + bool multi_profile_enable{false}; bool dump_ep_context_model{false}; std::string ep_context_file_path{""}; int ep_context_embed_mode{0}; From 2bdb57bb0a02316e8eb2a5bad03d91711bd79ff2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 23 May 2025 12:26:54 -0700 Subject: [PATCH 33/57] Update MatMulBNits spec and Add Input Checks (#24828) ### Description Major changes of spec: * 2D scale shape: [N * n_blocks_per_col] => [N, n_blocks_per_col] * 2D zero shape: [N * CeilDiv(n_blocks_per_col * bits, 8)] => [N, CeilDiv(n_blocks_per_col * bits, 8)] * For B, drop int32 type and only allow uint8. * allow bfloat16 as input/output type. * Mark input g_idx as deprecated (since it has no benefit on model size and performance in inference). Add a function CheckInputs to verify the input shape. The reason of the shape change is to make scale and zero compatible with other operators like DequantizeLinear and GatherBlockQuantized. That will make it easy for graph fusion and model builder. Note that ORT can still handle the legacy 1D format for scale and zero points, and CUDA/CPU could still handle g_idx. However, they are deprecated, and our tools shall generate 2D scale and zeros, and avoid using g_idx going forward. This change is backward compatible. Model from old spec can run in latest ORT (CheckInputs handles 1D scale and zero points), and model from latest spec can still run in older ORT (since older ORT does not check dimension of scale and zero points) ### Motivation and Context CUDA and CPU provider does not check inputs for MatMulNBits. It could cause out of boundary access. We are going to share the lm_head weights of MatMulNBits to GatherBlockQuantized. 2D shape can be used in Gather directly, and we can avoid Reshape nodes. Our latest models published for foundry use 2D scale and zero points. So I update the spec to reflect that. --- docs/ContribOperators.md | 75 ++++++++---------- docs/OperatorKernels.md | 2 +- .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 5 +- js/web/test/data/ops/matmulnbits.jsonc | 66 ++++++++-------- .../cpu/quantization/matmul_nbits.cc | 36 +++++---- .../cpu/quantization/matmul_nbits_helper.h | 77 ++++++++++++++++++ .../cuda/quantization/matmul_nbits.cc | 11 ++- .../webgpu/quantization/matmul_nbits.cc | 2 +- .../core/graph/contrib_ops/contrib_defs.cc | 79 ++++++++++--------- onnxruntime/core/util/shape_checker.h | 63 +++++++++++++++ .../quantization/matmul_nbits_quantizer.py | 4 +- .../test/contrib_ops/matmul_4bits_test.cc | 64 +++++++-------- .../test/contrib_ops/matmul_8bits_test.cc | 9 ++- 13 files changed, 320 insertions(+), 173 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h create mode 100644 onnxruntime/core/util/shape_checker.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index b29fe7adb0da4..7ba2f820e9bdb 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2037,9 +2037,9 @@ This version of the operator has been available since version 1 of the 'com.micr GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (https://github.com/onnx/onnx/blob/main/docs/Operators.md#gather) with differences: 1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`. - `block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, .. + `block_size` must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ... 2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants. - If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8. + If `zero_points` is not provided, the default value is 0 for int4/uint4, or 2^(bits-1) for uint8. 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. @@ -2946,29 +2946,20 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.MatMulNBits** - MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: - 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. - 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. - And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. - 3. Input B's scale and zero point are specified by input scales and zero_points. - - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) - For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. - - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. - 4bit example: - |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) - - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. - 3bit example: - |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. - The last uint_8 may have some bits unused. - - - Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] - Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. - - [N * CeilDiv(n_blocks_per_col * bits, 8)] - If zero_points has same type as A, it's not packed and has the same shape as Scales. + MatMulNBits performs a matrix multiplication where the right-hand-side matrix (weights) is quantized to N bits. + + It is a fusion of two operations: + 1. Linear dequantization of the quantized weights using scale and (optionally) zero-point with formula: + dequantized_weight = (quantized_weight - zero_point) * scale + 2. Matrix multiplication between the input matrix A and the dequantized weight matrix. + + The weight matrix is a 2D constant matrix with the input feature count and output feature count specified by attributes 'K' and 'N'. + It is quantized block-wise along the K dimension with a block size specified by the 'block_size' attribute. + The block size must be a power of 2 and not smaller than 16 (e.g., 16, 32, 64, 128). Each block has its own scale and zero-point. + The quantization is performed using a bit-width specified by the 'bits' attribute, which can take values from 2 to 8. + + The quantized weights are stored in a bit-packed format along the K dimension, with each block being represented by a blob of uint8. + For example, for 4 bits, the first 4 bits are stored in the lower 4 bits of a byte, and the second 4 bits are stored in the higher 4 bits of a byte. #### Version @@ -2978,30 +2969,30 @@ This version of the operator has been available since version 1 of the 'com.micr
K : int (required)
-
size of each input feature
+
Input feature dimension of the weight matrix.
N : int (required)
-
size of each output feature
+
Output feature dimension of the weight matrix.
accuracy_level : int
The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) (default unset). It is used to control how input A is quantized or downcast internally while doing computation, for example: 0 means input A will not be quantized or downcast while doing computation. 4 means input A can be quantized with the same block_size to int8 internally from type T1.
-
bits : int (required)
-
number of bits used for weight quantization (default 4)
+
bits : int
+
Bit-width used to quantize the weights (valid range: 2~8)
block_size : int (required)
-
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
+
Size of each quantization block along the K (input feature) dimension. Must be a power of two and ≥ 16 (e.g., 16, 32, 64, 128).
#### Inputs (3 - 6)
A : T1
-
The input tensor, not quantized
+
The input tensor, not quantized.
B : T2
-
1 or 2 dimensional data blob
+
Packed uint8 tensor of shape (N, k_blocks, blob_size), where k_blocks = ceil(K / block_size) and blob_size = (block_size * bits / 8). The quantized weights are stored in a bit-packed format along the K dimension, packed within each block_size.
scales : T1
-
quantization scale
+
Per-block scaling factors for dequantization with shape (N, k_blocks) and same data type as input A.
zero_points (optional) : T3
-
quantization zero points
+
Per-block zero point for dequantization. It can be either packed or unpacked: Packed (uint8) format has shape (N, ceil(k_blocks * bits / 8)), and it uses same bit-packing method as Input B. Unpacked (same type as A) format has shape (N, k_blocks). If not provided, a default zero point is used: 2^(bits - 1) (e.g., 8 for 4-bit quantization, 128 for 8-bit).
g_idx (optional) : T4
-
group_idx
+
group_idx. This input is deprecated
bias (optional) : T1
Bias to add to result. It should have shape [N].
@@ -3016,12 +3007,12 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T1 : tensor(float), tensor(float16)
-
Constrain input and output types to float/half_float tensors.
-
T2 : tensor(uint8), tensor(int32)
-
Constrain quantized weight types to uint8/int32.
-
T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
-
Constrain quantized zero point types to uint8/int32/float16/float.
+
T1 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
T3 : tensor(uint8), tensor(float16), tensor(float), tensor(bfloat16)
+
Constrain quantized zero point types to uint8 or float tensors.
T4 : tensor(int32)
the index tensor.
@@ -6354,5 +6345,3 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
- - diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 86b490f8f4c43..8c1ab002bce67 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -943,7 +943,7 @@ Do not modify directly.* |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(float16), tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 3e1f1be22efa2..c8e77d14117bf 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -48,8 +48,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt if (inputs.length === 4) { const zeroPoints = inputs[3]; const zeroPointsShape = zeroPoints.dims; + + // This assumes zero points are packed. + // Unpack format (zero point has same data type and shape as scale) is not supported by webgpu. const expectedZeroPointsSize = - attributes.bits > 4 ? attributes.n * nBlocksPerCol : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + attributes.n * (attributes.bits === 8 ? nBlocksPerCol : Math.floor((nBlocksPerCol * attributes.bits + 7) / 8)); if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { throw new Error('zeroPoints input size error.'); } diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc index 63e0a0ed52879..f6671fdab7089 100644 --- a/js/web/test/data/ops/matmulnbits.jsonc +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -35,7 +35,7 @@ ] }, { - "dims": [8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] } @@ -92,12 +92,12 @@ ] }, { - "dims": [8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] }, { - "dims": [8], + "dims": [8, 1], "type": "uint8", "data": [248, 249, 250, 251, 252, 253, 254, 255] } @@ -163,7 +163,7 @@ ] }, { - "dims": [16], + "dims": [8, 2], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] } @@ -229,12 +229,12 @@ ] }, { - "dims": [16], + "dims": [8, 2], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] }, { - "dims": [8], + "dims": [8, 1], "type": "uint8", "data": [0, 1, 2, 3, 4, 5, 6, 7] } @@ -309,7 +309,7 @@ ] }, { - "dims": [24], + "dims": [8, 3], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] } @@ -384,12 +384,12 @@ ] }, { - "dims": [24], + "dims": [8, 3], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] }, { - "dims": [16], + "dims": [8, 2], "type": "uint8", "data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] } @@ -474,7 +474,7 @@ ] }, { - "dims": [32], + "dims": [8, 4], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -562,7 +562,7 @@ ] }, { - "dims": [32], + "dims": [8, 4], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -570,7 +570,7 @@ ] }, { - "dims": [16], + "dims": [8, 2], "type": "uint8", "data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] } @@ -604,7 +604,7 @@ ], "cases": [ { - "name": "MatMulNBits; K=80, N=8, block_size=16, bits=4; asymmetric", + "name": "MatMulNBits; K=80, N=8, block_size=16, bits=4; asymmetric; 1D scale and zero point", "inputs": [ { "data": [ @@ -742,7 +742,7 @@ ] }, { - "dims": [16], + "dims": [16, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] } @@ -822,12 +822,12 @@ ] }, { - "dims": [16], + "dims": [16, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] }, { - "dims": [16], + "dims": [16, 1], "type": "uint8", "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] } @@ -925,7 +925,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1084,7 +1084,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1092,7 +1092,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "uint8", "data": [ 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, @@ -1251,7 +1251,7 @@ ] }, { - "dims": [32], + "dims": [16, 2], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1353,7 +1353,7 @@ ] }, { - "dims": [32], + "dims": [16, 2], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1361,7 +1361,7 @@ ] }, { - "dims": [16], + "dims": [16, 1], "type": "uint8", "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] } @@ -1496,7 +1496,7 @@ ] }, { - "dims": [64], + "dims": [32, 2], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1701,7 +1701,7 @@ ] }, { - "dims": [64], + "dims": [32, 2], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1710,7 +1710,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "uint8", "data": [ 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, @@ -1912,7 +1912,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -2112,7 +2112,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -2120,7 +2120,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "uint8", "data": [ 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, @@ -2259,7 +2259,7 @@ ] }, { - "dims": [1, 8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] } @@ -2322,7 +2322,7 @@ ] }, { - "dims": [1, 8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] } @@ -2386,12 +2386,12 @@ ] }, { - "dims": [16], + "dims": [16, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] }, { - "dims": [16], + "dims": [16, 1], "type": "uint8", "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] } @@ -2458,7 +2458,7 @@ ] }, { - "dims": [8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] } diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 8f013a1426ef8..65e8808190da3 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -15,6 +15,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" namespace onnxruntime { namespace contrib { @@ -677,11 +678,17 @@ template Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(InputIndex::A); + // If B is prepacked, B would have been removed from the context + const bool is_b_prepacked = packed_b_size_ > 0; + const Tensor* b = is_b_prepacked ? nullptr : ctx->Input(InputIndex::B); const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input(InputIndex::scales); const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* reorder_idx = ctx->Input(InputIndex::g_idx); const Tensor* bias = ctx->Input(InputIndex::bias); + ORT_RETURN_IF_ERROR(matmul_nbits_helper::CheckInputs( + a, b, scales, zero_points, reorder_idx, bias, N_, K_, block_size_, nbits_)); + TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); @@ -713,25 +720,22 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { } } - // If B is prepacked, B would have been removed from the context - const Tensor* b = ctx->Input(InputIndex::B); return ComputeBUnpacked(a, b, scales, zero_points, reorder_idx, bias, y, allocator, thread_pool, helper); } -#define REGISTER_MatMulNBits(T1) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - MatMulNBits, \ - kMSDomain, \ - 1, \ - T1, \ - kCpuExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType()}) \ - .TypeConstraint("T4", DataTypeImpl::GetTensorType()), \ +#define REGISTER_MatMulNBits(T1) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MatMulNBits, \ + kMSDomain, \ + 1, \ + T1, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}) \ + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), \ MatMulNBits); REGISTER_MatMulNBits(float); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h new file mode 100644 index 0000000000000..80a360ebb1b29 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "core/util/shape_checker.h" + +namespace onnxruntime { +namespace contrib { +namespace matmul_nbits_helper { + +template +Status CheckInputs(const T* /*activation*/, + const T* quantized_weight, + const T* scales, + const T* zero_points, + const T* group_index, + const T* bias, + int64_t n, + int64_t k, + int64_t block_size, + int64_t bits) { + // activation (A) + // quantized_weight (B) : (N, k_blocks, blob_size), or null after prepacking. + // k_blocks = (K + block_size - 1) / block_size + // blob_size = block_size * bits / 8 + // scales : (N, k_blocks) + // zero_points : (N, (k_blocks * bits + 7) / 8) for uint8, (N, k_blocks) for other types, or null + // group_index : (K) or (k_blocks * block_size), or null + // bias : (N), or null + // Note that scales and zero_points can be 1D for backward compatibility. + if (bits != 4 && bits != 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bits should be 4 or 8, got ", bits); + } + + if (block_size < 16 || (block_size & (block_size - 1)) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_size must be a power of 2, and >= 16. Got ", block_size); + } + + int64_t k_blocks = (k + block_size - 1) / block_size; + int64_t blob_size = block_size * bits / 8; + + ASSERT_TENSOR_SHAPE(quantized_weight, make_shape(n, k_blocks, blob_size)); + + // 1D shape is for backward compatibility for existing models. + ASSERT_TENSOR_SHAPE_2(scales, make_shape(n * k_blocks), make_shape(n, k_blocks)); + + if (zero_points != nullptr) { + if (zero_points->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + const int64_t zero_point_blob_size = (k_blocks * bits + 7) / 8; + + ASSERT_TENSOR_SHAPE_2(zero_points, make_shape(n * zero_point_blob_size), make_shape(n, zero_point_blob_size)); + } else { + if (zero_points->GetElementType() != scales->GetElementType()) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'zero_points' and 'scales' should have the same data type when zero_points is not uint8"); + } + + ASSERT_TENSOR_SHAPE_2(zero_points, make_shape(n * k_blocks), make_shape(n, k_blocks)); + } + } + + // Group_index shall be 1D of K, or K padded to multiple of block_size + ASSERT_TENSOR_SHAPE_2(group_index, make_shape(k), make_shape(k_blocks * block_size)); + + ASSERT_TENSOR_SHAPE(bias, make_shape(n)); + + return Status::OK(); +} + +} // namespace matmul_nbits_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 33265744f3a7d..ed6021530018f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -9,6 +9,7 @@ #include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" #include "matmul_nbits.cuh" #include "dequantize_blockwise.cuh" @@ -25,10 +26,14 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(3); const Tensor* reorder_idx = ctx->Input(4); const Tensor* bias = ctx->Input(5); + if (bias != nullptr) { ORT_THROW("MatMulNBits does not support bias in CUDA kernel"); } + ORT_RETURN_IF_ERROR(matmul_nbits_helper::CheckInputs( + a, b, scales, zero_points, reorder_idx, bias, N_, K_, block_size_, nbits_)); + const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); const auto* scales_data = scales->Data(); @@ -207,7 +212,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MatMulNBits); ONNX_OPERATOR_TYPED_KERNEL_EX( @@ -218,7 +224,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MatMulNBits); } // namespace cuda diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 65ecdff44acd6..c384b216f049a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -49,7 +49,6 @@ fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_ ss << "const default_zero_point = " << (nbits == 4 ? 8 : 128) << ";\n"; ss << R"( fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t { - // The default zero point is 8. return output_element_t(default_zero_point); } )"; @@ -433,6 +432,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context } // zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits. + // For bits==4, this is counted by elements of uint4. Need add 1 if not divisible by 2. uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1; // WideTileProgram diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index f9f7be60a9bd6..96a1ad91a7f17 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3428,39 +3428,33 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, except t }); static const char* MatMulNBits_ver1_doc = R"DOC( -MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: - 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. - 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. - And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. - 3. Input B's scale and zero point are specified by input scales and zero_points. - - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) - For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. - - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. - 4bit example: - |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) - - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. - 3bit example: - |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. - The last uint_8 may have some bits unused. - - -Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] -Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. - - [N * CeilDiv(n_blocks_per_col * bits, 8)] - If zero_points has same type as A, it's not packed and has the same shape as Scales. +MatMulNBits performs a matrix multiplication where the right-hand-side matrix (weights) is quantized to N bits. + +It is a fusion of two operations: +1. Linear dequantization of the quantized weights using scale and (optionally) zero-point with formula: + dequantized_weight = (quantized_weight - zero_point) * scale +2. Matrix multiplication between the input matrix A and the dequantized weight matrix. + +The weight matrix is a 2D constant matrix with the input feature count and output feature count specified by attributes 'K' and 'N'. +It is quantized block-wise along the K dimension with a block size specified by the 'block_size' attribute. +The block size must be a power of 2 and not smaller than 16 (e.g., 16, 32, 64, 128). Each block has its own scale and zero-point. +The quantization is performed using a bit-width specified by the 'bits' attribute, which can take values from 2 to 8. + +The quantized weights are stored in a bit-packed format along the K dimension, with each block being represented by a blob of uint8. +For example, for 4 bits, the first 4 bits are stored in the lower 4 bits of a byte, and the second 4 bits are stored in the higher 4 bits of a byte. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) .SetDomain(kMSDomain) .SinceVersion(1) .SetDoc(MatMulNBits_ver1_doc) - .Attr("K", "size of each input feature", AttributeProto::INT) - .Attr("N", "size of each output feature", AttributeProto::INT) - .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) - .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("K", "Input feature dimension of the weight matrix.", AttributeProto::INT) + .Attr("N", "Output feature dimension of the weight matrix.", AttributeProto::INT) + .Attr("bits", "Bit-width used to quantize the weights (valid range: 2~8)", AttributeProto::INT, static_cast(4)) + .Attr("block_size", + "Size of each quantization block along the K (input feature) dimension. " + "Must be a power of two and ≥ 16 (e.g., 16, 32, 64, 128).", + AttributeProto::INT) .Attr("accuracy_level", "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) " "(default unset). It is used to control how input A is quantized or downcast internally while " @@ -3468,16 +3462,27 @@ Input zero_points is stored as uint8_t or same as type(A). It has the same packi "computation. 4 means input A can be quantized with the same block_size to int8 internally from " "type T1.", AttributeProto::INT, static_cast(0)) - .Input(0, "A", "The input tensor, not quantized", "T1") - .Input(1, "B", "1 or 2 dimensional data blob", "T2") - .Input(2, "scales", "quantization scale", "T1") - .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional) - .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) + .Input(0, "A", "The input tensor, not quantized.", "T1") + .Input(1, "B", + "Packed uint8 tensor of shape (N, k_blocks, blob_size), " + "where k_blocks = ceil(K / block_size) and blob_size = (block_size * bits / 8). " + "The quantized weights are stored in a bit-packed format along the K dimension, packed within each block_size.", + "T2") + .Input(2, "scales", "Per-block scaling factors for dequantization with shape (N, k_blocks) and same data type as input A.", "T1") + .Input(3, "zero_points", + "Per-block zero point for dequantization. It can be either packed or unpacked: " + "Packed (uint8) format has shape (N, ceil(k_blocks * bits / 8)), and it uses same bit-packing method as Input B. " + "Unpacked (same type as A) format has shape (N, k_blocks). " + "If not provided, a default zero point is used: 2^(bits - 1) (e.g., 8 for 4-bit quantization, 128 for 8-bit). ", + "T3", OpSchema::Optional) + .Input(4, "g_idx", "group_idx. This input is deprecated", "T4", OpSchema::Optional) .Input(5, "bias", "Bias to add to result. It should have shape [N].", "T1", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") - .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.") - .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(float16)", "tensor(float)", "tensor(bfloat16)"}, + "Constrain quantized zero point types to uint8 or float tensors.") .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference @@ -3569,9 +3574,9 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 static const char* GatherBlockQuantized_ver1_doc = R"DOC( GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (https://github.com/onnx/onnx/blob/main/docs/Operators.md#gather) with differences: 1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`. - `block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, .. + `block_size` must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ... 2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants. - If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8. + If `zero_points` is not provided, the default value is 0 for int4/uint4, or 2^(bits-1) for uint8. 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. diff --git a/onnxruntime/core/util/shape_checker.h b/onnxruntime/core/util/shape_checker.h new file mode 100644 index 0000000000000..9c975275c45b9 --- /dev/null +++ b/onnxruntime/core/util/shape_checker.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/framework/tensor_shape.h" + +namespace onnxruntime { + +template +TensorShape make_shape(Args... args) { + std::initializer_list dims = {args...}; + return TensorShape(dims); +} + +// This assumes the tensor is optional, and check wether its shape is expected. +#define ASSERT_TENSOR_DIMS(tensor, ...) \ + if (tensor != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + const TensorShape& tensor_shape = tensor->Shape(); \ + const TensorShape& expected_shape = make_shape(__VA_ARGS__); \ + if (tensor_shape != expected_shape) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, "Input '" #tensor "' is expected to have shape ", expected_shape, \ + ", got ", tensor_shape); \ + } \ + } + +// This assumes the tensor is optional, and check wether its shape is expected. +#define ASSERT_TENSOR_SHAPE(tensor, shape) \ + if (tensor != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + static_assert(std::is_same>, \ + TensorShape>::value, \ + "shape must be or refer to a TensorShape"); \ + const TensorShape& tensor_shape = tensor->Shape(); \ + if (tensor_shape != shape) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, "Input '" #tensor "' is expected to have shape ", shape, \ + ", got ", tensor_shape); \ + } \ + } + +// This assumes the tensor is optional, and check wether its shape is shape_1 or shape_2 when it is not null. +#define ASSERT_TENSOR_SHAPE_2(tensor, shape_1, shape_2) \ + if (tensor != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + static_assert(std::is_same>, \ + TensorShape>::value, \ + "shape_1 must be or refer to a TensorShape"); \ + static_assert(std::is_same>, \ + TensorShape>::value, \ + "shape_2 must be or refer to a TensorShape"); \ + const TensorShape& tensor_shape = tensor->Shape(); \ + if (tensor_shape != shape_1 && tensor_shape != shape_2) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, "Input '" #tensor "' is expected to have shape ", shape_1, \ + " or ", shape_2, ", got ", tensor_shape); \ + } \ + } + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index 0297472d0738c..9de11041f5331 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -866,7 +866,9 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis kwargs["N"] = cols kwargs["bits"] = bits kwargs["block_size"] = self.config.block_size - if self.config.accuracy_level is not None: + + # Do not output accuracy_level if it is 0 since the attribute is optional and is not supported by most EPs. + if self.config.accuracy_level: kwargs["accuracy_level"] = self.config.accuracy_level matmul_qbit_node = onnx.helper.make_node( diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index bb2bfab585da8..043f7ed57b8b0 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -81,6 +81,8 @@ struct TestOptions { bool has_g_idx{false}; bool has_bias{false}; + bool legacy_shape{false}; // for backward compatibility + std::optional output_abs_error{}; std::optional output_rel_error{}; }; @@ -107,28 +109,20 @@ void RunTest(const TestOptions& opts, const bool zp_is_4bit = opts.zp_is_4bit || opts.has_g_idx; - const int64_t M = opts.M, - K = opts.K, - N = opts.N; + const int64_t M = opts.M; + const int64_t K = opts.K; + const int64_t N = opts.N; RandomValueGenerator random{1234}; std::vector input0_vals(random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f)); std::vector input1_f_vals(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); -#if 0 // for Debugging - std::vector input1_f_vals_trans(N * K); - MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N); -#endif - - int q_rows, q_cols; - MlasBlockwiseQuantizedShape(static_cast(opts.block_size), /* columnwise */ true, - static_cast(K), static_cast(N), - q_rows, q_cols); - - size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(static_cast(opts.block_size), /* columnwise */ true, - static_cast(K), static_cast(N), - q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + int64_t k_blocks = (K + opts.block_size - 1) / opts.block_size; + int64_t blob_size = (opts.block_size * QBits + 7) / 8; + size_t q_scale_size = static_cast(N * k_blocks); + size_t q_data_size_in_bytes = static_cast(N * k_blocks * blob_size); // packed as UInt4x2 + const int64_t zero_point_blob_size = (k_blocks * QBits + 7) / 8; + size_t q_zp_size_in_bytes = static_cast(N * zero_point_blob_size); // packed as UInt4x2 std::vector input1_vals(q_data_size_in_bytes); std::vector scales(q_scale_size); @@ -142,16 +136,6 @@ void RunTest(const TestOptions& opts, static_cast(K), static_cast(opts.block_size)); -#if 0 - for (int i = 0; i < input1_vals.size(); i++) - { - uint8_t byte = input1_vals[i]; - uint8_t val_lo = byte & 0x0f; - uint8_t val_hi = byte >> 4; - std::cout << (int)val_lo << ", " << (int)val_hi << ", "; - } -#endif - const std::vector bias_shape = {N}; const auto bias = [&]() -> std::optional> { if (opts.has_bias) { @@ -184,17 +168,22 @@ void RunTest(const TestOptions& opts, test.AddInput("A", {M, K}, input0_vals, false); } - test.AddInput("B", {q_cols, q_rows}, input1_vals, true); + test.AddInput("B", {N, k_blocks, blob_size}, input1_vals, true); + + auto scales_shape = opts.legacy_shape ? std::vector{N * k_blocks} + : std::vector{N, k_blocks}; if constexpr (use_float16) { - test.AddInput("scales", {static_cast(q_scale_size)}, ToFloat16(scales), true); + test.AddInput("scales", scales_shape, ToFloat16(scales), true); } else { - test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); + test.AddInput("scales", scales_shape, scales, true); } if (opts.has_zero_point) { if (zp_is_4bit) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + auto zp_shape = opts.legacy_shape ? std::vector{N * zero_point_blob_size} + : std::vector{N, zero_point_blob_size}; + test.AddInput("zero_points", zp_shape, zp, true); } else { std::vector zp_f; zp_f.reserve(q_zp_size_in_bytes * 2); @@ -209,9 +198,9 @@ void RunTest(const TestOptions& opts, } if constexpr (use_float16) { - test.AddInput("zero_points", {static_cast(q_scale_size)}, ToFloat16(zp_f), true); + test.AddInput("zero_points", scales_shape, ToFloat16(zp_f), true); } else { - test.AddInput("zero_points", {static_cast(q_scale_size)}, zp_f, true); + test.AddInput("zero_points", scales_shape, zp_f, true); } } } else { @@ -267,7 +256,7 @@ void RunTest(const TestOptions& opts, } // namespace -template +template void TestMatMulNBitsTyped() { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; @@ -464,6 +453,13 @@ TEST(MatMulNBits, Float16_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); } + +TEST(MatMulNBits, LegacyShape) { + constexpr bool legacy_shape = true; + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); +} + #endif #endif #endif diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index 257d3b3efdf9c..63677094b1b4b 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -143,16 +143,17 @@ void RunTest8Bits(const TestOptions8Bits& opts) { test.AddInput("A", {M, K}, FloatsToMLFloat16s(input0_fp32_vals), false); } - test.AddInput("B", {q_cols, q_rows}, input1_vals, true); + int64_t k_blocks = (K + opts.block_size - 1) / opts.block_size; + test.AddInput("B", {q_cols, k_blocks, q_rows / k_blocks}, input1_vals, true); if constexpr (std::is_same::value) { - test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); + test.AddInput("scales", {N, static_cast(q_scale_size) / N}, scales, true); } else { - test.AddInput("scales", {static_cast(q_scale_size)}, FloatsToMLFloat16s(scales), true); + test.AddInput("scales", {N, static_cast(q_scale_size) / N}, FloatsToMLFloat16s(scales), true); } if (opts.has_zero_point) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + test.AddInput("zero_points", {N, static_cast(q_zp_size_in_bytes) / N}, zp, true); } else { test.AddOptionalInputEdge(); } From 9dad9af9f9b48c05814d0c2d067d0565e8da6ce8 Mon Sep 17 00:00:00 2001 From: Kevin Chen <45886021+kevinch-nv@users.noreply.github.com> Date: Fri, 23 May 2025 15:57:08 -0700 Subject: [PATCH 34/57] [TRT EP] Update build and API usage for TensorRT 10.11 (#24832) ### Description Resolves the following issues starting in TensorRT 10.11: - Version macros changed in `NvInferVersion.h`, update build to look for new macros - Updated deprecated APIs (setShapeValues -> setShapeValuesV2() to support INT64 shape values) ### Motivation and Context - Resolves building TensorRT EP from source with latest 10.11 release. Signed-off-by: Kevin Chen --- cmake/onnxruntime_providers_tensorrt.cmake | 27 ++++++++++++---- .../tensorrt/tensorrt_execution_provider.cc | 31 ++++++++++++++++++- 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 59c7db9999b43..3698aaa902922 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -33,12 +33,27 @@ PATH_SUFFIXES include) file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h NVINFER_VER_CONTENT) - string(REGEX MATCH "define NV_TENSORRT_MAJOR * +([0-9]+)" NV_TENSORRT_MAJOR "${NVINFER_VER_CONTENT}") - string(REGEX REPLACE "define NV_TENSORRT_MAJOR * +([0-9]+)" "\\1" NV_TENSORRT_MAJOR "${NV_TENSORRT_MAJOR}") - string(REGEX MATCH "define NV_TENSORRT_MINOR * +([0-9]+)" NV_TENSORRT_MINOR "${NVINFER_VER_CONTENT}") - string(REGEX REPLACE "define NV_TENSORRT_MINOR * +([0-9]+)" "\\1" NV_TENSORRT_MINOR "${NV_TENSORRT_MINOR}") - string(REGEX MATCH "define NV_TENSORRT_PATCH * +([0-9]+)" NV_TENSORRT_PATCH "${NVINFER_VER_CONTENT}") - string(REGEX REPLACE "define NV_TENSORRT_PATCH * +([0-9]+)" "\\1" NV_TENSORRT_PATCH "${NV_TENSORRT_PATCH}") + + # Starting TRT 10.11, TRT version macros have changed + string(REGEX MATCH "TRT_MAJOR_ENTERPRISE" TRT_VER_CHECK "${NVINFER_VER_CONTENT}") + # Pre TRT 10.11 + if("${TRT_VER_CHECK}" STREQUAL "") + string(REGEX MATCH "define NV_TENSORRT_MAJOR * +([0-9]+)" NV_TENSORRT_MAJOR "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define NV_TENSORRT_MAJOR * +([0-9]+)" "\\1" NV_TENSORRT_MAJOR "${NV_TENSORRT_MAJOR}") + string(REGEX MATCH "define NV_TENSORRT_MINOR * +([0-9]+)" NV_TENSORRT_MINOR "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define NV_TENSORRT_MINOR * +([0-9]+)" "\\1" NV_TENSORRT_MINOR "${NV_TENSORRT_MINOR}") + string(REGEX MATCH "define NV_TENSORRT_PATCH * +([0-9]+)" NV_TENSORRT_PATCH "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define NV_TENSORRT_PATCH * +([0-9]+)" "\\1" NV_TENSORRT_PATCH "${NV_TENSORRT_PATCH}") + # TRT 10.11+ + else() + string(REGEX MATCH "define TRT_MAJOR_ENTERPRISE * +([0-9]+)" NV_TENSORRT_MAJOR "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define TRT_MAJOR_ENTERPRISE * +([0-9]+)" "\\1" NV_TENSORRT_MAJOR "${NV_TENSORRT_MAJOR}") + string(REGEX MATCH "define TRT_MINOR_ENTERPRISE * +([0-9]+)" NV_TENSORRT_MINOR "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define TRT_MINOR_ENTERPRISE * +([0-9]+)" "\\1" NV_TENSORRT_MINOR "${NV_TENSORRT_MINOR}") + string(REGEX MATCH "define TRT_PATCH_ENTERPRISE * +([0-9]+)" NV_TENSORRT_PATCH "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define TRT_PATCH_ENTERPRISE * +([0-9]+)" "\\1" NV_TENSORRT_PATCH "${NV_TENSORRT_PATCH}") + endif() + math(EXPR NV_TENSORRT_MAJOR_INT "${NV_TENSORRT_MAJOR}") math(EXPR NV_TENSORRT_MINOR_INT "${NV_TENSORRT_MINOR}") math(EXPR NV_TENSORRT_PATCH_INT "${NV_TENSORRT_PATCH}") diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c30b862395e96..f1fa59bf03495 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -538,9 +538,18 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(j)][i].push_back(opt_value); } +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 10) || NV_TENSORRT_MAJOR > 10 + std::vector shapes_min_64(shapes_min.begin(), shapes_min.end()); + std::vector shapes_opt_64(shapes_opt.begin(), shapes_opt.end()); + std::vector shapes_max_64(shapes_max.begin(), shapes_max.end()); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min_64[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_opt_64[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_max_64[0], shape_size); +#else trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); +#endif } // Execution tensor else { @@ -627,6 +636,17 @@ Status ApplyProfileShapesFromInputTensorValue(std::vectorisShapeTensor()) { // shape tensor int shape_size = nb_dims == 0 ? 1 : static_cast(tensor_shapes[0]); +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 10) || NV_TENSORRT_MAJOR > 10 + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + for (int j = 0; j < shape_size; j++) { + shapes_min[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN)); + shapes_max[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX)); + shapes_opt[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT)); + } + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); +#else std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); for (int j = 0; j < shape_size; j++) { shapes_min[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN)); @@ -636,6 +656,7 @@ Status ApplyProfileShapesFromInputTensorValue(std::vectorsetShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); +#endif } else { // execution tensor nvinfer1::Dims dims_min, dims_opt, dims_max; @@ -733,10 +754,18 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector 10) || NV_TENSORRT_MAJOR > 10 + std::vector shapes_min_64(shapes_min.begin(), shapes_min.end()); + std::vector shapes_opt_64(shapes_opt.begin(), shapes_opt.end()); + std::vector shapes_max_64(shapes_max.begin(), shapes_max.end()); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min_64[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_opt_64[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_max_64[0], shape_size); +#else trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); +#endif } else { // Execution tensor nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims); for (int j = 0, end = nb_dims; j < end; ++j) { From 3a2091075afaff3425788e36c22bec43cf59ed4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Sat, 24 May 2025 01:15:08 +0200 Subject: [PATCH 35/57] [NvTensorRT RTX] Add Bfloat16 (#24743) ### Description TRT supports Bfloat 16 and ORT does as well. In addition the `setup.py` was missing a copy for NVTRT EP and TRT EP can only be built against the packaged parser with TRT RTX. --- cmake/onnxruntime_providers_nv.cmake | 4 ++ .../tensorrt/tensorrt_provider_options.h | 1 + .../providers/nv_tensorrt_rtx/nv_allocator.cc | 1 + .../nv_tensorrt_rtx/nv_execution_provider.cc | 3 + .../nv_tensorrt_rtx/nv_execution_provider.h | 1 + .../nv_execution_provider_custom_ops.cc | 1 + .../nv_execution_provider_custom_ops.h | 1 + .../nv_execution_provider_helper.cc | 1 + .../nv_execution_provider_info.cc | 1 + .../nv_execution_provider_info.h | 1 + .../nv_execution_provider_utils.h | 1 + .../providers/nv_tensorrt_rtx/nv_includes.h | 1 + .../nv_tensorrt_rtx/nv_provider_factory.cc | 1 + .../nv_tensorrt_rtx/nv_provider_factory.h | 1 + .../nv_provider_factory_creator.h | 1 + .../nv_tensorrt_rtx/onnx_ctx_model_helper.cc | 1 + .../nv_tensorrt_rtx/onnx_ctx_model_helper.h | 1 + .../tensorrt/tensorrt_execution_provider.cc | 54 ++++++++++------ .../tensorrt/tensorrt_execution_provider.h | 3 + .../tensorrt_execution_provider_info.cc | 5 ++ .../tensorrt_execution_provider_info.h | 1 + .../tensorrt/tensorrt_provider_factory.cc | 1 + .../python/onnxruntime_pybind_state.cc | 8 +++ .../gen_trt_engine_wrapper_onnx_model.py | 4 ++ .../tools/transformers/io_binding_helper.py | 1 + .../nv_tensorrt_rtx/nv_basic_test.cc | 64 +++++++++++++++++-- onnxruntime/test/shared_lib/test_inference.cc | 2 +- setup.py | 2 + 28 files changed, 144 insertions(+), 23 deletions(-) diff --git a/cmake/onnxruntime_providers_nv.cmake b/cmake/onnxruntime_providers_nv.cmake index 12d824fc3360e..a804f2d7ae55c 100644 --- a/cmake/onnxruntime_providers_nv.cmake +++ b/cmake/onnxruntime_providers_nv.cmake @@ -1,4 +1,5 @@ # Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Licensed under the MIT License. find_package(CUDAToolkit REQUIRED 12.8) enable_language(CUDA) @@ -9,6 +10,9 @@ if (onnxruntime_NV_PLACEHOLDER_BUILDER) add_definitions(-DORT_NV_PLACEHOLDER_BUILDER) endif() +if (NOT onnxruntime_USE_TENSORRT_BUILTIN_PARSER) + message(FATAL_ERROR "TensorRT RTX can not be used with the open source parser.") +endif () set(BUILD_LIBRARY_ONLY 1) add_definitions("-DONNX_ML=1") add_definitions("-DONNX_NAMESPACE=onnx") diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 687f74c94f154..9fb1eb9107774 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -21,6 +21,7 @@ struct OrtTensorRTProviderOptionsV2 { int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs size_t trt_max_workspace_size{0}; // maximum workspace size for TensorRT. Default is 0 means max device memory size int trt_fp16_enable{0}; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true + int trt_bf16_enable{0}; // enable TensorRT BF16 precision. Default 0 = false, nonzero = true int trt_int8_enable{0}; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true const char* trt_int8_calibration_table_name{nullptr}; // TensorRT INT8 calibration table name. int trt_int8_use_native_calibration_table{0}; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc index 4e8179d86fd73..a44ab93ccca8b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "nv_allocator.h" diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 696bb3edb9b85..0fb44fe4eda85 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -745,6 +745,7 @@ Status BindContextInput(Ort::KernelContext& ctx, switch (tensor_type) { CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -831,6 +832,7 @@ Status BindContextOutput(Ort::KernelContext& ctx, switch (output_type) { CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -894,6 +896,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx, switch (output_type) { CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 35315bdc7d908..6c5e1a1f0a8d3 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc index 0806ae3638036..c8df7c9437adf 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h index 897c2ce0e0b98..81c0d49239ec8 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc index cd50f1e6b2d48..8728558006fc5 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index 444fe1025e393..78f2723a20118 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h" diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index 36addd0a1ce27..e70e70bf05eb9 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h index 046010deedf62..22e5eea6924de 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_includes.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_includes.h index 047f325f49b70..a4e3777008560 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_includes.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_includes.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index 1f4eed7db7203..0fc3e5443bc28 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.h index 928874475735f..5672c5dda632e 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "onnxruntime_c_api.h" diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h index 616f5f1fbe754..6b2e516211257 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 25decd8f2ce8f..21d964b0c341f 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h index ccd06750692fc..f0a05c42414e5 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index f1fa59bf03495..5da7be0f758e0 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -987,6 +987,7 @@ Status BindContextInput(Ort::KernelContext& ctx, switch (tensor_type) { CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -1079,6 +1080,7 @@ Status BindContextOutput(Ort::KernelContext& ctx, switch (output_type) { CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -1148,6 +1150,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx, switch (output_type) { CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -1365,6 +1368,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv min_subgraph_size_ = info.min_subgraph_size; max_workspace_size_ = info.max_workspace_size; fp16_enable_ = info.fp16_enable; + bf16_enable_ = info.bf16_enable; int8_enable_ = info.int8_enable; if (int8_enable_) { int8_calibration_cache_name_ = info.int8_calibration_table_name; @@ -1411,7 +1415,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } force_sequential_engine_build_ = info.force_sequential_engine_build; context_memory_sharing_enable_ = info.context_memory_sharing_enable; - if (fp16_enable_) { + if (fp16_enable_ || bf16_enable_) { layer_norm_fp32_fallback_ = info.layer_norm_fp32_fallback; } build_heuristics_enable_ = info.build_heuristics_enable; @@ -1448,6 +1452,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } + const std::string bf16_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kBF16Enable); + if (!bf16_enable_env.empty()) { + bf16_enable_ = (std::stoi(bf16_enable_env) == 0 ? false : true); + } + const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); @@ -1789,6 +1798,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_min_subgraph_size: " << min_subgraph_size_ << ", trt_max_workspace_size: " << max_workspace_size_ << ", trt_fp16_enable: " << fp16_enable_ + << ", trt_bf16_enable: " << bf16_enable_ << ", trt_int8_enable: " << int8_enable_ << ", trt_int8_calibration_cache_name: " << int8_calibration_cache_name_ << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ @@ -2328,7 +2338,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 - network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #else network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); #endif @@ -2941,7 +2951,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 - network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #else network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); #endif @@ -2954,7 +2964,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow - if (fp16_enable_ && layer_norm_fp32_fallback_) { + if ((fp16_enable_ || bf16_enable_) && layer_norm_fp32_fallback_) { for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { auto layer = trt_network->getLayer(idx); auto next_layer = trt_network->getLayer(idx + 1); @@ -3103,7 +3113,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // Check platform availability for low precision - if (fp16_enable_) { + if (fp16_enable_ || bf16_enable_) { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) @@ -3113,7 +3123,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView #pragma warning(pop) #endif fp16_enable_ = false; - LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16"; + bf16_enable_ = false; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but platform doesn't support fast native fp16/bf16"; } } @@ -3142,15 +3153,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Set precision flags std::string trt_node_name_with_precision = fused_node.Name(); - if (fp16_enable_ && int8_enable_) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - trt_node_name_with_precision += "_fp16_int8"; - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled"; - } else if (fp16_enable_) { + if (fp16_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); trt_node_name_with_precision += "_fp16"; LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; - } else if (int8_enable_) { + } + if (bf16_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + trt_node_name_with_precision += "_bf16"; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; + } + if (int8_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); trt_node_name_with_precision += "_int8"; LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; @@ -3570,7 +3583,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, bf16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, &context_memory_, dynamic_range_map, engine_decryption_enable_, @@ -3772,12 +3785,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // Set precision - if (trt_state->fp16_enable && trt_state->int8_enable) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - } else if (trt_state->fp16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - } else if (trt_state->int8_enable) { + if (trt_state->int8_enable) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + } + if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + } + if (trt_state->bf16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; } // Set DLA (DLA can only run with FP16 or INT8) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index d2e8febea2339..b00c800999f3b 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -23,6 +23,7 @@ static const std::string kMaxPartitionIterations = "ORT_TENSORRT_MAX_PARTITION_I static const std::string kMinSubgraphSize = "ORT_TENSORRT_MIN_SUBGRAPH_SIZE"; static const std::string kMaxWorkspaceSize = "ORT_TENSORRT_MAX_WORKSPACE_SIZE"; static const std::string kFP16Enable = "ORT_TENSORRT_FP16_ENABLE"; +static const std::string kBF16Enable = "ORT_TENSORRT_BF16_ENABLE"; static const std::string kINT8Enable = "ORT_TENSORRT_INT8_ENABLE"; static const std::string kINT8CalibrationTableName = "ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME"; static const std::string kINT8UseNativeTensorrtCalibrationTable = "ORT_TENSORRT_INT8_USE_NATIVE_CALIBRATION_TABLE"; @@ -172,6 +173,7 @@ struct TensorrtFuncState { std::unordered_map>>> input_shape_ranges; std::mutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; + bool bf16_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; bool dla_enable = false; @@ -297,6 +299,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { size_t min_subgraph_size_ = 1; size_t max_workspace_size_ = 0; bool fp16_enable_ = false; + bool bf16_enable_ = false; bool int8_enable_ = false; bool dla_enable_ = false; int dla_core_ = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index ace5bbe65fc24..1a515c37f7ecb 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -19,6 +19,7 @@ constexpr const char* kMaxPartitionIterations = "trt_max_partition_iterations"; constexpr const char* kMinSubgraphSize = "trt_min_subgraph_size"; constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kBf16Enable = "trt_bf16_enable"; constexpr const char* kInt8Enable = "trt_int8_enable"; constexpr const char* kInt8CalibTable = "trt_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "trt_int8_use_native_calibration_table"; @@ -93,6 +94,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size) .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kBf16Enable, info.bf16_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table) @@ -155,6 +157,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.min_subgraph_size)}, {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)}, {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {tensorrt::provider_option_names::kBf16Enable, MakeStringWithClassicLocale(info.bf16_enable)}, {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {tensorrt::provider_option_names::kInt8CalibTable, MakeStringWithClassicLocale(info.int8_calibration_table_name)}, {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.int8_use_native_calibration_table)}, @@ -222,6 +225,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.trt_min_subgraph_size)}, {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.trt_max_workspace_size)}, {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.trt_fp16_enable)}, + {tensorrt::provider_option_names::kBf16Enable, MakeStringWithClassicLocale(info.trt_bf16_enable)}, {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.trt_int8_enable)}, {tensorrt::provider_option_names::kInt8CalibTable, kInt8CalibTable_}, {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.trt_int8_use_native_calibration_table)}, @@ -319,6 +323,7 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_min_subgraph_size = internal_options.min_subgraph_size; trt_provider_options_v2.trt_max_workspace_size = internal_options.max_workspace_size; trt_provider_options_v2.trt_fp16_enable = internal_options.fp16_enable; + trt_provider_options_v2.trt_bf16_enable = internal_options.bf16_enable; trt_provider_options_v2.trt_int8_enable = internal_options.int8_enable; trt_provider_options_v2.trt_int8_calibration_table_name = copy_string_if_needed(internal_options.int8_calibration_table_name); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 139319829c210..a7c3624674dc6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -24,6 +24,7 @@ struct TensorrtExecutionProviderInfo { int min_subgraph_size{1}; size_t max_workspace_size{0}; bool fp16_enable{false}; + bool bf16_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 0d2e88d17519c..da1c2514bf6a2 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -82,6 +82,7 @@ struct Tensorrt_Provider : Provider { info.min_subgraph_size = options.trt_min_subgraph_size; info.max_workspace_size = options.trt_max_workspace_size; info.fp16_enable = options.trt_fp16_enable != 0; + info.bf16_enable = options.trt_bf16_enable != 0; info.int8_enable = options.trt_int8_enable != 0; info.int8_calibration_table_name = options.trt_int8_calibration_table_name == nullptr ? "" : options.trt_int8_calibration_table_name; info.int8_use_native_calibration_table = options.trt_int8_use_native_calibration_table != 0; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 12a44e65e247b..5c389a85e5316 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -624,6 +624,14 @@ static std::shared_ptr CreateExecutionProviderFactory } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_fp16_enable' should be 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "trt_bf16_enable") { + if (option.second == "True" || option.second == "true") { + params.trt_bf16_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.trt_bf16_enable = false; + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_bf16_enable' should be 'True' or 'False'. Default value is 'False'.\n"); + } } else if (option.first == "trt_int8_enable") { if (option.second == "True" || option.second == "true") { params.trt_int8_enable = true; diff --git a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py index 1180945d5b5dc..5183ae9a72246 100644 --- a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py +++ b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py @@ -114,6 +114,8 @@ def trt_data_type_to_onnx_data_type(self, trt_data_type): return TensorProto.FLOAT elif trt_data_type == trt.DataType.HALF: return TensorProto.FLOAT16 + elif trt_data_type == trt.DataType.BF16: + return TensorProto.BFLOAT16 elif trt_data_type == trt.DataType.INT8: return TensorProto.INT8 elif trt_data_type == trt.DataType.INT32: @@ -122,6 +124,8 @@ def trt_data_type_to_onnx_data_type(self, trt_data_type): return TensorProto.BOOL elif trt_data_type == trt.DataType.UINT8: return TensorProto.UINT8 + elif trt_data_type == trt.DataType.INT64: + return TensorProto.INT64 else: return TensorProto.UNDEFINED diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 2b19ae5029ecc..072bb9bb39a79 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -53,6 +53,7 @@ def ort_type_to_torch_type(ort_type: str): "tensor(int32)": torch.int32, "tensor(float)": torch.float32, "tensor(float16)": torch.float16, + "tensor(bfloat16)": torch.bfloat16, "tensor(bool)": torch.bool, "tensor(uint8)": torch.uint8, } diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index f3a963ce47eda..c04cbc7d4924e 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" @@ -24,6 +25,35 @@ namespace onnxruntime { namespace test { +template +class NvExecutionProviderTest : public ::testing::Test { + protected: + std::string getTypeAsName() { + std::string dtype_name = ""; + if constexpr (std::is_same::value) { + dtype_name = "fp64"; + } else if constexpr (std::is_same::value) { + dtype_name = "fp32"; + } else if constexpr (std::is_same::value) { + dtype_name = "fp16"; + } else if constexpr (std::is_same::value) { + dtype_name = "bf16"; + } else if constexpr (std::is_same::value) { + dtype_name = "int8"; + } else if constexpr (std::is_same::value) { + dtype_name = "uint8"; + } else if constexpr (std::is_same::value) { + dtype_name = "int32"; + } else if constexpr (std::is_same::value) { + dtype_name = "int64"; + } + return dtype_name; + } +}; + +using NvExecutionProviderTestTypes = ::testing::Types; // double, +TYPED_TEST_SUITE(NvExecutionProviderTest, NvExecutionProviderTestTypes); + std::string PathToUTF8(const PathString& path) { #ifdef WIN32 std::wstring_convert> converter; @@ -89,7 +119,8 @@ void VerifyOutputs(const std::vector& fetches, const std::vector dims, - bool add_fast_gelu = false) { + bool add_fast_gelu = false, + ONNX_NAMESPACE::TensorProto_DataType dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; @@ -97,13 +128,13 @@ static void CreateBaseModel(const PathString& model_name, // FLOAT tensor ONNX_NAMESPACE::TypeProto float_tensor; - float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + float_tensor.mutable_tensor_type()->set_elem_type(dtype); for (auto dim : dims) { float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); } ONNX_NAMESPACE::TypeProto dyn_float_tensor; - dyn_float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + dyn_float_tensor.mutable_tensor_type()->set_elem_type(dtype); auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); @@ -139,7 +170,7 @@ static void CreateBaseModel(const PathString& model_name, } ONNX_NAMESPACE::TypeProto float_scalar; - float_scalar.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + float_scalar.mutable_tensor_type()->set_elem_type(dtype); float_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); auto& input_scalar = graph.GetOrCreateNodeArg("S", &float_scalar); inputs.push_back(&input_scalar); @@ -331,5 +362,30 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { } } +TYPED_TEST(NvExecutionProviderTest, IOTypeTests) { + std::string dtype_name = this->getTypeAsName(); + ASSERT_FALSE(dtype_name.empty()); + PathString model_name = ORT_TSTR("nv_execution_provider_" + dtype_name + ".onnx"); + std::string graph_name = "test" + dtype_name; + std::vector dims = {1, -1, -1}; + + CreateBaseModel(model_name, graph_name, dims, true); + + auto env = Ort::Env(); + auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; + env.UpdateEnvWithCustomLogLevel(logging_level); + + // AOT time + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + Ort::Session session_object(env, model_name.c_str(), so); + + auto io_binding = generate_io_binding(session_object); + session_object.Run(run_options, io_binding); + } +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 6460e3cb3aec4..b49c0bad711e1 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3931,7 +3931,7 @@ TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { * The TensorrtExecutionProviderOptionsTest can be used to test TRT options */ INSTANTIATE_TEST_SUITE_P(CApiTensorRTTest, CApiTensorRTTest, - ::testing::Values("trt_build_heuristics_enable=1", "trt_sparsity_enable=1", "trt_builder_optimization_level=0", "trt_tactic_sources=-CUDNN,+CUBLAS", "trt_auxiliary_streams=2")); + ::testing::Values("trt_build_heuristics_enable=1", "trt_sparsity_enable=1", "trt_builder_optimization_level=0", "trt_tactic_sources=-CUDNN,+CUBLAS", "trt_auxiliary_streams=2", "trt_bf16_enable=1")); #endif #ifdef USE_CUDA diff --git a/setup.py b/setup.py index c45657c0c2873..b30d0e1e6aea4 100644 --- a/setup.py +++ b/setup.py @@ -324,6 +324,7 @@ def finalize_options(self): if platform.system() == "Linux": providers_cuda_or_rocm = "lib" + providers_cuda_or_rocm + ".so" providers_tensorrt_or_migraphx = "lib" + providers_tensorrt_or_migraphx + ".so" + providers_nv_tensorrt_rtx = "lib" + providers_nv_tensorrt_rtx + ".so" providers_openvino = "lib" + providers_openvino + ".so" providers_cann = "lib" + providers_cann + ".so" providers_qnn = "lib" + providers_qnn + ".so" @@ -361,6 +362,7 @@ def finalize_options(self): libs.extend(["libonnxruntime_providers_openvino.so"]) libs.extend(["libonnxruntime_providers_vitisai.so"]) libs.append(providers_cuda_or_rocm) + libs.append(providers_nv_tensorrt_rtx) libs.append(providers_tensorrt_or_migraphx) libs.append(providers_cann) libs.append(providers_qnn) From 625289c181ca3c91405d98cb1e755285cedc4bf4 Mon Sep 17 00:00:00 2001 From: quic-tirupath Date: Tue, 27 May 2025 09:21:15 -0700 Subject: [PATCH 36/57] [QNN EP] Add ScatterND reduction attribute (#24844) ### Description - Add support for ScatterND reduction attribute - Gracefully handle the unsupported reduction values - Add unit tests to validate Reduction attribute support ### Motivation and Context --- .../builder/opbuilder/simple_op_builder.cc | 39 ++++++++++ .../test/providers/qnn/simple_op_htp_test.cc | 72 +++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index ab022df063c96..2650316dd07ac 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -40,6 +40,7 @@ class SimpleOpBuilder : public BaseOpBuilder { static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; + static constexpr std::array scatternd_supported_reduction = {"none", "add", "mul"}; }; Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, @@ -101,6 +102,14 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, } } + // QNN ScatterND doesn't support MAX, MIN reduction + if (op_type == "ScatterND") { + NodeAttrHelper node_helper(node_unit); + std::string reduction = node_helper.Get("reduction", "none"); + ORT_RETURN_IF_NOT(utils::ArrayHasString(scatternd_supported_reduction, reduction), "ScatterND does not support reduction ", + reduction.c_str()); + } + return Status::OK(); } @@ -254,6 +263,31 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +// Process Reduction attribute of ScatterND op +Status ProcessScatterNDReductionAttribute(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector& param_tensor_names) { + NodeAttrHelper node_helper(node_unit); + std::string reduction = node_helper.Get("reduction", "none"); + Qnn_Scalar_t reduction_qnn_scalar = QNN_SCALAR_INIT; + reduction_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; + if ("none" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_NONE; + } else if ("add" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_ADD; + } else if ("mul" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_MUL; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ScatterND support only reduction:{none, add, mul}."); + } + QnnParamWrapper reduction_param(node_unit.Index(), node_unit.Name(), QNN_OP_SCATTER_ND_PARAM_REDUCTION, + reduction_qnn_scalar); + param_tensor_names.push_back(reduction_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(reduction_param)); + + return Status::OK(); +} + Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -358,6 +392,11 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names)); } + if (op_type == "ScatterND") { + // Process reduction attribute + ORT_RETURN_IF_ERROR(ProcessScatterNDReductionAttribute(qnn_model_wrapper, node_unit, param_tensor_names)); + } + return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index bfdb1a1a6afdd..b441af4a0efe9 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1017,6 +1017,78 @@ TEST_F(QnnHTPBackendTests, BinaryOp_HTP_Or_Unsupported) { ExpectedEPNodeAssignment::All); } +// Test ScatterND with reduction ADD on HTP +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_add) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "add"), + }, + 17, + ExpectedEPNodeAssignment::All); +} + +// Test ScatterND with reduction Mul on HTP +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_mul) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "mul"), + }, + 17, + ExpectedEPNodeAssignment::All); +} + +// Test ScatterND with reduction Max on CPU Fallback +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_max) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "max"), + }, + 17, + ExpectedEPNodeAssignment::None); +} + +// Test ScatterND with reduction Min on CPU Fallback +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_min) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "min"), + }, + 17, + ExpectedEPNodeAssignment::None); +} + // Test 8-bit QDQ GridSample with bilinear TEST_F(QnnHTPBackendTests, GridSample_Bilinear) { RunQDQOpTest("GridSample", From 5618199f8f1c426d9ad34e387ade068d25212f36 Mon Sep 17 00:00:00 2001 From: Wang Ning Date: Wed, 28 May 2025 03:19:00 +0800 Subject: [PATCH 37/57] [WebNN] Refactor op mappings and add input name mapping between ONNX and WebNN (#24830) ### Description Add `map_info.h` to centralize the operation types and inputs mapping between onnx and webnn. ### Motivation and Context To simplify the maintenance of operation types and inputs. The mapping of onnx input names and webnn input names will be used in the future to check the `rankRange`. @honry, @fdwr, @guschmue, PTAL, thanks! --------- Co-authored-by: Wanming Lin --- .../core/providers/webnn/builders/helper.h | 184 ++-------------- .../core/providers/webnn/builders/map_info.h | 205 ++++++++++++++++++ 2 files changed, 219 insertions(+), 170 deletions(-) create mode 100644 onnxruntime/core/providers/webnn/builders/map_info.h diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 072273a137557..f124e90580353 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -9,8 +9,8 @@ #include "core/common/inlined_containers.h" #include #include "core/optimizer/initializer.h" -#include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" +#include "map_info.h" #include #include @@ -201,183 +201,27 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe const emscripten::val& wnn_limits, const logging::Logger& logger); -// Some ONNX ops are supported by decomposed WebNN ops. -const std::map> decomposed_op_map = { - {"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}}, - {"GroupQueryAttention", - {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND", - "softmax", "transpose", "where"}}, - {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}}, - {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}}, - {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}}, - {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}}, - {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}}, - {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, - {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, -}; -// ONNX op type to WebNN op type mapping. -const std::map op_map = { - {"Abs", "abs"}, - {"Add", "add"}, - {"And", "logicalAnd"}, - {"ArgMax", "argMax"}, - {"ArgMin", "argMin"}, - {"AveragePool", "averagePool2d"}, - {"BatchNormalization", "batchNormalization"}, - {"Cast", "cast"}, - {"Ceil", "ceil"}, - {"Clip", "clamp"}, - {"Concat", "concat"}, - {"Conv", "conv2d"}, - {"ConvTranspose", "convTranspose2d"}, - {"Cos", "cos"}, - {"CumSum", "cumulativeSum"}, - {"Div", "div"}, - {"DequantizeLinear", "dequantizeLinear"}, - {"Dropout", "identity"}, - {"DynamicQuantizeLinear", "dynamicQuantizeLinear"}, - {"Einsum", "matmul"}, - {"Elu", "elu"}, - {"Equal", "equal"}, - {"Erf", "erf"}, - {"Exp", "exp"}, - {"Expand", "expand"}, - {"Flatten", "reshape"}, - {"Floor", "floor"}, - {"Gather", "gather"}, - {"GatherElements", "gatherElements"}, - {"GatherND", "gatherND"}, - {"Gelu", "gelu"}, - {"Gemm", "gemm"}, - {"GlobalAveragePool", "averagePool2d"}, - {"GlobalMaxPool", "maxPool2d"}, - {"GlobalLpPool", "l2Pool2d"}, - {"Greater", "greater"}, - {"GreaterOrEqual", "greaterOrEqual"}, - {"GRU", "gru"}, - {"HardSigmoid", "hardSigmoid"}, - {"HardSwish", "hardSwish"}, - {"Identity", "identity"}, - {"InstanceNormalization", "instanceNormalization"}, - {"LayerNormalization", "layerNormalization"}, - {"LeakyRelu", "leakyRelu"}, - {"Less", "lesser"}, - {"LessOrEqual", "lesserOrEqual"}, - {"Log", "log"}, - {"LpPool", "l2Pool2d"}, - {"LSTM", "lstm"}, - {"MatMul", "matmul"}, - {"Max", "max"}, - {"MaxPool", "maxPool2d"}, - {"Min", "min"}, - {"Mul", "mul"}, - {"Neg", "neg"}, - {"Not", "logicalNot"}, - {"Or", "logicalOr"}, - {"Pad", "pad"}, - {"Pow", "pow"}, - {"PRelu", "prelu"}, - {"QuantizeLinear", "quantizeLinear"}, - {"Reciprocal", "reciprocal"}, - {"ReduceL1", "reduceL1"}, - {"ReduceL2", "reduceL2"}, - {"ReduceLogSum", "reduceLogSum"}, - {"ReduceLogSumExp", "reduceLogSumExp"}, - {"ReduceMax", "reduceMax"}, - {"ReduceMean", "reduceMean"}, - {"ReduceMin", "reduceMin"}, - {"ReduceProd", "reduceProduct"}, - {"ReduceSum", "reduceSum"}, - {"ReduceSumSquare", "reduceSumSquare"}, - {"Relu", "relu"}, - {"Reshape", "reshape"}, - {"Resize", "resample2d"}, - {"ScatterElements", "scatterElements"}, - {"ScatterND", "scatterND"}, - {"Shape", "slice"}, - {"Sigmoid", "sigmoid"}, - {"Sign", "sign"}, - {"Softplus", "softplus"}, - {"Softsign", "softsign"}, - {"Sin", "sin"}, - {"Slice", "slice"}, - {"Softmax", "softmax"}, - {"Split", "split"}, - {"Sqrt", "sqrt"}, - {"Squeeze", "reshape"}, - {"Sub", "sub"}, - {"Tan", "tan"}, - {"Tanh", "tanh"}, - {"Tile", "tile"}, - {"Transpose", "transpose"}, - {"Trilu", "triangular"}, - {"Unsqueeze", "reshape"}, - {"Where", "where"}, - {"Xor", "logicalXor"}, -}; - -// WebNN op name to its first input name mapping, only record the name that is different from "input". -// This map is used to determine the first input name of a WebNN op and is utilized by OpSupportLimits. -const std::map webnn_op_first_input_name_map = { - {"add", "a"}, - {"concat", "inputs"}, - {"div", "a"}, - {"equal", "a"}, - {"gemm", "a"}, - {"greater", "a"}, - {"greaterOrEqual", "a"}, - {"lesser", "a"}, - {"lesserOrEqual", "a"}, - {"logicalAnd", "a"}, - {"logicalNot", "a"}, - {"logicalOr", "a"}, - {"logicalXor", "a"}, - {"matmul", "a"}, - {"max", "a"}, - {"min", "a"}, - {"mul", "a"}, - {"pow", "a"}, - {"sub", "a"}, - {"where", "condition"}, -}; - // Retrieve the first input name of a WebNN op used for validating supported input data types. // WebNN ops have various first input names such as 'a', 'input', 'inputs', etc. -// Special names other than 'input' are recorded in the webnn_op_first_input_name_map. +// All WebNN op inputs are recorded in op_inputs_map. inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) { - auto it = webnn_op_first_input_name_map.find(webnn_op_type); - return (it != webnn_op_first_input_name_map.end()) ? it->second : "input"; + auto it = op_inputs_map.find(webnn_op_type); + if (it != op_inputs_map.end()) { + for (const auto& input : it->second.inputs) { + if (input.index == 0) { + return input.name; + } + } + } + return "input"; } inline std::string_view GetWebNNOpType(const std::string_view op_type) { - auto it = op_map.find(op_type); - // Return an empty string if the op_type is not listed in the op_map. - return (it != op_map.end()) ? it->second : ""; + auto it = op_inputs_map.find(op_type); + // Return an empty string if the op_type is not listed in the op_inputs_map. + return (it != op_inputs_map.end()) ? it->second.opType : ""; } -const std::map onnx_to_webnn_data_type_map = { - {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"}, - {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, - {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, - {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, - {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, -}; - -// This array contains the input/output data types of a WebNN graph that are allowed to be fallback to int32. -constexpr std::array supported_fallback_integer_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, -}; - bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h new file mode 100644 index 0000000000000..59408ba244842 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -0,0 +1,205 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/common.h" + +/** + * This file defines mappings and structures to facilitate the translation of ONNX operations + * and data types to their corresponding WebNN representations. + * + * It includes: + * - Data type mappings between ONNX and WebNN. + * - Lists of supported fallback integer types for WebNN. + * - Decomposition of certain ONNX operations into sequences of WebNN operations. + * - Structures and maps for input index-to-name translation for ONNX to WebNN ops. + */ +namespace onnxruntime { +namespace webnn { +const std::map onnx_to_webnn_data_type_map = { + {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"}, + {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, +}; + +// This array contains the input/output data types of a WebNN graph that are allowed to be fallback to int32. +constexpr std::array supported_fallback_integer_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_BOOL, + ONNX_NAMESPACE::TensorProto_DataType_INT8, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_UINT32, + ONNX_NAMESPACE::TensorProto_DataType_INT64, +}; + +// Some ONNX ops are supported by decomposed WebNN ops. +const std::map> decomposed_op_map = { + {"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}}, + {"GroupQueryAttention", + {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND", + "softmax", "transpose", "where"}}, + {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}}, + {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}}, + {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}}, + {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}}, + {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}}, + {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, + {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, +}; + +/** + * Represents information about an input to a WebNN operation. + * + * This structure is used to map ONNX operation inputs to their corresponding + * WebNN operation inputs. It contains the index of the input as specified + * in the ONNX operation and the name of the input in the WebNN operation. + * + * InputInfo::index + * The index of this input as defined in the ONNX operation specification. + * + * InputInfo::name + * The name of this input in the WebNN operation. + */ +struct InputInfo { + int index; + std::string_view name; +}; + +struct WebnnOpInfo { + std::string_view opType; + std::vector inputs; + WebnnOpInfo(std::string_view op, std::initializer_list in) + : opType(op), inputs(in) {} +}; + +/** + * Maps ONNX operation type to their corresponding WebNN operation type and input mappings. + * + * This unordered map provides a mapping between ONNX operation names (keys) and their corresponding + * WebNN operation information (values). Each value is a `WebnnOpInfo` structure that contains: + * - The WebNN operation name (`opType`). + * - A vector of `InputInfo` structures, where each `InputInfo` specifies: + * - The index of the input in the ONNX operation (`index`). + * - The corresponding input name in the WebNN operation (`name`). + * + * For the ONNX operation "Abs", it has only one "input", which is at index 0 in the "Node.InputDefs" array. + * The corresponding WebNN operation is "abs", and the input name is "input". + * + * This mapping is used to translate ONNX operations and their inputs into WebNN operations + * and their respective input names. + * + * Order: + * The sorting rule is based on character length in ascending order (for better formatting), + * and for items with the same length, they are sorted alphabetically. + */ +const std::unordered_map op_inputs_map = { + {"Cos", {"cos", {{0, "input"}}}}, + {"Abs", {"abs", {{0, "input"}}}}, + {"Elu", {"elu", {{0, "input"}}}}, + {"Erf", {"erf", {{0, "input"}}}}, + {"Exp", {"exp", {{0, "input"}}}}, + {"Log", {"log", {{0, "input"}}}}, + {"Neg", {"neg", {{0, "input"}}}}, + {"Pad", {"pad", {{0, "input"}}}}, + {"Sin", {"sin", {{0, "input"}}}}, + {"Tan", {"tan", {{0, "input"}}}}, + {"Cast", {"cast", {{0, "input"}}}}, + {"Ceil", {"ceil", {{0, "input"}}}}, + {"Gelu", {"gelu", {{0, "input"}}}}, + {"Relu", {"relu", {{0, "input"}}}}, + {"Sign", {"sign", {{0, "input"}}}}, + {"Sqrt", {"sqrt", {{0, "input"}}}}, + {"Tanh", {"tanh", {{0, "input"}}}}, + {"Tile", {"tile", {{0, "input"}}}}, + {"Clip", {"clamp", {{0, "input"}}}}, + {"Floor", {"floor", {{0, "input"}}}}, + {"Shape", {"slice", {{0, "input"}}}}, + {"Slice", {"slice", {{0, "input"}}}}, + {"Split", {"split", {{0, "input"}}}}, + {"Sub", {"sub", {{0, "a"}, {1, "b"}}}}, + {"Add", {"add", {{0, "a"}, {1, "b"}}}}, + {"ArgMax", {"argMax", {{0, "input"}}}}, + {"ArgMin", {"argMin", {{0, "input"}}}}, + {"Div", {"div", {{0, "a"}, {1, "b"}}}}, + {"Expand", {"expand", {{0, "input"}}}}, + {"Max", {"max", {{0, "a"}, {1, "b"}}}}, + {"Min", {"min", {{0, "a"}, {1, "b"}}}}, + {"Mul", {"mul", {{0, "a"}, {1, "b"}}}}, + {"Pow", {"pow", {{0, "a"}, {1, "b"}}}}, + {"Concat", {"concat", {{0, "inputs"}}}}, + {"Not", {"logicalNot", {{0, "input"}}}}, + {"Flatten", {"reshape", {{0, "input"}}}}, + {"LpPool", {"l2Pool2d", {{0, "input"}}}}, + {"Reshape", {"reshape", {{0, "input"}}}}, + {"Sigmoid", {"sigmoid", {{0, "input"}}}}, + {"Softmax", {"softmax", {{0, "input"}}}}, + {"Squeeze", {"reshape", {{0, "input"}}}}, + {"Dropout", {"identity", {{0, "input"}}}}, + {"Trilu", {"triangular", {{0, "input"}}}}, + {"Equal", {"equal", {{0, "a"}, {1, "b"}}}}, + {"Identity", {"identity", {{0, "input"}}}}, + {"Less", {"lesser", {{0, "a"}, {1, "b"}}}}, + {"MaxPool", {"maxPool2d", {{0, "input"}}}}, + {"ReduceL1", {"reduceL1", {{0, "input"}}}}, + {"ReduceL2", {"reduceL2", {{0, "input"}}}}, + {"Resize", {"resample2d", {{0, "input"}}}}, + {"Softplus", {"softplus", {{0, "input"}}}}, + {"Softsign", {"softsign", {{0, "input"}}}}, + {"Unsqueeze", {"reshape", {{0, "input"}}}}, + {"Or", {"logicalOr", {{0, "a"}, {1, "b"}}}}, + {"Einsum", {"matmul", {{0, "a"}, {1, "b"}}}}, + {"HardSwish", {"hardSwish", {{0, "input"}}}}, + {"LeakyRelu", {"leakyRelu", {{0, "input"}}}}, + {"MatMul", {"matmul", {{0, "a"}, {1, "b"}}}}, + {"ReduceMax", {"reduceMax", {{0, "input"}}}}, + {"ReduceMin", {"reduceMin", {{0, "input"}}}}, + {"ReduceSum", {"reduceSum", {{0, "input"}}}}, + {"Transpose", {"transpose", {{0, "input"}}}}, + {"And", {"logicalAnd", {{0, "a"}, {1, "b"}}}}, + {"CumSum", {"cumulativeSum", {{0, "input"}}}}, + {"Xor", {"logicalXor", {{0, "a"}, {1, "b"}}}}, + {"GlobalLpPool", {"l2Pool2d", {{0, "input"}}}}, + {"Greater", {"greater", {{0, "a"}, {1, "b"}}}}, + {"Reciprocal", {"reciprocal", {{0, "input"}}}}, + {"ReduceMean", {"reduceMean", {{0, "input"}}}}, + {"GlobalMaxPool", {"maxPool2d", {{0, "input"}}}}, + {"HardSigmoid", {"hardSigmoid", {{0, "input"}}}}, + {"ReduceProd", {"reduceProduct", {{0, "input"}}}}, + {"AveragePool", {"averagePool2d", {{0, "input"}}}}, + {"Gemm", {"gemm", {{0, "a"}, {1, "b"}, {2, "c"}}}}, + {"PRelu", {"prelu", {{0, "input"}, {1, "slope"}}}}, + {"ReduceLogSum", {"reduceLogSum", {{0, "input"}}}}, + {"Gather", {"gather", {{0, "input"}, {1, "indices"}}}}, + {"LessOrEqual", {"lesserOrEqual", {{0, "a"}, {1, "b"}}}}, + {"GlobalAveragePool", {"averagePool2d", {{0, "input"}}}}, + {"ReduceLogSumExp", {"reduceLogSumExp", {{0, "input"}}}}, + {"ReduceSumSquare", {"reduceSumSquare", {{0, "input"}}}}, + {"GatherND", {"gatherND", {{0, "input"}, {1, "indices"}}}}, + {"GreaterOrEqual", {"greaterOrEqual", {{0, "a"}, {1, "b"}}}}, + {"Conv", {"conv2d", {{0, "input"}, {1, "filter"}, {2, "bias"}}}}, + {"DynamicQuantizeLinear", {"dynamicQuantizeLinear", {{0, "input"}}}}, + {"GatherElements", {"gatherElements", {{0, "input"}, {1, "indices"}}}}, + {"ScatterND", {"scatterND", {{0, "input"}, {1, "indices"}, {2, "updates"}}}}, + {"Where", {"where", {{0, "condition"}, {1, "trueValue"}, {2, "falseValue"}}}}, + {"ConvTranspose", {"convTranspose2d", {{0, "input"}, {1, "filter"}, {2, "bias"}}}}, + {"QuantizeLinear", {"quantizeLinear", {{0, "input"}, {1, "scale"}, {2, "zeroPoint"}}}}, + {"ScatterElements", {"scatterElements", {{0, "input"}, {1, "indices"}, {2, "updates"}}}}, + {"LayerNormalization", {"layerNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}}}}, + {"DequantizeLinear", {"dequantizeLinear", {{0, "input"}, {1, "scale"}, {2, "zeroPoint"}}}}, + {"InstanceNormalization", {"instanceNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}}}}, + {"GRU", {"gru", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}}}}, + {"BatchNormalization", {"batchNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}, {3, "input_mean"}, {4, "input_var"}}}}, + {"LSTM", {"lstm", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}, {6, "initialCellState"}, {7, "peepholeWeight"}}}}, +}; +} // namespace webnn +} // namespace onnxruntime From 6433c06bae6a69a0c3efd9a59c093eb97ea836e1 Mon Sep 17 00:00:00 2001 From: Xiaofei Han Date: Wed, 28 May 2025 03:47:38 +0800 Subject: [PATCH 38/57] [Mac] Fix --use_xcode build with Nodejs binding (#24868) ### Description Currently, the XCode build with nodejs binding(`--use_xcode`) always fails on Mac. ``` ./build.sh --config Debug --use_xcode --use_webgpu --build_shared_lib --build_nodejs --parallel --compile_no_warning_as_error --skip_submodule_sync --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 --skip_tests ``` The root cause is that the dylib locates on `/Debug/Debug` not `/Debug` with using XCode generator. For other generator(e.g. make, ninja), the dylib locates on `/Debug` as expected. Mac pipeline can pass because they didn't use XCode generator. image --- js/node/CMakeLists.txt | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt index 52af5dc48a21a..aedb1e35158ef 100644 --- a/js/node/CMakeLists.txt +++ b/js/node/CMakeLists.txt @@ -92,9 +92,21 @@ if (WIN32) endif() message(STATUS "onnxruntime dist dir: ${ONNXRUNTIME_WIN_BIN_DIR}") endif() + +if (APPLE) + if (${ONNXRUNTIME_GENERATOR} MATCHES "Xcode") + set(ONNXRUNTIME_MAC_BIN_DIR ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}) + else() + set(ONNXRUNTIME_MAC_BIN_DIR ${ONNXRUNTIME_BUILD_DIR}) + endif() + message(STATUS "onnxruntime dist dir: ${ONNXRUNTIME_MAC_BIN_DIR}") +endif() + # add libraries if (WIN32) target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_WIN_BIN_DIR}) +elseif (APPLE) + target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_MAC_BIN_DIR}) else() target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_BUILD_DIR}) endif() @@ -114,7 +126,7 @@ if (WIN32) file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/onnxruntime.dll DESTINATION ${dist_folder}) elseif (APPLE) - file(COPY ${ONNXRUNTIME_BUILD_DIR}/libonnxruntime.dylib + file(COPY ${ONNXRUNTIME_MAC_BIN_DIR}/libonnxruntime.dylib DESTINATION ${dist_folder} FOLLOW_SYMLINK_CHAIN) elseif (UNIX) file(COPY ${ONNXRUNTIME_BUILD_DIR}/libonnxruntime.so From f8c13937fa6929c2f574e06be45727c33bdde36a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 May 2025 15:37:45 -0700 Subject: [PATCH 39/57] Bump setuptools from 69.0.3 to 78.1.1 in /tools/ci_build/github/linux/docker/scripts (#24810) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [setuptools](https://github.com/pypa/setuptools) from 69.0.3 to 78.1.1.
Changelog

Sourced from setuptools's changelog.

v78.1.1

Bugfixes

  • More fully sanitized the filename in PackageIndex._download. (#4946)

v78.1.0

Features

  • Restore access to _get_vc_env with a warning. (#4874)

v78.0.2

Bugfixes

  • Postponed removals of deprecated dash-separated and uppercase fields in setup.cfg. All packages with deprecated configurations are advised to move before 2026. (#4911)

v78.0.1

Misc

v78.0.0

Bugfixes

  • Reverted distutils changes that broke the monkey patching of command classes. (#4902)

Deprecations and Removals

  • Setuptools no longer accepts options containing uppercase or dash characters in setup.cfg.

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=setuptools&package-manager=pip&previous-version=69.0.3&new-version=78.1.1)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tools/ci_build/github/linux/docker/scripts/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index ee5cedb73ff04..7c1731aef992d 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -3,7 +3,7 @@ numpy==1.24.4 ; python_version < '3.9' numpy==2.1.2; python_version >= '3.9' mypy pytest -setuptools==69.0.3 +setuptools==78.1.1 wheel==0.42.0 onnx==1.17.0 ; python_version < '3.13' argparse From 9349c3703c37bcab79624ce4dc25f840b1d1f8f9 Mon Sep 17 00:00:00 2001 From: keshavv27 <165012837+keshavv27@users.noreply.github.com> Date: Wed, 28 May 2025 17:17:41 +0000 Subject: [PATCH 40/57] [onnxruntimeperftest] Add option to enable IO bindings on CUDA before session run (#24672) ### Description Add option to enable tensor input and output bindings on CUDA before perftest inference session run. Output binding is handled by changing the memory allocator type to CUDA. Input binding is handled by creating default ORT tensor on CPU, initializing it with data, then cudaMemcpy the data from CPU to CUDA allocated GPU tensor using the raw pointers. ### Motivation and Context By this change, the end-to-end inference time reported is more accurate as the CPU<->GPU overhead is moved out of the inference run --- cmake/onnxruntime_unittests.cmake | 3 + .../test/perftest/command_args_parser.cc | 6 +- onnxruntime/test/perftest/ort_test_session.cc | 56 +++++++++++++++++-- .../test/perftest/test_configuration.h | 1 + 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index b31fdd4ea1ee1..6d409e1ee167d 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1278,6 +1278,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_perf_test_libs onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 gtest absl_failure_signal_handler absl_examine_stack absl_flags_parse absl_flags_usage absl_flags_usage_internal) endif() target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads) + if (onnxruntime_USE_CUDA OR onnxruntime_USE_NV OR onnxruntime_USE_TENSORRT) + target_link_libraries(onnxruntime_perf_test PRIVATE CUDA::cudart) + endif() if(WIN32) target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) endif() diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 103da5f534ea7..b63ef7959e1db 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -61,6 +61,7 @@ namespace perftest { "\t-u [optimized_model_path]: Specify the optimized model path for saving.\n" "\t-d [CUDA only][cudnn_conv_algorithm]: Specify CUDNN convolution algorithms: 0(benchmark), 1(heuristic), 2(default). \n" "\t-q [CUDA only] use separate stream for copy. \n" + "\t-g [TensorRT RTX | TensorRT | CUDA] Enable tensor input and output bindings on CUDA before session run \n" "\t-z: Set denormal as zero. When turning on this option reduces latency dramatically, a model may have denormals.\n" "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" @@ -189,7 +190,7 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlR:"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlgR:"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -389,6 +390,9 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier, case 'R': test_config.run_config.register_custom_op_path = optarg; break; + case 'g': + test_config.run_config.enable_cuda_io_binding = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 8257cbfaa7f95..46e167b2ef823 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -19,6 +19,10 @@ #include "TestCase.h" #include "strings_helper.h" +#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_NV) +#include +#endif + #ifdef USE_OPENVINO #include "nlohmann/json.hpp" #endif @@ -145,6 +149,9 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device "\nSupported options are:\n", options); } session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); + if (performance_test_config.run_config.enable_cuda_io_binding) { + device_memory_name_ = CUDA; + } #else ORT_THROW("CUDA is not supported in this build\n"); #endif @@ -188,12 +195,18 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device cuda_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream; // TODO: Support arena configuration for users of perf test session_options.AppendExecutionProvider_CUDA(cuda_options); + if (performance_test_config.run_config.enable_cuda_io_binding) { + device_memory_name_ = CUDA; + } #else ORT_THROW("TensorRT is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kNvTensorRTRTXExecutionProvider) { #ifdef USE_NV session_options.AppendExecutionProvider("NvTensorRtRtx", provider_options); + if (performance_test_config.run_config.enable_cuda_io_binding) { + device_memory_name_ = CUDA; + } #else ORT_THROW("NV TensorRT RTX is not supported in this build\n"); #endif @@ -855,7 +868,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); return Ort::Value(nullptr); }; } else { - Ort::MemoryInfo memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); + Ort::MemoryInfo memory_info(nullptr); // Default initialize, will be overwritten + if (device_memory_name_ == CUDA) { + memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeDefault); + } else { + memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); + } custom_allocator_ = Ort::Allocator(session_, memory_info); allocator_ = custom_allocator_; @@ -956,6 +974,7 @@ static void InitializeTensorWithSeed(int32_t seed, Ort::Value& tensor) { } bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { + Ort::AllocatorWithDefaultOptions default_allocator; // iterate over all input nodes for (size_t i = 0; i < static_cast(input_length_); i++) { Ort::TypeInfo type_info = session_.GetInputTypeInfo(i); @@ -967,10 +986,37 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { auto transform_fcn = [](int64_t input) { return (input == -1) ? -input : input; }; std::transform(input_node_dim.begin(), input_node_dim.end(), input_node_dim.begin(), transform_fcn); - Ort::Value input_tensor = Ort::Value::CreateTensor(allocator_, (const int64_t*)input_node_dim.data(), - input_node_dim.size(), tensor_info.GetElementType()); - InitializeTensorWithSeed(seed, input_tensor); - PreLoadTestData(0, i, std::move(input_tensor)); + if (device_memory_name_ != CUDA) { + Ort::Value input_tensor = Ort::Value::CreateTensor(allocator_, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, input_tensor); + PreLoadTestData(0, i, std::move(input_tensor)); + } +// Create tensor on CPU, initialize and copy to CUDA tensor +#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_NV) + else { + Ort::Value default_tensor = Ort::Value::CreateTensor(default_allocator, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, default_tensor); + + // Get pointer to CPU tensor data + const void* default_ptr = default_tensor.GetTensorRawData(); + + size_t total_bytes = default_tensor.GetTensorSizeInBytes(); + + Ort::Value cuda_tensor = Ort::Value::CreateTensor(allocator_, input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + + void* cuda_ptr = cuda_tensor.GetTensorMutableData(); + + // Copy the initialized data from CPU to GPU + cudaError_t cuda_err = cudaMemcpy(cuda_ptr, default_ptr, total_bytes, cudaMemcpyHostToDevice); + if (cuda_err != cudaSuccess) { + ORT_THROW("Failed to copy tensor data from CPU to CUDA device. CUDA Error: ", cudaGetErrorString(cuda_err)); + } + PreLoadTestData(0, i, std::move(cuda_tensor)); + } +#endif } } return true; diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 90759a4d2f65a..e180efca5b9db 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -66,6 +66,7 @@ struct RunConfig { bool disable_spinning_between_run = false; bool exit_after_session_creation = false; std::basic_string register_custom_op_path; + bool enable_cuda_io_binding{false}; }; struct PerformanceTestConfig { From 801006d80f51f70af7cd9fa9acb3f55e161ad685 Mon Sep 17 00:00:00 2001 From: minfhong-quic Date: Thu, 29 May 2025 06:00:02 +0800 Subject: [PATCH 41/57] [QNN-EP] Define SpaceToDepth fusion for YOLOv2. (#24848) ### Description - Add SpaceToDepth fusion for QNN preprocess. - The pattern in YOLOv2 is uncommon while the common seen one is left as future work. - Add entry point/API for non-quantization user to preprocess models for QNN execution. - Revise cmake to package newly introduced directory into Python wheel. ### Motivation and Context - While executing YOLOv2 model on QNN-EP, a sequence of Reshape and Transpose having 6D shapes are falling back to CPU due to HTP limitation. Add fusion to fuse this sequence of ops into a single SpaceToDepth which can be directly executed on QNN-EP. - Since current QNN preprocess is provided in `onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py` which is under quantization directory, the path may be confusing for non-quantization users. In order to allow non-quantization users to preprocess models for QNN, introduce `onnxruntime/python/tools/qnn/preprocess.py` to serve as the entry point and provide API to preprocess models. --- cmake/onnxruntime_python.cmake | 7 + onnxruntime/python/tools/qnn/preprocess.py | 139 +++++++++++++++ .../qnn/fusion_spacetodepth.py | 162 ++++++++++++++++++ .../execution_providers/qnn/preprocess.py | 8 + setup.py | 1 + 5 files changed, 317 insertions(+) create mode 100644 onnxruntime/python/tools/qnn/preprocess.py create mode 100644 onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_spacetodepth.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index cf5a6c78b925c..f6eac2c24eca2 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -453,6 +453,9 @@ endif() file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/*.py" ) +file(GLOB onnxruntime_python_tools_qnn_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/qnn/*.py" +) file(GLOB onnxruntime_python_quantization_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/quantization/*.py" ) @@ -564,6 +567,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/qdq_helpers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/ort_format_model COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/ort_format_model/ort_flatbuffers_py + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/qnn COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/bart @@ -649,6 +653,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy_directory ${ONNXRUNTIME_ROOT}/core/flatbuffers/ort_flatbuffers_py $/onnxruntime/tools/ort_format_model/ort_flatbuffers_py + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_tools_qnn_src} + $/onnxruntime/tools/qnn/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_quantization_src} $/onnxruntime/quantization/ diff --git a/onnxruntime/python/tools/qnn/preprocess.py b/onnxruntime/python/tools/qnn/preprocess.py new file mode 100644 index 0000000000000..b7ddf1de9dc34 --- /dev/null +++ b/onnxruntime/python/tools/qnn/preprocess.py @@ -0,0 +1,139 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""Provide entry point to preprocess ONNX model especially for QNN.""" + +import argparse +import pathlib + +import onnx + +from onnxruntime.quantization.execution_providers import qnn + + +def _parse_arguments(): + """Parse cmdline arguments.""" + parser = argparse.ArgumentParser(description="Arguments for QNN model preprocess.") + + parser.add_argument("--input_model_path", "-i", required=True, help="Path to the input ONNX model.") + parser.add_argument("--output_model_path", "-o", required=True, help="Path to the output ONNX model.") + + # Save preprocessed model with external data. + parser.add_argument( + "--save_as_external_data", + action="store_true", + help="Whether the output model would be saved with external data.", + ) + parser.add_argument( + "--all_tensors_to_one_file", + action="store_true", + help="Whether to save all external data in one file or save each tensor to a file named with the tensor name.", + ) + parser.add_argument( + "--external_data_location", + help="Filename of the external file where all tensors are saved. The path is relative to the model path.", + ) + parser.add_argument( + "--external_data_size_threshold", + default=1024, + type=int, + help="Tensors with data size larger than this threshold are converted to external data.", + ) + parser.add_argument( + "--external_data_convert_attribute", + action="store_true", + help="Whether to save all tensors, including attribute tensors, to external data.", + ) + + # Preprocess options. + parser.add_argument( + "--fuse_layernorm", + action="store_true", + help="Whether to fuse matched sequences into LayerNormalization nodes if possible.", + ) + + # I/O layouts. + parser.add_argument( + "--inputs_to_make_channel_last", + nargs="+", + default=None, + help="List of graph input names to be transposed into channel-last.", + ) + + parser.add_argument( + "--outputs_to_make_channel_last", + nargs="+", + default=None, + help="List of graph output names to be transposed into channel-last.", + ) + + return parser.parse_args() + + +def qnn_preprocess_model( + model_input: str | pathlib.Path | onnx.ModelProto, + model_output: str | pathlib.Path, + fuse_layernorm: bool = False, + save_as_external_data: bool = False, + all_tensors_to_one_file: bool = False, + external_data_location: str | None = None, + external_data_size_threshold: int = 1024, + external_data_convert_attribute: bool = False, + inputs_to_make_channel_last: list[str] | None = None, + outputs_to_make_channel_last: list[str] | None = None, +) -> bool: + """Preprocess ONNX model for QNN. + + Args: + model_input: A path or ONNX ModelProto specifiying the model to be preprocessed. + model_output: A path specifying where the preprocessed model to be saved. + fuse_layernorm: A bool specifying whether to fuse the matched sequence into a single LayerNormalization node. + Defaults to False. + save_as_external_data: A bool specifying whether to save model with external data. Defaults to False. + all_tensors_to_one_file: A bool specifying whether to save all external data in one file or save each tensor to + a file named with the tensor name. This argument is effective only when `save_as_external_data` is True. + Defaults to False. + external_data_location: A str specifying where to save the external data. The path is relative to the model + path. This argument is effective only when `save_as_external_data` is True. Defaults to the model name. + external_data_size_threshold: An int specifying the threshold of data size for tensors be saved as external + data. This argument is effective only when `save_as_external_data` is True. Defaults to 1024. + external_data_convert_attribute: A bool specifying whether to save all tensors including attributes as external + data. This argument is effective only when `save_as_external_data` is True. Defaults to False. + inputs_to_make_channel_last: A list of strs specifying graph input names to be transposed into channel-last. + Defaults to None. + outputs_to_make_channel_last: A list of strs specifying graph output names to be transposed into channel-last. + Defaults to None. + + Returns: + A bool indicating whether the model is modified. + """ + return qnn.qnn_preprocess_model( + model_input, + model_output, + fuse_layernorm=fuse_layernorm, + save_as_external_data=save_as_external_data, + all_tensors_to_one_file=all_tensors_to_one_file, + external_data_location=external_data_location, + external_data_size_threshold=external_data_size_threshold, + external_data_convert_attribute=external_data_convert_attribute, + inputs_to_make_channel_last=inputs_to_make_channel_last, + outputs_to_make_channel_last=outputs_to_make_channel_last, + ) + + +if __name__ == "__main__": + args = _parse_arguments() + qnn_preprocess_model( + args.input_model_path, + args.output_model_path, + fuse_layernorm=args.fuse_layernorm, + save_as_external_data=args.save_as_external_data, + all_tensors_to_one_file=args.all_tensors_to_one_file, + external_data_location=args.external_data_location, + external_data_size_threshold=args.external_data_size_threshold, + external_data_convert_attribute=args.external_data_convert_attribute, + inputs_to_make_channel_last=args.inputs_to_make_channel_last, + outputs_to_make_channel_last=args.outputs_to_make_channel_last, + ) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_spacetodepth.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_spacetodepth.py new file mode 100644 index 0000000000000..ce92b3e2a1d76 --- /dev/null +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_spacetodepth.py @@ -0,0 +1,162 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""Define SpaceToDepth fusion.""" + +import onnx + +from ... import fusions, onnx_model + + +class FusionSpaceToDepth(fusions.Fusion): + """Fusion for SpaceToDepth.""" + + def __init__(self, model: onnx_model.ONNXModel): + """Initialize. + + Args: + model: An onnx_model.ONNXModel instance. + """ + super().__init__(model, "SpaceToDepth", "Reshape") + + def _fuse_yolo( + self, + node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """Fuse for early version of YOLO. + + Pattern: + + | [N, C, H, W] + Reshape + | [N, C, H/blk, blk, W/blk, blk] + Transpose + | [N, C, H/blk, W/blk, blk, blk] + Reshape + | [N, C, H/blk * W/blk, blk * blk] + Transpose + | [N, C, blk * blk, H/blk * W/blk] + Reshape + | [N, C, blk * blk, H/blk, W/blk] + Transpose + | [N, blk * blk, C, H/blk, W/blk] + Reshape + | [N, blk * blk * C, H/blk, W/blk] + + This sequence can be fused into a single SpaceToDepth with blocksize `blk`. Note that unlike DepthToSpace + supporting DCR or CRD mode, SpaceToDepth only supports DCR mode in its latest opset version (13), which matches + the pattern here. + """ + reshape_node1 = node + + def get_target_child(parent_node, target_op_type): + """Get target child of given node.""" + if parent_node.output[0] not in input_name_to_nodes: + return None + + children = input_name_to_nodes[parent_node.output[0]] + if len(children) > 1 or children[0].op_type != target_op_type: + return None + + return children[0] + + if ( + (transpose_node1 := get_target_child(reshape_node1, "Transpose")) is None + or (reshape_node2 := get_target_child(transpose_node1, "Reshape")) is None + or (transpose_node2 := get_target_child(reshape_node2, "Transpose")) is None + or (reshape_node3 := get_target_child(transpose_node2, "Reshape")) is None + or (transpose_node3 := get_target_child(reshape_node3, "Transpose")) is None + or (reshape_node4 := get_target_child(transpose_node3, "Reshape")) is None + ): + return False + + def get_tensor_shape(tensor_name): + """Get shape for given tensor name.""" + tensor_type = self.model.get_tensor_type(tensor_name) + if not tensor_type: + return None + + tensor_shape = self.tensor_shape_to_list(tensor_type) + if not tensor_shape: + return None + + return tensor_shape + + if ( + (input_shape := get_tensor_shape(reshape_node1.input[0])) is None + or (reshape_shape1 := get_tensor_shape(reshape_node1.output[0])) is None + or (reshape_shape2 := get_tensor_shape(reshape_node2.output[0])) is None + or (reshape_shape3 := get_tensor_shape(reshape_node3.output[0])) is None + or (reshape_shape4 := get_tensor_shape(reshape_node4.output[0])) is None + ): + return False + + transpose_perm1 = self.get_node_attribute(transpose_node1, "perm") + transpose_perm2 = self.get_node_attribute(transpose_node2, "perm") + transpose_perm3 = self.get_node_attribute(transpose_node3, "perm") + + # Check rank. + if ( + len(input_shape) != 4 + or len(reshape_shape1) != 6 + or len(reshape_shape2) != 4 + or len(reshape_shape3) != 5 + or len(reshape_shape4) != 4 + ): + return False + + # Check shape and perm. + batch, channel, height, width = input_shape + blocksize = reshape_shape1[3] + if ( + reshape_shape1 != [batch, channel, height // blocksize, blocksize, width // blocksize, blocksize] + or transpose_perm1 != [0, 1, 2, 4, 3, 5] + or reshape_shape2 != [batch, channel, (height // blocksize) * (width // blocksize), blocksize**2] + or transpose_perm2 != [0, 1, 3, 2] + or reshape_shape3 != [batch, channel, blocksize**2, height // blocksize, width // blocksize] + or transpose_perm3 != [0, 2, 1, 3, 4] + or reshape_shape4 != [batch, blocksize**2 * channel, height // blocksize, width // blocksize] + ): + return False + + self.nodes_to_remove.extend( + [ + reshape_node1, + transpose_node1, + reshape_node2, + transpose_node2, + reshape_node3, + transpose_node3, + reshape_node4, + ] + ) + + s2d_node = onnx.helper.make_node( + self.fused_op_type, + name=self.create_unique_node_name(), + inputs=[reshape_node1.input[0]], + outputs=[reshape_node4.output[0]], + blocksize=blocksize, + ) + self.nodes_to_add.append(s2d_node) + + return True + + def fuse( + self, + node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """Fuse a sequence of Reshape and Transpose nodes into a single SpaceToDepth node. + + Args: + node: An onnx.NodeProto matching the specified search type (i.e., Reshape). + input_name_to_nodes: A dict mapping tensor name to consumed nodes. + output_name_to_node: A dict mapping tensor name to produced node. + """ + self._fuse_yolo(node, input_name_to_nodes, output_name_to_node) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index 85f5d967f9ee3..44ff7e4aba10b 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -12,7 +12,9 @@ from ...fusions import FusionGelu, FusionLayerNormalization from ...onnx_model import ONNXModel +from ...quant_utils import save_and_reload_model_with_shape_infer from .fusion_lpnorm import FusionLpNormalization +from .fusion_spacetodepth import FusionSpaceToDepth def qnn_preprocess_model( @@ -83,6 +85,7 @@ def qnn_preprocess_model( """ modified = False model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input) + model = save_and_reload_model_with_shape_infer(model) onnx_model = ONNXModel(model) # Fuse Erf sequence into a single Gelu @@ -95,6 +98,11 @@ def qnn_preprocess_model( if fusion_lpnorm.apply(): modified = True + # Fuse Reshape/Transpose sequence into a single SpaceToDepth. + fusion_s2d = FusionSpaceToDepth(onnx_model) + if fusion_s2d.apply(): + modified = True + # Optionally, fuse ReduceMean sequence into a single LayerNormalization node. if fuse_layernorm: onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx") diff --git a/setup.py b/setup.py index b30d0e1e6aea4..6a1f126476158 100644 --- a/setup.py +++ b/setup.py @@ -514,6 +514,7 @@ def finalize_options(self): "onnxruntime.tools.ort_format_model.ort_flatbuffers_py", "onnxruntime.tools.ort_format_model.ort_flatbuffers_py.fbs", "onnxruntime.tools.qdq_helpers", + "onnxruntime.tools.qnn", "onnxruntime.quantization", "onnxruntime.quantization.operators", "onnxruntime.quantization.CalTableFlatBuffers", From f9739c2da7e86e8c91058a2b934fe825e03d94b3 Mon Sep 17 00:00:00 2001 From: qti-yuduo Date: Wed, 28 May 2025 15:06:25 -0700 Subject: [PATCH 42/57] [QNN EP] Fuse scale into softmax (#24809) QNN [Softmax op defines pre-scale (`beta`)](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html#softmax) that we can fold constant scalar multiply into it. --- cmake/onnxruntime_unittests.cmake | 1 + .../core/optimizer/bias_softmax_fusion.cc | 2 +- .../builder/qnn_node_group/qnn_node_group.cc | 2 + .../qnn_node_group/scale_softmax_fusion.cc | 226 ++++++++++++++++++ .../qnn_node_group/scale_softmax_fusion.h | 54 +++++ .../scale_softmax_fusion_test.cc | 147 ++++++++++++ 6 files changed, 431 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h create mode 100644 onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 6d409e1ee167d..15cc238173f29 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -724,6 +724,7 @@ endif() # or reduced op builds. if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/*) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/qnn_node_group/*) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_qnn) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_qnn) if(NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) diff --git a/onnxruntime/core/optimizer/bias_softmax_fusion.cc b/onnxruntime/core/optimizer/bias_softmax_fusion.cc index bcbb70ba8fac5..2bbc70db16cde 100644 --- a/onnxruntime/core/optimizer/bias_softmax_fusion.cc +++ b/onnxruntime/core/optimizer/bias_softmax_fusion.cc @@ -135,7 +135,7 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node new_axis = (int)HandleNegativeAxis(axis, rank); // The axis attribute for Softmax in OpSet-11 and OpSet-13 are different. - // Details in function documentatin. + // Details in function documentation. if (is_since_opset_13 && new_axis != rank - 1) return false; int singlebatch_rank = rank - new_axis; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 20b37a2fb2b22..839079e6c1a8e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -15,6 +15,7 @@ #include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/ort_api.h" @@ -90,6 +91,7 @@ static std::unique_ptr TryQnnFusions( {"DequantizeLinear", DQQFusion::TryFusion}, {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, {"Gemm", ReshapeGemmFusion::TryFusion}, + {"Mul", ScaleSoftmaxFusion::TryFusion}, }; // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc new file mode 100644 index 0000000000000..5c7091b3be3cc --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" + +namespace onnxruntime { +namespace qnn { +namespace { + +constexpr char kOpMul[] = "Mul"; +constexpr char kOpSoftmax[] = "Softmax"; + +/// @brief Get the index of the scalar input in the mul node +/// @param mul Multiply node unit +/// @return The index of the scalar input (0 or 1) if found, otherwise std::nullopt +std::optional GetMulScalarInputIndex(const NodeUnit* mul) { + const NodeArg* mul_y = mul->GetNode().InputDefs()[1]; + const NodeArg* mul_x = mul->GetNode().InputDefs()[0]; + auto y_shape_proto = mul_y->Shape(); + auto x_shape_proto = mul_x->Shape(); + bool is_y_scalar = false; + if (y_shape_proto != nullptr) { + auto y_shape = utils::GetTensorProtoShape(*y_shape_proto); + is_y_scalar = y_shape.NumDimensions() == 0; + } + bool is_x_scalar = false; + if (x_shape_proto != nullptr) { + auto x_shape = utils::GetTensorProtoShape(*x_shape_proto); + is_x_scalar = x_shape.NumDimensions() == 0; + } + if (is_y_scalar) { + return 1U; + } else if (is_x_scalar) { + return 0U; + } + return std::nullopt; +} + +/// @brief Get the axis for softmax +/// @param mul Multiply node unit +/// @param softmax Softmax node unit +/// @return The axis for softmax +std::optional GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUnit* softmax) { + NodeAttrHelper softmax_attr_helper(softmax->GetNode()); + std::optional param_axis = softmax_attr_helper.GetInt64(QNN_OP_SOFTMAX_PARAM_AXIS); + if (!param_axis.has_value()) { + return std::nullopt; + } + int64_t axis_value = param_axis.value(); + if (axis_value < 0) { + size_t input_scale_index = GetMulScalarInputIndex(mul).value(); + size_t input_other_index = 1U - input_scale_index; + int rank = mul->GetNode().InputDefs()[input_other_index]->Shape()->dim_size(); + axis_value += static_cast(rank); + } + return static_cast(axis_value); +} + +/// @brief Identify scalar input from mul node if present +/// @param mul Multiply node unit +/// @return The scalar input float value if found, otherwise std::nullopt +std::optional ExtractScalarValueFromMul(const GraphViewer& graph_viewer, const NodeUnit* mul) { + std::optional input_scale_index = GetMulScalarInputIndex(mul); + if (!input_scale_index.has_value()) { + return std::nullopt; + } + const NodeArg* scalar_arg = mul->GetNode().InputDefs()[input_scale_index.value()]; + if (!graph_viewer.IsConstantInitializer(scalar_arg->Name(), true)) { + return std::nullopt; + } + const auto* scalar_tensor = graph_viewer.GetConstantInitializer(scalar_arg->Name()); + if (!scalar_tensor) { + return std::nullopt; + } + if (scalar_tensor->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return std::nullopt; + } + const auto& raw_data = scalar_tensor->raw_data(); + if (raw_data.size() != sizeof(float) || reinterpret_cast(raw_data.data()) % alignof(float) != 0) { + return std::nullopt; + } + return *reinterpret_cast(raw_data.data()); +} + +/// @brief Create or validate the QNN node +/// @param qnn_model_wrapper QNN model wrapper +/// @param node_units The node units containing the softmax and mul nodes +/// @param validate Whether to validate the QNN node +/// @return Status +Status CreateOrValidateOnQnn( + QnnModelWrapper* qnn_model_wrapper, + gsl::span node_units, + bool validate) { + const NodeUnit* mul = node_units[0]; + const NodeUnit* softmax = node_units[1]; + ORT_RETURN_IF_NOT(mul->OpType() == kOpMul, + "Expected scale node to be of type Mul, got ", mul->OpType()); + ORT_RETURN_IF_NOT(softmax->OpType() == kOpSoftmax, + "Expected softmax node to be of type Softmax, got ", softmax->OpType()); + size_t input_scale_index = GetMulScalarInputIndex(mul).value(); + size_t input_other_index = 1U - input_scale_index; + const NodeUnitIODef& mul_input_other = mul->Inputs()[input_other_index]; + const NodeUnitIODef& softmax_output = softmax->Outputs()[0]; + + std::vector param_tensor_names; + { // axis + std::optional axis = GetPositiveSoftmaxAxis(mul, softmax); + if (axis.has_value()) { + Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT; + axis_scalar.dataType = QNN_DATATYPE_UINT_32; + axis_scalar.uint32Value = axis.value(); + QnnParamWrapper param_wrapper(softmax->Index(), + softmax->Name(), + QNN_OP_SOFTMAX_PARAM_AXIS, + axis_scalar); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); + } + } + { // beta + NodeAttrHelper softmax_attr_helper(softmax->GetNode()); + std::optional beta = softmax_attr_helper.GetFloat(QNN_OP_SOFTMAX_PARAM_BETA); + float scale = ExtractScalarValueFromMul(qnn_model_wrapper->GetGraphViewer(), mul).value_or(1.0f); + Qnn_Scalar_t beta_scalar = QNN_SCALAR_INIT; + beta_scalar.dataType = QNN_DATATYPE_FLOAT_32; + beta_scalar.floatValue = scale * beta.value_or(1.0f); + QnnParamWrapper param_wrapper(softmax->Index(), + softmax->Name(), + QNN_OP_SOFTMAX_PARAM_BETA, + beta_scalar); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); + } + + QnnTensorWrapper fused_softmax_input; + QnnTensorWrapper fused_softmax_output; + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(mul_input_other, fused_softmax_input)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(softmax_output, fused_softmax_output)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper->ValidateQnnNode(softmax->Name(), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_SOFTMAX, + {fused_softmax_input.GetQnnTensor()}, + {fused_softmax_output.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_input)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_output)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(softmax->Name(), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_SOFTMAX, + {mul_input_other.node_arg.Name()}, + {softmax_output.node_arg.Name()}, + std::move(param_tensor_names), + validate), + "Failed to add fused " + std::string(kOpSoftmax) + " node."); + } + return Status::OK(); +} + +} // namespace + +std::unique_ptr ScaleSoftmaxFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& mul_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + [[maybe_unused]] const logging::Logger& logger) { + if (mul_node_unit.OpType() != kOpMul || mul_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + // Check if the mul node has a scalar input that can fold into the softmax's beta + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + std::optional scalar = ExtractScalarValueFromMul(graph_viewer, &mul_node_unit); + if (!scalar.has_value()) { + return nullptr; + } + + // Mul node must have a single Softmax node as child + const std::array child_op_types{kOpSoftmax}; + const NodeUnit* softmax = GetOnlyChildOfType(graph_viewer, mul_node_unit, child_op_types, + node_to_node_unit, node_unit_to_qnn_node_group); + if (softmax == nullptr) { + return nullptr; + } + + std::array node_unit_array{&mul_node_unit, softmax}; + auto node_units = gsl::make_span(node_unit_array.data(), 2); + if (CreateOrValidateOnQnn(&qnn_model_wrapper, node_units, /*validate=*/true) != Status::OK()) { + return nullptr; + } + return std::make_unique(node_units); +} + +gsl::span ScaleSoftmaxFusion::GetNodeUnits() const { + return gsl::span{node_units_.data(), node_units_.size()}; +} + +Status ScaleSoftmaxFusion::IsSupported( + QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const { + return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/true); +} + +Status ScaleSoftmaxFusion::AddToModelBuilder( + QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const { + return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/false); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h new file mode 100644 index 0000000000000..66eb892e7a884 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of pattern: Softmax(Mul(x, scalar_scale)) => QnnSoftmax(x, beta=scalar_scale) +/// +class ScaleSoftmaxFusion : public IQnnNodeGroup { + public: + explicit ScaleSoftmaxFusion(gsl::span node_units) { + ORT_ENFORCE(node_units.size() == 2, "Pattern expect exactly 2 NodeUnits."); + node_units_[0] = node_units[0]; + node_units_[1] = node_units[1]; + } + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ScaleSoftmaxFusion); + + Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override { return node_units_[1]; } + std::string_view Type() const override { return "ScaleSoftmaxFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid Softmax -> Mul sequence. + /// If so, returns a IQnnNodeGroup that contains the Softmax and Mul NodeUnits. + /// + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& mul_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc new file mode 100644 index 0000000000000..aa8dc492a95c9 --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" + +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +namespace { + +GetTestModelFn BuildTestCaseScalar( + const TestInputDef& input_def, + float scale_value, + bool use_constant, + bool reverse_input_order, + std::optional softmax_axis = std::nullopt) { + return [&](ModelTestBuilder& builder) -> void { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* scale{nullptr}; + if (use_constant) { + onnx::TensorProto scale_value_proto; + scale_value_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + utils::SetRawDataInTensorProto(scale_value_proto, reinterpret_cast(&scale_value), sizeof(float)); + scale = builder.MakeIntermediate(); + builder.AddNode("Constant", {}, {scale}).AddAttribute("value", scale_value_proto); + } else { + scale = builder.MakeScalarInitializer(scale_value); + } + NodeArg* intermediate = builder.MakeIntermediate(); + auto mul_inputs = reverse_input_order ? std::vector{scale, input} : std::vector{input, scale}; + builder.AddNode("Mul", mul_inputs, {intermediate}); + Node& softmax = builder.AddNode("Softmax", {intermediate}, {builder.MakeOutput()}); + if (softmax_axis.has_value()) { + softmax.AddAttribute("axis", softmax_axis.value()); + } + }; +} + +GetTestModelFn BuildTestCaseNoScalar(const TestInputDef& input_def1, const TestInputDef& input_def2) { + return [&input_def1, input_def2](ModelTestBuilder& builder) -> void { + NodeArg* input = MakeTestInput(builder, input_def1); + NodeArg* scale = MakeTestInput(builder, input_def2); + NodeArg* intermediate = builder.MakeIntermediate(); + builder.AddNode("Mul", {input, scale}, {intermediate}); + builder.AddNode("Softmax", {intermediate}, {builder.MakeOutput()}); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarInitializer) { + ProviderOptions provider_options = GetProviderOptions(); + + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.125f, /*use_constant=*/false, /*reverse_input_order=*/false), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarConstant) { + ProviderOptions provider_options = GetProviderOptions(); + + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.375f, /*use_constant=*/true, /*reverse_input_order=*/false), + provider_options, + /*opset_version=*/14, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarInitializerReversed) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.375f, /*use_constant=*/false, /*reverse_input_order=*/true), + provider_options, + /*opset_version=*/15, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarConstantReversed) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.125f, /*use_constant=*/true, /*reverse_input _order=*/true), + provider_options, + /*opset_version=*/16, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionSoftmaxNegativeAxis) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.125f, + /*use_constant=*/true, /*reverse_input_order=*/true, /*softmax_axis=*/-1), + provider_options, + /*opset_version=*/22, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionSkipNoScalar4d) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def1 = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + auto input_def2 = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseNoScalar(input_def1, input_def2), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionSkipNoScalar1d) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def1 = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + auto input_def2 = TestInputDef({1}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseNoScalar(input_def1, input_def2), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) From b17ab8e543a9a2047e211a2f8cde4b790c3ee0fd Mon Sep 17 00:00:00 2001 From: qc-tbhardwa Date: Thu, 29 May 2025 22:33:46 +0530 Subject: [PATCH 43/57] Download protobuf dependency on ARM64 build host (#24847) Windows on ARM support AMD64 emulation, so we can use win64 version of protoc. Description Compilation on ARM64 machine fails due to missing protoc dependency. Motivation and Context With this change we can compile onnxruntime on Windows on Arm devices without setting protobuf manually. CMake will download and setup protoc dependency. --- cmake/external/onnxruntime_external_deps.cmake | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e2b0432f3e8e1..4f6bcc8c90419 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -132,6 +132,9 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE AND NOT onnxruntime_USE_VCPKG) elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") onnxruntime_fetchcontent_declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32} EXCLUDE_FROM_ALL) FetchContent_Populate(protoc_binary) + elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "ARM64") + onnxruntime_fetchcontent_declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64} EXCLUDE_FROM_ALL) + FetchContent_Populate(protoc_binary) endif() if(protoc_binary_SOURCE_DIR) From b925dfd74745ecd1b0ea7e6e038069d6d7667ec7 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 29 May 2025 10:26:27 -0700 Subject: [PATCH 44/57] Disable VCPKG's binary cache (#24889) VCPKG has removed this feature. --- .../locate-vcvarsall-and-setup-env/action.yml | 2 +- .github/workflows/android.yml | 4 +- .../linux-wasm-ci-build-and-test-workflow.yml | 2 +- .github/workflows/linux_cuda_ci.yml | 4 +- .github/workflows/linux_minimal_build.yml | 52 +++++++++---------- .github/workflows/linux_tensorrt_ci.yml | 4 +- .github/workflows/reusable_linux_build.yml | 8 +-- tools/ci_build/build.py | 1 - 8 files changed, 38 insertions(+), 39 deletions(-) diff --git a/.github/actions/locate-vcvarsall-and-setup-env/action.yml b/.github/actions/locate-vcvarsall-and-setup-env/action.yml index 3066721e797ea..bf1016bf2265b 100644 --- a/.github/actions/locate-vcvarsall-and-setup-env/action.yml +++ b/.github/actions/locate-vcvarsall-and-setup-env/action.yml @@ -14,7 +14,7 @@ runs: steps: - name: Setup VCPKG - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 69ff9a1cec976..092b6fc8f5ce5 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -37,7 +37,7 @@ jobs: ndk-version: 28.0.13004108 - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -131,7 +131,7 @@ jobs: architecture: x64 - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index d74d9e9a4f0bf..ef2f35aecdb3b 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -42,7 +42,7 @@ jobs: with: python-version: "3.12" architecture: ${{ env.buildArch }} - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.03.19' vcpkg-hash: '17e96169cd3f266c4716fcdc1bb728e6a64f103941ece463a2834d50694eba4fb48f30135503fd466402afa139abc847ef630733c442595d1c34979f261b0114' diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 0dbe63371c7b8..38526e7a5c00f 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -50,7 +50,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -93,7 +93,7 @@ jobs: # So build.py --build_dir build/Release inside the container correctly finds the artifacts. - name: Test ONNX Runtime id: test_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: Release diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index e68ef56cdb1ce..5f90d9430342e 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -43,7 +43,7 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' @@ -53,7 +53,7 @@ jobs: disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-and-prep-ort-files@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-and-prep-ort-files@v0.0.7 - name: Upload Test Data Artifact uses: actions/upload-artifact@v4 @@ -80,7 +80,7 @@ jobs: node-version: 20 - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -98,7 +98,7 @@ jobs: core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Run Build 2 (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -113,7 +113,7 @@ jobs: --enable_training_ops - name: Run Build 2 (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -151,7 +151,7 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' @@ -161,7 +161,7 @@ jobs: disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.7 with: reduced-ops-config-file: required_ops.ort_models.config enable-custom-ops: 'true' @@ -191,7 +191,7 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' @@ -200,7 +200,7 @@ jobs: add-cmake-to-path: 'true' disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.7 with: reduced-ops-config-file: required_ops_and_types.ort_models.config enable-type-reduction: 'true' @@ -229,7 +229,7 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' @@ -239,7 +239,7 @@ jobs: disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.7 with: globally_allowed_types: 'bool,float,int8_t,uint8_t' enable-type-reduction: 'true' @@ -264,7 +264,7 @@ jobs: node-version: 20 - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -282,7 +282,7 @@ jobs: core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Run Build 5 (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -295,7 +295,7 @@ jobs: --minimal_build extended - name: Run Build 5 (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -307,7 +307,7 @@ jobs: --use_binskim_compliant_compile_flags --minimal_build extended - name: Run Build 5 (Test) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -334,7 +334,7 @@ jobs: submodules: false - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -358,7 +358,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Run Build 6a (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -374,7 +374,7 @@ jobs: --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - name: Run Build 6a (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -391,7 +391,7 @@ jobs: - name: Run Build 6a (Test) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -427,7 +427,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -445,7 +445,7 @@ jobs: core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Run Build 6b (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -464,7 +464,7 @@ jobs: --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - name: Run Build 6b (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -503,7 +503,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -526,7 +526,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Run Build 6c (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -545,7 +545,7 @@ jobs: --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - name: Run Build 6c (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -588,7 +588,7 @@ jobs: path: ${{ runner.temp }}/.test_data/ - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 405de75e95454..1df467043329a 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -52,7 +52,7 @@ jobs: # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -95,7 +95,7 @@ jobs: # So build.py --build_dir build/Release inside the container correctly finds the artifacts. - name: Test ONNX Runtime id: test_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: Release diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index 27595254800f9..af24e3a3d901a 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -83,7 +83,7 @@ jobs: python-version: ${{ inputs.python_version }} - name: Build Docker Image (${{ inputs.architecture }} / ${{ inputs.build_config }}) - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/${{ inputs.dockerfile_path }} @@ -103,7 +103,7 @@ jobs: # ------------- Update Step (CMake Generation) ------------- - name: Generate Build Files (CMake) (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: update_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -115,7 +115,7 @@ jobs: # ------------- Build Step (Compilation) ------------- - name: Build ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: build_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -128,7 +128,7 @@ jobs: - name: Test ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: test_step if: inputs.run_tests == true - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 8dce6be731402..82372645d364f 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -318,7 +318,6 @@ def generate_vcpkg_install_options(build_dir, args): elif "RUNNER_TEMP" in os.environ: temp_dir = os.environ["RUNNER_TEMP"] vcpkg_install_options.append(f"--x-buildtrees-root={temp_dir}") - vcpkg_install_options.append("--binarysource=clear\\;x-gha,readwrite") # Config asset cache if args.use_vcpkg_ms_internal_asset_cache: From 49f51799cab6755a80feff437dff7d504e6380c9 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 29 May 2025 14:15:15 -0700 Subject: [PATCH 45/57] Fix symbol publishing (#24879) **Description:** This pull request refactors the symbol publishing workflow that uses the internal REST API. It addresses a breaking change introduced by the `Az.Accounts` module update (v5.0.1+) where `Get-AzAccessToken` now returns a `SecureString`. Additionally, it improves the structure and robustness of the custom symbol publishing steps. **Problem:** 1. The pipeline recently stopped working due to an update in the Azure PowerShell `Az.Accounts` module. The `Get-AzAccessToken` cmdlet now returns a `SecureString` by default, which was incompatible with the previous script that expected a plain string token when setting a pipeline variable. 2. The previous implementation used two separate tasks: one `AzurePowerShell@5` task to generate the token and set it as a pipeline variable, and a subsequent `pwsh` task to consume this variable and make REST API calls. This separation required converting the `SecureString` to plain text before setting the pipeline variable. **Solution:** To address these issues and improve the pipeline's design: 1. The "Generate an Azure Token" (`AzurePowerShell@5`) task and the "Publish Symbols using internal REST API" (`pwsh`) task have been **combined into a single `AzurePowerShell@5` task.** 2. Within this unified task: * `Get-AzAccessToken` is called, and its `SecureString` output is stored in a local PowerShell variable. * The `SecureString` token is converted to plain text *only within the scope of this script* and immediately before it's used in the `Authorization` header for `Invoke-RestMethod` calls. * The token is no longer passed between tasks via a pipeline variable, enhancing security by limiting the scope of the plain text token. **Key Changes:** * **Enhanced `SecureString` Management:** The token remains a `SecureString` for most of its lifetime within the script, reducing exposure. * **Improved Error Handling:** `try-catch` blocks have been added around the token retrieval and `Invoke-RestMethod` calls for better error reporting and pipeline stability. * **Robust Parameter Handling:** Explicit conversion for boolean parameters (e.g., `includePublicSymbolServer`) to ensure correct PowerShell boolean types before JSON serialization. --- .../publish-symbolrequestprod-api.yml | 105 ++++++++++++------ 1 file changed, 70 insertions(+), 35 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml b/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml index b2a3eaca0280f..9f0230c4b1141 100644 --- a/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml +++ b/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml @@ -1,4 +1,3 @@ -# This file was copied from https://github.com/microsoft/devhome/blob/main/build/templates/publish-symbolrequestprod-api.yml#L71 parameters: - name: includePublicSymbolServer type: boolean @@ -32,18 +31,6 @@ steps: Install-Module -Verbose -AllowClobber -Force Az.Accounts, Az.Storage, Az.Network, Az.Resources, Az.Compute displayName: Install Azure Module Dependencies - # Transit the Azure token from the Service Connection into a secret variable for the rest of the pipeline to use. - - task: AzurePowerShell@5 - displayName: Generate an Azure Token - inputs: - azureSubscription: ${{ parameters.subscription }} - azurePowerShellVersion: LatestVersion - pwsh: true - ScriptType: InlineScript - Inline: |- - $AzToken = (Get-AzAccessToken -ResourceUrl api://30471ccf-0966-45b9-a979-065dbedb24c1).Token - Write-Host "##vso[task.setvariable variable=SymbolAccessToken;issecret=true]$AzToken" - - task: PublishSymbols@2 displayName: Publish Symbols (to current Azure DevOps tenant) continueOnError: True @@ -60,28 +47,76 @@ steps: env: LIB: $(Build.SourcesDirectory) - - pwsh: |- - # Prepare the defaults for IRM - $PSDefaultParameterValues['Invoke-RestMethod:Headers'] = @{ Authorization = "Bearer $(SymbolAccessToken)" } - $PSDefaultParameterValues['Invoke-RestMethod:ContentType'] = "application/json" - $PSDefaultParameterValues['Invoke-RestMethod:Method'] = "POST" + - task: AzurePowerShell@5 + displayName: Generate Token and Publish Symbols via REST API + inputs: + azureSubscription: ${{ parameters.subscription }} + azurePowerShellVersion: LatestVersion + pwsh: true + ScriptType: InlineScript + Inline: | + # Part 1: Generate an Azure Token + Write-Host "Attempting to retrieve Azure Access Token for symbol publishing API." + $apiResourceUrl = "api://30471ccf-0966-45b9-a979-065dbedb24c1" + try { + $secureTokenObject = (Get-AzAccessToken -ResourceUrl $apiResourceUrl).Token + Write-Host "Successfully retrieved a token object." + } + catch { + Write-Error "Failed to retrieve Azure Access Token. Error: $($_.Exception.Message)" + throw "Failed to retrieve Azure Access Token." # Fail the task + } + + # Convert the SecureString token to a plain text string for the HTTP header + # This is done just-in-time before its use. + $plainTextToken = $secureTokenObject | ConvertFrom-SecureString -AsPlainText + Write-Host "Token converted to plain text for API call (will not be logged)." + + # Part 2: Publish Symbols using internal REST API + Write-Host "Preparing to publish symbols using internal REST API." + + # Prepare the defaults for Invoke-RestMethod for this scope + $PSDefaultParameterValues = @{} # Initialize to ensure a clean state for default parameters + $PSDefaultParameterValues['Invoke-RestMethod:Headers'] = @{ Authorization = "Bearer $plainTextToken" } + $PSDefaultParameterValues['Invoke-RestMethod:ContentType'] = "application/json" + $PSDefaultParameterValues['Invoke-RestMethod:Method'] = "POST" # Default method for symbol request creation/update + + $baseUri = "https://symbolrequestprod.trafficmanager.net/projects/${{ parameters.symbolProject }}/requests" + + # Prepare and submit the symbol request creation + $expirationDate = (Get-Date).Add([TimeSpan]::FromDays(${{ parameters.symbolExpiryTime }})) + $createRequestBody = @{ + requestName = "${{ parameters.symbolsArtifactName }}_${{ parameters.symbolsVersion }}"; + expirationTime = $expirationDate.ToString(); + } + $requestNameForUri = $createRequestBody.requestName # Store for use in the next URI + + Write-Host "##[debug]Creating symbol request: Name '$($createRequestBody.requestName)', Expiration '$($createRequestBody.expirationTime)'. URI: '$baseUri'" + try { + Invoke-RestMethod -Uri $baseUri -Body ($createRequestBody | ConvertTo-Json -Compress) -Verbose + Write-Host "Successfully initiated symbol request '$($createRequestBody.requestName)'." + } + catch { + Write-Error "Failed to create symbol request. Error: $($_.Exception.Message)" + # Optionally inspect response: $_.ErrorDetails.Message or $_.Exception.Response + throw "Failed to create symbol request." + } - $BaseUri = "https://symbolrequestprod.trafficmanager.net/projects/${{ parameters.symbolProject }}/requests" + # Prepare and submit the symbol publication details + $publishRequestBody = @{ + publishToInternalServer = $true; + publishToPublicServer = [System.Convert]::ToBoolean("${{ parameters.includePublicSymbolServer }}"); # Ensure YAML boolean is correctly PowerShell boolean + } + $publishUri = "$baseUri/$requestNameForUri" - # Prepare the request - $expiration = (Get-Date).Add([TimeSpan]::FromDays(${{ parameters.symbolExpiryTime }})) - $createRequestBody = @{ - requestName = "${{ parameters.symbolsArtifactName }}_${{ parameters.symbolsVersion }}"; - expirationTime = $expiration.ToString(); - } - Write-Host "##[debug]Starting request $($createRequestBody.requestName) with expiration date of $($createRequestBody.expirationTime)" - Invoke-RestMethod -Uri "$BaseUri" -Body ($createRequestBody | ConvertTo-Json -Compress) -Verbose + Write-Host "##[debug]Submitting symbol publication details for request '$requestNameForUri'. URI: '$publishUri'. Payload: $($publishRequestBody | ConvertTo-Json -Compress)" + try { + Invoke-RestMethod -Uri $publishUri -Body ($publishRequestBody | ConvertTo-Json -Compress) -Verbose + Write-Host "Successfully submitted symbol publication details for '$requestNameForUri'." + } + catch { + Write-Error "Failed to submit symbol publication details. Error: $($_.Exception.Message)" + throw "Failed to submit symbol publication details." + } - # Request symbol publication - $publishRequestBody = @{ - publishToInternalServer = $true; - publishToPublicServer = $${{ parameters.includePublicSymbolServer }}; - } - Write-Host "##[debug]Submitting request $($createRequestBody.requestName) ($($publishRequestBody | ConvertTo-Json -Compress))" - Invoke-RestMethod -Uri "$BaseUri/$($createRequestBody.requestName)" -Body ($publishRequestBody | ConvertTo-Json -Compress) -Verbose - displayName: Publish Symbols using internal REST API + Write-Host "Symbol publishing process via REST API completed for '$requestNameForUri'." From 9705b17ed15d48b6b02c9578209e80182a90ebc3 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 29 May 2025 14:24:27 -0700 Subject: [PATCH 46/57] workaround for a VC++ bug in VS 17.14 (#24878) The following unit tests failed when building ONNX Runtime with Visual Studio 17.14 in Release or RelWithDebInfo configuration. - SparseTensorConversionTests.TestDenseToSparseConversion - MeanVarianceNormalizationTest.AllAxes - MVNContribOpTest.MeanVarianceNormalizationCPUTest_Version1_TO_8 This PR provides a workaround for the two MVN tests. --- onnxruntime/test/common/tensor_op_test_utils.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h index acb520f894569..0ab3b693d59d9 100644 --- a/onnxruntime/test/common/tensor_op_test_utils.h +++ b/onnxruntime/test/common/tensor_op_test_utils.h @@ -133,7 +133,8 @@ inline std::vector ValueRange(size_t count, BFloat16 start, return result; } -inline std::pair MeanStdev(gsl::span v) { +template +inline std::pair MeanStdev(const T& v) { float sum = std::accumulate(v.begin(), v.end(), 0.0f); float mean = sum / v.size(); From f57db79743c4d1a3553aa05cf95bcd10966030e6 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Thu, 29 May 2025 16:09:09 -0700 Subject: [PATCH 47/57] change dependency from gitlab eigen to github eigen-mirror (#24884) Fix for https://github.com/microsoft/onnxruntime/issues/24861 --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- cmake/deps.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index 728241840f723..6e045f6dcdc9d 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -22,7 +22,9 @@ dlpack;https://github.com/dmlc/dlpack/archive/5c210da409e7f1e51ddf445134a4376fdb # it contains changes on top of 3.4.0 which are required to fix build issues. # Until the 3.4.1 release this is the best option we have. # Issue link: https://gitlab.com/libeigen/eigen/-/issues/2744 -eigen;https://gitlab.com/libeigen/eigen/-/archive/1d8b82b0740839c0de7f1242a3585e3390ff5f33/eigen-1d8b82b0740839c0de7f1242a3585e3390ff5f33.zip;5ea4d05e62d7f954a46b3213f9b2535bdd866803 +# Moved to github mirror to avoid gitlab issues. +# Issue link: https://github.com/bazelbuild/bazel-central-registry/issues/4355 +eigen;https://github.com/eigen-mirror/eigen/archive/1d8b82b0740839c0de7f1242a3585e3390ff5f33/eigen-1d8b82b0740839c0de7f1242a3585e3390ff5f33.zip;05b19b49e6fbb91246be711d801160528c135e34 flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip;59422c3b5e573dd192fead2834d25951f1c1670c fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 From a6121187cb3fbd028bf52fa1c065eb55ed40197d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 May 2025 16:15:42 -0700 Subject: [PATCH 48/57] [ci] revise wasm CI (#24825) ### Description revise WASM CI to run test as later step than publishing artifacts. This allows download the binary to diagnose test failures. ### Motivation and Context --- .../linux-wasm-ci-build-and-test-workflow.yml | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index ef2f35aecdb3b..fc9bb53659442 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -65,27 +65,6 @@ jobs: --skip_tests working-directory: ${{ github.workspace }} - - name: Test (Node.js) (simd + threads) - # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. - if: ${{ inputs.build_config == 'Debug' }} - run: | - python ./tools/ci_build/build.py \ - ${{ env.common_build_args }} \ - --build_dir ${{ github.workspace }}/build/wasm_inferencing \ - --test - working-directory: ${{ github.workspace }} - - - name: Test (browser) (simd + threads) - # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. - if: ${{ inputs.build_config == 'Debug' }} - run: | - python ./tools/ci_build/build.py \ - ${{ env.common_build_args }} \ - --build_dir ${{ github.workspace }}/build/wasm_inferencing \ - --wasm_run_tests_in_browser \ - --test - working-directory: ${{ github.workspace }} - - name: Build (simd + threads + JSEP) if: ${{ inputs.build_jsep == true }} run: | @@ -143,6 +122,28 @@ jobs: name: ${{ inputs.build_config }}_wasm_webgpu path: ${{ github.workspace }}/artifacts/wasm_webgpu + - name: Test (Node.js) (simd + threads) + # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. + if: ${{ inputs.build_config == 'Debug' }} + run: | + python ./tools/ci_build/build.py \ + ${{ env.common_build_args }} \ + --build_dir ${{ github.workspace }}/build/wasm_inferencing \ + --test + working-directory: ${{ github.workspace }} + + - name: Test (browser) (simd + threads) + # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. + if: ${{ inputs.build_config == 'Debug' }} + run: | + python ./tools/ci_build/build.py \ + ${{ env.common_build_args }} \ + --build_dir ${{ github.workspace }}/build/wasm_inferencing \ + --wasm_run_tests_in_browser \ + --target onnxruntime_test_all \ + --update --build --test + working-directory: ${{ github.workspace }} + - name: Publish test results if: ${{ always() && inputs.build_config == 'Debug' }} uses: actions/upload-artifact@v4 From d520798be751e27222a16d92ee013966109bcee9 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Thu, 29 May 2025 18:52:36 -0700 Subject: [PATCH 49/57] [WebGPU EP] Fix NaN bug in softmax operator (#24855) Handle NaN in softmax operator for WebGPU EP and JSEP. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- js/web/lib/wasm/jsep/webgpu/ops/softmax.ts | 4 +++- .../core/providers/webgpu/math/softmax.cc | 4 +++- .../test/providers/cpu/math/softmax_test.cc | 16 ++++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 7c62d1f7182a7..2056416873df5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -152,7 +152,9 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt // calculate final value for each element in the row for (var col = lindex; col < cols; col += wg) { - let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared; + var value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared; + // max operation protects against NaN since all values should be >=0 + value = max(value, ${valueType}(0.0)); setValue(row, col, row_stride, value); } }`; diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index 178ca0b9e0515..2f34aa21c8309 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -141,7 +141,9 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { // Calculate the final value for each element in the row << " for (var col = lindex; col < cols; col += wg) {\n" - << " let value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n" + << " var value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n" + << " // max operation protects against NaN since all values should be >=0\n" + << " value = max(value, x_value_t(0.0));\n" << " setValue(row, col, row_stride, value);\n" << " }\n"; diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 1c6375ebdb0b1..d97873c21983f 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -49,6 +49,22 @@ TEST(SoftmaxOperator, Simple) { RunTest(x_vals, expected_vals, dimensions); } +#ifdef USE_WEBGPU +TEST(SoftmaxOperator, webgpu_nan) { + OpTester test("Softmax", 13); // axis default is -1 + + std::vector x_vals = {-INFINITY, -INFINITY, -INFINITY}; + std::vector expected_result = {0.0f, 0.0f, 0.0f}; + std::vector dimensions = {1, 3}; + + test.AddInput("X", dimensions, x_vals); + test.AddOutput("Y", dimensions, expected_result); + + // explicitly disable CPU EP for this test since CPU implementation does not handle NaN + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider}); +} +#endif + #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_XNNPACK) TEST(SoftmaxOperator, Simple_fp16) { #ifdef USE_CUDA From 9c186713f13396d9752c97aa3738086c963a34b2 Mon Sep 17 00:00:00 2001 From: Ken Zhou <74420508+notken12@users.noreply.github.com> Date: Fri, 30 May 2025 00:44:46 -0400 Subject: [PATCH 50/57] Fix inference unable to run due to JS WASM runtime not being bundled into `onnxruntime-web/wasm` build (#24836) ### Description Fixes inference error from `ort-wasm-simd-threaded.mjs` not being bundled into `ort.wasm.bundle.min.mjs` as it is for other `bundle.min.mjs` builds. ### Motivation and Context To decrease my app's bundle size, I followed the [conditional importing guide](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/js/importing_onnxruntime-web#conditional-importing) and imported the WASM-only build: ```diff - import * as ort from 'onnxruntime-web'; + import * as ort from 'onnxruntime-web/wasm'; ``` After this change, creating an inference session would result in: `TypeError: Failed to resolve module specifier './ort-wasm-simd-threaded.mjs'`. This was because `ort-wasm-simd-threaded.mjs` was not bundled into the build at `onnxruntime-web/wasm`, which points to `ort.wasm.bundle.min.mjs`, despite how its name suggests. In other builds with `bundle` in their name, the module is bundled, yet it was not done so in the WASM one. This PR bundles the Javascript WASM runtime in to match the other builds, fixing the error. --- js/web/script/build.ts | 7 ++++++- js/web/test/e2e/src/cjs-js/main.js | 20 +++++++++++--------- js/web/test/e2e/src/esm-js/main.js | 20 +++++++++++--------- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 6a9432c2b5acd..2ea883f739c52 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -644,7 +644,12 @@ async function main() { isProduction: true, outputName: 'ort.wasm.bundle', format: 'esm', - define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'true', + }, }); // ort.webgl[.min].[m]js await addAllWebBuildTasks({ diff --git a/js/web/test/e2e/src/cjs-js/main.js b/js/web/test/e2e/src/cjs-js/main.js index c9b8d3e85455d..5eea342fdcae7 100644 --- a/js/web/test/e2e/src/cjs-js/main.js +++ b/js/web/test/e2e/src/cjs-js/main.js @@ -6,13 +6,15 @@ const ort = require('onnxruntime-web/wasm'); const { setupMultipleThreads, testInferenceAndValidate } = require('./shared'); -if (typeof SharedArrayBuffer === 'undefined') { - it('Browser package consuming test - single-thread - [js][commonjs]', async function () { - await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); - }); -} else { - it('Browser package consuming test - multi-thread - [js][commonjs]', async function () { - setupMultipleThreads(ort); - await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); - }); +if (typeof it !== 'undefined') { + if (typeof SharedArrayBuffer === 'undefined') { + it('Browser package consuming test - single-thread - [js][commonjs]', async function () { + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); + }); + } else { + it('Browser package consuming test - multi-thread - [js][commonjs]', async function () { + setupMultipleThreads(ort); + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); + }); + } } diff --git a/js/web/test/e2e/src/esm-js/main.js b/js/web/test/e2e/src/esm-js/main.js index 7687b8b731878..54744a2a4b16f 100644 --- a/js/web/test/e2e/src/esm-js/main.js +++ b/js/web/test/e2e/src/esm-js/main.js @@ -6,13 +6,15 @@ import * as ort from 'onnxruntime-web/wasm'; import { setupMultipleThreads, default as testInferenceAndValidate } from './shared.js'; -if (typeof SharedArrayBuffer === 'undefined') { - it('Browser package consuming test - single-thread - [js][esm]', async function () { - await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); - }); -} else { - it('Browser package consuming test - multi-thread - [js][esm]', async function () { - setupMultipleThreads(ort); - await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); - }); +if (typeof it !== 'undefined') { + if (typeof SharedArrayBuffer === 'undefined') { + it('Browser package consuming test - single-thread - [js][esm]', async function () { + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); + }); + } else { + it('Browser package consuming test - multi-thread - [js][esm]', async function () { + setupMultipleThreads(ort); + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); + }); + } } From cd9d5fce39e54839a1f67b0e408865dfd1ab3698 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 29 May 2025 21:56:52 -0700 Subject: [PATCH 51/57] Bump ruff from 0.11.10 to 0.11.11 (#24859) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [ruff](https://github.com/astral-sh/ruff) from 0.11.10 to 0.11.11.
Release notes

Sourced from ruff's releases.

0.11.11

Release Notes

Preview features

  • [airflow] Add autofixes for AIR302 and AIR312 (#17942)
  • [airflow] Move rules from AIR312 to AIR302 (#17940)
  • [airflow] Update AIR301 and AIR311 with the latest Airflow implementations (#17985)
  • [flake8-simplify] Enable fix in preview mode (SIM117) (#18208)

Bug fixes

  • Fix inconsistent formatting of match-case on [] and _ (#18147)
  • [pylint] Fix PLW1514 not recognizing the encoding positional argument of codecs.open (#18109)

CLI

  • Add full option name in formatter warning (#18217)

Documentation

  • Fix rendering of admonition in docs (#18163)
  • [flake8-print] Improve print/pprint docs for T201 and T203 (#18130)
  • [flake8-simplify] Add fix safety section (SIM110,SIM210) (#18114,#18100)
  • [pylint] Fix docs example that produced different output (PLW0603) (#18216)

Contributors

... (truncated)

Changelog

Sourced from ruff's changelog.

0.11.11

Preview features

  • [airflow] Add autofixes for AIR302 and AIR312 (#17942)
  • [airflow] Move rules from AIR312 to AIR302 (#17940)
  • [airflow] Update AIR301 and AIR311 with the latest Airflow implementations (#17985)
  • [flake8-simplify] Enable fix in preview mode (SIM117) (#18208)

Bug fixes

  • Fix inconsistent formatting of match-case on [] and _ (#18147)
  • [pylint] Fix PLW1514 not recognizing the encoding positional argument of codecs.open (#18109)

CLI

  • Add full option name in formatter warning (#18217)

Documentation

  • Fix rendering of admonition in docs (#18163)
  • [flake8-print] Improve print/pprint docs for T201 and T203 (#18130)
  • [flake8-simplify] Add fix safety section (SIM110,SIM210) (#18114,#18100)
  • [pylint] Fix docs example that produced different output (PLW0603) (#18216)
Commits
  • 0397682 Bump 0.11.11 (#18259)
  • bcefa45 [ty] Rename call-possibly-unbound-method to `possibly-unbound-implicit-call...
  • 91b7a57 [ty] Implement Python's floor division semantics for Literal ints (#18249)
  • 98da200 [ty] Fix server panic when calling system_mut (#18252)
  • 029085f [ty] Clarify ty check output default in documentation. (#18246)
  • 6df10c6 [pylint] Fix docs example that produced different output (PLW0603) (#18216)
  • bdf4884 Preserve tuple parentheses in case patterns (#18147)
  • 01eeb2f [ty] Support frozen dataclasses (#17974)
  • cb04343 [ty] Split invalid-base error code into two error codes (#18245)
  • 02394b8 [ty] Improve invalid-type-form diagnostic where a module-literal type is us...
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ruff&package-manager=pip&previous-version=0.11.10&new-version=0.11.11)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-lintrunner.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index e66ec3bb58d74..ed3471bdb47a9 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -3,6 +3,6 @@ lintrunner==0.12.7 lintrunner-adapters==0.12.4 # RUFF -ruff==0.11.10 +ruff==0.11.11 # CLANGFORMAT clang-format==19.1.7 From 9d6546e68a81c31bd19571b187d922317253f602 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 30 May 2025 10:55:08 -0700 Subject: [PATCH 52/57] [CUDA] fp16 intB gemm (#24854) ### Description * Add fpA intB gemm kernel from WeightOnlyGroupwiseQuantGemmPlugin of TensorRT-LLM. * Add prepacking to convert weight/scales/zero_points to adapt MatMulNBits to use the kernel. Limitations: * Only enable fp16 kernel. BF16 support will be added later. * Requires zero points. The support of scales only might be added later. * Bias is not enabled since previous MatMulNBits kernel does not support bias. ### Motivation and Context To improve performance of LLM. Initial result shows 2.2x throughput on prompt processing and 1.25X throughput on token generation using onnxruntime-genai benchmark_e2e.py on phi-4-mini-instruct on A100. --- cmake/CMakeLists.txt | 35 +- cmake/external/cuda_configuration.cmake | 172 +++ cmake/onnxruntime_providers_cuda.cmake | 7 +- cmake/utils/detect_cuda_arch.cu | 39 + .../cuda/llm/common/cuda_runtime_utils.h | 46 + .../contrib_ops/cuda/llm/common/logger.h | 18 + .../contrib_ops/cuda/llm/common/workspace.h | 75 + .../cuda/llm/cutlass_extensions/arch/mma.h | 97 ++ .../cutlass_extensions/compute_occupancy.h | 74 + .../epilogue/thread/fused_activations.h | 61 + .../llm/cutlass_extensions/epilogue_helpers.h | 122 ++ .../sm90_gmma_builder_interleaved.inl | 140 ++ .../collective_builder_interleaved.hpp | 55 + .../collective/collective_mma_interleaved.hpp | 55 + ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 1372 +++++++++++++++++ .../gemm/device/gemm_universal_base_compat.h | 370 +++++ .../gemm/kernel/default_fpA_intB_traits.h | 149 ++ .../gemm/kernel/default_int8_traits.h | 51 + .../gemm/kernel/fpA_intB_gemm.h | 461 ++++++ .../gemm/kernel/gemm_with_epilogue_visitor.h | 451 ++++++ .../gemm/kernel/mixed_gemm_B_layout.h | 112 ++ .../gemm/threadblock/default_dq_mma.h | 117 ++ .../threadblock/default_dq_mma_multistage.h | 289 ++++ .../threadblock/default_dq_mma_pipelined.h | 270 ++++ .../gemm/threadblock/default_mma.h | 336 ++++ .../gemm/threadblock/default_mma_bf16.h | 336 ++++ .../gemm/threadblock/dq_mma_base.h | 211 +++ .../gemm/threadblock/dq_mma_multistage.h | 93 ++ .../dq_mma_multistage_finegrained.h | 612 ++++++++ .../gemm/threadblock/dq_mma_pipelined.h | 89 ++ .../dq_mma_pipelined_finegrained.h | 431 ++++++ .../gemm/warp/default_mma_tensor_op.h | 89 ++ .../warp/mma_tensorop_compute_B_with_f16.h | 263 ++++ .../gemm/warp/mma_tensorop_dequantizer.h | 393 +++++ .../llm/cutlass_extensions/gemm_configs.h | 405 +++++ .../interleaved_numeric_conversion.h | 399 +++++ .../tile_interleaved_layout.h | 48 + .../fine_grained_scale_zero_iterator.h | 218 +++ .../cutlass_extensions/weight_only_quant_op.h | 41 + .../contrib_ops/cuda/llm/cutlass_heuristic.cc | 479 ++++++ .../contrib_ops/cuda/llm/cutlass_heuristic.h | 50 + .../cuda/llm/cutlass_preprocessors.cc | 687 +++++++++ .../cuda/llm/cutlass_preprocessors.h | 75 + .../cuda/llm/cutlass_type_conversion.h | 146 ++ .../bf16_int4_gemm_scale_zeros.cu | 26 + .../bf16_int8_gemm_scale_zeros.cu | 26 + .../fp16_int4_gemm_scale_zeros.cu | 26 + .../fp16_int8_gemm_scale_zeros.cu | 25 + .../cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h | 135 ++ .../fpA_intB_gemm/fpA_intB_gemm_template.h | 489 ++++++ .../fpA_intB_gemm_template_sm90.h | 244 +++ .../fpA_intB_gemm_launcher_1.generated.cu | 264 ++++ .../fpA_intB_gemm_launcher_2.generated.cu | 516 +++++++ .../launchers/fpA_intB_launcher_sm90.h | 36 + .../launchers/fpA_intB_launcher_sm90.inl | 282 ++++ .../cuda/llm/fpA_intB_gemm_adaptor.cu | 260 ++++ .../cuda/llm/fpA_intB_gemm_adaptor.h | 43 + .../cuda/llm/fpA_intB_gemm_profiler.cc | 100 ++ .../cuda/llm/fpA_intB_gemm_profiler.h | 86 ++ .../cuda/llm/fpA_intB_gemv/details.h | 239 +++ .../cuda/llm/fpA_intB_gemv/dispatcher.h | 423 +++++ .../llm/fpA_intB_gemv/dispatcher_bf16_int4.cu | 32 + .../dispatcher_bf16_int4_hopper.cu | 31 + .../llm/fpA_intB_gemv/dispatcher_bf16_int8.cu | 28 + .../dispatcher_bf16_int8_hopper.cu | 28 + .../llm/fpA_intB_gemv/dispatcher_fp16_int4.cu | 32 + .../dispatcher_fp16_int4_hopper.cu | 31 + .../llm/fpA_intB_gemv/dispatcher_fp16_int8.cu | 28 + .../dispatcher_fp16_int8_hopper.cu | 28 + .../cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu | 96 ++ .../cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h | 79 + .../contrib_ops/cuda/llm/gemm_profiler.cc | 311 ++++ .../contrib_ops/cuda/llm/gemm_profiler.h | 283 ++++ .../contrib_ops/cuda/llm/generate_kernels.py | 397 +++++ .../contrib_ops/cuda/llm/nv_infer_datatype.h | 59 + .../cuda/quantization/matmul_nbits.cc | 316 +++- .../cuda/quantization/matmul_nbits.h | 91 ++ .../providers/cuda/shared_inc/cuda_call.h | 30 +- .../test/contrib_ops/matmul_4bits_test.cc | 54 +- .../test/contrib_ops/matmul_8bits_test.cc | 40 +- 80 files changed, 15173 insertions(+), 80 deletions(-) create mode 100644 cmake/external/cuda_configuration.cmake create mode 100644 cmake/utils/detect_cuda_arch.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/common/logger.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/common/workspace.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu create mode 100644 onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc create mode 100644 onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h create mode 100644 onnxruntime/contrib_ops/cuda/llm/generate_kernels.py create mode 100644 onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 08aed0cb296a2..301fb0fbe82b0 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -859,6 +859,10 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu) set(ORT_PROVIDER_FLAGS) if (onnxruntime_USE_CUDA) + include(cuda_configuration) + setup_cuda_compiler() + setup_cuda_architectures() + enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") @@ -878,9 +882,6 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) endif() - if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) - message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") - endif() if (WIN32) message( STATUS "Lean Attention unsupported in Windows") set(onnxruntime_USE_LEAN_ATTENTION OFF) @@ -1590,25 +1591,17 @@ if (onnxruntime_USE_CUDA) file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME}) endif() find_package(CUDAToolkit REQUIRED) - if (NOT CMAKE_CUDA_ARCHITECTURES) - # Note that we generate SASS+PTX code for specified cuda architectures by assigning "xy" - # To add SASS only, assign "xy-real" - # To add PTX only, assign "xy-virtual" - if (CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu") - # Support for Jetson/Tegra ARM devices - set(CMAKE_CUDA_ARCHITECTURES "53-real;62-real;72-real;87") # TX1/Nano, TX2, Xavier, Orin - else() - if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) - # 37, 50 still work in CUDA 11 but are marked deprecated and will be removed in future CUDA version. - set(CMAKE_CUDA_ARCHITECTURES "37-real;50-real;52-real;60-real;70-real;75-real;80-real;86-real;89") - elseif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8) - set(CMAKE_CUDA_ARCHITECTURES "52-real;60-real;70-real;75-real;80-real;86-real;89-real;90") - else() - # https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html - set(CMAKE_CUDA_ARCHITECTURES "all") # Supporting all, including latest Blackwell B series & RTX 50 series - endif() - endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.8) + add_definitions("-DENABLE_FP8") + message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_FP8 flag") endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) + add_definitions("-DENABLE_FP4") + message(STATUS "CUDA Toolkit version is greater or equal than 12.8, enable -DENABLE_FP4 flag") + endif() + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xfatbin=-compress-all") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror default-stream-launch") diff --git a/cmake/external/cuda_configuration.cmake b/cmake/external/cuda_configuration.cmake new file mode 100644 index 0000000000000..ef94ec25132e3 --- /dev/null +++ b/cmake/external/cuda_configuration.cmake @@ -0,0 +1,172 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. +# + +macro(setup_cuda_compiler) + # Determine CUDA version before enabling the language extension check_language(CUDA) clears CMAKE_CUDA_HOST_COMPILER + # if CMAKE_CUDA_COMPILER is not set + include(CheckLanguage) + if(NOT CMAKE_CUDA_COMPILER AND CMAKE_CUDA_HOST_COMPILER) + set(CMAKE_CUDA_HOST_COMPILER_BACKUP ${CMAKE_CUDA_HOST_COMPILER}) + endif() + check_language(CUDA) + if(CMAKE_CUDA_HOST_COMPILER_BACKUP) + set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CUDA_HOST_COMPILER_BACKUP}) + check_language(CUDA) + endif() + if(CMAKE_CUDA_COMPILER) + message(STATUS "CUDA compiler: ${CMAKE_CUDA_COMPILER}") + if(NOT WIN32) # Linux + execute_process( + COMMAND "bash" "-c" "${CMAKE_CUDA_COMPILER} --version | grep -E -o 'V[0-9]+.[0-9]+.[0-9]+' | cut -c2-" + RESULT_VARIABLE _BASH_SUCCESS + OUTPUT_VARIABLE CMAKE_CUDA_COMPILER_VERSION + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT _BASH_SUCCESS EQUAL 0) + message(FATAL_ERROR "Failed to determine CUDA version") + endif() + + else() # Windows + execute_process( + COMMAND ${CMAKE_CUDA_COMPILER} --version + OUTPUT_VARIABLE versionString + RESULT_VARIABLE versionResult) + + if(versionResult EQUAL 0 AND versionString MATCHES "V[0-9]+\\.[0-9]+\\.[0-9]+") + string(REGEX REPLACE "V" "" version ${CMAKE_MATCH_0}) + set(CMAKE_CUDA_COMPILER_VERSION "${version}") + else() + message(FATAL_ERROR "Failed to determine CUDA version") + endif() + endif() + else() + message(FATAL_ERROR "No CUDA compiler found") + endif() + + set(CUDA_REQUIRED_VERSION "11.4") + if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS CUDA_REQUIRED_VERSION) + message(FATAL_ERROR "CUDA version ${CMAKE_CUDA_COMPILER_VERSION} must be at least ${CUDA_REQUIRED_VERSION}") + endif() +endmacro() + +macro(setup_cuda_architectures) + # cmake-format: off + # Initialize and normalize CMAKE_CUDA_ARCHITECTURES before enabling CUDA. + # Special values: + # (1) `native` is resolved to HIGHEST available architecture. Fallback to `all` if detection failed. + # (2) `all` / `all-major` / unset is resolved to a default set of architectures we optimized and compiler supports. + # Numerical architectures: + # * For `-virtual` architectures, the last one is kept as it is, and the others are ignored. + # * `-real` suffix is automatically added for other cases. + # * Always use accelerated (`-a` suffix) target for supported real architectures. + # cmake-format: on + + if(CMAKE_CUDA_ARCHITECTURES STREQUAL "native") + # Detect highest available compute capability + set(OUTPUTFILE ${PROJECT_BINARY_DIR}/detect_cuda_arch) + set(CUDAFILE ${CMAKE_SOURCE_DIR}/utils/detect_cuda_arch.cu) + execute_process(COMMAND ${CMAKE_CUDA_COMPILER} -lcuda ${CUDAFILE} -o ${OUTPUTFILE}) + message(VERBOSE "Detecting native CUDA compute capability") + execute_process( + COMMAND ${OUTPUTFILE} + RESULT_VARIABLE CUDA_RETURN_CODE + OUTPUT_VARIABLE CUDA_ARCH_OUTPUT) + if(NOT ${CUDA_RETURN_CODE} EQUAL 0) + message(WARNING "Detecting native CUDA compute capability - fail") + message(WARNING "CUDA compute capability detection failed, compiling for all optimized architectures") + unset(CMAKE_CUDA_ARCHITECTURES) + else() + message(STATUS "Detecting native CUDA compute capability - done") + set(CMAKE_CUDA_ARCHITECTURES "${CUDA_ARCH_OUTPUT}") + endif() + elseif(CMAKE_CUDA_ARCHITECTURES STREQUAL "all") + unset(CMAKE_CUDA_ARCHITECTURES) + message(STATUS "Setting CMAKE_CUDA_ARCHITECTURES to all enables a list of architectures OnnxRuntime optimized for, " + "not all architectures CUDA compiler supports.") + elseif(CMAKE_CUDA_ARCHITECTURES STREQUAL "all-major") + unset(CMAKE_CUDA_ARCHITECTURES) + message( + STATUS "Setting CMAKE_CUDA_ARCHITECTURES to all-major enables a list of architectures OnnxRuntime optimized for, " + "not all major architectures CUDA compiler supports.") + else() + message(STATUS "Original CMAKE_CUDA_ARCHITECTURES : ${CMAKE_CUDA_ARCHITECTURES}") + endif() + + if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if(CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu") + # Support for Jetson/Tegra ARM devices + set(CMAKE_CUDA_ARCHITECTURES "53;62;72;87") # TX1/Nano, TX2, Xavier, Orin + else() + if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) + # 37, 50 still work in CUDA 11 but are marked deprecated and will be removed in future CUDA version. + set(CMAKE_CUDA_ARCHITECTURES "37;50;52;60;70;75;80;86;89") + elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8) + set(CMAKE_CUDA_ARCHITECTURES "52;60;70;75;80;86;89;90") + else() + set(CMAKE_CUDA_ARCHITECTURES "60;70;75;80;86;89;90;100;120") + endif() + endif() + endif() + + unset(CMAKE_CUDA_ARCHITECTURES_CLEAN) + unset(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL) + foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) + if(CUDA_ARCH STREQUAL "") + continue() + endif() + + if(CUDA_ARCH MATCHES "^([1-9])([0-9])+a?-virtual$") + set(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL ${CUDA_ARCH}) + elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?-real$") + list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1}) + elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?$") + list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1}) + else() + message(FATAL_ERROR "Unrecognized CUDA architecture: ${CUDA_ARCH}") + endif() + endforeach() + list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES_CLEAN) + set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_CLEAN}) + + # CMAKE_CUDA_ARCHITECTURES_ORIG contains all architectures enabled, without automatically added -real or -a suffix. + set(CMAKE_CUDA_ARCHITECTURES_ORIG "${CMAKE_CUDA_ARCHITECTURES}") + message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}") + + set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "120") + foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS) + if(NOT "${CUDA_ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) + add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}") + message(STATUS "Excluding SM ${CUDA_ARCH}") + endif() + endforeach() + + # Enable accelerated features (like WGMMA, TMA and setmaxnreg) for SM >= 90. + set(ARCHITECTURES_WITH_ACCEL "90" "100" "101" "120") + unset(CMAKE_CUDA_ARCHITECTURES_NORMALIZED) + foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) + if("${CUDA_ARCH}" IN_LIST ARCHITECTURES_WITH_ACCEL) + list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}a-real") + else() + list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}-real") + endif() + endforeach() + + if(DEFINED CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL) + list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL}") + endif() + + set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NORMALIZED}) + + message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}") +endmacro() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 6a7510a5d83bc..da46f29dacf5f 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -179,7 +179,7 @@ set(onnxruntime_NVCC_THREADS "1" CACHE STRING "Number of threads that NVCC can use for compilation.") target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") endif() - + # Since CUDA 12.8, compiling diagnostics become stricter if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) target_compile_options(${target} PRIVATE "$<$:--relocatable-device-code=true>") @@ -261,6 +261,11 @@ set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") + if("90" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) + target_compile_options(${target} PRIVATE $<$:-Xptxas=-w>) + target_compile_definitions(${target} PRIVATE COMPILE_HOPPER_TMA_GEMMS) + endif() + if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling target_link_libraries(${target} PRIVATE CUDA::cupti) endif() diff --git a/cmake/utils/detect_cuda_arch.cu b/cmake/utils/detect_cuda_arch.cu new file mode 100644 index 0000000000000..83fbc13dbff7f --- /dev/null +++ b/cmake/utils/detect_cuda_arch.cu @@ -0,0 +1,39 @@ +#include +#include +#include +#include +#include + +int main(int argc, char* argv[]) +{ + int n_devices = 0; + int rc = cudaGetDeviceCount(&n_devices); + if (rc != cudaSuccess) + { + cudaError_t error = cudaGetLastError(); + std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; + return rc; + } + + std::vector> arch(n_devices); + for (int cd = 0; cd < n_devices; ++cd) + { + cudaDeviceProp dev; + int rc = cudaGetDeviceProperties(&dev, cd); + if (rc != cudaSuccess) + { + cudaError_t error = cudaGetLastError(); + std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; + return rc; + } + else + { + arch[cd] = {dev.major, dev.minor}; + } + } + + std::pair best_cc = *std::max_element(begin(arch), end(arch)); + std::cout << best_cc.first << best_cc.second; + + return 0; +} diff --git a/onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h b/onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h new file mode 100644 index 0000000000000..06442c6e02ae0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h @@ -0,0 +1,46 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "core/providers/cuda/shared_inc/cuda_call.h" + +namespace onnxruntime::llm::common { +inline int getDevice() { + int deviceID{0}; + CUDA_CALL_THROW(cudaGetDevice(&deviceID)); + return deviceID; +} + +inline int getSMVersion() { + int device{-1}; + CUDA_CALL_THROW(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + CUDA_CALL_THROW(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + CUDA_CALL_THROW(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline int getMultiProcessorCount() { + int nSM{0}; + int deviceID{0}; + CUDA_CALL_THROW(cudaGetDevice(&deviceID)); + CUDA_CALL_THROW(cudaDeviceGetAttribute(&nSM, cudaDevAttrMultiProcessorCount, deviceID)); + return nSM; +} +} // namespace onnxruntime::llm::common diff --git a/onnxruntime/contrib_ops/cuda/llm/common/logger.h b/onnxruntime/contrib_ops/cuda/llm/common/logger.h new file mode 100644 index 0000000000000..a3992e751926d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/common/logger.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/shared_library/provider_api.h" + +#ifndef NDEBUG +#define ORT_LLM_LOG_TRACE(msg) LOGS_DEFAULT(VERBOSE) << msg +#define ORT_LLM_LOG_DEBUG(msg) LOGS_DEFAULT(VERBOSE) << msg +#else +#define ORT_LLM_LOG_TRACE(msg) +#define ORT_LLM_LOG_DEBUG(msg) +#endif + +#define ORT_LLM_LOG_INFO(msg) LOGS_DEFAULT(INFO) << msg +#define ORT_LLM_LOG_WARNING(msg) LOGS_DEFAULT(WARNING) << msg +#define ORT_LLM_LOG_ERROR(msg) LOGS_DEFAULT(ERROR) << msg diff --git a/onnxruntime/contrib_ops/cuda/llm/common/workspace.h b/onnxruntime/contrib_ops/cuda/llm/common/workspace.h new file mode 100644 index 0000000000000..126884a941336 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/common/workspace.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace onnxruntime::llm::common { + +std::uintptr_t constexpr kCudaMemAlign = 128; + +inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) { + uintptr_t addr = (uintptr_t)ptr; + if (addr % to) { + addr += to - addr % to; + } + return reinterpret_cast(addr); +} + +constexpr size_t alignSize(size_t size, size_t to) { + if ((size % to) != 0U) { + size += to - size % to; + } + return size; +} + +inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment) { + uintptr_t addr = (uintptr_t)ptr; + addr += previousWorkspaceSize; + return alignPtr(reinterpret_cast(addr), alignment); +} + +inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) { + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign); +} + +inline int8_t* nextWorkspacePtr( + int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign) { + uintptr_t curr_offset = offset; + uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment; + int8_t* newptr = size == 0 ? nullptr : base + curr_offset; + offset = next_offset; + return newptr; +} + +inline int8_t* nextWorkspacePtrWithAlignment( + int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign) { + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment); +} + +inline size_t calculateTotalWorkspaceSize( + size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign) { + size_t total = 0; + for (int i = 0; i < count; i++) { + total += workspaces[i]; + if (workspaces[i] % alignment) { + total += alignment - (workspaces[i] % alignment); + } + } + return total; +} + +}; // namespace onnxruntime::llm::common diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h new file mode 100644 index 0000000000000..6de056b44339d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Templates exposing architecture support for multiply-add operations +*/ + +#pragma once +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +// Tag which triggers MMA which will trigger +struct OpMultiplyAddDequantizeInterleavedBToA; + +/* + Below we have extra tags to signal what kind of dequantization we want to do + (per col, scale only fine grained, finegrained with zero). This still lets us + the existing template infrastructure (incl. that in CUTLASS). However, we + split out the template below into OpMultiplyAddDequantizeInterleavedBToA along + with the quantization op before instantiating the GEMM pieces. + + Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of + code we need to duplicate. + */ +struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; + +// The default just forwards the original operator +template +struct TagOperator { + using TaggedOperator = MmaOp; +}; + +// Specializations below attach more information to the operator +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; +}; + +// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original +// operator + the extra information. If no extra info was tagged, the dequant op per column scaling +// as a default. +template +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +}; + +} // namespace arch +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h new file mode 100644 index 0000000000000..63dca2f458e1a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "cutlass/device_kernel.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime::llm { +namespace cutlass_extensions { + +template +inline int compute_occupancy_for_kernel() { + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) { + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + CUDA_CALL_THROW(cudaGetDevice(&device)); + CUDA_CALL_THROW( + cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); + } else { + CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::Kernel)); + } + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) + // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this + // configuration. + return 0; + } + + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW(cudaFuncSetAttribute( + cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } else { + CUDA_CALL_THROW(cudaFuncSetAttribute( + cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + } + + int max_active_blocks = -1; + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, + 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); + } else { + CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); + } + + return max_active_blocks; +} + +} // namespace cutlass_extensions +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h new file mode 100644 index 0000000000000..e0911460ef8a3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Functor performing linear combination with a maximum operation used by epilogues. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) { +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h new file mode 100644 index 0000000000000..1d7ff42d591e2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h @@ -0,0 +1,122 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. + * + */ + +#pragma once + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h" +#include + +namespace onnxruntime::llm { +namespace cutlass_extensions { + +struct EpilogueOpBiasSilu { +}; + +struct EpilogueOpBiasReLU { +}; + +struct EpilogueOpBiasFtGelu { +}; + +struct EpilogueOpBias { +}; + +struct EpilogueOpDefaultSilu { +}; + +struct EpilogueOpDefaultReLU { +}; + +struct EpilogueOpDefaultFtGelu { +}; + +struct EpilogueOpDefault { +}; + +template +struct Epilogue { + static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); +}; + +constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace cutlass_extensions +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl new file mode 100644 index 0000000000000..a7146d99224eb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_RS Mixed Scaled GEMM +template +struct CollectiveBuilderInterleaved + || cute::is_same_v + || cute::is_same_v)>> +{ + +private: + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementPairA_>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementPairB_>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementPairA_>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementPairB_>; + static constexpr bool NeitherIsTuple + = !cute::is_tuple::value && !cute::is_tuple::value; + +public: + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementPairA_>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementPairB_>; + static_assert(cute::is_tuple::value ^ cute::is_tuple::value + || (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value)), + "Either A OR B must be a tuple or the widths of A and B must be different."); + + static constexpr bool IsANarrow = sizeof_bits::value < sizeof_bits::value; + + using GmemLayoutATag = GmemLayoutATag_; + using GmemLayoutBTag = GmemLayoutBTag_; + + using ElementPairA = cute::conditional_t, ElementPairA_>; + using ElementPairB = cute::conditional_t, ElementPairB_>; + + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + static constexpr bool IsWarpSpecializedTransposeB = detail::is_warpspecialized_transpose_B(); + static_assert(!IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B."); + + // If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to RF and we must swap the operands. + static constexpr bool SwapAB = !IsATransformed; + + // When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly. + static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB; + + using ElementMma = cute::conditional_t; + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma + = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA + = decltype(detail::rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB + = decltype(detail::rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + + using RealElementA = cute::conditional_t; + using RealElementB = cute::conditional_t; + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}); + + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + using DispatchPolicy + = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + + // We pack the scale data with the operand that will be optionally scaled and converted before MMA. + using StrideA = TagToStrideA_t; + using StrideB = TagToStrideB_t; + + using CollectiveOp = CollectiveMmaInterleaved; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp new file mode 100644 index 0000000000000..97feaa2498bba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp" + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveBuilderInterleaved { + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp new file mode 100644 index 0000000000000..ce56a9d717ceb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMmaInterleaved { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000000000..499504439aa46 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -0,0 +1,1372 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop that source A operand from registers +template +struct CollectiveMmaInterleaved, + TileShape_, ElementAOptionalTuple, StrideA_, ElementBOptionalTuple, StrideB_, TiledMma_, GmemTiledCopyA_, + SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_> { + private: + template + static constexpr auto get_logical_ptr(PointerType const* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return subbyte_iterator(ptr); + } else { + return ptr; + } + } + + template + static constexpr auto get_smem_interleave_layout() { + if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 8) { + return Layout(TileShape{})), Shape<_4, _4, _2, _4>>, + Stride<_128, Stride<_1, _8, _4, _32>>>{}; + } else if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 16) { + return Layout(TileShape{})), Shape<_2, _4, _4, _2>>, + Stride<_64, Stride<_1, _8, _2, _32>>>{}; + } else if constexpr (cute::sizeof_bits_v == 8 && cute::sizeof_bits_v == 16) { + return Layout(TileShape{})), Shape<_2, _4, _2, _4>>, + Stride<_64, Stride<_1, _4, _2, _16>>>{}; + } else { + static_assert(dependent_false, + "unsupported weight and activation, must be one of w4a8,w4a16,w8a16"); + } + } + + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + + public: + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + using TileShape = TileShape_; + + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is + // void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = StrideA_; + using StrideB = StrideB_; + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the code to compile when the scale is void. + using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert((IsATransformed && cutlass::gemm::detail::is_k_major()) || (!IsATransformed && cutlass::gemm::detail::is_k_major()), + "The transformed type must be K-major."); + + static_assert((IsATransformed && (sizeof(ElementB) == 2)) || (!IsATransformed && (sizeof(ElementA) == 2)) || (cutlass::gemm::detail::is_k_major() && cutlass::gemm::detail::is_k_major()), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + // Scale layout atom set after swapping. + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using SmemCopyAtomScale = Copy_Atom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemCopyAtomA = cute::conditional_t; + using InternalSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealInternalElementA = cute::conditional_t; + using RealInternalElementB = cute::conditional_t; + using InternalElementA = cute::conditional_t; + using InternalElementB = cute::conditional_t; + using InternalStrideA = cute::conditional_t; + using InternalStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using InternalTransformA = cute::conditional_t; + using InternalTransformB = cute::conditional_t; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + using SmemLayoutAtomScale = Layout(InternalSmemLayoutAtomA{})), cute::Int<1>>>; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); + static constexpr int type_factor = sizeof_bits::value / sizeof_bits::value; + + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape(InternalSmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, InternalStrideA>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + using Layout_Interleave = decltype(cute::composition(SmemLayoutA{}.layout_a(), SmemLayoutA{}.offset(), + get_smem_interleave_layout())); + using SmemLayoutA_mma_interleave = decltype(tile_to_shape(Layout_Interleave{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, InternalStrideA>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + using SmemLayoutA_mma = decltype(cute::composition(SmemLayoutA{}.layout_a(), SmemLayoutA{}.offset(), + make_layout(make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + make_stride(get<2>(TileShape{}), _1{}, get<0>(TileShape{}) * get<2>(TileShape{}))))); + // cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), + // Stride<_1, cute::Int(TileShape{})>, cute::Int(TileShape{}) * + // get<2>(TileShape{})>>, Stride(TileShape{})>, _1, + // cute::Int(TileShape{}) * get<2>(TileShape{})>>>{}))); + + using SmemLayoutB = decltype(tile_to_shape(InternalSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, InternalStrideB>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape(SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, NonVoidStrideScale>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(!cute::is_base_of::value && cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. + // We must also handle updating the pipeline transaction bytes on the fly. + // NOTE: Deleting this assertion without required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + + private: + static constexpr ConversionMode get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + + static constexpr auto elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + return cute::cosize_v; + } else { + static_assert( + cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + static constexpr auto elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || KernelConversionMode == ConversionMode::ConvertAndScale) { + return 0; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } else { + static_assert( + cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + // These methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return baseline_bytes; + } else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return baseline_bytes + scale_tx_bytes; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in tma transaction bytes computation."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in tma transaction bytes computation."); + } + } + + static constexpr uint32_t compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + + public: + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 && SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage { + static constexpr int scale_elements = elements_per_smem_scale(); + static constexpr int zero_elements = elements_per_smem_zero(); + + struct TensorStorage : cute::aligned_struct { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + private: + using Outer = CollectiveMmaInterleaved; + + public: + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(InternalStrideA{}, static_cast(0)), InternalStrideA{}), + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + + using TMA_Scale = decltype(make_tma_copy(GmemTiledCopyScale{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, static_cast(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_, _, cute::Int<0>{}), ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + using TMA_Zero = decltype(make_tma_copy(GmemTiledCopyScale{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, static_cast(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_, _, cute::Int<0>{}), ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(InternalStrideB{}, static_cast(0)), InternalStrideB{}), + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + if constexpr (SwapAB) { + M = get<1>(problem_shape_MNKL); + N = get<0>(problem_shape_MNKL); + } + + InternalElementA const* ptr_A; + InternalStrideA dA; + InternalElementB const* ptr_B; + InternalStrideB dB; + + if constexpr (not SwapAB) { + ptr_A = reinterpret_cast(args.ptr_A); + ptr_B = reinterpret_cast(args.ptr_B); + dA = args.dA; + dB = args.dB; + } else { + ptr_A = reinterpret_cast(args.ptr_B); + ptr_B = reinterpret_cast(args.ptr_A); + dA = args.dB; + dB = args.dA; + } + + Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M, K, L), dA)); + Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N, K, L), dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + + typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + typename Params::TMA_Scale tma_load_scale; + typename Params::TMA_Zero tma_load_zero; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, TmaTransactionBytes, + TmaTransactionBytesMK, TmaTransactionBytesNK}; + } else if constexpr (ModeHasScales) { + auto scale_k = (K + args.group_size - 1) / args.group_size; + ElementScale const* ptr_S = args.ptr_S; + StrideScale dS = args.dS; + Tensor tensor_scale = make_tensor(get_logical_ptr(ptr_S), make_layout(make_shape(M, scale_k, L), dS)); + tma_load_scale = make_tma_copy(GmemTiledCopyScale{}, tensor_scale, SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, + TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK}; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M, scale_k, L), dS)); + tma_load_zero = make_tma_copy(GmemTiledCopyScale{}, tensor_zero, SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, + TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK}; + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } else if constexpr (ModeHasScales) { + int const scale_mn = SwapAB ? N : M; + int const scale_k = (K + args.group_size - 1) / args.group_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.group_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!implementable) { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(gA_mkl, gB_nkl); + } else if constexpr (ModeHasScales) { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof...(Ts) == 2, "Direct convert needs two inputs"); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof...(Ts) == 3, "Scaled convert needs three inputs"); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof...(Ts) == 4, "Scaled and zero convert needs four inputs"); + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + auto extra_input_partitions = partition_extra_tma_inputs( + mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify + // tma transaction bytes on the fly. We must do a ceiling divide here to correctly handle with + // group_size == K. In that case, we don't require that K is a multiple of the threadblock tile K + int const ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + int const scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. + copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_, _, _, scale_load_k), + tSsS(_, _, _, write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), + tZgZ(_, _, _, scale_load_k), tZsZ(_, _, _, write_stage)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for TMA copy op."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + template + constexpr auto interleave_for_mixed_input() { + if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 8) { + return Layout, _1, Shape<_2, _2>>, + Stride, _0, Stride<_16, _32>>>{}; + } else if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 16) { + return Layout, _1, Shape<_2>>, + Stride, _0, Stride<_16>>>{}; + } else if constexpr (cute::sizeof_bits_v == 8 && cute::sizeof_bits_v == 16) { + return Layout, _1, Shape<_2, _2>>, + Stride, _0, Stride<_8, _16>>>{}; + } else { + static_assert(dependent_false, + "unsupported weight and activation, must be one of w4a8,w4a16,w8a16"); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum, + int k_tile_count, int thread_idx, TensorStorage& shared_tensors, Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + Tensor sA_ = make_tensor( + make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA_mma_interleave{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = mma_thread_slice.partition_A(sA); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + auto interleave_layout = interleave_for_mixed_input(); + + auto interleave_remapping = cute::flat_product(interleave_layout, Layout>>{}); + + Tensor tCsA_remapped = tCsA.compose(interleave_remapping); + + auto interleave_remapping_thread = right_inverse(interleave_layout); + + // Allocate fragments and descriptors + Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_, _, Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_load = make_fragment_like(tCrA_mma); + + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + // Compute the max vector length that can be used to copy A. This will match the vector width of the + // conversions used. It helps by allowing the compiler to convert using the same register that was used + // to load the data from smem. This significantly reduces the need to move data among registers. + // Note that this is correct even if copy fails to vectorize, since the granularity at which we perform + // the conversion does not impact correctness. + using A_CPY_VEC = decltype(max_common_vector(tCsA, tCrA_copy_view)); + using A_CPY_VEC_remapped = decltype(max_common_vector(tCsA_remapped, tCrA_copy_view)); + static_assert(A_CPY_VEC_remapped{} == 32 / cutlass::sizeof_bits::value, + "max_common_vector(tCsA_remapped, tCrA_copy_view) is 32 / cutlass::sizeof_bits::value"); + auto tCrA_mma_tmp = tCrA_mma.compose(interleave_remapping_thread); + auto tCrA_mma_inverse_mapping = tCrA_mma_tmp.compose(tCrA_mma.layout()); + + auto tCrA_load_tmp = tCrA_load.compose(interleave_remapping_thread); + auto tCrA_load_inverse_mapping = tCrA_load_tmp.compose(tCrA_load.layout()); + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = partition_extra_mma_info(mma_thread_slice, shared_tensors); + auto copy_partitions_extra_info = retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + + constexpr int kNumKIterationsPerWarpBLoad = type_factor / 2; + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 0, read_stage, kNumKIterationsPerWarpBLoad); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 1, read_stage, kNumKIterationsPerWarpBLoad); + } + + transform_A_kblock( + tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, 0, kNumKIterationsPerWarpBLoad); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma_inverse_mapping(_, _, k_block), tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + if (k_block < K_BLOCK_MAX - 2) // prefetch next block + { + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 2, read_stage, kNumKIterationsPerWarpBLoad); + } + if (k_block < K_BLOCK_MAX - 1) { + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, k_block + 1, + kNumKIterationsPerWarpBLoad); + } + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for + // the first mma. + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 0, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 1, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + } + warpgroup_wait(); + transform_A_kblock( + tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, 0, kNumKIterationsPerWarpBLoad); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma_inverse_mapping(_, _, k_block), tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this + // stage, so we can release prior barrier + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 0, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 1, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + } + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, 0, + kNumKIterationsPerWarpBLoad); + } else { + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 2, read_stage, kNumKIterationsPerWarpBLoad); + } + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, k_block + 1, + kNumKIterationsPerWarpBLoad); + } + } + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma_inverse_mapping(_, _, k_block), tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) // release prior barrier + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 2) // prefetch next block + { + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 2, read_stage, kNumKIterationsPerWarpBLoad); + } + + if (k_block < K_BLOCK_MAX - 1) { + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 1, read_stage, kNumKIterationsPerWarpBLoad); + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, k_block + 1, + kNumKIterationsPerWarpBLoad); + } + } + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + private: + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE auto partition_extra_tma_inputs(Params const& mainloop_params, cute::tuple const& load_inputs, + TensorStorage& shared_tensors, uint2 const& cluster_local_block_id, int const m_coord, int const l_coord) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + Tensor sS = make_tensor( + make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor( + make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for input partitioning."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for input partitioning."); + } + } + + template + constexpr auto scale_remapping() { + if constexpr (cute::sizeof_bits_v == 8) { + return Layout, Stride<_1, _8, _4>>{}; + } else if constexpr (cute::sizeof_bits_v == 16) { + return Layout, Stride<_1, _4, _2>>{}; + } else { + static_assert(dependent_false, "cute::sizeof_bits_v must be 8 or 16"); + } + } + + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. + template + CUTLASS_DEVICE auto partition_extra_mma_info(ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + Tensor sS = make_tensor( + make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + auto remappingScale = scale_remapping(); + Tensor tCsS_remapped = tCsS.compose(remappingScale, _, _, _); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS_remapped, tCrS); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor( + make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCsZ_remapped = tCsZ.compose(remappingScale, _, _, _); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_, _, Int<0>{})).shape()); + return cute::make_tuple(tCsS_remapped, tCrS, tCsZ_remapped, tCrZ); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Returns the tiled copy and copy views for the extra inputs. + template + CUTLASS_DEVICE auto retile_extra_mma_info( + TiledMma const& tiled_mma, cute::tuple& partitioned_extra_info, int const warp_group_thread_idx) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Utilities to copy A and extra inputs from smem to RF + template + CUTLASS_DEVICE void copy_A_and_extra_info(SmemTiledCopyA const& smem_tiled_copy_A, TensorASmemView const& tCsA, + TensorACopyView& tCrA_copy_view, cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, int k_block, int read_stage, int kNumKIterationsPerWarpBLoad) { + if (kNumKIterationsPerWarpBLoad == 1) { + copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block)); + } else { + using reshape_layout = Layout, Int<1>, Int<2>>>; + auto tCrA_copy_view_reshaped = tCrA_copy_view.compose(reshape_layout{}); + if (k_block % kNumKIterationsPerWarpBLoad == 0) + copy(smem_tiled_copy_A, tCsA(_, _, k_block / kNumKIterationsPerWarpBLoad, read_stage), + tCrA_copy_view_reshaped(_, _, k_block / kNumKIterationsPerWarpBLoad)); + } + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + } + + /// Utilities to transform A. + template + CUTLASS_DEVICE void transform_A_kblock(TCrA_load const& tCrA_load, cute::Int vec_A, + TCrA_mma& tCrA_mma, cute::tuple const& partitioned_extra_info, int const k_block, + int kNumKIterationsPerWarpBLoad) { + if (kNumKIterationsPerWarpBLoad != 1) { + if (k_block % kNumKIterationsPerWarpBLoad == 0) { + int k_block_load = k_block / kNumKIterationsPerWarpBLoad; + using reshape_layout = Layout, _1, _2>>; + auto tCrA_load_reshaped = tCrA_load.compose(reshape_layout{}); + auto tCra_mma_reshaped = tCrA_mma.compose(reshape_layout{}); + + using scale_reshape = Layout, _1, _1>, Stride, _0, _0>>; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + transform_internal_A( + tCrA_load_reshaped(_, _, k_block_load), vec_A, tCra_mma_reshaped(_, _, k_block_load)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrS_reshaped = tCrS.compose(scale_reshape{}); + transform_internal_A(tCrA_load_reshaped(_, _, k_block_load), vec_A, + make_fragment_like(tCra_mma_reshaped)(_, _, k_block_load), tCrS_reshaped(_, _, 0), + tCra_mma_reshaped(_, _, k_block_load)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrS_reshaped = tCrS.compose(scale_reshape{}); + auto tCrZ = cute::get<3>(partitioned_extra_info); + auto tCrZ_reshaped = tCrZ.compose(scale_reshape{}); + transform_internal_A(tCrA_load_reshaped(_, _, k_block_load), vec_A, + make_fragment_like(tCra_mma_reshaped)(_, _, k_block_load), tCrS_reshaped(_, _, 0), + tCrZ_reshaped(_, _, 0), tCra_mma_reshaped(_, _, k_block_load)); + } else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + } else { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + transform_internal_A(tCrA_load(_, _, k_block), vec_A, tCrA_mma(_, _, k_block)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + auto tCrS = cute::get<1>(partitioned_extra_info); + transform_internal_A(tCrA_load(_, _, k_block), vec_A, + make_fragment_like(tCrA_mma)(_, _, k_block), tCrS(_, _, 0), tCrA_mma(_, _, k_block)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrZ = cute::get<3>(partitioned_extra_info); + transform_internal_A(tCrA_load(_, _, k_block), vec_A, + make_fragment_like(tCrA_mma)(_, _, k_block), tCrS(_, _, 0), tCrZ(_, _, 0), + tCrA_mma(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + } + + /// Utilities for transforming the A operand prior to issuing tensorcore math. + template > + CUTLASS_DEVICE void convert_tensor(Tensor const& in, Tensor& out, + cute::Int width = {}) { + /// This is an element-wise conversion where we expect both tensors to have the same layout. + /// As a result, we can cast as a cutlass array to use the fast numeric converters without + /// worrying about indexing into the layout. + constexpr int N = cosize_v; + + /// The inputs must be backed by registers & be statically sized. + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(is_static_v, "Tensor layout for the conversion must be static"); + static_assert(cosize_v == size(TensorLayout{}), "Cosize and size of the layout must be equal."); + static_assert( + N % ConversionVectorWidth == 0, "Conversion vector width must divide cosize of the tensor layout."); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + using Converter = std::conditional_t < cutlass::sizeof_bits_v, + cutlass::FastInterleavedAndBiasedNumericArrayConverter, + cutlass::NumericArrayConverter>; + + constexpr int NumIterations = N / ConversionVectorWidth; + + for (int ii = 0; ii < NumIterations; ++ii) { + SrcArray const* src_array_ptr = reinterpret_cast(raw_pointer_cast(in.data())) + ii; + DstArray* dst_array_ptr = reinterpret_cast(raw_pointer_cast(out.data())) + ii; + *dst_array_ptr = Converter::convert(*src_array_ptr); + } + } + + template + CUTLASS_DEVICE void transform_internal_A(Tensor&& in, + cute::Int a_vec_width, Tensor&& out) { + convert_tensor(in, out, a_vec_width); + } + + template + CUTLASS_DEVICE void transform_internal_A(Tensor&& in, + cute::Int a_vec_width, Tensor&& converted_inputs, + Tensor&& scales, Tensor&& out) { + static_assert(cute::is_same_v, + "Type of the engine input buffer must equal the scale buffer"); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, a_vec_width); + + // Apply scales and broadcast across inputs, store in converted_inputs + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(converted_inputs); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(converted_inputs); ++j) { + if constexpr (cute::is_same_v) { + converted_inputs(j, i) = bfloat16_t(__hmul(reinterpret_cast<__nv_bfloat16 const&>(converted_inputs(j, i)), + reinterpret_cast<__nv_bfloat16 const&>(scales(j, i)))); + } else { + converted_inputs(j, i) *= scales(j, i); + } + } + } + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } + + template + CUTLASS_DEVICE void transform_internal_A(Tensor&& in, + cute::Int a_vec_width, Tensor&& converted_inputs, + Tensor&& scales, Tensor&& zeros, + Tensor&& out) { + static_assert(cute::is_same_v, + "Type of the engine input buffer must equal the scale buffer"); + + static_assert(cute::is_same_v, + "Type of the engine zero buffer must equal the scale buffer"); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, a_vec_width); + + // Apply scales and broadcast across inputs, store in converted_inputs + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(converted_inputs); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(converted_inputs); ++j) { + if constexpr (cute::is_same_v) { + converted_inputs(j, i) = bfloat16_t(__hfma(reinterpret_cast<__nv_bfloat16 const&>(converted_inputs(j, i)), + reinterpret_cast<__nv_bfloat16 const&>(scales(j, i)), + reinterpret_cast<__nv_bfloat16 const&>(zeros(j, i)))); + } else { + converted_inputs(j, i) = converted_inputs(j, i) * scales(j, i) + zeros(j, i); + } + } + } + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h new file mode 100644 index 0000000000000..c7f2a682323a0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -0,0 +1,370 @@ +/* + * Copyright (c) 2017-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// #include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. + */ + +template +class GemmUniversalBaseCompat { + public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + + protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + + protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + + public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) { + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } else { + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + return -1; + } + + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h new file mode 100644 index 0000000000000..83ebe2191717b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -0,0 +1,149 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/half.h" +#include "cutlass/layout/matrix.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct MixedGemmArchTraits { + static_assert(dependent_false, "Unrecognized parameterization"); +}; + +template +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::ColumnMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ada Traits ============================== +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +// FP8 A/B = fp8, C/D = fp32 +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + // be careful, TypeC should align with TmaWarpSpecializedGroupedGemmInput::OutputTypeAdaptor_t + using TypeC = __nv_bfloat16; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h new file mode 100644 index 0000000000000..fe4bc0940d9e8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +}; + +// ======================= Turing Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +}; + +// ======================= Ampere Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 0000000000000..a888ea3e71487 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,461 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +template +inline constexpr bool dependent_false_v = false; +} + +template +struct GemmFpAIntB { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size), group_size(group_size), ref_A(ref_A), ref_B(ref_B), ref_scale(ref_scale), ref_zero(ref_zero), ref_C(ref_C), ref_D(ref_D), batch_count(serial_split_k_factor), output_op(output_op), gather_A_indices(gather_A_indices), gather_B_indices(gather_B_indices), scatter_D_indices(scatter_D_indices) { + } + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size), group_size(args.group_size), grid_tiled_shape(grid_tiled_shape), swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A(args.ref_A.layout()), ref_A(args.ref_A), params_B(args.ref_B.layout()), ref_B(args.ref_B), params_scale(args.ref_scale.layout()), ref_scale(args.ref_scale), ref_zero(args.ref_zero), params_C(args.ref_C.layout()), ref_C(args.ref_C), params_D(args.ref_D.layout()), ref_D(args.ref_D), output_op(args.output_op), semaphore(static_cast(workspace)), gemm_k_size(gemm_k_size), gather_A_indices(args.gather_A_indices), gather_B_indices(args.gather_B_indices), scatter_D_indices(args.scatter_D_indices) { + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const& args) { + static int const alignmentA = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const alignmentB = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const alignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const alignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, alignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, alignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, alignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_zero, alignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, alignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, alignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!args.ref_scale.good()) { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) { + if (!args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } else { + if (args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) { + if (args.group_size != 64 && args.group_size != 128) { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& /*args*/, cutlass::gemm::GemmCoord const& /*grid_tiled_shape*/) { + return 0; + } + + // Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) { + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) { + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, + params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ == 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 1000) + // Use SM80 implementation for GB10x, GB20x. + run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h new file mode 100644 index 0000000000000..163a43238a425 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -0,0 +1,451 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. + + original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/trace.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" + +namespace tk = onnxruntime::llm::common; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() + : mode(GemmUniversalMode::kGemm), batch_count(1) { + } + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, + TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, + int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_), problem_size(problem_size_), batch_count(batch_count_), ref_A(ref_A_), ref_B(ref_B_), quant_option(quant_option_), ref_alpha_col(ref_alpha_col_), ref_alpha_row(ref_alpha_row_), ref_C(ref_C_), ref_D(ref_D_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), batch_stride_D(0), epilogue_visitor(epilogue_visitor_) { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), params_A(0), params_B(0), params_alpha_col(0), params_C(0), params_D(0), batch_count(0), gemm_k_size(0), mode(cutlass::gemm::GemmUniversalMode::kGemm), ptr_A(nullptr), ptr_B(nullptr), ptr_alpha_col(nullptr), ptr_alpha_row(nullptr), ptr_C(nullptr), ptr_D(nullptr), batch_stride_A(0), batch_stride_B(0) { + } + + Params( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size), swizzle_log_tile(0), params_A(args.ref_A.layout()), params_B(args.ref_B.layout()), params_alpha_col(args.ref_alpha_col.layout()), params_alpha_row(args.ref_alpha_col.layout()), params_C(args.ref_C.layout()), params_D(args.ref_D.layout()), mode(args.mode), batch_count(args.batch_count), gemm_k_size(args.problem_size.k()), ptr_A(args.ref_A.data()), ptr_B(args.ref_B.data()), quant_option(args.quant_option), ptr_alpha_col(args.ref_alpha_col.data()), ptr_alpha_row(args.ref_alpha_row.data()), ptr_C(args.ref_C.data()), ptr_D(args.ref_D.data()), batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), epilogue_visitor(args.epilogue_visitor) { + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + + struct + { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, + params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, + params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, + params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. Only Ampere+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 0000000000000..c0656ac784830 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,112 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct LayoutDetailsB { +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; + // for fast accumulation + // using Operator = cutlass::arch::OpMultiplyAddFastAccum; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, +// which signals that we want to dequantize after loading from smem. +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 0000000000000..ef28dcc46cd21 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,117 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters { +}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = NumericArrayConverter; + + using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 0000000000000..8d73329ed7713 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,289 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsMultistage; + +// Fine grained iterators +template +struct DefaultScaleIteratorsMultistage> { + using IteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsMultistage> { + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 0000000000000..ae0cee20d3575 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,270 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsPipelined; + +// Fine grained iterators +template +struct DefaultScaleIteratorsPipelined> { + private: + using SmemScaleType = half_t; + + public: + using IteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, + SmemScaleType, Layout, 0, Alignment>; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsPipelined> { + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + // ThreadMap for scale iterator + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + using SmemScaleType = half_t; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, + Layout, 0, IteratorScaleThreadMap, Alignment>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 0000000000000..dfe99c271f547 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,336 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#ifdef ENABLE_FP8 +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#endif + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 0000000000000..cb5ce0f72b362 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,336 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + private: + using MmaElementA = bfloat16_t; + using MmaElementB = bfloat16_t; + + public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, + AccessTypeA, GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + AccessTypeB, GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 0000000000000..cad280febbe76 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// The dequantizing op to be performed. + WeightOnlyQuantOp DequantOp, + /// Used for partial specialization, + typename Enable = bool> +class DqMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. + using ShapeZero = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 0000000000000..78b6abb50513f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = void> +class DqMmaMultistage; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h new file mode 100644 index 0000000000000..5db74039469c4 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -0,0 +1,612 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applied immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr = reinterpret_cast(this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + } + + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } else { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + using FragmentOperandB = cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); + warp_mma( + accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for the B scales immediately. + if (group_start_iteration_B == 0) { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 0000000000000..e992915cafeea --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Used for partial specialization + typename Enable = void> +class DqMmaPipelined; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h new file mode 100644 index 0000000000000..b362195834c87 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h @@ -0,0 +1,431 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_> +class DqMmaPipelined> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + + static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + using WarpFragmentZero = typename Dequantizer::FragmentZero; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< The group size for quantization + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale) { + using TransformScale = NumericArrayConverter; + + FragmentScale tb_frag_scales; + FragmentScale tb_frag_zeros; + tb_frag_scales.clear(); + tb_frag_zeros.clear(); + + TransformScale transformScale; + + using FragmentElement = typename FragmentScale::Element; + + auto gmem_scale_ptr = iterator_scale.get_scale(); + auto gmem_zero_ptr = iterator_scale.get_zero(); + + arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + arch::global_load( + tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid()); + } + + typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); + typename TransformScale::result_type tb_frag_zeros_fp16; + if (gmem_zero_ptr != nullptr) + tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); + + auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); + auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); + auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); + auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); + + if (iterator_scale.valid()) { + auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); + arch::shared_store(smem_offset, frag_scale_ptr_fp16); + + if (gmem_zero_ptr != nullptr) { + smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); + arch::shared_store(smem_offset, frag_zero_ptr_fp16); + } + } + + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } else { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. + TransformA transformA; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + copy_scales_and_advance(iterator_scale); + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + WarpFragmentScale warp_frag_scales; + WarpFragmentZero warp_frag_zero; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + iterator_scale.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + copy_scales_and_advance(iterator_scale); + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + iterator_scale.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); + warp_mma(accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + + // Load the scales needed for the next tile iteration + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + // Update internal pointer to the set of scales in shared memory + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 0000000000000..e680493cf060a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp { + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = 128 / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape; + + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 0000000000000..21c787e91be50 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" +#include "cutlass/arch/mma_sm89.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value && platform::is_same::value) || (platform::is_same::value && platform::is_same::value && ArchTag::kMinComputeCapability >= 80) || (platform::is_same::value && platform::is_same::value && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + + static_assert(platform::is_same::value || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert( + SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, + LayoutB, MatrixShape, Policy::OpDelta::kRow, + kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + int const warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 0000000000000..47f1bb240e8b3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,393 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" + +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_, + /// + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 80 && platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 75 && platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h new file mode 100644 index 0000000000000..e48ef3f154883 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h @@ -0,0 +1,405 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + +#include "cute/tensor.hpp" + +namespace onnxruntime::llm { +namespace cutlass_extensions { + +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. +enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, + + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, + + // TensorCore config CTA_N = 64, CTA_K = 128 + CtaShape128x64x128_WarpShape64x32x128, + + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64, + + // TensorCore config CTA_N = 256, CTA_K = 128 + CtaShape16x256x128_WarpShape16x64x128 + +}; + +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + STREAM_K, // Sm80+ + // SPLIT_K_PARALLEL // Not supported yet +}; + +enum class CutlassTileConfigSM90 { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + + // CTA configs for M=128 + CtaShape256x128x128B, +}; + +enum class CutlassTileConfigSM100 { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + /* + * Grouped GEMM + */ + // M=64 + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // M=128 + CtaShape128x8x256B, + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + CtaShape128x128x256B, + CtaShape128x256x256B, + + // M=256 + CtaShape256x64x128B, + CtaShape256x128x128B, + CtaShape256x256x128B, +}; + +enum class MainloopScheduleType { + AUTO, // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this + // defaults to the "legacy" main loop schedule. + PINGPONG, + COOPERATIVE, + WARPSPECIALIZED +}; + +#if 0 +static auto get_mainloop_schedule_name(MainloopScheduleType schedule) { + if (schedule == MainloopScheduleType::AUTO) { + return "auto"; + } else if (schedule == MainloopScheduleType::PINGPONG) { + return "pingpong"; + } else if (schedule == MainloopScheduleType::COOPERATIVE) { + return "cooperative"; + } else if (schedule == MainloopScheduleType::WARPSPECIALIZED) { + return "warpspecialized"; + } + return "unknown schedule"; +} +#endif + +enum class EpilogueScheduleType { + AUTO, // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For + // architectures older than hopper, the epilogue is always performed by the same thread block as the main + // loop. +}; + +enum class TileShape { + TileShape_64x16x128, + TileShape_64x32x128, + TileShape_64x64x128, + TileShape_64x128x128, + TileShape_64x256x128, + TileShape_64x512x128, + TileShape_128x16x128, + TileShape_128x32x128, + TileShape_128x64x128, + TileShape_128x128x128, + TileShape_128x256x128 +}; + +template +constexpr auto get_tile_shape() { + using namespace cute; + if constexpr (Shape_MNK == TileShape::TileShape_64x16x128) { + return cute::Shape<_64, _16, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x32x128) { + return cute::Shape<_64, _32, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x64x128) { + return cute::Shape<_64, _64, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x128x128) { + return cute::Shape<_64, _128, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x256x128) { + return cute::Shape<_64, _256, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x512x128) { + return cute::Shape<_64, _512, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x16x128) { + return cute::Shape<_128, _16, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x32x128) { + return cute::Shape<_128, _32, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x64x128) { + return cute::Shape<_128, _64, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x128x128) { + return cute::Shape<_128, _128, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x256x128) { + return cute::Shape<_128, _256, _128>{}; + } +} + +#if 0 +static auto get_tile_shape_name(TileShape Shape_MNK) { + if (Shape_MNK == TileShape::TileShape_64x16x128) { + return "64x16x128"; + } else if (Shape_MNK == TileShape::TileShape_64x32x128) { + return "64x32x128"; + } else if (Shape_MNK == TileShape::TileShape_64x64x128) { + return "64x64x128"; + } else if (Shape_MNK == TileShape::TileShape_64x128x128) { + return "64x128x128"; + } else if (Shape_MNK == TileShape::TileShape_64x256x128) { + return "64x256x128"; + } else if (Shape_MNK == TileShape::TileShape_64x512x128) { + return "64x512x128"; + } else if (Shape_MNK == TileShape::TileShape_128x16x128) { + return "128x16x128"; + } else if (Shape_MNK == TileShape::TileShape_128x32x128) { + return "128x32x128"; + } else if (Shape_MNK == TileShape::TileShape_128x64x128) { + return "128x64x128"; + } else if (Shape_MNK == TileShape::TileShape_128x128x128) { + return "128x128x128"; + } else if (Shape_MNK == TileShape::TileShape_128x256x128) { + return "128x256x128"; + } + return "Unknown shape"; +} +#endif + +enum class ClusterShape { + ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1, + ClusterShape_1x4x1, + ClusterShape_4x2x1, + ClusterShape_2x4x1, + ClusterShape_4x4x1, + ClusterShape_1x8x1, + ClusterShape_8x1x1 +}; + +#if 0 +static auto get_cluster_shape_name(ClusterShape Shape_MNK) { + if (Shape_MNK == ClusterShape::ClusterShape_1x1x1) { + return "1x1x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_2x1x1) { + return "2x1x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_1x2x1) { + return "1x2x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { + return "2x2x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_1x8x1) { + return "1x8x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { + return "8x1x1"; + } + return "Unknown shape"; +} + +template +constexpr auto get_cluster_shape() { + using namespace cute; + if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x1x1) { + return cute::Shape<_1, _1, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x1x1) { + return cute::Shape<_2, _1, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x2x1) { + return cute::Shape<_1, _2, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { + return cute::Shape<_2, _2, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x8x1) { + return cute::Shape<_1, _8, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { + return cute::Shape<_8, _1, _1>{}; + } +} +#endif + +struct CutlassGemmConfig { + enum CandidateConfigTypeParam : int { + NONE = 0, + WEIGHT_ONLY = 1u << 0, + SIMT_ONLY = 1u << 1, + INT8_ONLY = 1u << 2, + HOPPER = 1u << 3, + BLACKWELL = 1u << 4, + GROUPED_GEMM = 1u << 5, + FP8_ONLY = 1u << 6, + FP4_ONLY = 1u << 7 + }; + + CutlassTileConfig tile_config_sm80 = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; + CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + bool enableCudaKernel = false; + int sm_version = 80; // Use 80 as a catch all for <90 + bool is_tma_warp_specialized = false; + + CutlassGemmConfig() = default; + + CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) + : tile_config_sm80(tile_config), split_k_style(split_k_style), split_k_factor(split_k_factor), stages(stages), sm_version(80) { + } + + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm90(tile_config_sm90), mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape), sm_version(90), is_tma_warp_specialized(true) { + } + + CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm100(tile_config_sm100), mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape), sm_version(100), is_tma_warp_specialized(true) { + } + + int getTileConfigAsInt() const { + if (sm_version == 120) + return (int)tile_config_sm80; + if (sm_version >= 100) + return (int)tile_config_sm100; + if (sm_version == 90) + return (int)tile_config_sm90; + if (sm_version < 90) + return (int)tile_config_sm80; + assert(false && "Invalid SM version"); + return -1; + } + + std::string toString() const { + std::stringstream tactic; + tactic << "Cutlass GEMM Tactic"; + if (is_tma_warp_specialized && getTileConfigAsInt() != (int)CutlassTileConfigSM90::ChooseWithHeuristic) { + assert(sm_version >= 90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=TMA Warp Specialized" + << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt() + << "\n\tcluster shape ID: " << (int)cluster_shape + << "\n\tmainloop sched: " << (int)mainloop_schedule << "\n\tepi sched: " << (int)epilogue_schedule + << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + } else if (tile_config_sm80 != onnxruntime::llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { + assert(sm_version < 90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=compatible" + << "\n\ttile shape ID: " << (int)tile_config_sm80 << "\n\tstages: " << (int)stages + << "\n\tsplit k: " << (int)split_k_factor + << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + } else if (enableCudaKernel) { + tactic << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + } else { + tactic << "\n\tundefined"; + } + tactic << "\n"; + return tactic.str(); + } +}; + +inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) { + // clang-format off + if (config.is_tma_warp_specialized) + { + out << "tile_config_sm90_enum: " << config.getTileConfigAsInt() + << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) + << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) + << ", cluster_shape_enum: " << int(config.cluster_shape) + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); + } + else + { + out << "tile_config_enum: " << config.getTileConfigAsInt() + << ", split_k_style_enum: " << int(config.split_k_style) + << ", split_k_factor: " << config.split_k_factor + << ", stages: " << config.stages + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); + } + // clang-format on + return out; +} + +} // namespace cutlass_extensions +} // namespace onnxruntime::llm + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 0000000000000..86c45a865954e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,399 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass { + +// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low +// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. +// This converter will uninterleave the data and subtract the bias while converting to the result type. +template +struct FastInterleavedAndBiasedNumericArrayConverter { +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h new file mode 100644 index 0000000000000..30df05f24257e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass { +namespace layout { + +template +struct ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 0000000000000..cf5ebdaeec261 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM + quantization. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + using Fragment = cutlass::Array; + + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_(layout.stride(0)) { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), pointer_scale_(reinterpret_cast(const_cast(pointer_scale))), pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) { + pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); + } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + } + + // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on + // a given iteration. The same threads will be responsible for issues reads since the number of scales + // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ + // outside of the constructor. + int const global_row = threadblock_offset.row() + thread_row; + int const global_col = threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator( + params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) { + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) { + pointer_zero_ += row_byte_offset + col_byte_offset; + } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return is_valid_; + } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const { + return reinterpret_cast(pointer_zero_); + } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h new file mode 100644 index 0000000000000..cc54764c2be50 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2017-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +namespace cutlass { + +enum class WeightOnlyQuantOp { + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS +}; + +constexpr bool isFinegrained(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +} + +constexpr bool hasZero(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +} + +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc new file mode 100644 index 0000000000000..d53fb558ba1a1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc @@ -0,0 +1,479 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#pragma GCC diagnostic ignored "-Wsign-compare" +#endif // __GNUC__ + +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" + +#include + +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_types.h" +#include "core/common/common.h" + +#include +#include +#include +#include + +using namespace onnxruntime::llm::cutlass_extensions; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +struct TileShape { + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + return TileShape{16, 128}; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + return TileShape{16, 256}; + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + return TileShape{64, 64}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + return TileShape{128, 64}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + return TileShape{256, 128}; + case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + return TileShape{16, 256}; + default: + ORT_THROW("[get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape, + int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only) { + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } + + if ((k % split_k_factor) != 0) { + return false; + } + + int const k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + int const required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; +} + +std::vector get_candidate_tiles( + int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + enum class CutlassGemmType : char { + Default, + WeightOnly, + Simt, + Int8, + Fp8 + }; + + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (config_type_param & CutlassGemmConfig::SIMT_ONLY) { + gemm_type = CutlassGemmType::Simt; + } else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + gemm_type = CutlassGemmType::WeightOnly; + } else if (config_type_param & CutlassGemmConfig::INT8_ONLY) { + gemm_type = CutlassGemmType::Int8; + } else if (config_type_param & CutlassGemmConfig::FP8_ONLY) { + gemm_type = CutlassGemmType::Fp8; + } + + std::vector base_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; + if (sm >= 75) { + base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); + } + + switch (gemm_type) { + case CutlassGemmType::Simt: + return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + case CutlassGemmType::WeightOnly: + if (sm >= 75) { + return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + } else { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; + } + case CutlassGemmType::Int8: + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + case CutlassGemmType::Fp8: + if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) { + if (sm == 89) { + return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + } else { + // no valid ampere style fp8 configs for sm90 + return {}; + } + } else { + if (sm == 89) { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x64x128_WarpShape64x32x128, + CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128}; + } else { + return {}; + } + } + default: + return base_configs; + } +} + +std::vector get_candidate_tiles_sm90(CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM90 + return {CutlassTileConfigSM90::CtaShape128x128x128B}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + } else { + return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B, + CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B}; + } +#endif +} + +// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve +// compilation speed. +bool sm90_supports_mcast_along_m(CutlassTileConfigSM90 const tile) { +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve +// compilation speed. +bool sm90_supports_mcast_along_n(CutlassTileConfigSM90 const tile) { +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +std::vector get_candidate_configs_sm90(CutlassGemmConfig::CandidateConfigTypeParam const config) { + auto tiles = get_candidate_tiles_sm90(config); + std::vector candidate_configs; + for (auto const& tile_config : tiles) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(config); + + bool const has_m_mcast = sm90_supports_mcast_along_m(tile_config); + bool const has_n_mcast = sm90_supports_mcast_along_n(tile_config); + if (has_m_mcast) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(config); + } + + if (has_n_mcast) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(config); + } + + if (has_m_mcast && has_n_mcast) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(config); + } + } + // add cuda kernel profiler to tactics for weight-only plugins + if (config & CutlassGemmConfig::WEIGHT_ONLY) { + if (tiles.size() > 0) { + CutlassGemmConfig CudaKernelConfig( + tiles[0], MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); + CudaKernelConfig.enableCudaKernel = true; + candidate_configs.push_back(CudaKernelConfig); + } + } + return candidate_configs; +} + +std::vector get_candidate_configs_sm100(CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM100 + return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + std::vector candidate_configs; + if ((config & CutlassGemmConfig::FP4_ONLY) != 0) { + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + // candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, + // MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + return candidate_configs; + } + + for (int cluster_m = 1; cluster_m <= 2; cluster_m++) { + bool Is2SM = cluster_m == 2; + for (int cluster_n = 1; cluster_n <= 2; cluster_n++) { + std::vector base = {// M=128 + CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B}; + + if (Is2SM) { + if (cluster_n == 1) { + base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B); + base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B); + } + + std::vector twosm = {// M=256 + CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B}; + std::copy(twosm.begin(), twosm.end(), std::back_inserter(base)); + } else { + if (cluster_n == 1) { + base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B); + if ((config & CutlassGemmConfig::FP8_ONLY) != 0) { + base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B); + } + } + + if (cluster_n == 1 && cluster_m == 1 && ((config & CutlassGemmConfig::FP8_ONLY) != 0)) { + base.push_back(CutlassTileConfigSM100::CtaShape128x8x256B); + } + + std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B, + CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B, + CutlassTileConfigSM100::CtaShape128x64x128B}; + std::copy(onesm.begin(), onesm.end(), std::back_inserter(base)); + } + + constexpr std::array, 2> cluster_shapes = + {{std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1}, + std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}}}; + + auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1]; + for (auto tile : base) { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster}; + candidate_configs.push_back(config); + } + } + } + return candidate_configs; + } else { + ORT_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); + } +#endif + +} // namespace kernels + +std::vector get_candidate_configs( + int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + if ((config_type_param & CutlassGemmConfig::FP4_ONLY) && !(config_type_param & CutlassGemmConfig::BLACKWELL)) { + // FP4 is only supported on blackwell + return {}; + } + + if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) { + return get_candidate_configs_sm90(config_type_param); + } + if (sm >= 100 && sm != 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { + return get_candidate_configs_sm100(config_type_param); + } + + std::vector tiles = get_candidate_tiles(sm, config_type_param); + + std::vector candidate_configs; + bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY; + int const min_stages = int8_configs_only ? 3 : 2; + int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); + for (auto const& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); + candidate_configs.push_back(config); + if (sm >= 75) { + for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) { + auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; + candidate_configs.push_back(config); + } + } + } + } + // add cuda kernel profiler to tactics for weight-only plugins + if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + if (tiles.size() > 0) { + CutlassGemmConfig CudaKernelConfig(tiles[0], SplitKStyle::NO_SPLIT_K, 1, min_stages); + CudaKernelConfig.enableCudaKernel = true; + candidate_configs.push_back(CudaKernelConfig); + } + } + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, + std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const /*num_experts*/, + int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + ORT_THROW( + "[estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config_sm80); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; + } + + // Keep small tile sizes when possible. + if (best_config.tile_config_sm80 != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile && current_m_tile < tile_shape.m) { + continue; + } + + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { + int const ctas_per_wave = occupancy * multi_processor_count; + int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + float const current_score = float(num_waves_total) - num_waves_fractional; + + float const score_slack = 0.1f; + if (current_score < config_score || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig( + candidate_config.tile_config_sm80, split_style, split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + } else if (current_score == config_score && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor || current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig( + candidate_config.tile_config_sm80, split_style, split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config_sm80 == CutlassTileConfig::ChooseWithHeuristic) { + ORT_THROW("Heuristic failed to find a valid config."); + } + + return best_config; +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h new file mode 100644 index 0000000000000..b9b0301d78fc7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cute/tensor.hpp" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +struct should_filter_tma_warp_specialized_gemm_problem_shape { +#ifdef FAST_BUILD + using SupportedCtaShape = cute::Shape(TileShape{}))>; + using SupportedCgaShape = cute::Shape; + + constexpr static bool value = !cute::is_same_v || !cute::is_same_v; +#else + constexpr static bool value = false; +#endif +}; +template +constexpr static bool should_filter_tma_warp_specialized_gemm_problem_shape_v = should_filter_tma_warp_specialized_gemm_problem_shape::value; + +std::vector get_candidate_configs( + int sm, int const max_split_k, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const); + +onnxruntime::llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, + std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const /*num_experts*/, + int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc new file mode 100644 index 0000000000000..50ee944161538 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc @@ -0,0 +1,687 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/cutlass_preprocessors.h" + +#include + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +using namespace onnxruntime::llm::common; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +struct LayoutDetails { + enum class Layout { + UNKNOWN, + ROW_MAJOR, + COLUMN_MAJOR + }; + + Layout layoutB = Layout::UNKNOWN; + int rows_per_column_tile = 1; + int columns_interleaved = 1; + + bool uses_imma_ldsm = false; +}; + +template +struct getLayoutDetails { +}; + +template <> +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; + return layout_details; + } +}; + +template <> +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + return layout_details; + } +}; + +template +struct getLayoutDetails> { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + layout_details.rows_per_column_tile = RowsPerTile; + layout_details.columns_interleaved = ColumnsInterleaved; + return layout_details; + } +}; + +template +LayoutDetails getLayoutDetailsForArchAndQuantType() { + using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; + using LayoutB = typename CompileTraits::Layout; + using MmaOperator = typename CompileTraits::Operator; + LayoutDetails details = getLayoutDetails()(); + details.uses_imma_ldsm = std::is_same::value; + return details; +} + +template +LayoutDetails getLayoutDetailsForArch(QuantType quant_type) { + LayoutDetails details; + switch (quant_type) { + case QuantType::W8_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; + case QuantType::W4_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; + case QuantType::W4_AFP8: + details = getLayoutDetailsForArchAndQuantType(); + break; + default: + ORT_THROW("Unsupported quantization type"); + } + return details; +} + +LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) { + if (arch >= 75 && arch < 80) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 80 && arch < 90) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 90 && arch < 100) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 100) { + return getLayoutDetailsForArch(quant_type); + } else { + ORT_THROW("Unsupported Arch"); + return LayoutDetails(); + } +} + +// Permutes the rows of B in a way that is compatible with Turing+ architectures. +// +// Throws an error for other architectures. +// The data is permuted such that: +// For W8_A16, each group of 16 rows is permuted using the map below: +// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 +// For W4_A16, each group of 32 rows is permuted using the map below: +// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 +// For W4_A8, see the map in the code. The idea is similar to above. +// The goal of this permutation is to ensure data ends up in the correct threads after +// we execute LDSM. It counteracts the effect of the data being of different widths. +// For more information about the expected layouts, see the MMA section in the PTX docs. +std::vector get_permutation_map(QuantType quant_type) { + if (quant_type == QuantType::W8_A16) { + return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + } else if (quant_type == QuantType::W4_A16) { + return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, + 22, 23, 30, 31}; + } else if (quant_type == QuantType::W4_AFP8) { + return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, + 28, 29, 30, 31}; + } else { + ORT_THROW("Invalid quantization type for LDSM permutation"); + } +} + +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + // We only want to run this step for weight only quant. + std::vector row_permutation = get_permutation_map(quant_type); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const K = 16 / BITS_PER_ELT; + + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const num_vec_cols = num_cols / elts_in_int32; + + ORT_ENFORCE(num_rows % B_ROWS_PER_MMA == 0, + "Invalid shape for quantized tensor. Number of rows of quantized matrix must be a multiple of ", + B_ROWS_PER_MMA); + ORT_ENFORCE(num_cols % MMA_SHAPE_N == 0, + "Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of ", + MMA_SHAPE_N); + + ORT_ENFORCE(size_t(B_ROWS_PER_MMA) == row_permutation.size(), "Unexpected number of LDSM rows permuted."); + + for (int expert = 0; expert < static_cast(num_experts); ++expert) { + const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols); + for (int base_row = 0; base_row < static_cast(num_rows); base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + int const write_row = base_row + tile_row; + int const tile_read_row = row_permutation[tile_row]; + int const read_row = base_row + tile_read_row; + int const read_col = write_col; + + const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; + + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +// We need to use this transpose to correctly handle packed int4 and int8 data +// The reason this code is relatively complex is that the "trivial" loops took a substantial +// amount of time to transpose leading to long preprocessing times. This seemed to be a big +// issue for relatively large models. +template +void subbyte_transpose_impl( + int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector const& shape) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + constexpr int bits_per_elt = get_weight_quant_bits(quant_type); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + + uint8_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); + + static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt; + + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + + // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples + // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it + // allows GCC to emit vector instructions. + ORT_ENFORCE(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), + "Number of bytes for rows and cols must be a multiple of ", VECTOR_WIDTH, ". However, num_rows_bytes = ", + col_bytes_trans, " and num_col_bytes = ", col_bytes); + + for (size_t expert = 0; expert < num_experts; ++expert) { + const size_t matrix_offset = expert * num_rows * col_bytes; + for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) { + int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start + ii; + + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte + jj; + + const size_t logical_src_offset = matrix_offset + row * col_bytes + col; + + if (row < row_limit && col < col_limit) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } + } + } + } + + if constexpr (bits_per_elt == 8) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + for (int jj = ii + 1; jj < N_TILE_L1; ++jj) { + std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); + } + } + } else if constexpr (bits_per_elt == 4) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + // Using M_TILE_L1 here is deliberate since we assume that the cache tile + // is square in the number of elements (not necessarily the number of bytes). + for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { + int const ii_byte = ii / ELTS_PER_BYTE; + int const ii_bit_offset = ii % ELTS_PER_BYTE; + + int const jj_byte = jj / ELTS_PER_BYTE; + int const jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + } else { + ORT_THROW("Unsupported quantization type."); + } + + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + + int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); + int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte_trans + jj; + + const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; + + if (row < row_limit_trans && col < col_limit_trans) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } + } + } + } + } + } + } +} + +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + + if (quant_type == QuantType::W8_A16) { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_A16) { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_AFP8) { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else { + ORT_THROW("Invalid quant_type"); + } +} + +void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num_elts) { + for (size_t ii = 0; ii < num_elts; ++ii) { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no + // performance benefit and is purely so that int4 and int8 have the same layout. + // Pictorially, this does the following: + // bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + ORT_ENFORCE(num_elts % 4 == 0, "Dimensions of int8 tensor must be a multiple of 4 for register relayout"); + for (size_t base = 0; base < num_elts; base += 4) { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } +} + +void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) { + size_t const num_bytes = num_elts / 2; + + // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little + // instructions as possible in the CUDA code. + for (size_t ii = 0; ii < num_bytes; ++ii) { + int8_t transformed_packed_int4s = 0; + int8_t transformed_first_elt = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8; // The double shift here is to ensure sign extension + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; + + ORT_ENFORCE( + transformed_first_elt >= 0 && transformed_first_elt <= 15, "Illegal result for int4 transform (first elt)"); + ORT_ENFORCE(transformed_second_elt >= 0 && transformed_second_elt <= 15, + "Illegal result for int4 transform (second elt)"); + + // We don't need to mask in these ops since everything should be in the range 0-15 + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical + // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the + // following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) + + ORT_ENFORCE(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { + int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; + int const src_shift = 4 * src_idx; + int const dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + if (quant_type == QuantType::W8_A16) { + add_bias_and_interleave_int8s_inplace(tensor, num_elts); + } else if (quant_type == QuantType::W4_A16 || quant_type == QuantType::W4_AFP8) { + // W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must + // be converted to FP16 before the scales can be applied using CUDA cores. + // As a result, we still want permute the data so that it is well aligned + // for conversion to FP16. + add_bias_and_interleave_int4s_inplace(tensor, num_elts); + } else { + ORT_THROW("Invalid quantization type for interleaving."); + } +} + +void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type, LayoutDetails details) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const rows_per_tile = details.rows_per_column_tile; + + ORT_ENFORCE(!(num_rows % elts_in_int32), + "The number of rows must be a multiple of ", elts_in_int32, " but the number of rows is ", num_rows); + + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); + + ORT_ENFORCE(!(num_rows % rows_per_tile), + "The number of rows must be a multiple of ", rows_per_tile, " but the number of rows is ", num_rows); + + int const num_vec_rows = num_rows / elts_in_int32; + int const vec_rows_per_tile = rows_per_tile / elts_in_int32; + int const interleave = details.columns_interleaved; + + for (int expert = 0; expert < static_cast(num_experts); ++expert) { + const int64_t matrix_offset = expert * int64_t(num_vec_rows) * int64_t(num_cols); + for (int64_t read_col = 0; read_col < static_cast(num_cols); ++read_col) { + const int64_t write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) { + for (int vec_read_row = base_vec_row; + vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); ++vec_read_row) { + const int64_t vec_write_row = interleave * base_vec_row + vec_rows_per_tile * (read_col % interleave) + vec_read_row % vec_rows_per_tile; + + const int64_t read_offset = matrix_offset + read_col * num_vec_rows + vec_read_row; + const int64_t write_offset = matrix_offset + int64_t(write_col) * num_vec_rows * interleave + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave) { + int arch = getSMVersion(); + if (force_interleave && arch >= 90) { + // Workaround for MOE which doesn't have specialized Hopper/Blackwell kernels yet + arch = 80; + } + // Force use sm80 kernel for GB20x. + if (arch >= 100) { + arch = 80; + } + LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + + size_t num_elts = 1; + for (auto const& dim : shape) { + num_elts *= dim; + } + + const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8; + + std::vector src_buf(num_bytes); + std::vector dst_buf(num_bytes); + std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); + + // Works on row major data, so issue this permutation first. + if (details.uses_imma_ldsm) { + permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) { + subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.columns_interleaved > 1 && arch != 90) { + interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); + src_buf.swap(dst_buf); + } + + add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); + std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); +} + +/* + Arguments: + input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. + + quant_type - the type of the output quantization weight. + + This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the + zero-point is zero and will automatically construct the scales. + + It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is + viewed as a stack of matrices and a scale is produced for each column of every matrix. + +Outputs + processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM + unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. + scale_ptr - scales for the quantized weight. + + Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data + layout may not make sense if printed. + + Shapes: + quant_type == int8: + If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] + quant_type == int4: + If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape + [b,n] + + The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the + reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind + of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors + must have a dimension of 1, which breaks the semantics we need for batched weights. + */ + +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, + bool force_interleave) { + ORT_ENFORCE(processed_quantized_weight, "Processed quantized tensor is NULL"); + ORT_ENFORCE(scale_ptr, "Scale output pointer is NULL"); + ORT_ENFORCE(input_weight_ptr, "Input weight pointer is NULL"); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const bits_in_type = get_weight_quant_bits(quant_type); + int const bytes_per_out_col = num_cols * bits_in_type / 8; + + int const bits_per_weigtht_element = get_weight_quant_bits(quant_type); + + std::vector weight_buf; + if (unprocessed_quantized_weight == nullptr) { + weight_buf.resize(num_experts * num_rows * num_cols); + unprocessed_quantized_weight = weight_buf.data(); + } + + int const input_mat_size = num_rows * num_cols; + int const quantized_mat_size = num_rows * bytes_per_out_col; + float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); + + std::vector per_col_max(num_cols); + + for (int expert = 0; expert < static_cast(num_experts); ++expert) { + WeightType const* current_weight = input_weight_ptr + expert * input_mat_size; + int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; + + // First we find the per column max for this expert weight. + for (size_t jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = 0.f; + } + + for (size_t ii = 0; ii < num_rows; ++ii) { + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (size_t jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); + } + } + + // Then, we construct the scales + ComputeType* current_scales = scale_ptr + expert * num_cols; + for (size_t jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] *= quant_range_scale; + current_scales[jj] = ComputeType(per_col_max[jj]); + } + + // Finally, construct the weights. + for (size_t ii = 0; ii < num_rows; ++ii) { + int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < bytes_per_out_col; ++jj) { + if (bits_per_weigtht_element == 8) { + float const col_scale = per_col_max[jj]; + float const weight_elt = float(current_weight_row[jj]); + float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } else if (bits_per_weigtht_element == 4) { + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + int const input_idx = 2 * jj + packed_idx; + if (input_idx < static_cast(num_cols)) { + float const col_scale = per_col_max[input_idx]; + float const weight_elt = float(current_weight_row[input_idx]); + float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + int int_weight = int(scaled_weight); + const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits + // if packing the second int4 and or the bits into the final result. + packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + current_quantized_weight_row[jj] = packed_int4s; + } else { + ORT_THROW("Unsupported quantization type"); + } + } + } + } + + preprocess_weights_for_mixed_gemm( + processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave); +} + +template void symmetric_quantize( + int8_t*, int8_t*, half*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, int8_t*, half*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, float>( + int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave) { + symmetric_quantize( + processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave); +} + +template void symmetric_quantize( + int8_t*, float*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, half*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize(int8_t*, half*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, half>( + int8_t*, __nv_bfloat16*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, half*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, float>( + int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h new file mode 100644 index 0000000000000..3e83852228e24 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +enum class QuantType { + W8_A16, + W4_A16, + W4_AFP8 +}; + +constexpr int get_weight_quant_bits(QuantType quant_type) { + switch (quant_type) { + case QuantType::W8_A16: + return 8; + case QuantType::W4_A16: + return 4; + case QuantType::W4_AFP8: + return 4; + default: + ORT_THROW("Invalid quant_type"); + return -1; + } +} + +// Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols] +// 3-D shapes are [num_experts, num_rows, num_cols] +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type); + +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type); + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave = false); + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave); + +// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight +// to implement a simple reference implementation. +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, + bool force_interleave); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h new file mode 100644 index 0000000000000..1fe8035cbcdae --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "contrib_ops/cuda/llm/nv_infer_datatype.h" + +#include "cutlass/half.h" +#include + +#include "cutlass/bfloat16.h" +#include + +#include "cutlass/float8.h" +#include + +#if defined(ENABLE_FP4) +#include "cutlass/float_subbyte.h" +#include +#endif + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +/////////////////////////////////////////////////////////////////////////////////////////////////// +// nvinfer::DataType to Cutlass +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct CutlassType { + using type = void; +}; + +template <> +struct CutlassType { + using type = cutlass::half_t; +}; + +template <> +struct CutlassType { + using type = cutlass::bfloat16_t; +}; + +template <> +struct CutlassType { + using type = cutlass::float_e4m3_t; +}; + +#if defined(ENABLE_FP4) +template <> +struct CutlassType { + using type = cutlass::float_e2m1_t; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// CUDA to Cutlass + +template +struct CudaToCutlassTypeAdapter { + using type = T; +}; + +template <> +struct CudaToCutlassTypeAdapter { + using type = cutlass::half_t; +}; + +template <> +struct CudaToCutlassTypeAdapter<__nv_bfloat16> { + using type = cutlass::bfloat16_t; +}; + +#if defined(ENABLE_FP8) +template <> +struct CudaToCutlassTypeAdapter<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; +}; + +template <> +struct CudaToCutlassTypeAdapter<__nv_fp8_e5m2> { + using type = cutlass::float_e5m2_t; +}; +#endif + +#if defined(ENABLE_FP4) +template <> +struct CudaToCutlassTypeAdapter<__nv_fp4_e2m1> { + using type = cutlass::float_e2m1_t; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Cutlass to CUDA + +template +struct CudaToCudaTypeAdapter { + using type = T; +}; + +template <> +struct CudaToCudaTypeAdapter { + using type = half; +}; + +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_bfloat16; +}; + +#if defined(ENABLE_FP8) +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_fp8_e4m3; +}; + +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_fp8_e5m2; +}; +#endif + +#if defined(ENABLE_FP4) +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_fp4_e2m1; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..47e662b9a88ba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..9452aa0e1fbe6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..4a22e0f1b2aac --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..9f4091be4cd07 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h new file mode 100644 index 0000000000000..0141c76bbc031 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" +#include +#include + +namespace tkc = onnxruntime::llm::cutlass_extensions; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +// TRT Activation Type does not have Gelu or Silu +enum class ActivationType { + Gelu, + Relu, + Silu, + Identity, + InvalidType +}; + +/* + This runner only supports: + T in {half, __nv_bfloat} WeightType in {int8_t, cutlass::uint4b_t} + + Activations, biases, scales and outputs are all assumed to be row-major. + + However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. + In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor + will instantiate the layout and preprocess based on the instantiation, so layout changes should only require + modifications to mix_gemm_B_layout.h. +*/ + +class CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunnerInterface() {} + + virtual ~CutlassFpAIntBGemmRunnerInterface() {} + + virtual void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; + + virtual void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, + int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) = 0; + + virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; + + virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; + + // Returns desired workspace size in bytes. + virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; + + virtual std::vector getConfigs() const = 0; + + protected: + static constexpr int SPLIT_K_LIMIT = 7; + static constexpr int MIN_M_TILE = 16; + static constexpr int MIN_N_TILE = 64; + + static constexpr int MAX_M_TILE_SM90 = 128; + static constexpr int MAX_N_TILE_SM90 = 256; +}; + +template +class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunner(); + ~CutlassFpAIntBGemmRunner(); + + void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; + + void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + // Disabled since the fused GEMM, activation kernels will not be used in v1. + + // void gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int m, int n, + // int k, ActivationType activation_type, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t + // stream); + + // Returns desired workspace size in bytes. + size_t getWorkspaceSize(int const m, int const n, int const k) override; + + std::vector getConfigs() const override; + + private: + template + void dispatch_to_arch(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); + + private: + int sm_; + int multi_processor_count_; +}; + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h new file mode 100644 index 0000000000000..715397270331b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -0,0 +1,489 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/gemm/kernel/default_gemm.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" +#include "contrib_ops/cuda/llm/cutlass_type_conversion.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +namespace tk = onnxruntime::llm::common; +namespace tkc = onnxruntime::llm::cutlass_extensions; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + static_assert( +#ifdef ENABLE_FP8 + cutlass::platform::is_same::value || +#endif + cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); + + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using CutlassActivationType = typename CudaToCutlassTypeAdapter::type; + using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; + using CutlassScaleZeroType = typename CudaToCutlassTypeAdapter::type; + using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; + using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using EpilogueOp = + typename tkc::Epilogue::Op; + + using Operator = typename MixedGemmArchTraits::Operator; + using TaggedOperator = typename cutlass::arch::TagOperator::TaggedOperator; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm, Stages, true, + TaggedOperator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB; + + if (occupancy != nullptr) { + *occupancy = onnxruntime::llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; + + int const ldb = cutlass::platform::is_same::value + ? n + : k * GemmKernel::kInterleave; + + if (weight_scales == nullptr) { + ORT_THROW("Weight scales must always be set to a non-null value."); + } + + if constexpr (cutlass::isFinegrained(QuantOp)) { + if constexpr (cutlass::platform::is_same::value) { + if (group_size != 128) { + ORT_THROW("Only group size 128 supported for fine grained W4A(fp)8 kernels."); + } + } + if (group_size != 64 && group_size != 128) { + ORT_THROW("Only group size 64 and 128 supported for fine grained kernels."); + } + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero pointer must be a nullptr for scale only fine grained"); + } + } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { + if (weight_zero_points == nullptr) { + ORT_THROW("Weight zero pointer must be valid for scale and bias fine grained"); + } + } + } else { + if (group_size != k) { + ORT_THROW("Invalid group size for per column scaling kernels."); + } + + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero-points must be null when running per column scaling"); + } + } + + int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; + ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f); + typename Gemm::Arguments args({m, n, k}, group_size, + {reinterpret_cast(const_cast(A)), k}, + {reinterpret_cast(const_cast(B)), ldb}, + {reinterpret_cast(const_cast(weight_scales)), ld_scale_zero}, + {reinterpret_cast(const_cast(weight_zero_points)), ld_scale_zero}, + {reinterpret_cast(const_cast(biases)), 0}, + {reinterpret_cast(C), n}, gemm_config.split_k_factor, + {ElementAccumulator(alpha), output_op_beta}); + + // This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of + // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the + // interleaved matrix. The way masking in handled in these do not map to the interleaved layout. We need to write + // our own predicated iterator in order to relax this limitation. + if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) { + ORT_THROW("Temp assertion: k must be multiple of threadblockK"); + } + + Gemm gemm; + if (gemm.get_workspace_size(args) > workspace_bytes) { + ORT_LLM_LOG_WARNING( + "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."); + // If requested split-k factor will require more workspace bytes, revert to standard gemm. + args.batch_count = 1; + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } +} + +// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example, +// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine grained +// quanitzation is only supported on Ampere+ GPUs. FP8 GEMM is only supported on Ada+ GPUs. +template +void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if constexpr (Stages > 2 && arch::kMinComputeCapability < 80) { + // Multistage only supported on Ampere + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } else if constexpr (Stages == 2 && arch::kMinComputeCapability >= 89) { + // Multistage only supported on Ampere + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } else if constexpr (cutlass::platform::is_same::value && arch::kMinComputeCapability < 89) { + // FP8 activation type only supported on Ada+ GPUs + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8"; + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } else { + generic_mixed_gemm_kernelLauncher(A, B, weight_scales, weight_zero_points, biases, + alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + } +} + +template +void dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.stages) { + case 2: + filter_and_run_mixed_gemm(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, + n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case 3: + filter_and_run_mixed_gemm(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, + n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case 4: + filter_and_run_mixed_gemm(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, + n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + break; + } +} + +template +constexpr bool is_fp8() { + return std::is_same_v || std::is_same_v; +} + +template +void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + // Don't instantiate configs that are not supported pre-hopper. Produce a sensible error instead. + constexpr bool any_is_fp8 = is_fp8() || is_fp8() || is_fp8() || is_fp8() || is_fp8(); + + constexpr bool all_types_are_the_same = std::is_same_v && std::is_same_v && std::is_same_v; + + constexpr bool is_valid_pre_hopper = (all_types_are_the_same && !any_is_fp8) || (arch::kMinComputeCapability == 89); + + if constexpr (is_valid_pre_hopper) { + // Note that SIMT configs are omitted here since they are not supported for fpA_intB. + // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the + // best for mixed type gemms. + constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits::value; + switch (gemm_config.tile_config_sm80) { + case tkc::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 64, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::Undefined: + ORT_THROW("[fpA_intB_gemm] Error:[dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case tkc::CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW( + "[fpA_intB_gemm] Error:[dispatch_gemm_to_cutlass] gemm config should have already been set by " + "heuristic."); + break; + default: + ORT_THROW( + "[fpA_intB_gemm] Error:[dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); + break; + } + } else { + // This is not a limitation in CUTLASS. We just do not need to support this case. + std::string err_msg = "The activation type must equal the scale, bias and output types on Ampere and earlier."; + ORT_THROW("[fpA_intB_gemm] Error: [dispatch_gemm_to_cutlass] ", err_msg); + } +} + +template +CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + sm_ = ::onnxruntime::llm::common::getSMVersion(); + multi_processor_count_ = ::onnxruntime::llm::common::getMultiProcessorCount(); +} + +template +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); +} + +template +template +void CutlassFpAIntBGemmRunner::dispatch_to_arch(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + // std::string config_str = gemm_config.toString(); + // printf("######## sm=%d, alpha: %f m:%d n:%d, k:%d, group_size:%d, workspace_bytes:%zu config:%s\n", sm_, alpha, m, n, k, group_size, workspace_bytes, config_str.c_str()); + + if (sm_ >= 75 && sm_ < 80) { + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 100) { + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if (sm_ == 89) { +#if ENABLE_FP8 && ((__CUDACC_VER_MAJOR__ < 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) + if constexpr (cutlass::platform::is_same::value) { + ORT_THROW( + "[fpA_intB_gemm] Error: INT4xFP8 GEMM for Ada needs CUDA>=12.4"); + } +#endif + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if (sm_ == 90) { + static_assert(!cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "ScaleZeroType must be half for activation=fp8"); + sm90_dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr, + workspace_bytes, gemm_config, stream, occupancy); + } else { + ORT_THROW("[fpA_intB_gemm] Error:Arch unsupported for CUTLASS mixed type GEMM"); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm( + void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases, + float const alpha, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) || (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)) { + dispatch_to_arch((ActivationType const*)A, (WeightType const*)B, + (ScaleZeroType const*)weight_scales, (ScaleZeroType const*)weight_zero_points, (BiasType const*)biases, + alpha, (OutputType*)C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr); + } else { + ORT_THROW("Overload with scale, zero and group size only supported for fine grained bias template."); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm( + void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases, + void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + gemm(A, B, weight_scales, weight_zero_points, biases, 1.f, C, m, n, k, group_size, gemmConfig, workspace_ptr, + workspace_bytes, stream); +} + +template +void CutlassFpAIntBGemmRunner::gemm( + void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) { + dispatch_to_arch((ActivationType const*)A, (WeightType const*)B, + (ScaleZeroType const*)weight_scales, nullptr, nullptr, alpha, (OutputType*)C, m, n, k, k, gemmConfig, + workspace_ptr, workspace_bytes, stream, nullptr); + } else { + ORT_THROW("Overload with scale only (and no group size) only supported for per column scaling."); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm( + void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + gemm(A, B, weight_scales, 1.f, C, m, n, k, gemmConfig, workspace_ptr, workspace_bytes, stream); +} + +template +std::vector +CutlassFpAIntBGemmRunner::getConfigs() const { + static constexpr bool is_weight_only = !std::is_same::value; + tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER; + if (is_weight_only) { + config_type_param = static_cast( + config_type_param | tkc::CutlassGemmConfig::CandidateConfigTypeParam::WEIGHT_ONLY); + } + std::vector candidateConfigs = get_candidate_configs(sm_, SPLIT_K_LIMIT, config_type_param); + return candidateConfigs; +} + +template +size_t +CutlassFpAIntBGemmRunner::getWorkspaceSize( + int const m, int const n, int const /*k*/) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + // For Hopper, we have to allocate large memory size in case for stream-K + if (sm_ == 90) { + // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L878-L892 + // The above lines says sk_tiles = output_tiles - (static_cast(output_tiles / ctas_per_wave) - 1) * + // ctas_per_wave This means sk_tiles is at most 2 * ctas_per_wave, which is 2 * multi_processor_count_ + int const max_sk_tiles = 2 * multi_processor_count_; + + // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L939 + // The above line says uint64_t sk_units = platform::min(ctas_per_sk_wave, min_sized_sk_units); + // That means sk_units is at most ctas_per_sk_wave, which is multi_processor_count_ + int const max_sk_units = multi_processor_count_; + + // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L505 + // The above lines scales sk_tiles by the factor of static_cast(sk_units / sk_tiles + 2) + // That means the final sk_tiles is at most 2 * max_sk_tiles + max_sk_units; + int const max_sk_tiles_with_separate_reduction = 2 * max_sk_tiles + max_sk_units; + + return static_cast( + max_sk_tiles_with_separate_reduction * MAX_M_TILE_SM90 * MAX_N_TILE_SM90 * sizeof(float)); + } + // These are the min tile sizes for each config, which would launch the maximum number of blocks + int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); + int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); + // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. + return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h new file mode 100644 index 0000000000000..432adb20079b6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cute/numeric/integral_constant.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h" + +namespace tkc = onnxruntime::llm::cutlass_extensions; + +using namespace cute; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example, +// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine grained +// quanitzation is only supported on Ampere+ GPUs. +template +void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.epilogue_schedule) { + case tkc::EpilogueScheduleType::AUTO: + using EpilogueScheduleType = cute::conditional_t(CTAShape{}) == Int<64>{}, + cutlass::epilogue::TmaWarpSpecialized, cutlass::epilogue::TmaWarpSpecializedCooperative>; + sm90_generic_mixed_gemm_kernelLauncher(A, B, weight_scales, + weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, + occupancy); + break; + default: + ORT_THROW( + "[fpA_intB_gemm][sm90_dispatch_epilogue_schedules] epilogue schedule config is invalid for " + "mixed type GEMM."); + break; + } +} + +/* + 1x1x1 cluster shape is are supported for any tile shape. + + 2x1x1 cluster shape is only supported for when the M tile is at least 128. + + 1x2x1 cluster shape is only supported when the N tile is at least 128. + + 2x2x1 cluster shape is only supported when both the M and N tiles are at least 128. + + We make the above restrictions to improve compilation speed in TRT-LLM, by pruning kernels + that may not be very useful in practice. + */ +template +constexpr bool are_tile_shapes_supported() { + [[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{}); + [[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{}); + constexpr int cga_m = get<0>(ClusterShape{}); + constexpr int cga_n = get<1>(ClusterShape{}); + + if constexpr (cga_m == _1{} && cga_n == _1{}) { + return true; + } else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{}) { + return true; + } else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{}) { + return true; + } else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{}) { + return true; + } else { + return false; + } +} + +template +void sm90_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + constexpr bool tile_shapes_supported = are_tile_shapes_supported(); + + if constexpr (tile_shapes_supported) { + switch (gemm_config.mainloop_schedule) { + case tkc::MainloopScheduleType::AUTO: + using KernelScheduleType = cute::conditional_t(CTAShape{}) == Int<64>{}, + cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::gemm::KernelTmaWarpSpecializedCooperative>; + sm90_dispatch_epilogue_schedules(A, B, weight_scales, weight_zero_points, + biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + ORT_THROW( + "[fpA_intB_gemm][sm90_dispatch_mainloop_schedules] mainloop schedule config is invalid " + "for " + "mixed type GEMM."); + break; + } + } else { + ORT_THROW( + "[fpA_intB_gemm][sm90_dispatch_mainloop_schedules] Unsupported CTA and Cluster shapes for " + "mixed type GEMM."); + } +} + +template +void sm90_dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.cluster_shape) { + case tkc::ClusterShape::ClusterShape_1x1x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_2x1x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_1x2x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_2x2x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + ORT_THROW("[fpA_intB_gemm][dispatch_CGA_config] Config is invalid for mixed type GEMM."); + break; + } +} + +template +void sm90_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + // Note that SIMT configs are omitted here since they are not supported for fpA_intB. + // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best + // for mixed type gemms. + + constexpr int Ktile = 128 / sizeof(ActivationType); + using _Ktile = Int; + switch (gemm_config.tile_config_sm90) { + case tkc::CutlassTileConfigSM90::CtaShape64x16x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x32x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x64x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x128x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x256x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x16x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x32x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x64x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x128x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x256x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::Undefined: + ORT_THROW("[fpA_intB_gemm][sm90_dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case tkc::CutlassTileConfigSM90::ChooseWithHeuristic: + ORT_THROW( + "[fpA_intB_gemm][sm90_dispatch_gemm_to_cutlass] gemm config should have already been set by " + "heuristic."); + break; + default: + ORT_THROW("[fpA_intB_gemm][sm90_dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); + break; + } +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu new file mode 100644 index 0000000000000..468d53f336e55 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu @@ -0,0 +1,264 @@ +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl" +namespace onnxruntime::llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu new file mode 100644 index 0000000000000..0156c83840b09 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu @@ -0,0 +1,516 @@ +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl" +namespace onnxruntime::llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h new file mode 100644 index 0000000000000..594ae1079c06e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" +#include + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, + onnxruntime::llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl new file mode 100644 index 0000000000000..779ff88455703 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl @@ -0,0 +1,282 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" + +#include "contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" +#include "contrib_ops/cuda/llm/cutlass_type_conversion.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h" + +namespace tk = onnxruntime::llm::common; +namespace tkc = onnxruntime::llm::cutlass_extensions; + +using namespace cute; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +#ifdef COMPILE_HOPPER_TMA_GEMMS +void sm90_generic_mixed_gemm_kernelLauncher( + ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig /*gemm_config*/, + char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + using CutlassActivationType = typename CudaToCutlassTypeAdapter::type; + + if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) { + using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; + + using CutlassScaleZeroType = typename CudaToCutlassTypeAdapter::type; + using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; + using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; + + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + "Activation type must be bfloat16, half, FP8"); + + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + "Weight type must be fp8, uint8_t or uint4_t"); + + static_assert(!std::is_same_v || + std::is_same_v, + "Scale/Zero type must be half for fp8 activation"); + + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + // This example manually swaps and transposes, so keep transpose of input layouts + using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; + using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + + using ElementZero = CutlassScaleZeroType; + using ElementScale = CutlassScaleZeroType; + + // C/D matrix configuration. We reuse the C operand for the bias and set the stride for broadcast. + using LayoutBias = cutlass::layout::RowMajor; + constexpr int AlignmentBias = 128 / cutlass::sizeof_bits::value; + + // D matrix configuration + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ElementCompute = float; // Element type for epilogue computation + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + using KernelSchedule = MainloopScheduleType; + using EpilogueSchedule = EpilogueScheduleType; + + // Shrink the N dimension to match CTA_N if needed + constexpr int epi_tile_M = cute::min(shape<0>(TileShape{}), 128); // 64 or 128 + constexpr int epi_tile_N = cute::min(shape<1>(TileShape{}), 32); // Allow this to be 16 for some small N tiles. + using EpilogueTileType = cute::Shape, cute::Int>; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + static_assert(std::is_same_v, ""); + using EVT_bias_addition = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute, // alpha * acc + bias + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // alpha + cutlass::epilogue::fusion::Sm90AccFetch, // acc + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, CutlassBiasType, CutlassBiasType, + Stride<_1, _0, _0>, + AlignmentBias> // bias + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementAccumulator, + // Transpose layout of D here since we use the explicit swap + transpose trick + // Void C since we don't use it. Prevents smem allocation. + void, typename cutlass::layout::LayoutTranspose::type, AlignmentBias, CutlassOutputType, + typename cutlass::layout::LayoutTranspose::type, AlignmentOutput, EpilogueSchedule, + EVT_bias_addition>::CollectiveOp; + + using PackedScaleZero = cute::tuple; + using PackedScale = cute::tuple; + using ElementBCollectiveInfo = std::conditional_t; + + // We swap A and B operands to the builder here + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderInterleaved< + ArchTag, + OperatorClass, ElementBCollectiveInfo, LayoutB_Transpose, AlignmentB, CutlassActivationType, + LayoutA_Transpose, AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using TileScheduler = cute::conditional_t(CTAShape{}) == Int<64>{}, cutlass::gemm::PersistentScheduler, + cutlass::gemm::StreamKScheduler>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + + if (occupancy != nullptr) { + *occupancy = onnxruntime::llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename GemmKernel::StrideA; + using StrideB = typename GemmKernel::StrideB; + using StrideC = typename GemmKernel::StrideC; + using StrideD = typename GemmKernel::StrideD; + using StrideS = typename CollectiveMainloop::StrideScale; + + if (weight_scales == nullptr) { + ORT_THROW("Weight scales must always be set to a non-null value."); + } + + if constexpr (cutlass::isFinegrained(QuantOp)) { + int cta_shape_k = cute::size<2>(TileShape{}); + if (group_size % cta_shape_k != 0) { + std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k); + ORT_THROW("[fpA_intB_gemm] ", err_msg); + } + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero pointer must be a nullptr for scale only fine grained"); + } + } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { + if (weight_zero_points == nullptr) { + ORT_THROW("Weight zero pointer must be valid for scale and bias fine grained"); + } + } + } else { + if (group_size != k) { + ORT_THROW("Invalid group size for per column scaling kernels."); + } + + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero-points must be null when running per column scaling"); + } + } + + auto cutlass_scale_k = (k + group_size - 1) / group_size; + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1)); + + // Use the output as the bias to avoid making a tma descriptor with a nullptr. + auto output_as_bias_type = reinterpret_cast(C); + + typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, 1}, + {reinterpret_cast(B), stride_B, + reinterpret_cast(A), stride_A, + reinterpret_cast(weight_scales), stride_S, + group_size, reinterpret_cast(weight_zero_points)}, + {{}, output_as_bias_type, stride_D, reinterpret_cast(C), stride_D}}; + + args.epilogue.thread = { + {alpha}, // alpha args + {}, // accumulator + {reinterpret_cast(biases), CutlassBiasType(0.f)}, // bias args + {} // end multiply_add + }; + + Gemm gemm; + if (gemm.get_workspace_size(args) > workspace_bytes) { + ORT_LLM_LOG_ERROR("[fpA_intB_gemm] given workspace size insufficient."); + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + ORT_THROW("[fpA_intB_gemm] ", err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); + ORT_THROW("[fpA_intB_gemm] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + ORT_THROW("[fpA_intB_gemm] " + err_msg); + } + } else { + std::stringstream ss; + ss << "[fpA_intB_gemm] Config (" << (int64_t)cute::size<0>(CTAShape{}) << "," + << (int64_t)cute::size<1>(CTAShape{}) << "," << (int64_t)cute::size<2>(CTAShape{}) << ") (" + << (int64_t)cute::size<0>(ClusterShape{}) << "," << (int64_t)cute::size<1>(ClusterShape{}) << "," + << (int64_t)cute::size<2>(ClusterShape{}) << ") not compiled with FAST_BUILD."; + + ORT_THROW(ss.str()); + } +} +#else // COMPILE_HOPPER_TMA_GEMMS +void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const*, WeightType const*, + ScaleZeroType const*, ScaleZeroType const*, BiasType const*, + float const, OutputType*, int, int, int, int const, tkc::CutlassGemmConfig, + char*, size_t, cudaStream_t, int*) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_THROW("[fpA_intB_gemm] Please recompile with support for hopper by passing 90a-real as an arch."); +} +#endif // COMPILE_HOPPER_TMA_GEMMS + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu new file mode 100644 index 0000000000000..55beb8b9ca029 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu @@ -0,0 +1,260 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h" +#include +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +__global__ void transposeScaleKernel( + const T* scale, + T* transposed_scale, + int n, int k_blocks) { + // Calculate the output matrix coordinates [row, col] for this thread + // The output matrix has dimensions [k_blocks, n] + int out_row = blockIdx.y * blockDim.y + threadIdx.y; + int out_col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check bounds to ensure we are within the output matrix dimensions [k_blocks, n] + if (out_row < k_blocks && out_col < n) { + int in_row = out_col; + int in_col = out_row; + int64_t input_offset = static_cast(in_row) * k_blocks + in_col; + int64_t output_offset = static_cast(out_row) * n + out_col; + T scale_val = scale[input_offset]; + transposed_scale[output_offset] = scale_val; + } +} + +template +void launch_transpose_scale_kernel( + cudaStream_t stream, + const T* scale, + T* transposed_scale, + int n, int k_blocks) { + constexpr int BLOCK_SIZE = 16; + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + dim3 gridDim( + (n + blockDim.x - 1) / blockDim.x, // Grid size in x covers output columns (n) + (k_blocks + blockDim.y - 1) / blockDim.y // Grid size in y covers output rows (k_blocks) + ); + + transposeScaleKernel<<>>( + scale, + transposed_scale, + n, + k_blocks); +} + +// CUDA kernel to compute -scale * zero_point and transpose +// Each thread computes one element of the OUTPUT matrix (shape [k_blocks, n]) +template +__global__ void computeScaledZeroPointAndTransposeKernel( + const Z* zero_point, // Input zero_point matrix [n, k_blocks] or [n, (k_blocks + 1) / 2] if packed int4 + const T* transposed_scale, // transposed scale [k_blocks, n] + T* scaled_zero_point, // Output matrix [k_blocks, n] + int n, // Rows of input matrices + int k_blocks, // Columns of input matrices + float default_zero_point) { + // Calculate the output matrix coordinates [row, col] for this thread + // The output matrix has dimensions [k_blocks, n] + int out_row = blockIdx.y * blockDim.y + threadIdx.y; + int out_col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check bounds to ensure we are within the output matrix dimensions [k_blocks, n] + if (out_row < k_blocks && out_col < n) { + int in_row = out_col; + int in_col = out_row; + int64_t output_offset = static_cast(out_row) * n + out_col; + + // Perform the computation: scaled_zero_point[out_row, out_col] = -scale[in_row, in_col] * zero_point[in_row, in_col] + T scale_val = transposed_scale[output_offset]; + float zero_point_val; + if (zero_point != nullptr) { + if constexpr (is_zero_point_int4_packed) { // zero point is 4 bit, and two elements are packed into one byte. + int64_t packed_row_size = (k_blocks + 1) / 2; + int64_t packed_zp_offset = static_cast(in_row) * packed_row_size + in_col / 2; + uint8_t packed_zp = zero_point[packed_zp_offset]; + zero_point_val = static_cast((in_col & 0x01) ? (packed_zp >> 4) : (packed_zp & 0x0f)); + } else { + int64_t input_offset = static_cast(in_row) * k_blocks + in_col; + zero_point_val = static_cast(zero_point[input_offset]); + } + } else { + zero_point_val = default_zero_point; + } + + float result = static_cast(scale_val) * (-zero_point_val + default_zero_point); + scaled_zero_point[output_offset] = static_cast(result); + } +} + +template +void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const Z* zero_point, + const T* transposed_scale, + T* scaled_zero_point, + int n, int k_blocks, float default_zero_point) { + assert(zero_point != nullptr); + constexpr int BLOCK_SIZE = 16; + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + dim3 gridDim( + (n + blockDim.x - 1) / blockDim.x, // Grid size in x covers output columns (n) + (k_blocks + blockDim.y - 1) / blockDim.y // Grid size in y covers output rows (k_blocks) + ); + + computeScaledZeroPointAndTransposeKernel<<>>( + zero_point, + transposed_scale, + scaled_zero_point, + n, + k_blocks, + default_zero_point); +} + +// Explicit instantiations: +template void launch_transpose_scale_kernel( + cudaStream_t stream, + const half* scale, + half* transposed_scale, + int n, int k_blocks); + +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const half* zero_point, + const half* transposed_scale, + half* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const uint8_t* zero_point, + const half* transposed_scale, + half* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +// zero point is 4 bits packed. +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const uint8_t* zero_point, + const half* transposed_scale, + half* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +// CUDA kernel to unpack uint4, transpose, and pack into int8 directly +__global__ void unpack_transpose_pack_uint4_to_int8_kernel_v2( + const unsigned char* __restrict__ packed_weight, + signed char* __restrict__ packed_transposed_weight, + int n, // original matrix rows + int k) // original matrix columns +{ + // The output 'packed_transposed_weight' has dimensions k x (n/2) bytes. + // Each thread processes one byte in the output. + int out_flat_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Total number of bytes in the output packed_transposed_weight matrix + int total_output_bytes = k * (n / 2); + + if (out_flat_idx < total_output_bytes) { + constexpr signed char default_zero_point = 8; + + // Calculate row and column in the output packed_transposed_weight matrix (k x n/2) + // out_row_packed: row in the k dimension of the output (0 to k-1) + // out_col_packed: column in the n/2 dimension of the output (0 to n/2 - 1) + const int out_row_packed = out_flat_idx / (n / 2); + const int out_col_packed = out_flat_idx % (n / 2); + + // These two int8 values will form the current output packed byte: + // val_0: corresponds to original_unpacked[2 * out_col_packed][out_row_packed] + // val_1: corresponds to original_unpacked[2 * out_col_packed + 1][out_row_packed] + + // --- Retrieve val_0 --- + // Its original (unpacked) row index was '2 * out_col_packed' + const int r_orig_0 = 2 * out_col_packed; + // Its original (unpacked) column index was 'out_row_packed' + const int c_orig_0 = out_row_packed; + + // Determine the flat index in the input 'packed_weight' (n x k/2) where val_0 resides + const int packed_weight_idx_0 = r_orig_0 * (k / 2) + c_orig_0 / 2; + + unsigned char packed_data_0 = packed_weight[packed_weight_idx_0]; + signed char val_0; + if ((c_orig_0 % 2) == 0) { // If original column is even, it's the lower 4 bits + val_0 = (signed char)(packed_data_0 & 0x0f) - default_zero_point; + } else { // If original column is odd, it's the upper 4 bits + val_0 = (signed char)(packed_data_0 >> 4) - default_zero_point; + } + + // --- Retrieve val_1 --- + // Its original (unpacked) row index was '2 * out_col_packed + 1' + const int r_orig_1 = 2 * out_col_packed + 1; + // Its original (unpacked) column index was 'out_row_packed' + const int c_orig_1 = out_row_packed; + + // Determine the flat index in the input 'packed_weight' (n x k/2) where val_1 resides + const int packed_weight_idx_1 = r_orig_1 * (k / 2) + c_orig_1 / 2; + + unsigned char packed_data_1 = packed_weight[packed_weight_idx_1]; + signed char val_1; + if ((c_orig_1 % 2) == 0) { // If original column is even, it's the lower 4 bits + val_1 = (signed char)(packed_data_1 & 0x0f) - default_zero_point; + } else { // If original column is odd, it's the upper 4 bits + val_1 = (signed char)(packed_data_1 >> 4) - default_zero_point; + } + + // Pack the two signed char values (now 8-bit, but we only care about their 4 LSBs) + // back into a single byte for the output. + packed_transposed_weight[out_flat_idx] = (unsigned char)((val_0 & 0x0f) | ((val_1 & 0x0f) << 4)); + } +} + +void unpack_uint4_transposed_to_int8_direct_cuda( + cudaStream_t stream, void* packed_transposed_weight, const void* packed_weight, int n, int k) { + int total_output_bytes = k * (n / 2); + int threads_per_block = 256; + int num_blocks = (total_output_bytes + threads_per_block - 1) / threads_per_block; + + unpack_transpose_pack_uint4_to_int8_kernel_v2<<>>( + (const unsigned char*)packed_weight, + (signed char*)packed_transposed_weight, + n, + k); +} + +__global__ void transpose_uint8_matrix_and_convert_to_int8_kernel( + const uint8_t* __restrict__ input, // shape: (n, k) + int8_t* __restrict__ output, // shape: (k, n) + int n, int k) { + + int row = blockIdx.y * blockDim.y + threadIdx.y; // index in n + int col = blockIdx.x * blockDim.x + threadIdx.x; // index in k + + if (row < n && col < k) { + int input_idx = row * k + col; + int output_idx = col * n + row; + output[output_idx] = static_cast(static_cast(input[input_idx]) - 128); + } +} + +void transpose_uint8_matrix_and_convert_to_int8( + cudaStream_t stream, + int8_t* output, // shape: (k, n) + const uint8_t* input, // shape: (n, k) + int n, int k) { + + dim3 blockDim(16, 16); + dim3 gridDim((k + blockDim.x - 1) / blockDim.x, + (n + blockDim.y - 1) / blockDim.y); + + transpose_uint8_matrix_and_convert_to_int8_kernel<<>>(input, output, n, k); +} + + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h new file mode 100644 index 0000000000000..61023b62d8a49 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include + +// Convert scale and zero_point from MatMulNBits to the format required by fpA_intB_gemm or fpA_intB_gemv kernels. +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const Z* zero_point, + const T* transposed_scale, + T* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +template +void launch_transpose_scale_kernel( + cudaStream_t stream, + const T* scale, + T* transposed_scale, + int n, int k_blocks); + +// Transpose uint4 weight matrix and add default zero points then pack as int8. +void unpack_uint4_transposed_to_int8_direct_cuda(cudaStream_t stream, + void* packed_transposed_weight, + const void* packed_weight, + int n, + int k); + +// Transpose uint8 weight matrix and add default zero points as int8. +void transpose_uint8_matrix_and_convert_to_int8(cudaStream_t stream, + int8_t* output, // shape: (k, n) + const uint8_t* input, // shape: (n, k) + int n, int k); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc new file mode 100644 index 0000000000000..8112562623791 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc @@ -0,0 +1,100 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h" +#include "contrib_ops/cuda/llm/common/workspace.h" + +using namespace onnxruntime::llm::common; +using namespace onnxruntime::llm::kernels::cutlass_kernels; + +namespace onnxruntime::llm::kernels::weight_only { + +void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic( + int m, int n, int k, + WeightOnlyGroupwiseQuantGemmPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream) { + int const originalN = mQuantBits == 8 ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO; + half* actPtr = reinterpret_cast(workspace); + void* weightPtr = nextWorkspacePtr(reinterpret_cast(actPtr), m * k * sizeof(half)); + half* inputScalesPtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(weightPtr), n * k * sizeof(float))); + half* zerosPtr = reinterpret_cast( + nextWorkspacePtr(reinterpret_cast(inputScalesPtr), k * originalN * sizeof(half) / mGroupSize)); + half* biasesPtr = reinterpret_cast( + nextWorkspacePtr(reinterpret_cast(zerosPtr), k * originalN * sizeof(half) / mGroupSize)); + half* outputPtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(biasesPtr), n * sizeof(half))); + char* workspacePtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(outputPtr), m * originalN * sizeof(half))); + + if (!mHasZeros) { + zerosPtr = nullptr; + } + + if (!mHasBiases) { + biasesPtr = nullptr; + } + + if (tactic.enableCudaKernel) { + // run CUDA kernel + void const* pre_quant_scale_ptr = nullptr; + bool apply_alpha_in_advance = false; + float alpha = 1.0f; + onnxruntime::llm::kernels::fpA_intB_gemv::Params params( + actPtr, pre_quant_scale_ptr, weightPtr, + inputScalesPtr, zerosPtr, + biasesPtr, outputPtr, + alpha, m, originalN, k, mGroupSize, mCudaKernelType, apply_alpha_in_advance); + onnxruntime::llm::kernels::fpA_intB_gemv::kernel_launcher(mArch, params, stream); + } else { + // run CUTLASS kernel + int const wsSize = mRunner->getWorkspaceSize(m, originalN, k); + if (mQuantBits == 8) { + mRunner->gemm(actPtr, reinterpret_cast(weightPtr), inputScalesPtr, zerosPtr, biasesPtr, outputPtr, + m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream); + } else { + mRunner->gemm(actPtr, reinterpret_cast(weightPtr), inputScalesPtr, zerosPtr, biasesPtr, + outputPtr, m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream); + } + } +} + +void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { + // Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16) + int const originalN = mQuantBits == 8 ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO; + std::vector workspaces = { + maxM * k * sizeof(half), // A + k * n * sizeof(float), // B + k * originalN * sizeof(half) / mGroupSize, // scales + k * originalN * sizeof(half) / mGroupSize, // zeros + originalN * sizeof(half), // biases + maxM * originalN * sizeof(half), // C + mRunner->getWorkspaceSize(maxM, originalN, k) // workspace + }; + size_t bytes = calculateTotalWorkspaceSize(workspaces.data(), workspaces.size()); + setTmpWorkspaceSizeInBytes(bytes); +} + +std::vector WeightOnlyGroupwiseQuantGemmPluginProfiler::getTactics( + int /*m*/, int /*n*/, int /*k*/) const { + return mRunner->getConfigs(); +} + +bool WeightOnlyGroupwiseQuantGemmPluginProfiler::checkTactic(int m, int /*n*/, int /*k*/, Config const& tactic) const { + // stop to profile Cuda kernel for m >= 16 + if (tactic.enableCudaKernel) { + return m < 16; + } + return true; +} + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h new file mode 100644 index 0000000000000..7be77fa43d85d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h @@ -0,0 +1,86 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "contrib_ops/cuda/llm/gemm_profiler.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" + +#include +#include +#include +#include +#include +#include + +using WeightOnlyGemmRunner = onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface; +using WeightOnlyGemmRunnerPtr = std::shared_ptr; +using KernelType = onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + +namespace onnxruntime::llm::kernels::weight_only { +enum class WeightTypeId { + INT8 = 1, + INT4 = 2, +}; + +constexpr int32_t FP16_BITS = 16; +constexpr int32_t INT8_BITS = 8; +constexpr int32_t INT4_BITS = 4; +constexpr int32_t FP16_INT4_RATIO = FP16_BITS / INT4_BITS; +constexpr int32_t FP16_INT8_RATIO = FP16_BITS / INT8_BITS; + +class WeightOnlyGroupwiseQuantGemmPluginProfiler + : public GemmPluginProfiler { + public: + using Config = onnxruntime::llm::cutlass_extensions::CutlassGemmConfig; + + void setQuant(int bits, bool has_bias, bool has_zeros) { + mQuantBits = bits; + mHasBiases = has_bias; + mHasZeros = has_zeros; + } + + void setGroupSize(int groupSize) { + mGroupSize = groupSize; + } + + void setCudaKernelType(KernelType cudaKernelType, int arch) { + mCudaKernelType = cudaKernelType; + mArch = arch; + } + + protected: + void runTactic(int m, int n, int k, Config const& tactic, + char* workspace, cudaStream_t const& stream) override; + + void computeTmpSize(size_t maxM, size_t n, size_t k) override; + + std::vector getTactics(int m, int n, int k) const override; + + bool checkTactic(int m, int n, int k, Config const& tactic) const override; + + private: + bool mHasBiases; + bool mHasZeros; + int mQuantBits; + int mGroupSize; + KernelType mCudaKernelType; + int mArch; +}; + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h new file mode 100644 index 0000000000000..4fa64ef329c57 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +struct kernel_type_traits; +#define KERNEL_TYPE_TRAITS_REGISTRY(KT, _isGroupwise, _isInt4) \ + template <> \ + struct kernel_type_traits { \ + static constexpr bool isGroupwise = _isGroupwise; \ + static constexpr bool isInt4 = _isInt4; \ + }; + +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8Groupwise, true, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4Groupwise, true, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8PerChannel, false, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4PerChannel, false, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8Groupwise, true, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4Groupwise, true, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8PerChannel, false, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4PerChannel, false, true); +#undef KERNEL_TYPE_TRAITS_REGISTRY + +// A generic memory iterator used for coalesced global memory access with optional enablement. +// Template parameters: +// Enable: If false, disables loading/storing. +// TVec: Vectorized type (e.g., float4, half2). +// Strided: Number of rows in a tile. +// Continuous: Number of contiguous vector elements to load/store at once. +// Scalar type (e.g., half). +template +class GMemIterator { + public: + __device__ __forceinline__ GMemIterator(T* addr, int offset, int step, int stride) + : addr_(Enable ? (addr + offset) : nullptr), step_(step), stride_(stride) { + } + + __device__ __forceinline__ void load(void* dst, int iter, int ii = 0) { + if constexpr (Enable) { +#pragma unroll + for (int jj = 0; jj < Continuous; ++jj) { + reinterpret_cast(dst)[jj] = reinterpret_cast(addr_ + iter * step_ + ii * stride_)[jj]; + } + } + } + + private: + T* addr_; + int step_; + int stride_; +}; + +struct FP16DetailsA { + using Type = half; + using Type2 = half2; + static constexpr int kElemBits = 16; +}; + +struct BF16DetailsA { + using Type = __nv_bfloat16; + using Type2 = __nv_bfloat162; + static constexpr int kElemBits = 16; +}; + +struct Int8DetailsW { + static constexpr int kElemBits = 8; +}; + +struct Int4DetailsW { + static constexpr int kElemBits = 4; +}; + +template +struct ColumnMajor { + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsA::kElemBits; + static constexpr int kTileSize = TileSizeK; + static constexpr int kInterleave = 1; + + struct Mapper { + __device__ __forceinline__ int operator()(int i) { + return i; + } + }; +}; + +template +struct ColumnMajorInterleavedForHopper { + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int4; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsW::kElemBits; + static constexpr int kTileSize = TileSizeK; + static constexpr int kInterleave = 1; + + static constexpr int kTypeFactor = 128 * 8 / (TileSizeK * TypeDetailsW::kElemBits); + + // constants for mapper + static constexpr int kElementGroupSizeA = TileSizeK / 32; + static constexpr int kElementGroupSizeW = kTypeFactor * kElementGroupSizeA; + static constexpr int kGroupOffsetA = 4 * kElementGroupSizeA; + + struct Mapper { + __device__ __forceinline__ int operator()(int i) { + return i % kElementGroupSizeA + (i % kGroupOffsetA) / kElementGroupSizeA * kElementGroupSizeW + i / kGroupOffsetA * kElementGroupSizeA; + } + }; +}; + +template +struct ColumnMajorInterleaved { + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int4; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsW::kElemBits; + static constexpr int kTileSize = TileSizeK; + static constexpr int kInterleave = 128 * 8 / (TileSizeK * TypeDetailsW::kElemBits); + + // constants for mapper + static constexpr int kElementGroupSizeA = TileSizeK / 32; + static constexpr int kElementGroupSizeW = kInterleave * kElementGroupSizeA; + static constexpr int kGroupOffsetA = 4 * kElementGroupSizeA; + + struct Mapper { + __device__ __forceinline__ int operator()(int i) { + return i % kElementGroupSizeA + (i % kGroupOffsetA) / kElementGroupSizeA * kElementGroupSizeW + i / kGroupOffsetA * kElementGroupSizeA; + } + }; +}; + +template class LayoutDetails_, + bool UseInterleavedConverter, int TileSizeK> +struct KernelDetails { + using TypeDetailsA = TypeDetailsA_; + using TypeDetailsW = TypeDetailsW_; + using LayoutDetails = LayoutDetails_; + using AccessTypeA = typename LayoutDetails::AccessTypeA; + using AccessTypeW = typename LayoutDetails::AccessTypeW; + static constexpr int kWarpSize = 32; + static constexpr int kStepK = LayoutDetails::kStepK; + static constexpr int kAccessNumA = kStepK * TypeDetailsA::kElemBits / (sizeof(AccessTypeA) * 8); + static constexpr int kAccessNumW = kStepK * TypeDetailsW::kElemBits / (sizeof(AccessTypeW) * 8); + static constexpr int kInterleave = LayoutDetails::kInterleave; + static constexpr int kThreadsPerInterleavedTile = LayoutDetails::kTileSize / kStepK; + static constexpr int kElemsPerByteW = 8 / TypeDetailsW::kElemBits; + static constexpr bool kUseInterleavedConverter = UseInterleavedConverter; +}; + +template +struct I2FConverter; + +template +struct I2FConverter { + static_assert(std::is_same_v || std::is_same_v); + static_assert(WElemBits == 4 || WElemBits == 8); + using CutlassAType = std::conditional_t, cutlass::half_t, cutlass::bfloat16_t>; + using CutlassWType = std::conditional_t; + static constexpr int kConvertCount = 32 / WElemBits; + using Converter = cutlass::FastInterleavedAndBiasedNumericArrayConverter; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + + template + __device__ __forceinline__ static void convert(void* src, void* dst) { + static_assert(N % kConvertCount == 0); +#pragma unroll + for (int ii = 0; ii < N / kConvertCount; ++ii) { + reinterpret_cast(dst)[ii] = Converter::convert(reinterpret_cast(src)[ii]); + } + } +}; + +template +struct I2FConverter { + static_assert(std::is_same_v || std::is_same_v); + static_assert(WElemBits == 4 || WElemBits == 8); + using CutlassAType = std::conditional_t, cutlass::half_t, cutlass::bfloat16_t>; + using CutlassWType = std::conditional_t; + static constexpr int kConvertCount = 32 / WElemBits; + using Converter = cutlass::NumericArrayConverter; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + + template + __device__ __forceinline__ static void convert(void* src, void* dst) { + static_assert(N % kConvertCount == 0); +#pragma unroll + for (int ii = 0; ii < N / kConvertCount; ++ii) { + reinterpret_cast(dst)[ii] = Converter::convert(reinterpret_cast(src)[ii]); + } + } +}; + +template +struct ConverterWrapper { + using TypeDetailsA = typename Details::TypeDetailsA; + using TypeDetailsW = typename Details::TypeDetailsW; + static constexpr bool kUseInterleavedConverter = Details::kUseInterleavedConverter; + using Converter = I2FConverter; +}; + +template +void select_gs(Params& params, cudaStream_t s); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h new file mode 100644 index 0000000000000..ff1a28661184f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h @@ -0,0 +1,423 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/details.h" +#include "core/common/common.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +struct MathWrapper { +}; + +template <> +struct MathWrapper { + using Type = typename FP16DetailsA::Type; + using Type2 = typename FP16DetailsA::Type2; + + __device__ __forceinline__ static Type2 to_vec2(Type const& v) { + return __half2half2(v); + } + + __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c) { + return __hfma2(a, b, c); + } + + __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b) { + return __hmul2(a, b); + } + + // __device__ __forceinline__ static Type2 deq2(Type2 const& weight, Type2 const& scale, Type2 const& zero_point) { + // return __hmul2(__hsub2(weight, zero_point), scale); + // } +}; + +template <> +struct MathWrapper { + using Type = typename BF16DetailsA::Type; + using Type2 = typename BF16DetailsA::Type2; + + __device__ __forceinline__ static Type2 to_vec2(Type const& v) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __bfloat162bfloat162(v); +#else + uint32_t val = 0; + Type2 ret = reinterpret_cast(val); + return ret; +#endif + } + + __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hfma2(a, b, c); +#else + return to_vec2(static_cast(0.f)); +#endif + } + + __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hmul2(a, b); +#else + return to_vec2(static_cast(0.f)); +#endif + } +}; + +template +__device__ __forceinline__ void apply_scale(void* act, void* act_scale) { + using Type2 = typename MathWrapper::Type2; + static_assert(K % 2 == 0); + [[maybe_unused]] static constexpr int VecK = K / 2; + if constexpr (Enable) { + Type2* pa = reinterpret_cast(act); + Type2* pb = reinterpret_cast(act_scale); +#pragma unroll + for (int m = 0; m < M; ++m) { +#pragma unroll + for (int k = 0; k < VecK; ++k) { + pa[m * VecK + k] = MathWrapper::mul2(pa[m * VecK + k], pb[k]); + } + } + } +} + +template +__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros, float alpha) { + using Type = typename MathWrapper::Type; + using Type2 = typename MathWrapper::Type2; + using Converter = typename ConverterWrapper
::Converter; + static_assert(K % 2 == 0); + static constexpr int VecK = K / 2; +#pragma unroll + for (int n = 0; n < N; ++n) { + Converter::convert(reinterpret_cast(quantized_w) + n * K / Details::kElemsPerByteW, + reinterpret_cast(w) + n * K); + Type2 vec_scale, vec_zero; + if constexpr (ApplyAlphaInAdvance) { + // For W4A8, we assume scales/zero is always half data type, no matter activation dtype is bf16 or fp16 + Type scales_ = static_cast(reinterpret_cast(scales)[n]) * alpha; + vec_scale = MathWrapper::to_vec2(scales_); + vec_zero = MathWrapper::to_vec2(static_cast(0.f)); + if constexpr (EnableZero) { + vec_zero = MathWrapper::to_vec2( + static_cast(reinterpret_cast(zeros)[n]) * alpha); + } + } else { + vec_scale = MathWrapper::to_vec2(reinterpret_cast(scales)[n]); + vec_zero = MathWrapper::to_vec2(static_cast(0.f)); + if constexpr (EnableZero) { + vec_zero = MathWrapper::to_vec2(reinterpret_cast(zeros)[n]); + } + } +#pragma unroll + for (int k = 0; k < VecK; ++k) { + reinterpret_cast(w)[n * VecK + k] = MathWrapper::fma2( + reinterpret_cast(w)[n * VecK + k], vec_scale, vec_zero); + } + } +} + +template +__device__ __forceinline__ void pack_to_vec2(void* dst, void* src, int n) { + using Type = typename MathWrapper::Type; + typename Details::LayoutDetails::Mapper mapper; + int n0 = n & ~0x1, n1 = n & 0x1; + for (int k = 0; k < K; ++k) { + int physical_idx = mapper(k); + reinterpret_cast(dst)[n0 * K + k * 2 + n1] = reinterpret_cast(src)[physical_idx]; + } +} + +template +__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act) { + using Type = typename MathWrapper::Type; + using Type2 = typename MathWrapper::Type2; + static_assert(N % 2 == 0); + static constexpr int VecN = N / 2; +#pragma unroll + for (int m = 0; m < M; ++m) { +#pragma unroll + for (int n = 0; n < VecN; ++n) { +#pragma unroll + for (int k = 0; k < K; ++k) { + reinterpret_cast(acc)[m * VecN + n] = MathWrapper::fma2( + reinterpret_cast(w_pack2)[n * K + k], + MathWrapper::to_vec2(reinterpret_cast(act)[m * K + k]), + reinterpret_cast(acc)[m * VecN + n]); + } + } + } +} + +template +__device__ __forceinline__ T warp_reduce_sum(T& val) { + val += __shfl_xor_sync(~0, val, 16); + val += __shfl_xor_sync(~0, val, 8); + if (Interleave != 2 && Interleave != 4) + val += __shfl_xor_sync(~0, val, 4); + if (Interleave != 4) + val += __shfl_xor_sync(~0, val, 2); + val += __shfl_xor_sync(~0, val, 1); + return val; +} + +template +__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha) { + using Type = typename MathWrapper::Type; + static constexpr int Interleave = Details::kInterleave; + static constexpr int ThreadsPerInterleavedTile = Details::kThreadsPerInterleavedTile; + static constexpr int WarpSize = Details::kWarpSize; + static constexpr int WarpNum = Threads / WarpSize; + static_assert(Threads % WarpSize == 0); + __shared__ float shmem[CtaM * CtaN * Interleave * WarpNum]; + int tid = threadIdx.x; + int warp_id = tid / WarpSize, lane_id = tid % WarpSize; +#pragma unroll + for (int m = 0; m < CtaM; ++m) { +#pragma unroll + for (int n = 0; n < CtaN; ++n) { + float v = static_cast(reinterpret_cast(tile_acc)[m * CtaN + n]); + v = warp_reduce_sum(v); + if (lane_id < Interleave * ThreadsPerInterleavedTile && lane_id % ThreadsPerInterleavedTile == 0) { + shmem[warp_id * CtaM * CtaN * Interleave + m * CtaN * Interleave + n * Interleave + lane_id / ThreadsPerInterleavedTile] = v; + } + } + } + __syncthreads(); +#pragma unroll + for (int ii = tid; ii < CtaM * CtaN * Interleave; ii += Threads) { + int m = ii / (CtaN * Interleave), n = ii % (CtaN * Interleave); + float val = 0.f, v_bias = 0.f; + if constexpr (EnableBias) { + v_bias = static_cast(reinterpret_cast(bias)[n]); + } +#pragma unroll + for (int jj = 0; jj < WarpNum; ++jj) { + val += shmem[jj * CtaM * CtaN * Interleave + ii]; + } + if constexpr (ApplyAlphaInAdvance) { + reinterpret_cast(out)[m * stride + n] = static_cast(val + v_bias); + } else { + reinterpret_cast(out)[m * stride + n] = static_cast(alpha * val + v_bias); + } + } +} + +template +__device__ __forceinline__ void fill(void* tile, T v) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + reinterpret_cast(tile)[ii] = v; + } +} + +template +__global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* scales, TypeA* zeros, TypeA* bias, + TypeA* out, float alpha, int m, int n, int k) { + // ArgType ArgName DataType Shape Layout + // input act fp16/bf16 [m, k] RowMajor + // input act_scale fp16/bf16 [1, k] RowMajor + // input weight int4b/int8b [k, n] ColumnMajor or ColumnMajorInterleaved + // input scales fp16/bf16 [k / GroupSize, n] RowMajor + // input zeros fp16/bf16 [k / GroupSize, n] RowMajor + // input bias fp16/bf16 [1, n] RowMajor + // output out fp16/bf16 [m, n] RowMajor + + using AccessTypeA = typename Details::AccessTypeA; + using AccessTypeW = typename Details::AccessTypeW; + + static constexpr bool Mandatory = true; + static constexpr int StepK = Details::kStepK; + static constexpr int CtaK = StepK * Threads; + static_assert(CtaN % 2 == 0); + if constexpr (GroupSize != 0) { + static_assert((CtaK / Details::kInterleave) % GroupSize == 0); + } + + int const origin_k = k, interleaved_k = k * Details::kInterleave; + + int const tile_id_m = blockIdx.x, tile_id_n = blockIdx.y, tid = threadIdx.x; + int const offset_m = tile_id_m * CtaM, interleaved_offset_n = tile_id_n * CtaN; + int const real_offset_n = interleaved_offset_n * Details::kInterleave + ((tid * StepK / Details::LayoutDetails::kTileSize) % Details::kInterleave); + int const real_offset_k = (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * Details::LayoutDetails::kTileSize + ((tid * StepK) % Details::LayoutDetails::kTileSize); + + GMemIterator act_iterator( + act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); + GMemIterator act_scale_iterator( + act_scale, real_offset_k, CtaK / Details::kInterleave, 0); + GMemIterator weight_iterator( + weight, + (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, + interleaved_k / Details::kElemsPerByteW); + + GMemIterator scales_iterator( + scales, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + + GMemIterator zeros_iterator( + zeros, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + + out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; + if constexpr (EnableBias) { + bias += tile_id_n * CtaN * Details::kInterleave; + } + + TypeA tile_acc[CtaM * CtaN]; + fill(tile_acc, static_cast(0.f)); + + for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { + TypeA vec_act_scale[StepK]; + TypeA vec_scale[CtaN], vec_zero[CtaN]; + TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; + uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + scales_iterator.load(vec_scale + i, iter, i); + zeros_iterator.load(vec_zero + i, iter, i); + } + act_scale_iterator.load(vec_act_scale, iter); +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + weight_iterator.load(tile_w_quantized, iter, i); + dequantize( + tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); + pack_to_vec2(tile_w_pack2, tile_w, i); + } +#pragma unroll + for (int i = 0; i < CtaM; ++i) { + act_iterator.load(tile_a, iter, i); + apply_scale(tile_a, vec_act_scale); + mma(tile_acc + i * CtaN, tile_w_pack2, tile_a); + } + } + epilogue(out, n, tile_acc, bias, alpha); +} + +template +void exec_kernel(Params& params, cudaStream_t s) { + using T = typename Details::TypeDetailsA::Type; + if (params.m % CtaM || params.n % (CtaN * Details::kInterleave)) { + throw std::runtime_error("launch failed"); + } + dim3 grid(params.m / CtaM, params.n / (CtaN * Details::kInterleave)); + dim3 block(Threads); + kernel<<>>( + reinterpret_cast(params.act), + reinterpret_cast(params.act_scale), + reinterpret_cast(params.weight), + reinterpret_cast(params.scales), + reinterpret_cast(params.zeros), + reinterpret_cast(params.bias), + reinterpret_cast(params.out), + params.alpha, + params.m, params.n, params.k); +} + +template +void dispatcher(Params& params, cudaStream_t s) { +#define DISPATCHER_FOR_M(target_m, CtaM, CtaN, Threads) \ + do { \ + if (params.m == target_m) { \ + exec_kernel(params, s); \ + return; \ + } \ + } while (0); + + if constexpr (EnableZero) { + DISPATCHER_FOR_M(1, 1, 4, 128); + DISPATCHER_FOR_M(2, 2, 4, 128); + DISPATCHER_FOR_M(3, 3, 4, 128); + DISPATCHER_FOR_M(4, 4, 4, 128); + DISPATCHER_FOR_M(5, 5, 4, 128); + DISPATCHER_FOR_M(6, 6, 4, 128); + DISPATCHER_FOR_M(7, 7, 4, 128); + DISPATCHER_FOR_M(8, 8, 4, 128); + DISPATCHER_FOR_M(9, 9, 4, 128); + DISPATCHER_FOR_M(10, 10, 4, 128); + DISPATCHER_FOR_M(11, 11, 4, 128); + DISPATCHER_FOR_M(12, 12, 4, 128); + DISPATCHER_FOR_M(13, 13, 4, 128); + DISPATCHER_FOR_M(14, 14, 4, 128); + DISPATCHER_FOR_M(15, 15, 4, 128); + } else { + DISPATCHER_FOR_M(1, 1, 8, 128); + DISPATCHER_FOR_M(2, 2, 8, 128); + DISPATCHER_FOR_M(3, 3, 8, 128); + DISPATCHER_FOR_M(4, 4, 8, 128); + DISPATCHER_FOR_M(5, 5, 8, 128); + DISPATCHER_FOR_M(6, 6, 8, 128); + DISPATCHER_FOR_M(7, 7, 8, 128); + DISPATCHER_FOR_M(8, 8, 8, 128); + DISPATCHER_FOR_M(9, 9, 8, 128); + DISPATCHER_FOR_M(10, 10, 8, 128); + DISPATCHER_FOR_M(11, 11, 8, 128); + DISPATCHER_FOR_M(12, 12, 8, 128); + DISPATCHER_FOR_M(13, 13, 8, 128); + DISPATCHER_FOR_M(14, 14, 8, 128); + DISPATCHER_FOR_M(15, 15, 8, 128); + } + throw std::runtime_error("unsupported m"); +#undef DISPATCHER_FOR_M +} + +template +void check_pointer(Params& params, cudaStream_t s) { + assert(!params.act_scale); // act_scale is not supported for now. + assert(!params.apply_alpha_in_advance); // apply_alpha_in_advance is not supported for now. + + if (params.zeros && params.bias) { + dispatcher(params, s); + } else if (!params.zeros && params.bias) { + dispatcher(params, s); + } else if (params.zeros && !params.bias) { + dispatcher(params, s); + } else { + dispatcher(params, s); + } +} + +template +void select_gs(Params& params, cudaStream_t s) { + if constexpr (isGroupwise) { + if (params.groupsize == 64) { + check_pointer(params, s); + return; + } else if (params.groupsize == 128) { + check_pointer(params, s); + return; + } + } + + ORT_THROW("unsupported block_size: ", params.groupsize); +} + +#define INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(KType, A, B, Layout, ConverterInterleave, KTile) \ + template void select_gs::isGroupwise, \ + KernelDetails>(Params & params, cudaStream_t s); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu new file mode 100644 index 0000000000000..e2c008884c998 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 64); + +// KTile=128 for Ada w4a8 +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu new file mode 100644 index 0000000000000..8cd96c44421e5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu new file mode 100644 index 0000000000000..1eb5f51bdffdc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu new file mode 100644 index 0000000000000..f5872841e1acb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu new file mode 100644 index 0000000000000..f6b76e67b20ba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 64); + +// KTile=128 for Ada w4a8 +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu new file mode 100644 index 0000000000000..2ca88285d4cfe --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu new file mode 100644 index 0000000000000..7a00e1ba35f80 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu new file mode 100644 index 0000000000000..4a8506ca6bbde --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu new file mode 100644 index 0000000000000..32cd607d36480 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/details.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +void kernel_launcher(int arch, Params& params, cudaStream_t s) { +#define EXEC(KType, A, B, Layout, ConverterInterleave) \ + if (params.type == KType) { \ + select_gs::isGroupwise, KernelDetails>( \ + params, s); \ + return; \ + } + +// This is not used since there is no alpha for MatMulNBits currently. +#define EXEC_W4A8(KType, A, B, Layout, ConverterInterleave) \ + if (params.type == KType && params.apply_alpha_in_advance) { \ + select_gs::isGroupwise, KernelDetails>( \ + params, s); \ + return; \ + } + + if (arch >= 75 && arch < 80) { + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + } else if (arch >= 80 && arch < 90 || arch >= 100) { + // if (arch == 89 || arch >= 120) + // { + // EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + // EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + // } + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + + EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + } else if (arch >= 90) { + // Dispatchers for W4A8 groupwise + // EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + // EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + + EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + } +#undef EXEC_W4A8 +#undef EXEC +} + +bool is_supported(int arch, KernelType kernel_type) { +#define SUPPORT(Type) \ + if (kernel_type == Type) \ + return true; + + if (arch >= 75 && arch < 80) { + SUPPORT(KernelType::FP16Int8Groupwise); + SUPPORT(KernelType::FP16Int4Groupwise); + } else if (arch >= 80) { + SUPPORT(KernelType::FP16Int8Groupwise); + SUPPORT(KernelType::FP16Int4Groupwise); + + SUPPORT(KernelType::BF16Int8Groupwise); + SUPPORT(KernelType::BF16Int4Groupwise); + } + return false; +#undef SUPPORT +} + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h new file mode 100644 index 0000000000000..db2860c6b265c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +enum class KernelType { + FP16Int8Groupwise, + FP16Int4Groupwise, + FP16Int8PerChannel, + FP16Int4PerChannel, + BF16Int8Groupwise, + BF16Int4Groupwise, + BF16Int8PerChannel, + BF16Int4PerChannel +}; + +struct Params { + using Pointer = void*; + using ConstPointer = void const*; + Pointer act; + Pointer act_scale; + Pointer weight; + Pointer scales; + Pointer zeros; + Pointer bias; + Pointer out; + float alpha; + int m; + int n; + int k; + int groupsize; + KernelType type; + bool apply_alpha_in_advance; + + Params(ConstPointer _act, ConstPointer _act_scale, ConstPointer _weight, ConstPointer _scales, ConstPointer _zeros, + ConstPointer _bias, Pointer _out, float _alpha, int _m, int _n, int _k, int _groupsize, KernelType _type, + bool _apply_alpha_in_advance = false) + : act(const_cast(_act)), + act_scale(const_cast(_act_scale)), + weight(const_cast(_weight)), + scales(const_cast(_scales)), + zeros(const_cast(_zeros)), + bias(const_cast(_bias)), + out(_out), + alpha(_alpha), + m(_m), + n(_n), + k(_k), + groupsize(_groupsize), + type(_type), + apply_alpha_in_advance(_apply_alpha_in_advance) { + } +}; + +void kernel_launcher(int arch, Params& params, cudaStream_t s); + +bool is_supported(int arch, KernelType kernel_type); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc new file mode 100644 index 0000000000000..893ff27c068f8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc @@ -0,0 +1,311 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/gemm_profiler.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +#include + +namespace onnxruntime::llm::kernels::weight_only { + +template +GemmPluginProfiler::GemmPluginProfiler() { + mMNKProfileMap = std::make_shared(); + + // set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings + auto const skipEnv = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS"); + mSkip = (skipEnv != NULL && std::stoi(skipEnv)); + if (mSkip) { + ORT_LLM_LOG_DEBUG( + "SKIP_GEMM_PLUGIN_PROFILINGS is set. Skipping GEMM plugin profilings. It could result in runtime error " + "if default tactic is not defined."); + } +} + +// template +// void GemmPluginProfiler::serialize( +// char*& buffer, GemmIdType const& gemmId) const +// { +// auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId); + +// // Save number of profiles for given GEMM ID +// write(buffer, static_cast(mProfileMap->size())); +// for (auto const& pair : *mProfileMap) +// { +// // Save pair of M to the best GEMM config +// write(buffer, pair); +// } +// } + +// template +// void GemmPluginProfiler::deserialize( +// char const*& data, GemmDims& dims, GemmIdType const& gemmId) +// { +// // NOTE: this mutex is not needed since each thread owns its private map, but will put here for +// // consistency +// writer_lock lock(mMNKProfileMap->mutex); + +// mDims = dims; + +// // GemmId gemmId(dims.n, dims.k); +// if (!mMNKProfileMap->existsMProfileMap(gemmId)) +// { +// // Create GEMM with GEMM ID if it does not exist +// mMNKProfileMap->createMProfileMap(gemmId); +// } +// // Populate map with profiles of GEMM ID +// auto profileMap = mMNKProfileMap->getMProfileMap(gemmId); +// int selectedMapSize; +// read(data, selectedMapSize); +// for (int ii = 0; ii < selectedMapSize; ++ii) +// { +// std::pair> config; +// read(data, config); +// profileMap->insert(config); +// } +// } + +// template +// size_t GemmPluginProfiler::getSerializationSize( +// GemmIdType const& gemmId) const +// { +// reader_lock lock(mMNKProfileMap->mutex); +// return sizeof(int) + // size of the tactics map +// mMNKProfileMap->getMProfileMap(gemmId)->size() +// * sizeof(std::pair>); // size of the tactics map +// } + +template +int GemmPluginProfiler::getMaxProfileM() const { + return 8192; +} + +template +void GemmPluginProfiler::initTmpData( + int /*m*/, int /*n*/, int /*k*/, char* /*workspace*/, size_t /*size*/, cudaStream_t /*stream*/) { + /* Do nothing */ +} + +template +void GemmPluginProfiler::profileTactics( + RunnerPtr const& runner, nvinfer::DataType const& type, GemmDims const& dims, GemmIdType const& gemmId, + bool hasWeightOnlyCudaKernel) { + writer_lock lock(mMNKProfileMap->mutex); + + if (!dims.isInitialized()) { + return; + } + + mRunner = runner; + mType = type; + + int const maxM = std::min(nextPowerOfTwo(dims.maxM), getMaxProfileM()); + computeTmpSize(maxM, dims.n, dims.k); + + if (!mMNKProfileMap->existsMProfileMap(gemmId)) { + // Create map for GEMM ID + mMNKProfileMap->createMProfileMap(gemmId); + } + + if (mSkip) { + return; + } + + auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId); + bool isAllocated{false}; + + auto profileTactics = [&mProfileMap, &isAllocated, this](int m, int n, int k) { + if (mProfileMap->count(m) == 0) { + if (!isAllocated) { + // Allocate tmp data to run GEMMs + allocateTmpData(); + isAllocated = true; + } + initTmpData(m, n, k, mWorkspaceTmp, mTmpWorkspaceSizeInBytes, mStream); + auto tactics = this->getTactics(m, n, k); + + // Profile different tactics for particular m and insert best config to the map + mProfileMap->insert({m, this->profileTacticsForProblem(m, n, k, tactics)}); + } + }; + + CUDA_CALL_THROW(cudaStreamCreate(&mStream)); + + int const startMinMRounded = nextPowerOfTwo(dims.minM); + + if (hasWeightOnlyCudaKernel) { + // Profile tactics for finer granularity of M, + // if CUDA kernel is enabled for weight-only plugins + int minM = dims.minM; + for (int m = std::max(1, minM); m < std::min(16, maxM); m += 1) { + profileTactics(m, dims.n, dims.k); + } + + for (int m = 16; m < maxM; m *= 2) { + profileTactics(m, dims.n, dims.k); + } + } else { + // Profile tactics for CUTLASS kernel only + for (int m = std::max(1, startMinMRounded); m < maxM; m *= 2) { + profileTactics(m, dims.n, dims.k); + } + } + + profileTactics(maxM, dims.n, dims.k); + + if (isAllocated) { + // Free tmp data + freeTmpData(); + } + CUDA_CALL_THROW(cudaStreamDestroy(mStream)); +} + +template +std::optional GemmPluginProfiler::getBestConfig( + int m, GemmIdType const& gemmId) const { + reader_lock lock(mMNKProfileMap->mutex); + + if (mSkip) { + ORT_LLM_LOG_TRACE("Skip is set, no best config is set for this instance"); + return std::nullopt; + } + + int const mRounded = std::min(std::max(1, nextPowerOfTwo(m)), getMaxProfileM()); + fflush(stdout); + + if (mMNKProfileMap->getMProfileMap(gemmId)->count(m) > 0) { + return mMNKProfileMap->getMProfileMap(gemmId)->at(m); + } else if (mMNKProfileMap->getMProfileMap(gemmId)->count(mRounded) > 0) { + return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded); + } else { + std::ostringstream msg; + msg << "Cannot find best tactic for m=" << m << " and GEMM ID " << gemmId; + ORT_LLM_LOG_WARNING(msg.str()); + return std::nullopt; + } +} + +template +void GemmPluginProfiler::allocateTmpData() { + ORT_ENFORCE(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0"); + auto const status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes); + ORT_ENFORCE(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling."); +} + +template +void GemmPluginProfiler::freeTmpData() { + auto const status = cudaFree(mWorkspaceTmp); + ORT_ENFORCE(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling."); +} + +template +std::optional GemmPluginProfiler::profileTacticsForProblem( + int m, int n, int k, std::vector const& tactics) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + float bestTime = std::numeric_limits::max(); + Config bestConfig; + bool foundOne = false; + + // Iterate over all tactics for given M, N and K + for (size_t ii = 0; ii < tactics.size(); ++ii) { + Config const& candidateConfig = tactics[ii]; + float time = std::numeric_limits::max(); + try { + if (!checkTactic(m, n, k, candidateConfig)) { + continue; + } + // Profile particular tactic for given M, N and K + time = profileTacticForProblem(m, n, k, candidateConfig); + foundOne = true; + } catch (std::exception const& e) { + std::ostringstream msg; + msg << "Cannot profile configuration " << ii; + if constexpr (std::is_same_v) { + msg << ": " << candidateConfig.toString(); + } + msg << "\n (for" + << " m=" << m << ", n=" << n << ", k=" << k << ")" + << ", reason: \"" << e.what() << "\". Skipped"; + ORT_LLM_LOG_TRACE(msg.str()); + cudaGetLastError(); // Reset the last cudaError to cudaSuccess. + continue; + } + + // Choose the fastest tactic + if (time < bestTime) { + bestConfig = candidateConfig; + bestTime = time; + } + } + + if (!foundOne) { + std::ostringstream msg; + msg << "Have not found any valid GEMM config for shape (" + << "m=" << m << ", n=" << n << ", k=" << k << "). Will try to use default or fail at runtime"; + ORT_LLM_LOG_WARNING(msg.str()); + return std::nullopt; + } + + return {bestConfig}; +} + +template +float GemmPluginProfiler::profileTacticForProblem( + int m, int n, int k, Config const& tactic) { + constexpr int warmup = 5; + constexpr int runs = 10; + + cudaStream_t stream = mStream; + + // Warmup the execution + for (int i = 0; i < warmup; ++i) { + runTactic(m, n, k, tactic, mWorkspaceTmp, stream); + } + + cudaEvent_t start; + cudaEvent_t stop; + CUDA_CALL_THROW(cudaEventCreate(&start)); + CUDA_CALL_THROW(cudaEventCreate(&stop)); + CUDA_CALL_THROW(cudaStreamSynchronize(stream)); + CUDA_CALL_THROW(cudaEventRecord(start, stream)); + + // Profile GEMM + for (int i = 0; i < runs; ++i) { + runTactic(m, n, k, tactic, mWorkspaceTmp, stream); + } + + CUDA_CALL_THROW(cudaEventRecord(stop, stream)); + + CUDA_CALL_THROW(cudaEventSynchronize(stop)); + + float elapsed; + CUDA_CALL_THROW(cudaEventElapsedTime(&elapsed, start, stop)); + + CUDA_CALL_THROW(cudaEventDestroy(start)); + CUDA_CALL_THROW(cudaEventDestroy(stop)); + + return elapsed / runs; +} + +template class GemmPluginProfiler, GemmIdCore, + GemmIdCoreHash>; + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h new file mode 100644 index 0000000000000..0ab9b91e7f43c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h @@ -0,0 +1,283 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "contrib_ops/cuda/llm/nv_infer_datatype.h" +#include "core/common/common.h" + +namespace onnxruntime::llm::kernels::weight_only { + +struct GemmDims { + int64_t minM; + int64_t maxM; + int64_t n; + int64_t k; + + GemmDims() + : minM(-1), maxM(-1), n(-1), k(-1) { + } + + GemmDims(int64_t minM_, int64_t maxM_, int64_t n_, int64_t k_) + : minM(minM_), maxM(maxM_), n(n_), k(k_) { + } + + [[nodiscard]] bool isInitialized() const { + return minM >= 0 && maxM >= 0 && n >= 0 && k >= 0; + } +}; + +// Unique ID of GEMM +// In our case GEMM is uniqly identified by N and K +class GemmIdCore { + public: + int n; + int k; + nvinfer::DataType dtype; + + GemmIdCore(int n_, int k_, nvinfer::DataType const& dtype_) + : n(n_), k(k_), dtype(dtype_) { + } + + GemmIdCore() + : n(-1), k(-1), dtype(nvinfer::DataType::kFLOAT) // dtype does not matter here + { + } + + bool operator==(GemmIdCore const& id) const { + return isEqual(id); + } + + friend std::ostream& operator<<(std::ostream& out, GemmIdCore const& id) { + out << "(N;K)=(" << id.n << ";" << id.k << "),"; + out << " type=" << static_cast(id.dtype); + return out; + } + + protected: + bool isEqual(GemmIdCore const& id) const { + return n == id.n && k == id.k && dtype == id.dtype; + } +}; + +// Hash of GemmId +struct GemmIdCoreHash { + std::size_t operator()(GemmIdCore const& id) const { + auto h1 = std::hash{}(id.n); + auto h2 = std::hash{}(id.k); + auto h3 = std::hash{}(static_cast(id.dtype)); + return h1 ^ h2 ^ h3; + } +}; + +// class GemmIdCublas : public GemmIdCore { +// public: +// bool transA{}; +// bool transB{}; +// nvinfer::DataType outputDtype; + +// GemmIdCublas(int n_, int k_, nvinfer::DataType const& dtype_, bool transA_, bool transB_, +// nvinfer::DataType const& output_dtype_) +// : GemmIdCore(n_, k_, dtype_), transA(transA_), transB(transB_), outputDtype(output_dtype_) { +// } + +// GemmIdCublas() {} + +// bool operator==(GemmIdCublas const& id) const { +// return isEqual(id) && transA == id.transA && transB == id.transB && outputDtype == id.outputDtype; +// } + +// friend std::ostream& operator<<(std::ostream& out, GemmIdCublas const& id) { +// out << "(N;K)=(" << id.n << ";" << id.k << "),"; +// out << " type=" << static_cast(id.dtype); +// out << " transA=" << id.transA; +// out << " transB=" << id.transB; +// out << " outputDtype=" << static_cast(id.outputDtype); +// return out; +// } +// }; + +// // Hash of GemmIdCublas +// struct GemmIdCublasHash { +// std::size_t operator()(GemmIdCublas const& id) const { +// auto h1 = std::hash{}(id.n); +// auto h2 = std::hash{}(id.k); +// auto h3 = std::hash{}(static_cast(id.dtype)); +// auto h4 = std::hash{}(id.transA); +// auto h5 = std::hash{}(id.transB); +// auto h6 = std::hash{}(static_cast(id.outputDtype)); +// return h1 ^ h2 ^ h3 ^ h4 ^ h5 ^ h6; +// } +// }; + +template +class GemmPluginProfiler { + public: + // Map for single GEMM for different Ms (GEMM dimension) to the best config for particular M + using MProfileMap = std::unordered_map>; + using MProfileMapPtr = std::shared_ptr; + + // requires exclusive ownership to write to *this + using reader_lock = std::unique_lock; + // requires shared ownership to read from other + using writer_lock = std::shared_lock; + + // Struct of continuing map if GEMMs to the best profiles for different Ms + struct MNKProfileMap { + // Mutex guarding map + std::shared_timed_mutex mutex; + // Map from GEMM Id to profile for particular GEMM + std::unordered_map profileMap; + + bool existsMProfileMap(GemmIdType const& id) { + auto const iter = profileMap.find(id); + return iter != profileMap.end(); + } + + void createMProfileMap(GemmIdType const& id) { + profileMap[id] = std::make_shared(); + } + + MProfileMapPtr getMProfileMap(GemmIdType const& id) { + auto const iter = profileMap.find(id); + if (iter == profileMap.end()) { + ORT_THROW("Cannot find ID (", id, ") in the profile map. Abort."); + } + return iter->second; + } + }; + + using MNKProfileMapPtr = std::shared_ptr; + + GemmPluginProfiler(); + + virtual ~GemmPluginProfiler() = default; + + // void serialize(char*& buffer, GemmIdType const& gemmId) const; + + // void deserialize(char const*& data, GemmDims& dims, GemmIdType const& gemmId); + // size_t getSerializationSize(GemmIdType const& gemmId) const; + + void profileTactics(RunnerPtr const& runner, nvinfer::DataType const& type, GemmDims const& dims, + GemmIdType const& gemmId, bool hasWeightOnlyCudaKernel = false); + + void setSelectionTactics(MNKProfileMapPtr const& map) { + mMNKProfileMap = map; + } + + void setTmpWorkspaceSizeInBytes(size_t bytes) { + mTmpWorkspaceSizeInBytes = bytes; + } + + void setSkip(bool skip) { + mSkip = mSkip || skip; + } + + std::optional getBestConfig(int m, GemmIdType const& gemmId) const; + + virtual int getMaxProfileM() const; + + protected: + virtual void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) = 0; + + virtual void computeTmpSize(size_t maxM, size_t n, size_t k) = 0; + + virtual bool checkTactic(int /*m*/, int /*n*/, int /*k*/, Config const& /*tactic*/) const { + return true; + } + + virtual std::vector getTactics(int m, int n, int k) const = 0; + + virtual void initTmpData(int m, int n, int k, char* workspace, size_t size, cudaStream_t stream); + + private: + void allocateTmpData(); + + void freeTmpData(); + + std::optional profileTacticsForProblem(int m, int n, int k, std::vector const& tactics); + + float profileTacticForProblem(int m, int n, int k, Config const& tactic); + + int nextPowerOfTwo(int v) const { + --v; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + return ++v; + } + + protected: + RunnerPtr mRunner{nullptr}; + + nvinfer::DataType mType{}; + + private: + MNKProfileMapPtr mMNKProfileMap{}; + + size_t mTmpWorkspaceSizeInBytes{0}; + + char* mWorkspaceTmp{nullptr}; + + cudaStream_t mStream; + + GemmDims mDims{}; + + bool mSkip{false}; +}; + +template +class GemmPluginProfilerManager { + public: + using MNKProfileMap = typename GemmPluginProfilerType::MNKProfileMap; + using MNKProfileMapPtr = typename GemmPluginProfilerType::MNKProfileMapPtr; + using GemmPluginProfilerPtr = std::shared_ptr; + + GemmPluginProfilerManager() { + mMNKProfileMap = std::make_shared(); + } + + GemmPluginProfilerPtr createGemmPluginProfiler(bool inference, bool skip = false) { + auto profiler = std::make_shared(); + profiler->setSkip(skip); + // If the profiler is created during the engine build, + // mMNKProfileMap is shared between different profilers to minimize the time spent on the profiling + // and do not repeat profiling for the GEMMs of the same shape. + if (!inference) { + profiler->setSelectionTactics(mMNKProfileMap); + } + return profiler; + } + + private: + MNKProfileMapPtr mMNKProfileMap{}; +}; + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py new file mode 100644 index 0000000000000..678102c809b63 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py @@ -0,0 +1,397 @@ +# Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generate fpA intB GEMM kernels: +# pip install nvidia-cutlass +# python generate_kernels.py -a "90" -o ./fpA_intB_gemm/launchers + +import argparse +import enum +import os +from itertools import product + +from cutlass_library import ( + DataType, + DataTypeNames, + DataTypeSize, + DataTypeTag, + EpilogueScheduleSuffixes, + EpilogueScheduleTag, + EpilogueScheduleType, + GemmKind, + GemmKindNames, + KernelScheduleSuffixes, + KernelScheduleTag, + KernelScheduleType, +) + + +################################################################################ +# Epilogue Tag enum and string utils +class LlmEpilogueTag(enum.Enum): + epilogue_op_default = enum.auto() + epilogue_op_bias = enum.auto() + epilogue_op_silu = enum.auto() + epilogue_op_gelu = enum.auto() + + +class LlmEpilogueFusion(enum.Enum): + epilogue_fusion_none = enum.auto() + epilogue_fusion_finalize = enum.auto() + + +EpiTagNames = { + LlmEpilogueTag.epilogue_op_default: "lc", # linear combination + LlmEpilogueTag.epilogue_op_bias: "lc_bias", # linear combination with bias addition + LlmEpilogueTag.epilogue_op_silu: "silu", # silu or swiglu + LlmEpilogueTag.epilogue_op_gelu: "gelu", # gelu or geglu +} + +EpiTag = { + LlmEpilogueTag.epilogue_op_default: "onnxruntime::llm::cutlass_extensions::EpilogueOpDefault", + LlmEpilogueTag.epilogue_op_bias: "onnxruntime::llm::cutlass_extensions::EpilogueOpBias", + LlmEpilogueTag.epilogue_op_silu: "onnxruntime::llm::cutlass_extensions::EpilogueOpDefaultSilu", + LlmEpilogueTag.epilogue_op_gelu: "onnxruntime::llm::cutlass_extensions::EpilogueOpDefaultFtGelu", +} + +EpiFusion = { + LlmEpilogueFusion.epilogue_fusion_none: "onnxruntime::llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE", + LlmEpilogueFusion.epilogue_fusion_finalize: "onnxruntime::llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE", +} + +EpiFusionSuffixes = { + None: "", + LlmEpilogueFusion.epilogue_fusion_none: "EpilogueFusion_NONE", + LlmEpilogueFusion.epilogue_fusion_finalize: "EpilogueFusion_FINALIZE", +} + + +################################################################################ +# Quantization Operation and string utils +class LlmQuantOp(enum.Enum): + per_column_scale_only = enum.auto() + finegrained_scale_only = enum.auto() + finegrained_scale_and_zeros = enum.auto() + none = enum.auto() + + +QuantOpNames = { + LlmQuantOp.per_column_scale_only: "cs", + LlmQuantOp.finegrained_scale_only: "fgs", + LlmQuantOp.finegrained_scale_and_zeros: "fgsz", + LlmQuantOp.none: "noquant", +} + +QuantOpTag = { + LlmQuantOp.per_column_scale_only: "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY", + LlmQuantOp.finegrained_scale_only: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY", + LlmQuantOp.finegrained_scale_and_zeros: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS", + LlmQuantOp.none: "void", +} + +################################################################################ +# The activations, biases, scales and zeros are instantiated using CUDA types, +# not CUTLASS types. This map materializes the name of the CUDA type. + + +def get_data_type_bits(type): + return DataTypeSize[type] + + +def get_data_type_names(type): + return DataTypeNames[type] + + +CudaTypeName = { + DataType.e4m3: "__nv_fp8_e4m3", + DataType.bf16: "__nv_bfloat16", + DataType.f16: "half", + DataType.f32: "float", +} + + +################################################################################ +# A data structure holding all info to instantiate gemm launchers in TRT LLM. +class LlmGemmLauncher: + def __init__( + self, + gemm_kind, + arch, + act_type, + weight_type, + scalezero_type, + bias_type, + output_type, + quant_op, + epi_tag, + cta_shape, + warp_shape, + stages, + cga_shape, + mainloop_schedule, + epi_schedule, + epi_fusion=None, + ): + self.gemm_kind = gemm_kind + self.arch = arch + self.act_type = act_type + self.weight_type = weight_type + self.scalezero_type = scalezero_type + self.bias_type = bias_type + self.output_type = output_type + self.quant_op = quant_op + self.epi_tag = epi_tag + self.cta_shape = cta_shape + self.warp_shape = warp_shape + self.stages = stages + self.cga_shape = cga_shape + self.mainloop_schedule = mainloop_schedule + self.epi_schedule = epi_schedule + self.epi_fusion = epi_fusion + + def __repr__(self): + kernel_prefix = f"{GemmKindNames[self.gemm_kind]}_sm{self.arch}_{get_data_type_names(self.act_type)}_{get_data_type_names(self.weight_type)}_{get_data_type_names(self.scalezero_type)}_{get_data_type_names(self.bias_type)}_{get_data_type_names(self.output_type)}_{QuantOpNames[self.quant_op]}_{EpiTagNames[self.epi_tag]}_{self.cta_shape[0]}x{self.cta_shape[1]}x{self.cta_shape[2]}_{self.warp_shape[0]}x{self.warp_shape[1]}x{self.warp_shape[2]}_{self.stages}" + + hopper_suffix = f"_{self.cga_shape[0]}x{self.cga_shape[1]}x{self.cga_shape[2]}{KernelScheduleSuffixes[self.mainloop_schedule]}{EpilogueScheduleSuffixes[self.epi_schedule]}{EpiFusionSuffixes[self.epi_fusion]}" + + if self.arch >= 90: + return kernel_prefix + hopper_suffix + elif self.arch > 100: + raise ValueError(f"SM{self.arch} not supported yet.") + return kernel_prefix + + +################################################################################ +def tuple_to_cute_shape(shape): + return f"cute::Shape, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>" + + +def instantiate_operation_tma_warp_specialized(operation): + act_tag = CudaTypeName[operation.act_type] + scale_zero_tag = CudaTypeName[operation.scalezero_type] + bias_tag = CudaTypeName[operation.bias_type] + out_tag = CudaTypeName[operation.output_type] + + quant_op = QuantOpTag[operation.quant_op] + epi_tag = EpiTag[operation.epi_tag] + + cute_cta_shape = tuple_to_cute_shape(operation.cta_shape) + cute_cga_shape = tuple_to_cute_shape(operation.cga_shape) + + kernel_sched = KernelScheduleTag[operation.mainloop_schedule] + epi_sched = EpilogueScheduleTag[operation.epi_schedule] + + assert operation.gemm_kind == GemmKind.Gemm + weight_tag = DataTypeTag[operation.weight_type] + + return f""" +template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag}, +{quant_op}, {epi_tag}, +{cute_cta_shape}, {cute_cga_shape}, +{kernel_sched}, {epi_sched}> ( +const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float, +{out_tag}*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); +""" + + +def instantiate_operation(insts_list, operation): + if operation.arch >= 90: + insts_list.append(instantiate_operation_tma_warp_specialized(operation)) + + +def get_file_content(launcher_inl_files, operations): + assert operations + include_list = list() + for file in launcher_inl_files: + include_list.append(f'#include "{file}"') + includes = "\n".join(include_list) + + insts_list = list() + for op in operations: + instantiate_operation(insts_list, op) + instantiations = "\n".join(insts_list) + + file_content = f"""{includes} +namespace onnxruntime::llm +{{ +namespace kernels +{{ +namespace cutlass_kernels +{{ + +{instantiations} + +}} // namespace cutlass_kernels +}} // namespace kernels +}} // namespace onnxruntime::llm +""" + return file_content + + +def write_file(launcher_inl_files, operations, output_file): + os.makedirs(os.path.dirname(output_file), exist_ok=True) + # Avoid changing modified time if file content is up to date + content = get_file_content(launcher_inl_files, operations) + if os.path.exists(output_file): + with open(output_file) as f: + if f.read() == content: + return + with open(output_file, mode="w") as f: + f.write(content) + + +def elementwise(x, y, f): + return tuple(f(a, b) for (a, b) in zip(x, y, strict=False)) + + +def is_gemm_op_valid(op): + tile_m, tile_n, _ = op.cta_shape + cga_m, cga_n, _ = op.cga_shape + + if cga_m == 1 and cga_n == 1: + return True + + if cga_m == 2 and cga_n == 1 and tile_m >= 128: + return True + + if cga_m == 1 and cga_n == 2 and tile_n >= 128: + return True + + if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128: + return True + + return False + + +################################################################################ +def generate_sm90_mixed_gemm_operations(enable_fp8=False, enable_scale_only=False): + arch = 90 + + # For legacy reasons, we use unsigned types for the weights. The instanitated template + # will remap those back to the signed type. + # Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type) + supported_dtypes = [ + (DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16), + (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16), + (DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16), + (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16, DataType.bf16), + ] + + if enable_fp8: + supported_dtypes = [ + *supported_dtypes, + (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16), + (DataType.e4m3, DataType.u4, DataType.f16, DataType.bf16, DataType.bf16), + ] + + quant_ops = [LlmQuantOp.finegrained_scale_and_zeros] + + if enable_scale_only: + quant_ops = [ + *quant_ops, + LlmQuantOp.finegrained_scale_only, + ] + + epi_tags = [LlmEpilogueTag.epilogue_op_bias] + + m_tiles = [64, 128] + n_tiles = [16, 32, 64, 128, 256] + cta_shapes_mn = product(m_tiles, n_tiles) + + warp_shape = [4, 1, 1] + stages = 0 # auto + + cga_shapes = product([1, 2], [1, 2], [1]) + + partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn, cga_shapes) + + operations = list() + for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args: + max_k_bits = 128 * 8 + cta_shape_k = max_k_bits // get_data_type_bits(dtype_combo[0]) + cta_shape_mnk = (*cta_shape_mn, cta_shape_k) + + use_coop = cta_shape_mn[0] == 128 + mainloop_schedule = ( + KernelScheduleType.TmaWarpSpecializedCooperative + if use_coop + else KernelScheduleType.TmaWarpSpecializedPingpong + ) + epi_schedule = ( + EpilogueScheduleType.TmaWarpSpecializedCooperative if use_coop else EpilogueScheduleType.TmaWarpSpecialized + ) + + mixed_gemm_operation = LlmGemmLauncher( + GemmKind.Gemm, + arch, + *dtype_combo, + quant_op, + epi_tag, + cta_shape_mnk, + warp_shape, + stages, + cga_shape, + mainloop_schedule, + epi_schedule, + ) + + if is_gemm_op_valid(mixed_gemm_operation): + operations.append(mixed_gemm_operation) + + return operations + + +def generate_sm90_operations(is_arch_enabled): + operations = generate_sm90_mixed_gemm_operations() + return operations + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Print the output directory") + + parser.add_argument("-o", "--output_dir", type=str, required=True, help="Path to the output directory") + parser.add_argument("-a", "--architectures", type=str, required=True, help="Architectures to generate kernels for") + + args = parser.parse_args() + + arches = args.architectures.split(";") + + output_dir = os.path.abspath(args.output_dir) + + include_map = { + (GemmKind.Gemm, 90): ["contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"], + } + + def has_arch(sm): + return f"{sm}" in arches or f"{sm}-real" in arches + + # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads. + # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve. + operations = [] + operations += generate_sm90_operations(has_arch(90)) + + op_groups = dict() + for op in operations: + dict_key = (op.gemm_kind, op.arch, op.cta_shape[0]) + op_group = op_groups.get(dict_key, list()) + op_group.append(op) + op_groups[dict_key] = op_group + + file_counter = 1 + for key, value in op_groups.items(): + gemm_kind, _, _ = key + out_file = os.path.join(output_dir, f"fpA_intB_gemm_launcher_{file_counter}.generated.cu") + write_file(include_map[key[:2]], value, out_file) + file_counter += 1 diff --git a/onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h b/onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h new file mode 100644 index 0000000000000..52e8eb225c79c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// This is corresponding to nvinfer1 namespace used by TensorRT. Add it to avoid dependency on TensorRT. +namespace onnxruntime::llm::nvinfer { + +enum class DataType : int32_t { + //! 32-bit floating point format. + kFLOAT = 0, + + //! IEEE 16-bit floating-point format -- has a 5 bit exponent and 11 bit significand. + kHALF = 1, + + //! Signed 8-bit integer representing a quantized floating-point value. + kINT8 = 2, + + //! Signed 32-bit integer format. + kINT32 = 3, + + //! 8-bit boolean. 0 = false, 1 = true, other values undefined. + kBOOL = 4, + + //! Unsigned 8-bit integer format. + //! Cannot be used to represent quantized floating-point values. + kUINT8 = 5, + + //! Signed 8-bit floating point with + //! 1 sign bit, 4 exponent bits, 3 mantissa bits, and exponent-bias 7. + kFP8 = 6, + + //! Brain float -- has an 8 bit exponent and 8 bit significand. + kBF16 = 7, + + //! Signed 64-bit integer type. + kINT64 = 8, + + //! Signed 4-bit integer type. + kINT4 = 9, + + kFP4 = 10, +}; +} // namespace onnxruntime::llm::nvinfer diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index ed6021530018f..3f485f0abdcb1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -9,21 +9,231 @@ #include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cpu/utils/dump_tensor.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h" +#include "contrib_ops/cuda/llm/cutlass_preprocessors.h" #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" -#include "matmul_nbits.cuh" -#include "dequantize_blockwise.cuh" + +constexpr int MatMulNBits_Input_B = 1; +constexpr int MatMulNBits_Input_Scale = 2; +constexpr int MatMulNBits_Input_ZeroPoint = 3; namespace onnxruntime { namespace contrib { namespace cuda { using namespace onnxruntime::cuda; +using onnxruntime::llm::kernels::weight_only::GemmPluginProfilerManager; +using onnxruntime::llm::kernels::weight_only::WeightOnlyGroupwiseQuantGemmPluginProfiler; +using onnxruntime::llm::kernels::weight_only::WeightTypeId; +static GemmPluginProfilerManager s_profilerManager; + +template +void MatMulNBits::InitGemmProfiler(int sm) { + gemmProfiler_ = s_profilerManager.createGemmPluginProfiler(/*inference*/ false); + + if constexpr (std::is_same_v) { + if (nbits_ == 8) { + weightOnlyGemmRunner_ = std::make_shared>(); + } else if (nbits_ == 4) { + weightOnlyGemmRunner_ = std::make_shared>(); + } + } else if constexpr (std::is_same_v) { + if (nbits_ == 8) { + weightOnlyGemmRunner_ = std::make_shared>(); + } else if (nbits_ == 4) { + weightOnlyGemmRunner_ = std::make_shared>(); + } + } + + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type = nbits_ == 8 ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + gemmProfiler_->setCudaKernelType(cuda_kernel_type, sm); + gemmProfiler_->setQuant(nbits_, has_bias_, has_zero_points_); + gemmProfiler_->setGroupSize(block_size_); +} + +template +void MatMulNBits::RunGemmProfile(bool hasWeightOnlyCudaKernel, int min_m, int max_m) { + // Number of 16-bit elements after casting int8/int4 to fp16. + int n_16b = N_ / (nbits_ == 8 ? 2 : 4); + + gemmId_ = GemmIdCore(n_16b, K_, onnxruntime::llm::nvinfer::DataType::kHALF); + + GemmDims dims = {min_m, max_m, n_16b, K_}; + gemmProfiler_->profileTactics(weightOnlyGemmRunner_, gemmId_.dtype, dims, gemmId_, hasWeightOnlyCudaKernel); +} + +template +Status MatMulNBits::PrePack(const Tensor& /* tensor */, int /* input_idx */, AllocatorPtr /*alloc*/, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + +template <> +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, + PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + if (has_fpA_intB_gemm_) { + cudaStream_t stream = cudaStreamLegacy; // Use default stream for prepacking. + if (input_idx == MatMulNBits_Input_B) { + ORT_RETURN_IF_ERROR(PrePack_B(tensor, alloc, stream)); + is_packed = true; + } else if (input_idx == MatMulNBits_Input_Scale) { + ORT_RETURN_IF_ERROR(PrePack_Scale(tensor, alloc, stream)); + is_packed = true; + } else if (input_idx == MatMulNBits_Input_ZeroPoint) { + if (has_zero_points_) { + ORT_RETURN_IF_ERROR(PrePack_ZeroPoint(tensor, alloc, stream)); + is_packed = true; + } + } + } + + return Status::OK(); +} + +template +Status MatMulNBits::PrePack_B([[maybe_unused]] const Tensor& tensor, + [[maybe_unused]] AllocatorPtr alloc, + [[maybe_unused]] cudaStream_t stream) { + if constexpr (std::is_same_v) { + size_t n = static_cast(N_); + size_t k = static_cast(K_); + + size_t packed_weight_bytes = n * k / (8 / nbits_); + + // uint8 does not need to be packed so we do not need to allocate extra space. + IAllocatorUniquePtr packed_transposed_weight_space = this->GetTransientScratchBuffer(packed_weight_bytes); + int8_t* packed_transposed_weight = reinterpret_cast(packed_transposed_weight_space.get()); + + fpA_intB_weight_buffer_ = IAllocator::MakeUniquePtr(alloc, packed_weight_bytes, true); // Transient buffer. + + int8_t* preprocessed_weight = reinterpret_cast(fpA_intB_weight_buffer_.get()); + + const uint8_t* blob_data = tensor.Data(); + if (nbits_ == 4) { + // Transpose the weight and add default zero point. + onnxruntime::llm::kernels::fpA_intB_gemv::unpack_uint4_transposed_to_int8_direct_cuda( + stream, packed_transposed_weight, blob_data, n, k); + } else { + onnxruntime::llm::kernels::fpA_intB_gemv::transpose_uint8_matrix_and_convert_to_int8( + stream, packed_transposed_weight, blob_data, n, k); + } + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + auto tranpose_weight_buffer = this->AllocateBufferOnCPUPinned(packed_weight_bytes); + CUDA_RETURN_IF_ERROR(cudaMemcpy(tranpose_weight_buffer.get(), packed_transposed_weight, packed_weight_bytes, cudaMemcpyDeviceToHost)); + + auto processed_weight_buffer = this->AllocateBufferOnCPUPinned(n * k / (8 / nbits_)); + bool force_interleave = false; + + using onnxruntime::llm::kernels::cutlass_kernels::QuantType; + QuantType quant_type = nbits_ == 4 ? QuantType::W4_A16 : QuantType::W8_A16; + + // TODO: Add a cuda kernle for preprocessing so that we can avoid copying the data back to CPU. + onnxruntime::llm::kernels::cutlass_kernels::preprocess_weights_for_mixed_gemm( + reinterpret_cast(processed_weight_buffer.get()), + reinterpret_cast(tranpose_weight_buffer.get()), + {static_cast(k), static_cast(n)}, + quant_type, + force_interleave); + + CUDA_RETURN_IF_ERROR(cudaMemcpy(preprocessed_weight, processed_weight_buffer.get(), n * k / (8 / nbits_), cudaMemcpyHostToDevice)); + CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize()); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed transposed_weight in GPU", packed_transposed_weight, k, n * nbits_ / 8); + DUMP_TENSOR_D("preprocessed_weight", reinterpret_cast(preprocessed_weight), k, n * nbits_ / 8); + } + + return Status::OK(); +} + +template +Status MatMulNBits::PrePack_Scale([[maybe_unused]] const Tensor& tensor, + [[maybe_unused]] AllocatorPtr alloc, + [[maybe_unused]] cudaStream_t stream) { + if constexpr (std::is_same_v) { + size_t n = static_cast(N_); + size_t k = static_cast(K_); + + size_t k_blocks = (k + block_size_ - 1) / block_size_; + size_t scale_bytes = n * k_blocks * sizeof(T); + + fpA_intB_scale_buffer_ = IAllocator::MakeUniquePtr(alloc, scale_bytes, true); // Transient buffer. + + typedef typename ToCudaType::MappedType CudaT; + CudaT* transposed_scales = reinterpret_cast(fpA_intB_scale_buffer_.get()); + + onnxruntime::llm::kernels::fpA_intB_gemv::launch_transpose_scale_kernel(stream, reinterpret_cast(tensor.Data()), transposed_scales, n, k_blocks); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("transposed_scales", transposed_scales, k_blocks, n); + } + return Status::OK(); +} + +template +Status MatMulNBits::PrePack_ZeroPoint([[maybe_unused]] const Tensor& tensor, + [[maybe_unused]] AllocatorPtr alloc, + [[maybe_unused]] cudaStream_t stream) { + if constexpr (std::is_same_v) { + size_t n = static_cast(N_); + size_t k = static_cast(K_); + + size_t k_blocks = (k + block_size_ - 1) / block_size_; + size_t scale_bytes = n * k_blocks * sizeof(T); + + typedef typename ToCudaType::MappedType CudaT; + const CudaT* transposed_scales = reinterpret_cast(fpA_intB_scale_buffer_.get()); + + fpA_intB_zero_buffer_ = IAllocator::MakeUniquePtr(alloc, scale_bytes, true); // Transient buffer. + CudaT* scaled_zero_points = reinterpret_cast(fpA_intB_zero_buffer_.get()); + + constexpr float kDefaultZeroPoint4Bit = 8.0f; + constexpr float kDefaultZeroPoint8Bit = 128.0f; + const float default_zero_point = nbits_ == 4 ? kDefaultZeroPoint4Bit : kDefaultZeroPoint8Bit; + const auto* zero_points_data = tensor.DataRaw(); + + // The scaled zero point will be zero for the default zero point, so there is no need to scale when it is nullptr. + if (!tensor.IsDataType()) { // zero point is uint8_t type + if (nbits_ == 4) { + onnxruntime::llm::kernels::fpA_intB_gemv::launch_scaled_zero_point_kernel( + stream, reinterpret_cast(zero_points_data), + transposed_scales, scaled_zero_points, n, k_blocks, default_zero_point); + } else { + onnxruntime::llm::kernels::fpA_intB_gemv::launch_scaled_zero_point_kernel( + stream, reinterpret_cast(zero_points_data), + transposed_scales, scaled_zero_points, n, k_blocks, default_zero_point); + } + } else { // zero point is not uint8_t type + onnxruntime::llm::kernels::fpA_intB_gemv::launch_scaled_zero_point_kernel( + stream, reinterpret_cast(zero_points_data), + transposed_scales, scaled_zero_points, n, k_blocks, default_zero_point); + } + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("scaled_zero_points", scaled_zero_points, k_blocks, n); + } + return Status::OK(); +} template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { + const bool is_prepacked = has_fpA_intB_gemm_; const Tensor* a = ctx->Input(0); - const Tensor* b = ctx->Input(1); - const Tensor* scales = ctx->Input(2); - const Tensor* zero_points = ctx->Input(3); + const Tensor* b = is_prepacked ? nullptr : ctx->Input(1); + const Tensor* scales = is_prepacked ? nullptr : ctx->Input(2); + const Tensor* zero_points = is_prepacked ? nullptr : ctx->Input(3); const Tensor* reorder_idx = ctx->Input(4); const Tensor* bias = ctx->Input(5); @@ -35,19 +245,17 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { a, b, scales, zero_points, reorder_idx, bias, N_, K_, block_size_, nbits_)); const auto* a_data = a->Data(); - const uint8_t* blob_data = b->Data(); - const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const uint8_t* blob_data = is_prepacked ? nullptr : b->Data(); + const auto* scales_data = is_prepacked ? nullptr : scales->Data(); + const auto* zero_points_data = (is_prepacked || zero_points == nullptr) ? nullptr : zero_points->DataRaw(); const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); - - typedef typename ToCudaType::MappedType CudaT; + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); constexpr bool transa = false; constexpr bool transb = true; MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); - ORT_RETURN_IF_ERROR( - helper.Compute(a->Shape(), b_shape, transa, transb)); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); Tensor* Y = ctx->Output(0, helper.OutputShape()); @@ -55,6 +263,61 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if (Y->Shape().Size() == 0) return Status::OK(); + cudaStream_t stream = static_cast(ctx->GetComputeStream()->GetHandle()); + + typedef typename ToCudaType::MappedType CudaT; + CudaT* out_data = reinterpret_cast(Y->MutableData()); + + int m = SafeInt(helper.M()); + int n = SafeInt(helper.N()); + int k = SafeInt(helper.K()); + + DUMP_TENSOR_INIT(); + + if constexpr (std::is_same::value) { + if (has_fpA_intB_gemm_) { + auto const& bestTactic = gemmProfiler_->getBestConfig(m, gemmId_); + + DUMP_STRING("Best tactic: m=", m, " n=", n, " k=", k, " group_size=", block_size_, bestTactic->toString()); + + if (bestTactic->enableCudaKernel) { + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type = (nbits_ == 8) ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + + void const* pre_quant_scale_ptr = nullptr; + bool apply_alpha_in_advance = false; + float alpha = 1.0f; + onnxruntime::llm::kernels::fpA_intB_gemv::Params params( + a_data, pre_quant_scale_ptr, fpA_intB_weight_buffer_.get(), + fpA_intB_scale_buffer_.get(), has_zero_points_ ? fpA_intB_zero_buffer_.get() : nullptr, + bias_data, out_data, + alpha, m, n, k, block_size_, cuda_kernel_type, apply_alpha_in_advance); + + onnxruntime::llm::kernels::fpA_intB_gemv::kernel_launcher(sm_, params, stream); + } else { + const size_t workspace_size = weightOnlyGemmRunner_->getWorkspaceSize(m, n, k); + auto workspace_buffer = GetScratchBuffer(workspace_size, ctx->GetComputeStream()); + + weightOnlyGemmRunner_->gemm( + a_data, + fpA_intB_weight_buffer_.get(), + fpA_intB_scale_buffer_.get(), + has_zero_points_ ? fpA_intB_zero_buffer_.get() : nullptr, + bias_data, + 1.f, + out_data, + m, n, k, + block_size_, + *bestTactic, + reinterpret_cast(workspace_buffer.get()), + workspace_size, + stream); + } + + return Status::OK(); + } + } + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { bool done = (nbits_ == 8) ? TryMatMul8Bits( reinterpret_cast(Y->MutableData()), @@ -62,24 +325,24 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { blob_data, reinterpret_cast(scales_data), static_cast(zero_points_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), + m, + n, + k, SafeInt(block_size_), GetDeviceProp().sharedMemPerBlock, - static_cast(ctx->GetComputeStream()->GetHandle())) + stream) : TryMatMul4Bits( reinterpret_cast(Y->MutableData()), reinterpret_cast(a_data), blob_data, reinterpret_cast(scales_data), static_cast(zero_points_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), + m, + n, + k, SafeInt(block_size_), GetDeviceProp().sharedMemPerBlock, - static_cast(ctx->GetComputeStream()->GetHandle())); + stream); if (done) { return Status::OK(); } @@ -104,7 +367,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } else { ORT_RETURN_IF_ERROR(Dequantize8Bits( reinterpret_cast(b_data), @@ -115,7 +378,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } else { // row-wise block ORT_RETURN_IF_ERROR(DequantizeBlockwise8b( @@ -127,7 +390,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { column_wise_quant_blk_, SafeInt(K_), SafeInt(N_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } else { // 4 bits if (column_wise_quant_blk_) { @@ -145,7 +408,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } else { ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), @@ -156,7 +419,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } else { // row-wise block @@ -171,11 +434,10 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { column_wise_quant_blk_, SafeInt(K_), SafeInt(N_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } - DUMP_TENSOR_INIT(); DUMP_TENSOR_D("DeQuantized", b_data, N_, K_padded); const CudaT alpha = ToCudaType::FromFloat(1.f); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index f5c2c6c4e4fdf..02740d905c7c7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -10,11 +10,27 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h" +#include "core/platform/env_var_utils.h" namespace onnxruntime { namespace contrib { namespace cuda { using namespace onnxruntime::cuda; +using onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner; +using onnxruntime::llm::kernels::weight_only::GemmDims; +using onnxruntime::llm::kernels::weight_only::GemmIdCore; +using onnxruntime::llm::kernels::weight_only::GemmPluginProfilerManager; +using onnxruntime::llm::kernels::weight_only::WeightOnlyGroupwiseQuantGemmPluginProfiler; +using GemmProfilerPtr = std::shared_ptr; +using WeightOnlyGemmRunnerPtr = std::shared_ptr; + +// Environment variable to configure fpA_intB_gemm for experiments. Set it to 0 to disable, 1 to eanble all. +constexpr const char* kFpAIntBGemmOption = "ORT_FPA_INTB_GEMM"; +constexpr int kFpAIntBGemmOption_All = 0x01; +constexpr int kFpAIntBGemmOption_Gemv = 0x02; +constexpr int kFpAIntBGemmOption_Int4 = 0x04; +constexpr int kFpAIntBGemmOption_Int8 = 0x08; template class MatMulNBits final : public CudaKernel { @@ -24,16 +40,91 @@ class MatMulNBits final : public CudaKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + + constexpr size_t kInputIndexScale = 2; + constexpr size_t kInputIndexZeroPoints = 3; + constexpr size_t kInputIndexGroupIndex = 4; + constexpr size_t kInputIndexBias = 5; + + has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints && info.node().InputDefs()[kInputIndexZeroPoints]->Exists(); + has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex && info.node().InputDefs()[kInputIndexGroupIndex]->Exists(); + has_bias_ = info.GetInputCount() > kInputIndexBias && info.node().InputDefs()[kInputIndexBias]->Exists(); + sm_ = this->GetDeviceProp().major * 10 + this->GetDeviceProp().minor; + + if (has_zero_points_) { + int32_t zero_point_type = info.node().InputDefs()[kInputIndexZeroPoints]->TypeAsProto()->tensor_type().elem_type(); + int32_t scale_type = info.node().InputDefs()[kInputIndexScale]->TypeAsProto()->tensor_type().elem_type(); + is_zero_points_scale_same_type_ = (zero_point_type == scale_type); + } + + if constexpr (std::is_same::value) { + int option = ParseEnvironmentVariableWithDefault(kFpAIntBGemmOption, 0); + if ((option & (static_cast(nbits_) | kFpAIntBGemmOption_All)) != 0 && + (block_size_ == 64 || block_size_ == 128) && + (nbits_ == 4 || nbits_ == 8) && + !has_g_idx_ && has_zero_points_ && !has_bias_ && + N_ % (nbits_ == 8 ? 32 : 64) == 0 && + K_ % block_size_ == 0 && + sm_ >= 75) { + if ((option & (kFpAIntBGemmOption_Gemv | kFpAIntBGemmOption_All)) != 0) { + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type = (nbits_ == 8) ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + if (onnxruntime::llm::kernels::fpA_intB_gemv::is_supported(sm_, cuda_kernel_type)) { + has_fpA_intB_gemv_ = true; + } + } + + InitGemmProfiler(sm_); + + constexpr int max_m = 8291; + RunGemmProfile(has_fpA_intB_gemv_, 1, max_m); + has_fpA_intB_gemm_ = true; + } + } + +#ifndef NDEBUG + printf("n=%d, k=%d, block_size=%d, bits=%d, zp_bits=%d, g_idx=%d, bias=%d, gemv=%d, gemm=%d\n", + int(N_), int(K_), int(block_size_), int(nbits_), + has_zero_points_ ? (is_zero_points_scale_same_type_ ? int(sizeof(T)) * 8 : int(nbits_)) : int(0), + int(has_g_idx_ ? 1 : 0), int(has_bias_ ? 1 : 0), + int(has_fpA_intB_gemv_), int(has_fpA_intB_gemm_)); +#endif } Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: + void InitGemmProfiler(int sm); + void RunGemmProfile(bool hasWeightOnlyCudaKernel, int min_m, int max_m); + + Status PrePack_B(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); + Status PrePack_Scale(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); + Status PrePack_ZeroPoint(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); + int64_t K_; int64_t N_; int64_t block_size_; int64_t nbits_; + int sm_{0}; bool column_wise_quant_blk_{true}; + + bool has_g_idx_{false}; + bool has_bias_{false}; + bool has_zero_points_{false}; + bool is_zero_points_scale_same_type_{false}; + bool has_fpA_intB_gemv_{false}; + bool has_fpA_intB_gemm_{false}; + + WeightOnlyGemmRunnerPtr weightOnlyGemmRunner_{nullptr}; + mutable GemmProfilerPtr gemmProfiler_{nullptr}; + GemmIdCore gemmId_{}; + + IAllocatorUniquePtr fpA_intB_weight_buffer_; + IAllocatorUniquePtr fpA_intB_scale_buffer_; + IAllocatorUniquePtr fpA_intB_zero_buffer_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h index 2b2b726e62c79..63e2ab8e9cb9b 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h @@ -15,30 +15,30 @@ std::conditional_t CudaCall( ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, const char* file, const int line); -#define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) -#define CUBLAS_CALL(expr) (CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDA_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) +#define CUBLAS_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUSPARSE_CALL(expr) (CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CURAND_CALL(expr) (CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUDNN_CALL(expr) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUDNN_CALL2(expr, m) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, m, __FILE__, __LINE__)) +#define CUSPARSE_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CURAND_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDNN_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDNN_CALL2(expr, m) (::onnxruntime::CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, m, __FILE__, __LINE__)) -#define CUFFT_CALL(expr) (CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) +#define CUFFT_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) -#define CUDA_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) -#define CUBLAS_CALL_THROW(expr) (CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDA_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) +#define CUBLAS_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUSPARSE_CALL_THROW(expr) (CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CURAND_CALL_THROW(expr) (CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUSPARSE_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CURAND_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) // the cudnn configuration call that doesn't need set stream -#define CUDNN_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDNN_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUFFT_CALL_THROW(expr) (CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) +#define CUFFT_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) #ifdef ORT_USE_NCCL -#define NCCL_CALL(expr) (CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) -#define NCCL_CALL_THROW(expr) (CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) +#define NCCL_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) +#define NCCL_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) #endif } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 043f7ed57b8b0..f8739b859bef5 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -20,6 +20,7 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/util/include/scoped_env_vars.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" @@ -486,6 +487,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura if (use_float16) { opts.output_abs_error = fp16_abs_error; + opts.output_rel_error = use_float16 ? 0.001f : 0.0005f; } std::vector> execution_providers; @@ -548,11 +550,8 @@ TEST(MatMulNBits, Float16Large) { // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances. float abs_error = 0.3f; -#elif USE_WEBGPU - // Use absolute error of 0.1 for WebGPU with subgroup implementation - float abs_error = 0.1f; #else - float abs_error = 0.05f; + float abs_error = 0.1f; #endif for (auto block_size : {16, 32, 64, 128}) { @@ -564,6 +563,53 @@ TEST(MatMulNBits, Float16Large) { } } +#ifdef USE_CUDA +TEST(MatMulNBits, Fp16_Int4_Int4ZeroPoint) { + float abs_error = 0.1f; + constexpr bool use_float16 = true; + constexpr bool has_g_idx = false; + constexpr bool zp_is_4bit = true; + constexpr bool has_zeropoint = true; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + } +} + +TEST(MatMulNBits, Fp16_Int4_Fp16ZeroPoint) { + float abs_error = 0.1f; + constexpr bool use_float16 = true; + constexpr bool has_g_idx = false; + constexpr bool zp_is_4bit = false; + constexpr bool has_zeropoint = true; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + } +} + +TEST(MatMulNBits, Fp16_Int4_NoZeroPoint) { + float abs_error = 0.1f; + constexpr bool use_float16 = true; + constexpr bool has_g_idx = false; + constexpr bool zp_is_4bit = true; + constexpr bool has_zeropoint = false; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + } +} +#endif + #endif // defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index 63677094b1b4b..39f6958d47a12 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -21,6 +21,7 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/util/include/scoped_env_vars.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" @@ -206,28 +207,26 @@ void RunTest8Bits(const TestOptions8Bits& opts) { } template -void TestMatMul8BitsTyped() { +void TestMatMul8BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) { TestOptions8Bits base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; base_opts.block_size = block_size; base_opts.accuracy_level = accuracy_level; - if (base_opts.accuracy_level == 4) { - base_opts.output_abs_error = 0.1f; - base_opts.output_rel_error = 0.02f; - } else if constexpr (std::is_same::value) { - base_opts.output_abs_error = 0.055f; - base_opts.output_rel_error = 0.02f; - } + base_opts.output_abs_error = abs_error; + base_opts.output_rel_error = rel_error; { TestOptions8Bits opts = base_opts; + opts.has_zero_point = false; + opts.has_bias = false; RunTest8Bits(opts); } { TestOptions8Bits opts = base_opts; opts.has_zero_point = true; + opts.has_bias = false; RunTest8Bits(opts); } @@ -235,6 +234,7 @@ void TestMatMul8BitsTyped() { #if !defined(USE_CUDA) && !defined(USE_WEBGPU) { TestOptions8Bits opts = base_opts; + opts.has_zero_point = false; opts.has_bias = true; RunTest8Bits(opts); } @@ -249,7 +249,7 @@ void TestMatMul8BitsTyped() { } } // namespace -TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { +TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); @@ -285,9 +285,25 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { } #if defined(USE_CUDA) || defined(USE_WEBGPU) -TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float16) { - TestMatMul8BitsTyped(); - TestMatMul8BitsTyped(); +TEST(MatMulNBits, Float16_8b_AccuracyLevel4) { + constexpr float abs_error = 0.055f; + constexpr float rel_error = 0.02f; + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); +} +#endif + +#if defined(USE_CUDA) +TEST(MatMulNBits, Fp16_Int8_Cuda) { + constexpr float abs_error = 0.5f; + constexpr float rel_error = 0.05f; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); } #endif From c7f86b3665bd5cfd41b7d626213d2ef85227f245 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 30 May 2025 13:13:10 -0700 Subject: [PATCH 53/57] Built-in onnxruntime-extensions updates (#24892) ### Description - Fix onnxruntime-extensions include path. - Add option to onnxruntime_perf_test to register custom ops from a built-in onnxruntime-extensions. ### Motivation and Context Fix build.py `--use_extensions` option. Make it simple to use the built-in onnxruntime-extensions with onnxruntime_perf_test. --- cmake/external/extensions.cmake | 2 +- onnxruntime/test/perftest/command_args_parser.cc | 8 +++++++- onnxruntime/test/perftest/ort_test_session.cc | 4 ++++ onnxruntime/test/perftest/test_configuration.h | 1 + 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/cmake/external/extensions.cmake b/cmake/external/extensions.cmake index 8c00c1c8a530b..bd3c47d53f53d 100644 --- a/cmake/external/extensions.cmake +++ b/cmake/external/extensions.cmake @@ -69,7 +69,7 @@ set_target_properties(ortcustomops PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "") # target library or executable are defined in CMakeLists.txt of onnxruntime-extensions target_include_directories(ocos_operators PRIVATE ${RE2_INCLUDE_DIR} ${json_SOURCE_DIR}/include) -target_include_directories(ortcustomops PUBLIC $) +target_include_directories(ortcustomops PUBLIC $) if(OCOS_ENABLE_SPM_TOKENIZER) onnxruntime_add_include_to_target(sentencepiece-static ${PROTOBUF_LIB} ${ABSEIL_LIBS}) endif() diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index b63ef7959e1db..d409032b4ebb3 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -161,6 +161,9 @@ namespace perftest { "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n" "\t-l Provide file as binary in memory by using fopen before session creation.\n" "\t-R [Register custom op]: allow user to register custom op by .so or .dll file.\n" + "\t-X [Enable onnxruntime-extensions custom ops]: Registers custom ops from onnxruntime-extensions. " + "onnxruntime-extensions must have been built in to onnxruntime. This can be done with the build.py " + "'--use_extensions' option.\n" "\t-h: help\n"); } #ifdef _WIN32 @@ -190,7 +193,7 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlgR:"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlgR:X"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -393,6 +396,9 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier, case 'g': test_config.run_config.enable_cuda_io_binding = true; break; + case 'X': + test_config.run_config.use_extensions = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 46e167b2ef823..05136ec0750a1 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -826,6 +826,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } + if (performance_test_config.run_config.use_extensions) { + session_options.EnableOrtCustomOps(); + } + if (!performance_test_config.model_info.load_via_path) { session_ = Ort::Session(env, performance_test_config.model_info.model_file_path.c_str(), session_options); } else { diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index e180efca5b9db..8145f5f35c3b3 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -67,6 +67,7 @@ struct RunConfig { bool exit_after_session_creation = false; std::basic_string register_custom_op_path; bool enable_cuda_io_binding{false}; + bool use_extensions = false; }; struct PerformanceTestConfig { From 5e36544debfea08512f8ce55e50986ae945c2caa Mon Sep 17 00:00:00 2001 From: omarhass47 <103273844+omarhass47@users.noreply.github.com> Date: Fri, 30 May 2025 20:09:58 -0700 Subject: [PATCH 54/57] Add WAITPKG checks, add support for TPAUSE within SpinPause (#24524) ### Description This change introduces `TPAUSE` support in the `SpinPause()` function in Windows and Linux to reduce power consumption and improve efficiency during spin-wait periods. `TPAUSE` is a lightweight power/performance ISA that goes into an optimized C0 power state while waiting on a delay event, compared to `_mm_pause()` which is a NOP-like instruction that provides a small delay in the CPU Pipeline. With this change, performance of First Inference Latency across certain models can also improve. Models that were tested internally have shown up to ~2x improvement in First Inference Latency and up to ~20% lower overall power consumption. Genuine Intel CPUID detection logic was also refactored into a shared utility (`CheckIntel()`), enabling consistent platform checks across the codebase. Here `TPAUSE` is enabled by default for architectures that support it. [Intel Intrinsics Guide (TPAUSE)](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=tpause&techs=MMX,SSE_ALL,AVX_ALL,AVX_512,AMX,SVML,Other&ig_expand=6888,6888) ### Motivation and Context Performance and power efficiency gains - Previous PR was created which initially introduced the TPAUSE instruction in `SpinPause()` with measured improvements in power (please see previous TPAUSE PR here: [Add WAITPKG checks, add support for TPAUSE in ThreadPool spin #16935](https://github.com/microsoft/onnxruntime/pull/16935)). Additional performance testing and measurements were done across Mobile, Desktop, and Server, influencing enhancements to the PR such as a tweak to the `spin_delay_cycles`, Linux support and the refactored Intel CPUID detection logic. --- cmake/CMakeLists.txt | 1 - cmake/onnxruntime_common.cmake | 10 ++ include/onnxruntime/core/common/spin_pause.h | 17 +--- onnxruntime/core/common/cpuid_info.cc | 16 ++- onnxruntime/core/common/cpuid_info.h | 2 + onnxruntime/core/common/spin_pause.cc | 48 +++++++++ onnxruntime/core/platform/check_intel.cc | 97 +++++++++++++++++++ onnxruntime/core/platform/check_intel.h | 13 +++ .../windows/hardware_core_enumerator.cc | 28 +----- 9 files changed, 191 insertions(+), 41 deletions(-) create mode 100644 onnxruntime/core/common/spin_pause.cc create mode 100644 onnxruntime/core/platform/check_intel.cc create mode 100644 onnxruntime/core/platform/check_intel.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 301fb0fbe82b0..416ed5e49f25a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -258,7 +258,6 @@ option(onnxruntime_USE_OPENVINO_INTERFACE "Build ONNXRuntime shared lib which is option(onnxruntime_USE_VITISAI_INTERFACE "Build ONNXRuntime shared lib which is compatible with Vitis-AI EP interface" OFF) option(onnxruntime_USE_QNN_INTERFACE "Build ONNXRuntime shared lib which is compatible with QNN EP interface" OFF) - if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 11.1) message(FATAL_ERROR "GCC version must be greater than or equal to 11.1") endif() diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index f9cd35fa71aa8..e629df4843109 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -11,6 +11,8 @@ set(onnxruntime_common_src_patterns "${ONNXRUNTIME_ROOT}/core/common/logging/*.cc" "${ONNXRUNTIME_ROOT}/core/common/logging/sinks/*.h" "${ONNXRUNTIME_ROOT}/core/common/logging/sinks/*.cc" + "${ONNXRUNTIME_ROOT}/core/platform/check_intel.h" + "${ONNXRUNTIME_ROOT}/core/platform/check_intel.cc" "${ONNXRUNTIME_ROOT}/core/platform/device_discovery.h" "${ONNXRUNTIME_ROOT}/core/platform/device_discovery.cc" "${ONNXRUNTIME_ROOT}/core/platform/env.h" @@ -100,6 +102,14 @@ if(WIN32) target_compile_options(onnxruntime_common PRIVATE "/Zc:char8_t-") endif() endif() + +if(NOT WIN32 AND NOT APPLE AND NOT ANDROID AND CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") + set_source_files_properties( + ${ONNXRUNTIME_ROOT}/core/common/spin_pause.cc + PROPERTIES COMPILE_FLAGS "-mwaitpkg" + ) +endif() + if (onnxruntime_USE_TELEMETRY) set_target_properties(onnxruntime_common PROPERTIES COMPILE_FLAGS "/FI${ONNXRUNTIME_INCLUDE_DIR}/core/platform/windows/TraceLoggingConfigPrivate.h") endif() diff --git a/include/onnxruntime/core/common/spin_pause.h b/include/onnxruntime/core/common/spin_pause.h index 49b71e5567d3e..4d987f1d12977 100644 --- a/include/onnxruntime/core/common/spin_pause.h +++ b/include/onnxruntime/core/common/spin_pause.h @@ -3,26 +3,11 @@ #pragma once -#if defined(_M_AMD64) -#include -#endif - -#if defined(__x86_64__) -#include -#endif - namespace onnxruntime { - namespace concurrency { // Intrinsic to use in spin-loops - -inline void SpinPause() { -#if defined(_M_AMD64) || defined(__x86_64__) - _mm_pause(); -#endif -} +void SpinPause(); } // namespace concurrency - } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 91961bf22ce1e..00ff896bf6749 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -3,9 +3,12 @@ #include "core/common/cpuid_info.h" #include "core/common/logging/logging.h" #include "core/common/logging/severity.h" +#include "core/platform/check_intel.h" #ifdef __linux__ - +#if (defined(_M_AMD64) || defined(__x86_64__)) && !defined(__ANDROID__) +#include +#endif #include #include #if !defined(__NR_getcpu) @@ -133,6 +136,17 @@ void CPUIDInfo::X86Init() { // avx512_skylake = avx512f | avx512vl | avx512cd | avx512bw | avx512dq has_avx512_skylake_ = has_avx512 && (data[1] & ((1 << 16) | (1 << 17) | (1 << 28) | (1 << 30) | (1 << 31))); is_hybrid_ = (data[3] & (1 << 15)); + // Check for TPAUSE + CheckIntelResult check_intel = CheckIntel(); + if (check_intel.is_intel) { +#ifdef __linux__ +#if !defined(__ANDROID__) + has_tpause_ = __builtin_cpu_supports("waitpkg") != 0; +#endif +#else + has_tpause_ = (data[2] & (1 << 5)) != 0; +#endif + } if (max_SubLeaves >= 1) { GetCPUID(7, 1, data); has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5)); diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index b820fa2ab1af7..9c67ebbffa260 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -33,6 +33,7 @@ class CPUIDInfo { bool HasSSE3() const { return has_sse3_; } bool HasSSE4_1() const { return has_sse4_1_; } bool IsHybrid() const { return is_hybrid_; } + bool HasTPAUSE() const { return has_tpause_; } // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } @@ -112,6 +113,7 @@ class CPUIDInfo { bool has_sse3_{false}; bool has_sse4_1_{false}; bool is_hybrid_{false}; + bool has_tpause_{false}; std::vector core_uarchs_; // micro-arch of each core diff --git a/onnxruntime/core/common/spin_pause.cc b/onnxruntime/core/common/spin_pause.cc new file mode 100644 index 0000000000000..9bada0841c162 --- /dev/null +++ b/onnxruntime/core/common/spin_pause.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/spin_pause.h" + +#if defined(_M_AMD64) +#include +#endif + +#if defined(__x86_64__) +#include +#endif + +#if defined(_M_AMD64) || defined(__x86_64__) +#include "core/common/cpuid_info.h" +#if defined(__linux__) +#include +#include +#endif +#endif + +namespace onnxruntime { +namespace concurrency { + +// Intrinsic to use in spin-loops +void SpinPause() { +#if (defined(_M_AMD64) || defined(__x86_64__)) && \ + !defined(__ANDROID__) && \ + !defined(__APPLE__) + + static const bool has_tpause = CPUIDInfo::GetCPUIDInfo().HasTPAUSE(); + static constexpr uint64_t tpause_spin_delay_cycles = 1000; + if (has_tpause) { +#if defined(_WIN32) + _tpause(0x0, __rdtsc() + tpause_spin_delay_cycles); +#elif defined(__linux__) + __builtin_ia32_tpause(0x0, __rdtsc() + tpause_spin_delay_cycles); +#else + _mm_pause(); +#endif + } else { + _mm_pause(); + } +#endif +} + +} // namespace concurrency +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/check_intel.cc b/onnxruntime/core/platform/check_intel.cc new file mode 100644 index 0000000000000..d773ae2d2be2f --- /dev/null +++ b/onnxruntime/core/platform/check_intel.cc @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/platform/check_intel.h" + +#if (defined(_M_AMD64) || defined(__x86_64__)) +#if defined(__linux__) +#include +#elif defined(_WIN32) +#include +#endif +#endif + +namespace onnxruntime { + +CheckIntelResult CheckIntel() { + CheckIntelResult intel_check = {false, false}; + bool is_intel = false; + bool is_intel_specified_platform = false; + +#if (defined(_M_AMD64) || defined(__x86_64__)) +#if defined(_WIN32) + constexpr unsigned int kVendorID_Intel[] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + constexpr unsigned int kVendorID_IntelSpecifiedPlatformIDs[] = { + // ExtendedModel, ExtendedFamily, Family Code, and Model Number + 0xa06a, // MTL + 0xc065, // ARL-H + 0xb065 // ARL-U + }; + + int regs_leaf0[4]; + int regs_leaf1[4]; + __cpuid(regs_leaf0, 0); + __cpuid(regs_leaf1, 0x1); + + is_intel = + (kVendorID_Intel[0] == static_cast(regs_leaf0[1])) && + (kVendorID_Intel[1] == static_cast(regs_leaf0[2])) && + (kVendorID_Intel[2] == static_cast(regs_leaf0[3])); + + if (!is_intel) { + return intel_check; // if not an Intel CPU, return early + } + + for (auto intel_specified_platform : kVendorID_IntelSpecifiedPlatformIDs) { + if ((static_cast(regs_leaf1[0]) >> 4) == intel_specified_platform) { + is_intel_specified_platform = true; + break; + } + } + +#elif defined(__linux__) + constexpr unsigned int kVendorID_Intel[] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + unsigned int regs[4] = {0}; + __get_cpuid(0, ®s[0], ®s[1], ®s[2], ®s[3]); + + is_intel = (regs[1] == kVendorID_Intel[0] && + regs[2] == kVendorID_Intel[1] && + regs[3] == kVendorID_Intel[2]); + if (!is_intel) { + return intel_check; // if not an Intel CPU, return early + } + + __get_cpuid(1, ®s[0], ®s[1], ®s[2], ®s[3]); + + unsigned int base_family = (regs[0] >> 8) & 0xF; + unsigned int base_model = (regs[0] >> 4) & 0xF; + unsigned int extended_model = (regs[0] >> 16) & 0xF; + + unsigned int model = + (base_family == 0x6 || base_family == 0xF) + ? (base_model + (extended_model << 4)) + : base_model; + + constexpr unsigned int kVendorID_IntelSpecifiedPlatformIDs[] = { + // ExtendedModel, ExtendedFamily, Family Code, and Model Number + 170, // MTL (0xAA) + 197, // ARL-H (0xC5) + 198 // ARL-U (0xC6) + }; + + for (auto id : kVendorID_IntelSpecifiedPlatformIDs) { + if (model == id) { + is_intel_specified_platform = true; + break; + } + } +#endif //__linux__ +#endif // (_M_AMD64) || (__x86_64__) + + intel_check.is_intel = is_intel; + intel_check.is_intel_specified_platform = is_intel_specified_platform; + + return intel_check; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/check_intel.h b/onnxruntime/core/platform/check_intel.h new file mode 100644 index 0000000000000..1b82940489171 --- /dev/null +++ b/onnxruntime/core/platform/check_intel.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +typedef struct { + bool is_intel; + bool is_intel_specified_platform; +} CheckIntelResult; + +CheckIntelResult CheckIntel(); +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc index 7464ab4c57d01..40a2fb780878c 100644 --- a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc @@ -3,6 +3,7 @@ #include "hardware_core_enumerator.h" #include "core/platform/windows/env.h" +#include "core/platform/check_intel.h" #include #include #include @@ -85,30 +86,11 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. auto cores = GetCoreInfo(); #if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) - const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" - bool isIntelSpecifiedPlatform = false; - const int kVendorID_IntelSpecifiedPlatformIDs[3] = { - // ExtendedModel, ExtendedFamily, Family Code, and Model Number - 0xa06a, // MTL - 0xc065, // ARL-H - 0xb065 // ARL-U - }; - - int regs_leaf0[4]; - int regs_leaf1[4]; - __cpuid(regs_leaf0, 0); - __cpuid(regs_leaf1, 0x1); - - auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && (kVendorID_Intel[2] == regs_leaf0[3]); - - for (int intelSpecifiedPlatform : kVendorID_IntelSpecifiedPlatformIDs) { - if ((regs_leaf1[0] >> 4) == intelSpecifiedPlatform) { - isIntelSpecifiedPlatform = true; - } - } - if (isIntel) { - if (isIntelSpecifiedPlatform) { + CheckIntelResult check_intel = CheckIntel(); + + if (check_intel.is_intel) { + if (check_intel.is_intel_specified_platform) { // We want to exclude cores without an LLC return cores.LLCCores; } else { From 70de20b94fdc77a0e603ef868a2c53ef38da557c Mon Sep 17 00:00:00 2001 From: Ranjit Ranjan <165394499+ranjitshs@users.noreply.github.com> Date: Sat, 31 May 2025 12:03:50 +0530 Subject: [PATCH 55/57] [AIX] Not enabling ABSL_ENABLE_INSTALL for AIX to fix the build failure (#24910) ### Description Recent changes in abseil-cpp.cmake is enabling ABSL_ENABLE_INSTALL which is causing compilation error for AIX. But the same was working before, so blocking this enablement. ``` [ 83%] Linking CXX executable onnxruntime_perf_test ld: 0706-006 Cannot find or open library file: -l absl_failure_signal_handler ld:open(): A file or directory in the path name does not exist. ld: 0706-006 Cannot find or open library file: -l absl_examine_stack ld:open(): A file or directory in the path name does not exist. ld: 0706-006 Cannot find or open library file: -l absl_flags_parse ld:open(): A file or directory in the path name does not exist. ld: 0706-006 Cannot find or open library file: -l absl_flags_usage ld:open(): A file or directory in the path name does not exist. ld: 0706-006 Cannot find or open library file: -l absl_flags_usage_internal ld:open(): A file or directory in the path name does not exist. .ibm-clang: error: linker command failed with exit code 255 (use -v to see invocation) ``` ### Motivation and Context To fix the compilation error, blocking the enablement of ABSL_ENABLE_INSTALL under AIX. --- cmake/external/abseil-cpp.cmake | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 5cfb9e78b4720..488df5a4e0de8 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -15,7 +15,9 @@ set(ABSL_USE_EXTERNAL_GOOGLETEST ON) if (onnxruntime_USE_XNNPACK) set(ABSL_ENABLE_INSTALL OFF) else() - set(ABSL_ENABLE_INSTALL ON) + if (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") + set(ABSL_ENABLE_INSTALL ON) + endif() endif() if(Patch_FOUND AND WIN32) From 196dea1b236737aa04ad2e407c461382303e5b9d Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Sat, 31 May 2025 13:20:21 -0500 Subject: [PATCH 56/57] Update Whisper attention fusions (#24857) ### Description This PR updates the attention fusions for Whisper to work with the latest `transformers` package (`4.52.3`). ### Motivation and Context Previously, the attention fusions were maintained for many older `transformers` versions. The existing fusions do not work with the latest `transformers` versions. --- .../tools/transformers/convert_generation.py | 140 +- .../tools/transformers/fusion_attention.py | 4 +- .../transformers/fusion_bart_attention.py | 619 ++--- .../models/whisper/requirements.txt | 10 +- .../models/whisper/whisper_jump_times.py | 2 +- .../python/tools/transformers/onnx_model.py | 18 +- .../decoder_attention_with_sln_fused.onnx | Bin 68540 -> 0 bytes .../models/whisper/decoder_mha_fused.onnx | Bin 68649 -> 0 bytes .../whisper/decoder_mha_split_bias_fused.onnx | Bin 68712 -> 0 bytes .../decoder_with_past_cross_mha_fused.onnx | Bin 35663 -> 0 bytes ..._with_past_cross_mha_split_bias_fused.onnx | Bin 35397 -> 0 bytes .../decoder_with_past_self_mha_fused.onnx | Bin 69261 -> 0 bytes ...r_with_past_self_mha_split_bias_fused.onnx | Bin 69284 -> 0 bytes .../encoder_attention_with_sln_fused.onnx | Bin 68160 -> 0 bytes .../hf_fp16_decoder_attention_no_past.onnx | Bin 0 -> 5949 bytes ..._decoder_attention_no_past_split_bias.onnx | Bin 0 -> 6207 bytes .../hf_fp16_decoder_attention_with_past.onnx | Bin 0 -> 5582 bytes ...ecoder_attention_with_past_split_bias.onnx | Bin 0 -> 5882 bytes .../hf_fp16_encoder_self_attention.onnx | Bin 0 -> 3549 bytes .../hf_fp32_decoder_attention_no_past.onnx | Bin 0 -> 8193 bytes ..._decoder_attention_no_past_split_bias.onnx | Bin 0 -> 8430 bytes .../hf_fp32_decoder_attention_with_past.onnx | Bin 0 -> 7426 bytes ...ecoder_attention_with_past_split_bias.onnx | Bin 0 -> 7664 bytes .../hf_fp32_encoder_self_attention.onnx | Bin 0 -> 4685 bytes .../oai_fp16_decoder_attention_no_past.onnx | Bin 0 -> 5657 bytes ..._decoder_attention_no_past_split_bias.onnx | Bin 0 -> 5851 bytes .../oai_fp16_decoder_attention_with_past.onnx | Bin 0 -> 5225 bytes ...ecoder_attention_with_past_split_bias.onnx | Bin 0 -> 5425 bytes .../oai_fp16_encoder_self_attention.onnx | Bin 0 -> 3161 bytes .../oai_fp32_decoder_attention_no_past.onnx | Bin 0 -> 7901 bytes ..._decoder_attention_no_past_split_bias.onnx | Bin 0 -> 8074 bytes .../oai_fp32_decoder_attention_with_past.onnx | Bin 0 -> 7069 bytes ...ecoder_attention_with_past_split_bias.onnx | Bin 0 -> 7207 bytes .../oai_fp32_encoder_self_attention.onnx | Bin 0 -> 4524 bytes .../test/python/transformers/test_whisper.py | 991 ++++++-- .../transformers/whisper_model_generator.py | 2021 ----------------- 36 files changed, 1185 insertions(+), 2620 deletions(-) delete mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx delete mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx delete mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx delete mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx delete mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx delete mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx delete mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx delete mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past_split_bias.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past_split_bias.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_encoder_self_attention.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past_split_bias.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past_split_bias.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_encoder_self_attention.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past_split_bias.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past_split_bias.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_encoder_self_attention.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past_split_bias.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past_split_bias.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_encoder_self_attention.onnx delete mode 100644 onnxruntime/test/python/transformers/whisper_model_generator.py diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 8eb2afb3db896..ed89d00bdc069 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1447,7 +1447,7 @@ def add_output_qk_to_mha(model: OnnxModel, dtype: int = 0, skip_node_idxs: list[ return model -def fix_past_sequence_length(model: ModelProto): +def fix_past_sequence_length(model: OnnxModel): # Modify total_sequence_length = past_sequence_length + curr_sequence_length subgraph to calculate # past_sequence_length from the new `past_sequence_length` input of size 1D and type int32 instead of # from `past_key_self_0` since DecoderMaskedMultiHeadAttention (DMMHA) uses buffer sharing and @@ -1480,56 +1480,119 @@ def fix_past_sequence_length(model: ModelProto): # | # Add + # Constant names to be used + past_seq_len_name = "past_sequence_length" + past_seq_len_int32 = "past_seq_len_int32" + past_seq_len_int64 = "past_seq_len_int64" + node = list(filter(lambda n: n.op_type == "LayerNormalization", model.model.graph.node))[0] # noqa: RUF015 - base_path = model.match_parent_path( + base_path_hf = model.match_parent_path( + node, + ["Add", "Gather", "Tile", "Expand", "Unsqueeze", "Range"], + [0, 1, 1, 0, 0, 0], + ) + base_path_oai = model.match_parent_path( node, ["Add", "Slice"], [0, 1], ) - if base_path is None: + if base_path_hf is not None: + base_path = base_path_hf + elif base_path_oai is not None: + base_path = base_path_oai + else: + logger.info("Cannot identify base path for fixing past_sequence_length subgraph") return + base_node = base_path[-1] - left_path = model.match_parent_path( - base_path[-1], - ["Unsqueeze", "Add", "Gather", "Shape"], - [2, 0, 0, 0], - ) - right_path = model.match_parent_path( - base_path[-1], - ["Unsqueeze", "Gather", "Shape"], - [1, 0, 0], - ) - long_right_path = model.match_parent_path( - base_path[-1], - ["Unsqueeze", "Gather", "Shape", "Reshape", "Transpose"], - [1, 0, 0, 0, 0], - ) - if left_path is None or right_path is None or left_path[-2:] != right_path[-2:]: - return + if base_node.op_type == "Range": + # Hugging Face implementation + range_node = base_path[-1] + + gather_path = model.match_parent_path( + range_node, + ["Gather", "Shape"], + [0, 0], + ) + if gather_path is None: + logger.info("Cannot identify gather path for fixing past_sequence_length subgraph") + return + + add_path = model.match_parent_path( + range_node, + ["Add", "Gather", "Shape"], + [1, 0, 0], + ) + if add_path is None: + logger.info("Cannot identify add path for fixing past_sequence_length subgraph") + return + add_node = add_path[0] + + if gather_path != add_path[1:]: + logger.info("Gather path and add path do not share the same nodes for calculating the past_sequence_length") + return + + # Remove `past_key_self_0 --> Shape --> Gather` connection + constant_in_gather = list(filter(lambda n: n.output[0] == gather_path[0].input[1], model.model.graph.node))[0] # noqa: RUF015 + model.model.graph.node.remove(constant_in_gather) + model.model.graph.node.remove(gather_path[0]) + model.model.graph.node.remove(gather_path[1]) + + # Add `past_seq_len_int64` as an input name to existing nodes + range_node.input[0] = past_seq_len_int64 + add_node.input[0] = past_seq_len_int64 - # Remove `past_key_self_0 --> [Transpose --> Reshape] --> Shape --> Gather` connection - # where `Transpose --> Reshape` part may or may not exist. The OpenAI implementation of - # Whisper has an extra `Transpose --> Reshape` connection to remove. - constant_node = list(filter(lambda n: n.output[0] == left_path[-2].input[1], model.model.graph.node))[0] # noqa: RUF015 - model.model.graph.node.remove(left_path[-2]) - model.model.graph.node.remove(left_path[-1]) - model.model.graph.node.remove(constant_node) - if long_right_path is not None: - # Remove `Transpose --> Reshape` part - model.model.graph.node.remove(long_right_path[-2]) - model.model.graph.node.remove(long_right_path[-1]) + else: + # OpenAI implementation + input_ids_path = model.match_parent_path( + base_node, + ["Unsqueeze", "Add", "Gather", "Shape", "Reshape", "Transpose"], + [2, 0, 0, 0, 0, 0], + ) + if input_ids_path is None: + logger.info("Cannot identify input_ids path for fixing past_sequence_length subgraph") + return + add_node = input_ids_path[1] + + past_key_path = model.match_parent_path( + base_node, + ["Unsqueeze", "Gather", "Shape", "Reshape", "Transpose"], + [1, 0, 0, 0, 0], + ) + if past_key_path is None: + logger.info("Cannot identify past_key path for fixing past_sequence_length subgraph") + return + unsqueeze_node = past_key_path[0] + + if input_ids_path[2:] != past_key_path[1:]: + logger.info( + "The input_ids path and past_key path do not share the same nodes for calculating the past_sequence_length" + ) + return + + # Remove `past_key_self_0 --> Transpose --> Reshape --> Shape --> Gather` connection + constant_in_gather = list(filter(lambda n: n.output[0] == past_key_path[1].input[1], model.model.graph.node))[0] # noqa: RUF015 + model.model.graph.node.remove(constant_in_gather) + constant_in_reshape = list(filter(lambda n: n.output[0] == past_key_path[-2].input[1], model.model.graph.node))[ # noqa: RUF015 + 0 + ] + model.model.graph.node.remove(constant_in_reshape) + model.model.graph.node.remove(past_key_path[1]) + model.model.graph.node.remove(past_key_path[2]) + model.model.graph.node.remove(past_key_path[3]) + model.model.graph.node.remove(past_key_path[4]) + + # Add `past_seq_len_int64` as an input name to existing nodes + unsqueeze_node.input[0] = past_seq_len_int64 + add_node.input[0] = past_seq_len_int64 # Add `past_sequence_length` as model input - past_seq_len_name = "past_sequence_length" model.model.graph.input.append( onnx.helper.make_tensor_value_info(past_seq_len_name, TensorProto.INT32, shape=[1]), ) # Add `past_sequence_length --> Squeeze --> Cast` connection - past_seq_len_int32 = "past_seq_len_int32" - past_seq_len_int64 = "past_seq_len_int64" - squeeze_node = onnx.helper.make_node( "Squeeze", inputs=[past_seq_len_name], @@ -1546,14 +1609,9 @@ def fix_past_sequence_length(model: ModelProto): ) cast_output = onnx.helper.make_tensor_value_info(past_seq_len_int64, TensorProto.INT64, shape=[]) - model.model.graph.value_info.extend([squeeze_output, cast_output]) - - # Add `past_seq_len_int64` as an input name to existing nodes - left_path[1].input[0] = past_seq_len_int64 - right_path[0].input[0] = past_seq_len_int64 - # Add new nodes to graph model.model.graph.node.extend([squeeze_node, cast_node]) + model.model.graph.value_info.extend([squeeze_output, cast_output]) model.topological_sort() return model, past_seq_len_name diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 5e1d491daae23..08f8691d8b2b5 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -663,12 +663,12 @@ def create_attention_node( first_input: str, output: str, add_qk_str: str = "", + causal: bool = False, past_k: str = "", past_v: str = "", present_k: str = "", present_v: str = "", scale: float | None = None, - causal: bool = False, ) -> NodeProto | None: """Create an Attention node. @@ -685,12 +685,12 @@ def create_attention_node( first_input (str): first input name output (str): output name add_qk_str (str): name of Add node after Q x K' + causal: whether it is uni-directional mask. past_k (str): name of input for past K value past_v (str): name of input for past V value present_k (str): name of output to store present K value present_v (str): name of output to store present V value scale: scale before softmax - causal: whether it is uni-directional mask. Returns: Union[NodeProto, None]: the node created or None if failed. diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py index 45bbfa94f6aa2..76dfeb76e4e8d 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py @@ -6,7 +6,7 @@ import numpy as np from fusion_attention import AttentionMask, FusionAttention -from onnx import TensorProto, helper +from onnx import helper from onnx_model import OnnxModel logger = logging.getLogger(__name__) @@ -26,115 +26,9 @@ def __init__( ): super().__init__(model, hidden_size, num_heads, attention_mask) - def check_runtime_shape_path( - self, - reshape_qkv_2, - reshape_qkv_1, - reshape_q_2, - reshape_k_2, - reshape_v_2, - root_input, - ): - concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) - if concat_qkv_2_path is None: - return False - concat_qkv_2 = concat_qkv_2_path[0] - - reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) - reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) - if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None: - return False - - _, gather_1, shape_1 = reshape_qkv_2_path_1 - _, gather_2, shape_2 = reshape_qkv_2_path_2 - - if shape_1.input[0] != root_input or shape_2.input[0] != root_input: - return False - - reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0]) - reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0]) - if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None: - return False - if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name: - return False - - reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) - reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) - reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) - if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None: - return False - - mul_q = reshape_q_2_path[-1] - mul_k = reshape_k_2_path[-1] - mul_v = reshape_v_2_path[-1] - - gather_1_out = gather_1.output[0] - if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: - return False - - return True - - def check_runtime_shape_path_openai( - self, - reshape_qkv_2, - matmul_qkv, - add_qk, - matmul_qk, - add_q, - ): - reshape_qkv_path = self.model.match_parent_path( - reshape_qkv_2, ["Concat", "Slice", "Shape", "Transpose"], [1, 0, 0, 0] - ) - if reshape_qkv_path is None or reshape_qkv_path[-1].input[0] != matmul_qkv.output[0]: - return False - - matmul_qk_path_1 = self.model.match_parent_path( - matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0] - ) - matmul_qk_path_2 = self.model.match_parent_path( - matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0] - ) - if matmul_qk_path_1 is None or matmul_qk_path_2 is None: - return False - - mul_1 = matmul_qk_path_1[0] - mul_2 = matmul_qk_path_2[0] - if mul_1.input[1] != mul_2.input[1]: - return False - if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]: - return False - - # For decoder attentions only - if add_qk is not None: - add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1]) - if add_qk_path is None: - return False - slice_q_path_1 = self.model.match_parent_path( - add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0] - ) - slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) - if slice_q_path_1 is None and slice_q_path_2 is None: - return False - _, unsqueeze_1, _, _ = slice_q_path_1 - unsqueeze_2, _, _ = slice_q_path_2 - if unsqueeze_1.input[0] != unsqueeze_2.input[0]: - return False - if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]: - return False - - return True - def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): - # Track if fusion is occurring for OpenAI implementation of Whisper - model_impl_openai = False - # SkipLayerNormalization has two inputs, and one of them is the root input for attention. qkv_nodes = self.model.match_parent_path( - normalize_node, - ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], - [1, 1, 0, 0, 0, 0], - ) - qkv_nodes_openai = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, 1, 0, 0, 0], @@ -143,32 +37,21 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ( add_out, matmul_out, - reshape_qkv_2, - transpose_qkv, - reshape_qkv_1, - matmul_qkv, - ) = qkv_nodes - elif qkv_nodes_openai is not None: - qkv_nodes = qkv_nodes_openai - ( - add_out, - matmul_out, - reshape_qkv_2, + reshape_qkv, transpose_qkv, matmul_qkv, ) = qkv_nodes - # Set model implementation to openai - model_impl_openai = True else: + logger.debug("fuse_attention: failed to match qkv path") return other_inputs = [] - for input in normalize_node.input: - if input not in output_name_to_node: + for input_ in normalize_node.input: + if input_ not in output_name_to_node: continue - if input == qkv_nodes[0].output[0]: + if input_ == qkv_nodes[0].output[0]: continue - other_inputs.append(input) + other_inputs.append(input_) if len(other_inputs) != 1: return root_input = other_inputs[0] @@ -185,9 +68,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization """ skip_layernorm = output_name_to_node[root_input] - # For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose + # For some attention blocks, the end SkipLayerNormalization node may point to another node whose # child is the LayerNormalization node. - if skip_layernorm.op_type == "Add": + if skip_layernorm.op_type in {"Add", "Clip"}: skip_layernorm = self.model.get_children(skip_layernorm)[0] for output in skip_layernorm.output: if not output: @@ -201,304 +84,203 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): graph_input_names = {node.name for node in self.model.graph().input} graph_output_names = {node.name for node in self.model.graph().output} - v_nodes = self.model.match_parent_path( - matmul_qkv, - ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 0, 0, None], - ) - v_nodes_openai = self.model.match_parent_path( + v_nodes_past_or_present = self.model.match_parent_path( matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None], ) - v_nodes_with_past_self_attn = self.model.match_parent_path( - # Decoder attention with past value concatenated before MatMul + v_nodes_with_past = self.model.match_parent_path( matmul_qkv, - ["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 1, 0, 0, None], + ["Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 1, 0, 0, None], ) - v_nodes_with_past_cross_attn = self.model.match_parent_path( - # Decoder attention with past value directly used in MatMul - matmul_qkv, - ["Reshape"], - [1], - ) - v_nodes_with_past_cross_attn_openai = self.model.match_parent_path( + v_nodes_past_only_oai = self.model.match_parent_path( matmul_qkv, ["Transpose", "Reshape", "Reshape", "Transpose"], [1, 0, 0, 0], ) past_v, present_v = "", "" - reshape_v_2, add_v = None, None - if v_nodes is not None: - (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes - # For initial pass through encoder-decoder_with_past to get starting past values (beam search) - present_v = transpose_v.output[0] - elif v_nodes_openai is not None: - v_nodes = v_nodes_openai - (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes - # For initial pass through encoder-decoder_with_past to get starting past values (beam search) - - # Find the child path to access the correct present_v values - # Openai impl provides present/past v values in 3D format - # whereas ort MultiHeadAttention expects v values in 4D, hence the - # additional Reshape and Transpose nodes are added - # For encoder attention types - # Add -> Reshape -> Transpose -> Present_V - reshape_path = self.model.match_child_path( - add_v, - ["Reshape", "Transpose"], - exclude=[reshape_v_1], - ) - # For decoder attention types - # add_v_node Reshape <- Transpose <-Past_V - # \ / - # \ / - # -> Concat <- - # | - # |--> Reshape -> Transpose -> Present_V - concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"]) - if reshape_path is not None: - (_, transpose_add_v) = reshape_path - if transpose_add_v.output[0] in graph_output_names: - present_v = transpose_add_v.output[0] - if concat_path is not None: - (concat_v, _, transpose_concat_v) = concat_path - if transpose_concat_v.output[0] in graph_output_names: - present_v = transpose_concat_v.output[0] - concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0]) - _, transpose_concat_v_in = concat_nodes - past_v = transpose_concat_v_in.input[0] - elif v_nodes_with_past_self_attn is not None: - (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn - v_nodes = v_nodes_with_past_self_attn + v_nodes, add_v, matmul_v = [], None, None + if v_nodes_past_or_present is not None: + v_nodes = v_nodes_past_or_present + (transpose_v, reshape_v, add_v, matmul_v) = v_nodes + + # Find past_v input name + start_child_nodes = input_name_to_nodes[add_v.output[0]] + for start_child_node in start_child_nodes: + if start_child_node.op_type == "Concat": + concat_v_nodes = self.model.match_parent_path( + start_child_node, + ["Reshape", "Transpose"], + [0, 0], + ) + if concat_v_nodes is not None: + past_v = concat_v_nodes[-1].input[0] + start_child_nodes = input_name_to_nodes[start_child_node.output[0]] + break + + # Find present_v output name + for start_child_node in start_child_nodes: + start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]] + for start_grandchild_node in start_grandchild_nodes: + if start_grandchild_node.output[0] in graph_output_names: + present_v = start_grandchild_node.output[0] + break + if present_v != "": + break + elif v_nodes_with_past is not None: + v_nodes = v_nodes_with_past + (concat_v, transpose_v, reshape_v, add_v, matmul_v) = v_nodes past_v = concat_v.input[0] present_v = concat_v.output[0] - elif ( - v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names - ): - v_nodes = v_nodes_with_past_cross_attn - past_v = v_nodes[-1].input[0] - present_v = v_nodes[-1].output[0] - if present_v not in graph_output_names: - identity_node_v = list( - filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) - ) - present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" - elif ( - v_nodes_with_past_cross_attn_openai is not None - and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names - ): - v_nodes = v_nodes_with_past_cross_attn_openai + elif matmul_qkv.input[1] in graph_input_names: + # Hugging Face's cross-attention where past_v is used directly as value + past_v = matmul_qkv.input[1] + elif v_nodes_past_only_oai is not None: + # OpenAI's cross-attention where past_v is used directly as value + v_nodes = v_nodes_past_only_oai past_v = v_nodes[-1].input[0] - present_v = v_nodes[-1].output[0] - if present_v not in graph_output_names: - identity_node_v = list( - filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) - ) - present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" else: logger.debug("fuse_attention: failed to match v path") return past_v = past_v if past_v in graph_input_names else "" present_v = present_v if present_v in graph_output_names else "" - qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) - qk_nodes_2 = self.model.match_parent_path( - matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0] - ) - qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) - add_qk = None - if qk_nodes_1 is not None: - _, matmul_qk = qk_nodes_1 - qk_nodes = qk_nodes_1 - elif qk_nodes_2 is not None: - _, _, add_qk, _, matmul_qk = qk_nodes_2 - qk_nodes = qk_nodes_2 - elif qk_nodes_2_openai is not None: - _, add_qk, matmul_qk = qk_nodes_2_openai - qk_nodes = qk_nodes_2_openai + qk_nodes_no_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) + qk_nodes_with_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + qk_nodes, add_qk = [], None + if qk_nodes_no_mask is not None: + _, matmul_qk = qk_nodes_no_mask + qk_nodes = qk_nodes_no_mask + elif qk_nodes_with_mask is not None: + _, add_qk, matmul_qk = qk_nodes_with_mask + qk_nodes = qk_nodes_with_mask else: + logger.debug("fuse_attention: failed to match qk path") return - q_nodes = self.model.match_parent_path( + q_nodes_hf = self.model.match_parent_path( matmul_qk, - ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], - [0, 0, 0, 0, 0, 1], + ["Transpose", "Reshape", "Mul", "Add", "MatMul"], + [0, 0, 0, 0, 1], ) - q_nodes_openai = self.model.match_parent_path( + q_nodes_oai = self.model.match_parent_path( matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, 1], ) - reshape_q_2 = None - if q_nodes is not None: - reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes - elif q_nodes_openai is not None: - q_nodes = q_nodes_openai - mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes + q_nodes = [] + if q_nodes_hf is not None: + q_nodes = q_nodes_hf + (transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes + elif q_nodes_oai is not None: + q_nodes = q_nodes_oai + (mul_q, transpose_q, reshape_q, add_q, matmul_q) = q_nodes else: + logger.debug("fuse_attention: failed to match q path") return - k_nodes_with_bias = self.model.match_parent_path( - matmul_qk, - ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 0, 0, 0, 1], - ) - k_nodes_no_bias_openai = self.model.match_parent_path( + k_nodes_no_past_hf = self.model.match_parent_path( matmul_qk, - ["Mul", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 0], - ) - k_nodes_no_bias = self.model.match_parent_path( - matmul_qk, - ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 0, 0], + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], ) - k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path( - # Decoder attention with past key concatenated before MatMul + k_nodes_with_past_hf = self.model.match_parent_path( matmul_qk, - ["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 1, 0, 0], + ["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 0, 0], ) - k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path( - # Decoder attention with past key directly used in MatMul + k_nodes_past_or_present_oai = self.model.match_parent_path( matmul_qk, - ["Transpose", "Reshape"], - [1, 0], + ["Mul", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0], ) - k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path( - # Decoder attention with past key directly used in MatMul + k_nodes_past_only_oai = self.model.match_parent_path( matmul_qk, ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], [1, 0, 0, 0, 0], ) past_k, present_k = "", "" - reshape_k_2, reshape_k_1, matmul_k = None, None, None - if k_nodes_with_bias is not None: - _, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias - k_nodes = k_nodes_with_bias - elif k_nodes_no_bias_openai is not None: - mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias_openai - k_nodes = k_nodes_no_bias_openai - present_k = matmul_k.output[0] - - # Find the child path to access the correct present_k values - # Openai impl provides present/past k values in 3D format - # whereas ort MultiHeadAttention expects k values in 4D, hence the - # additional Reshape and Transpose nodes are added - # For encoder attention types - # Matmul -> Reshape -> Transpose -> Present_K - reshape_path = self.model.match_child_path( - matmul_k, - ["Reshape", "Transpose"], - exclude=[reshape_k_1], - ) - # For decoder attention types - # matmul_k_node Reshape <- Transpose <- Past_K - # \ / - # \ / - # -> Concat <- - # | - # +--> Reshape -> Transpose -> Present_K - concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"]) - if reshape_path is not None: - (_, transpose_matmul_k) = reshape_path - if transpose_matmul_k.output[0] in graph_output_names: - present_k = transpose_matmul_k.output[0] - if concat_path is not None: - (concat_k, _, transpose_concat_k) = concat_path - if transpose_concat_k.output[0] in graph_output_names: - present_k = transpose_concat_k.output[0] - concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0]) - _, transpose_concat_k_in = concat_nodes - past_k = transpose_concat_k_in.input[0] - elif k_nodes_no_bias is not None: - _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias - k_nodes = k_nodes_no_bias - # For initial pass through encoder-decoder_with_past to get starting past values (beam search) - present_k = transpose_k_1.output[0] - elif k_nodes_no_bias_with_past_self_attn is not None: - _, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn - k_nodes = k_nodes_no_bias_with_past_self_attn + k_nodes, add_k, matmul_k = [], None, None + if k_nodes_no_past_hf is not None: + k_nodes = k_nodes_no_past_hf + (transpose_k, reshape_k, matmul_k) = k_nodes + + # Find present_k output name + transpose_k_nodes = input_name_to_nodes[reshape_k.output[0]] + for transpose_k_node in transpose_k_nodes: + if transpose_k_node.output[0] in graph_output_names: + present_k = transpose_k_node.output[0] + break + elif k_nodes_with_past_hf is not None: + k_nodes = k_nodes_with_past_hf + (_, concat_k, transpose_k, reshape_k, matmul_k) = k_nodes past_k = concat_k.input[0] present_k = concat_k.output[0] - elif ( - k_nodes_no_bias_with_past_cross_attn is not None - and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names - ): - k_nodes = k_nodes_no_bias_with_past_cross_attn - past_k = k_nodes[-1].input[0] - present_k = k_nodes[-1].output[0] - if present_k not in graph_output_names: - identity_node_k = list( - filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) - ) - present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" - elif ( - k_nodes_no_bias_with_past_cross_attn_openai is not None - and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names - ): - k_nodes = k_nodes_no_bias_with_past_cross_attn_openai + elif output_name_to_node[matmul_qk.input[1]].input[0] in graph_input_names: + # Hugging Face's cross-attention where past_k is used directly as key + k_nodes = [output_name_to_node[matmul_qk.input[1]]] + past_k = k_nodes[0].input[0] + elif k_nodes_past_or_present_oai is not None: + k_nodes = k_nodes_past_or_present_oai + (_, transpose_k, reshape_k, matmul_k) = k_nodes + + # Find past_k input name + start_child_nodes = input_name_to_nodes[matmul_k.output[0]] + for start_child_node in start_child_nodes: + if start_child_node.op_type == "Concat": + concat_k_nodes = self.model.match_parent_path( + start_child_node, + ["Reshape", "Transpose"], + [0, 0], + ) + if concat_k_nodes is not None: + past_k = concat_k_nodes[-1].input[0] + start_child_nodes = input_name_to_nodes[start_child_node.output[0]] + break + + # Find present_k output name + for start_child_node in start_child_nodes: + start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]] + for start_grandchild_node in start_grandchild_nodes: + if start_grandchild_node.output[0] in graph_output_names: + present_k = start_grandchild_node.output[0] + break + if present_k != "": + break + elif k_nodes_past_only_oai is not None: + # OpenAI's cross-attention where past_k is used directly as key + k_nodes = k_nodes_past_only_oai past_k = k_nodes[-1].input[0] - present_k = k_nodes[-1].output[0] - if present_k not in graph_output_names: - identity_node_k = list( - filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) - ) - present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" else: + logger.debug("fuse_attention: failed to match k path") return past_k = past_k if past_k in graph_input_names else "" present_k = present_k if present_k in graph_output_names else "" - if k_nodes in (k_nodes_no_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + if matmul_k is not None and add_k is None: # Create empty Add node for attention graph - bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] + add_v_tensor = self.model.get_initializer(add_v.input[0]) + bias_dim = add_v_tensor.dims[0] + dtype = add_v_tensor.data_type empty_bias_name = "empty_bias" empty_tensor = self.model.get_initializer(empty_bias_name) if empty_tensor is None: self.add_initializer( empty_bias_name, - TensorProto.FLOAT, + dtype, dims=[bias_dim], - vals=np.array([0.0] * bias_dim, dtype=np.float32), + vals=np.array([0.0] * bias_dim, dtype=helper.tensor_dtype_to_np_dtype(dtype)), ) add_name = self.model.create_node_name("Add") - add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) + add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k.name], add_name) - if ( - model_impl_openai - and not bool(past_k) - and not self.check_runtime_shape_path_openai( - reshape_qkv_2, - matmul_qkv, - add_qk, - matmul_qk, - add_q, - ) - ): - return - elif ( - not model_impl_openai - and not bool(past_k) - and not self.check_runtime_shape_path( - reshape_qkv_2, - reshape_qkv_1, - reshape_q_2, - reshape_k_2, - reshape_v_2, - root_input, - ) - ): - return - - three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and "matmul_v" not in locals() + three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and matmul_v is None one_root_input = ( not three_root_inputs - and matmul_k.input[0] == root_input and matmul_q.input[0] == root_input + and matmul_k.input[0] == root_input and matmul_v.input[0] == root_input ) two_root_inputs = ( @@ -509,84 +291,97 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ) # There are 5 types of attention: - # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1 - # 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2 - # 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value - # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1 - # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 - encoder_attention = one_root_input and qk_nodes == qk_nodes_1 - decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai) - decoder_attention_with_past = ( - (encoder_attention if not model_impl_openai else decoder_attention) and bool(past_k) and bool(past_v) - ) - decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 - decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1 - - # For decoder_attention, the attention mask needs to be included in the attention node - mask_index, mask_nodes = None, [] - if decoder_attention: + # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_no_mask + # 2) Decoder self attention with one_root_input=True and qk_nodes=qk_nodes_with_mask + # 3) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_no_mask + # 4) Decoder self attention with past with one_root_input=True and qk_nodes=qk_nodes_with_mask and past_k=past_decoder_key and past_v=past_decoder_value + # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_no_mask + encoder_attention = one_root_input and qk_nodes == qk_nodes_no_mask + decoder_self_attention = one_root_input and qk_nodes == qk_nodes_with_mask + decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_no_mask + decoder_self_attention_with_past = decoder_self_attention and bool(past_k) and bool(past_v) + decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_no_mask + + # For decoder self-attentions, the attention mask needs to be included in the attention node + causal_mask = qk_nodes == qk_nodes_with_mask + mask_nodes = [] + if causal_mask: mask_nodes_bart = self.model.match_parent_path( add_qk, ["Where"], [1], ) - mask_nodes_whisper = self.model.match_parent_path( + mask_nodes_whisper_hf = self.model.match_parent_path( + add_qk, + ["Slice", "Expand", "Where"], + [1, 0, 1], + ) + mask_nodes_whisper_oai = self.model.match_parent_path( + add_qk, + ["Slice", "Unsqueeze", "Gather", "Shape", "Add"], + [1, 2, 0, 0, 0], + ) + mask_nodes_whisper_oai_unit_test = self.model.match_parent_path( add_qk, - ["Expand", "Unsqueeze", "Unsqueeze", "Where"], - [1, 0, 0, 0], + ["Slice", "Slice"], + [1, 0], ) - if mask_nodes_whisper is not None: - mask_index = mask_nodes_whisper[0].output[-1] - mask_nodes = mask_nodes_whisper + if mask_nodes_whisper_hf is not None: + mask_nodes = mask_nodes_whisper_hf + elif mask_nodes_whisper_oai is not None: + mask_nodes = mask_nodes_whisper_oai + elif mask_nodes_whisper_oai_unit_test is not None: + mask_nodes = mask_nodes_whisper_oai_unit_test elif mask_nodes_bart is not None: - mask_index = mask_nodes_bart[0].output[-1] mask_nodes = mask_nodes_bart + else: + logger.debug("fuse_attention: failed to match mask nodes") + return + assert len(mask_nodes) > 0 if ( encoder_attention - or decoder_attention - or decoder_attention_with_past + or decoder_self_attention or decoder_cross_attention + or decoder_self_attention_with_past or decoder_cross_attention_with_past ): - attention_last_node = reshape_qkv_2 - num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1) + attention_last_node = reshape_qkv + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: logger.debug("fuse_attention: failed to detect num_heads or hidden_size") return new_node = None - if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: - # Note: Decoder attention with past key and past value is fused as multihead attention - # rather than attention because multihead attention supports separate past key and past + if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + # Note: Decoder attention with past key and past value is fused as multi-head attention + # rather than attention because multi-head attention supports separate past key and past # value whereas attention supports concatenated past key and past value. new_node = ( self.create_multihead_attention_node( q_matmul=matmul_q, - k_matmul=matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k, - v_matmul=matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v, + k_matmul=matmul_k if decoder_cross_attention or decoder_self_attention_with_past else past_k, + v_matmul=matmul_v if decoder_cross_attention or decoder_self_attention_with_past else past_v, q_add=add_q, - k_add=add_k if decoder_cross_attention or decoder_attention_with_past else None, - v_add=add_v if decoder_cross_attention or decoder_attention_with_past else None, + k_add=add_k if decoder_cross_attention or decoder_self_attention_with_past else None, + v_add=add_v if decoder_cross_attention or decoder_self_attention_with_past else None, num_heads=num_heads, hidden_size=hidden_size, output=attention_last_node.output[0], - unidirectional=decoder_attention_with_past, - past_k=past_k if decoder_attention_with_past else "", - past_v=past_v if decoder_attention_with_past else "", + unidirectional=causal_mask, + past_k=past_k if decoder_self_attention_with_past else "", + past_v=past_v if decoder_self_attention_with_past else "", present_k=present_k, present_v=present_v, - packed_qkv=decoder_attention_with_past, ) if self.use_multi_head_attention else None ) else: - # Temporarily set multihead attention flag to false + # Temporarily set multi-head attention flag to false use_multi_head_attention_ground_truth = self.use_multi_head_attention self.use_multi_head_attention = False - add_qk_str = mask_index if decoder_attention and mask_index else "" new_node = self.create_attention_node( mask_index=None, q_matmul=matmul_q, @@ -599,17 +394,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): hidden_size=hidden_size, first_input=root_input, output=attention_last_node.output[0], - add_qk_str=( - None if len(mask_nodes) > 1 else add_qk_str - ), # deprecate and use is_unidirectional attr instead for Whisper + causal=causal_mask, past_k=past_k, past_v=past_v, present_k=present_k, present_v=present_v, - causal=decoder_attention, ) self.use_multi_head_attention = use_multi_head_attention_ground_truth if new_node is None: + logger.debug("fuse_attention: failed to create fused node") return self.nodes_to_add.append(new_node) @@ -618,22 +411,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) self.nodes_to_remove.extend(qk_nodes) - # When using multihead attention, keep MatMul nodes in original graph - if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: - if q_nodes[-1].op_type == "MatMul": + # When using multi-head attention, keep MatMul nodes in original graph + if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + if len(q_nodes) > 0 and q_nodes[-1].op_type == "MatMul": q_nodes.pop() - if k_nodes[-1].op_type == "MatMul": + if len(k_nodes) > 0 and k_nodes[-1].op_type == "MatMul": k_nodes.pop() - if v_nodes[-1].op_type == "MatMul": + if len(v_nodes) > 0 and v_nodes[-1].op_type == "MatMul": v_nodes.pop() - if self.disable_multi_head_attention_bias and ( - decoder_cross_attention or decoder_cross_attention_with_past - ): - if q_nodes[-1].op_type == "Add": + if self.disable_multi_head_attention_bias: + if len(q_nodes) > 0 and q_nodes[-1].op_type == "Add": q_nodes.pop() - if k_nodes[-1].op_type == "Add": + if len(k_nodes) > 0 and k_nodes[-1].op_type == "Add": k_nodes.pop() - if v_nodes[-1].op_type == "Add": + if len(v_nodes) > 0 and v_nodes[-1].op_type == "Add": v_nodes.pop() self.nodes_to_remove.extend(q_nodes) diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 29a08b5ccd220..f1758cc52280f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,13 +1,13 @@ -torch>=1.13.0 -transformers>=4.36.0,<= 4.42.4 -openai-whisper>=20231117,<=20240927 +torch>=2.7.0 +transformers>=4.52.3 +openai-whisper==20240927 ffmpeg-python datasets soundfile librosa -optimum<=1.21.2 +optimum onnxruntime-extensions>=0.9.0 -onnx==1.17.0 +onnx protobuf==3.20.2 numpy==1.23.3 psutil diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py index 4765616ec2b6f..a7c0d3538b8da 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py @@ -196,7 +196,7 @@ def create_torch_ops(self): # Set torch extensions directory to cache directory os.environ["TORCH_EXTENSIONS_DIR"] = self.cache_dir - # Try to import `jinja` pip package + # Try to import `ninja` pip package try: assert torch.utils.cpp_extension.verify_ninja_availability() except Exception as e: diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 8add38b5a7d07..89c2d5e7cc259 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -1349,6 +1349,8 @@ def has_same_value( tensor2: TensorProto, signature_cache1: dict | None = None, signature_cache2: dict | None = None, + rtol: float = 1e-05, + atol: float = 1e-08, ) -> bool: """Returns True when two tensors have same value. Note that name can be different. @@ -1358,6 +1360,8 @@ def has_same_value( tensor2 (TensorProto): initializer 2 signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison. signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison. + rtol (float): Optional relative difference threshold for minor precision differences + atol (float): Optional absolute difference threshold for minor precision differences Returns: bool: True when two initializers has same value. """ @@ -1375,9 +1379,17 @@ def has_same_value( signature_cache1[tensor1.name] = sig1 if signature_cache2 is not None: signature_cache2[tensor2.name] = sig2 - if sig1 == sig2 and tensor1.data_type == tensor2.data_type and tensor1.dims == tensor2.dims: - # Same signature, now do the expensive check to confirm the data is the same - return (numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)).all() + if tensor1.data_type == tensor2.data_type and tensor1.dims == tensor2.dims: + n1 = numpy_helper.to_array(tensor1) + n2 = numpy_helper.to_array(tensor2) + if sig1 == sig2: + # Same signature, now do the expensive check to confirm the data is the same + return (n1 == n2).all() + else: + # Check if tensors are allclose + from numpy import allclose + + return allclose(n1, n2, rtol=rtol, atol=atol) return False diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx deleted file mode 100644 index a0e65a002361288446b99030650531e6bb586d32..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 68540 zcmeI)O^e$`7{GDMyR4OQlX^wGyEudpLusObkxzl8kX>I2q3I#*sTVP}Cab}+q?Jbd zl9Rth$fX2w3gjd7u$O*|p8Gi}#om|L$tLs^@ZZLu8TpxcW}csP99!(Zd$lgkEYZ+$ zS{UEEyl6%8RAq6JP5fh}4?-huUG@yM%$?3Isr&g@rx*R5 zn;$j1Tfdcef2rpEwVYQ*`lU2FiQP@nP-(8CByPX|`^U#`ENl1XD4FyodXy!3^1_I} zEOAXbVr$8w`#~@$t(2SpztKAPjzgWN<-X0UTJ7Il=|Lvb(0yIE_gYqWTCdfn zPgeOiSgNV}ioKbc7%aS(-FH*V-kZUhTC1#vwXgPDgZfi{3}65Q7{CAqFn|FJU;qOc zzyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6( z00S7n00uCC0SsW^TQV@X{)7(qEypO#zz(iC6vTZyT`P}NIJ2t5IHqggQhmvV_YDRx zfB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n z00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n00uCC0SsUO z0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD z3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAq zFn|FJU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFo1#8H89xP zH%7%qC$aDOhvOsvSm}e%xPNZgXI9JVPB-9DTk+58J|30*!R^IIWT<8Cwp)!sL&t$S z@x7<^xsa>f{in7y?|UO$g*4Sy2b%dtl?D%96RKNNOS xn6FCm@x@>J&Li=#%C}~s>%B2Alh`py4njG3&F2hJW@eV8Pu6AeT zI43f%?aaz^jzs+wZH-Ba<%5 zo7?X^wm$!oy@8hyxH89VWHA&pw6*ZmaW{|{Rh+oDPM9^{X%#F=+jVg*E`~BJ3-cQrubt}h zf+_KK0^-)pjZ{x zcISJdd%73Sp6)gG-Z3r@gTOC0lkEO~{N5MGp~=&7-U2-%I}DXP%2XP*Usc_My3?9e zE0xK^b^iG)tJ!|R+{{e)ODB8#_le^k%-~F%bymaJ*Zb69{aJquU;qOczyJm?fB_6( z00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n00uCC z0SsUO0~o*n1~7mD3}E0Y8SpQ@o1^_3*C@=u`WIXW(*Bt%71-8SX`YjYG@?ghXDX95N=fB_6(00S7n00uCC z0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n z1~7mD3}65Q7{CAqFn|FJTr~r}8*wi{1(#RNCHY_y0~o*n1~7mD3}65Q7{CAqFn|FJ zU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm? zfB_6(00S7nz|4T}j?eA|_;u5Lg>taqGsZ>u(Ipj5;tAP+Y_I9qNi*>l(~7X#V79HTcYY7Jo#2Mblgt@oypJy zfsSQvmDTwr?Rs~Gn-unAy6D|syn4_3M%*p4`D!VfU2iAXqoPbi%YlxMY&gxWw=WuL zrt@-~oo;8IWhc@n-fi(^8Rh0ul&ZUNHm-W_f!JGZJ63~2e?)74OVpI>4iu&FfJaguganw8cLf`cbP_3+=HR>hx=Sl6c8w^8-{i`SN?DmD4-(NO7TY{lL4 zmYSBCUpl`|`gUAylc*zYr{DIM^3S7p^=B>Jowb*?<}S-0e4qDzJAUrV{JFBNUrVbK zQ`r&gs+;L{Vs?N3>$9hC9OrPWo^*nauBSQS}q=j&nh*Ge{Ktrbt) zG5t;y=Z)F^>KkWyk%AO&2M`Y@A|rWha~x;Uw9!_Nb*a?fzt~;y*tyU%Bl3?UF0B*S zvlrLyEm0b~xFoV$?6&IJE}LE*l=kL!Pm~J^b9MPN{_kz&e>jYF*3IXwp(dNF6U(TX z%5GfwW!XLQobqUSX=(KAHvfeeRy-0hVxi(_E4VlNi7N1~7mD z3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAq zFn|FJU;qOczyJm?fB_6(00S7n00t%oLU%Ad7vP6w_sDa~qvhK^H{n_@s%e>3n33Yd zgm7(eHY4Ay4`ZEmRr>04v93+GZ=S9>hnq;KEn5w)kYTMzs%#N#~8rweIz{OwHj`gsr)wUuew zc$}?vD0aK4%9OD|H7HUBEp_TY5})LwJXjcI*|LxK$aY`7M|HR%)};;- zV@^CTyemX;5p>$#N_by*(`%bbZM126PPi}b*ON}r$pUSxN{ws>CLY$kd*Mf7e|mo& zX`t0Gc&lzlY@OBVdYWX(vCW72Sd`|P0(Fuvs_x_yA=`QXp13#PZ|dVGIpbtW+6=m< OHc9Jokguchoc{o7Ln842 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx deleted file mode 100644 index 552839e7234e232b5019d7785b8eec6a58735349..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 35663 zcmeI)J#W)M7zgk;RE>Gkm((h>3W_Qs@lt~#cdFE-ttSRrB-px~F3B}MB{@6xIia2M zArOcKRboP7W@6~Vw_xUT;Jh?#Q_=!crTh~`ito+OJ9TnL zP15_aci;WwzPurixt6}Nmnb8GNH%(Mkf*@`d2lQLCM!yDDxnFKLWcXHJat zcp^AmT#yQbdN6*dMacTMl1ggX%dH5l)vMjZyewVm8Rh1eK@;cE%|)jhWl^42tCK=D zht{0Fv9qg<{K_ReY`0o<6Qs=fe^l0b$CFWzfk=PqJRkF#3H1JC}L zHg|GHv71h%y4%zIbFWm_{(#&JOt@ojQ2R^QFn0!Uy2doC!QQ9)t%3TZ{xARoFaQHE z00S@p126ysFaQHE00S@p126ysFaQHE00S@p126ysFaQHE00S@p126ysFaQHE00S@p z126ysFaQHE00S@p126ysFaQHE00S@p126ysx5>XXHl8!ipL^(Ea;B;GfT-~BnNv`2zaDfT^zACZ^R-hIVGVgepQ5}q|tg#b~69* z*!;9PLkreie`+LJtRYmgT<+i%txaC8zy7~1SY9X7kM-YdrMx*Bx?w$|^(5*?V^J?z ci%GC?vOUp@xTd{c4%{qdxGecfz0Ku*1BS6`&Hw-a diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx deleted file mode 100644 index bc72c9b350087c9911916a7e2272f6bf9f52d7c3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 35397 zcmeI)O>Y}T7zgmIQ^ngSFWqblu?(TY5-%}m?(_aXjhhoyT=cg$7n+^7zj32q*6TiZ>xxvp5LV^7yG>15 z7?tMwtzUot=+R5hd%PLv8Ou~$<|aS3^qg+UhWz1d58=JxEJHp`xxG2$g|Nx2x3hL2 z>_L^%=d?ED^)6{SBHS2q(eKaWX%_FgI;+q6efr6xg(qmY5XSNodCpHoT1mrtOzJrA z%aXZHdU*Zt7kJ@z&u1VRS zaqLYP&RhHKodsUpRV#88cE|X?-GzIjM47^!TVGB}%16p!nv}GrX?T_hJ}5Rl>*gp?rjX^~$JA81s4VY=yY!Clg*&tn(vh%n!i{<%Y1RL8 zo!t2y9?)(fjO8ctoS%xcl7{t|)N$UIC3BtOr_^=vA1=i={5!N2ZcIXr9P(6ZH(hY} z6+OIuy2R;Zt`R6rh1Rs3tdc}xe}^OYAaGVzC-sH$8}t<eY{;uTFFngi%`_X>-uDV&Q6}ZdB}+u8za5O)s~hY%AOD`tD5feg1Ff6*|k;&sg{eooDvl?$#<=jE_q9z z$e!}MI<8bJo^FJmjw;EZSf746oA1smYlF1RM7X=uiFBOg84u#FTp!LZF3pdWwc#cP zt@&Il?{c(U6z>MzwrnPY@#@@}lPfQYg>GOx6X{0Uj`CNXWOx$15~FpAwRX1eq}sPS zl3B$jFxI4dNJsAK_0xlfN{iRdT;l)J5mT=dw{LE7>k~wjl;=$=xiBbDRF zGWLv+yPj7|?sT$k{`bSf-8^cVxRb2gg$sFPY*S_~V(stCc2miNNOhX!A9IDBqO;M@ z&CK*ayU%~MYd5$218XxdQJeTjS^l-}6m|x1`p$h;L+|(dslol{{$l_G7{CAqFn|FJ zU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm? zfB_6(00S7n00uCC0SsUO18&%RCf zms)te$5PpC1&?ouaeR ze*ov1d;Yn<+c~vgd$Mb-4y*}-i(OqeD*iTq)HHENM{hPo-UOYll~s35ykBtK74gs& zb!8h(8JklrW($AJi(Fyn&7N56q#LZIPh?N|T^(1d6;C%pPe+wx(EUV|lj}d4xIR~y z6ASLc?9x~tN?!*_xZ;~{#QOAbS^T>NtiKS2uxtNIEA>~GQ@7lg;zd&RhZ9vVy6+{$ zc6HFuX~hS&nXW^rTy2x6f4(54f?)w Qu-d%HD^m%B;J8@)2dG?(4FCWD diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx deleted file mode 100644 index c50162eb5bf8ebf1323b44e73eea771cd3880a64..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 69284 zcmeI)&u<$=6u|MU)7EB^M%yiEQbTA75)^}0PA`>uA@zx00SUyVm(j!?$0M(I?cK5C zR6_ER8%Kn=RN{ohAJGd3{s+$d8!%qmsqHig56#fiK26 zuG|ns-OWst#GAkV_TGc1p7-#2J?VrUQ%{pDX!i*`jM>4UVAWYC-4`7 zu7h%E7TLWv`6ix-Gqc34*(}+UnN_KkCT{5AT)cfwd|W(InU!j2vS4FwopJTz*O94f zS>3&wf4QXgc{K(vS-UOgYVXZze|uryi#2Ya_Mzj3Oq%%~FK@|ttqk5XtSM{bX)nQMH`SSr?cyTA*4c2L zTwZ%*t()#{4@V|&x~=(GtHx=x;)~bfen+-kFx$TJ)GJgr#YR6i4U_8nI8Q~}o(nJE zN=nS1m)Pm#ceGSZt5cbuv5Ac}nV>BYHmX5{p?`qLBV=84N68{*PulPHgs zoO8;h$qCA8{*;;>?UlC{o<7ofsyTvgcmMTcDY$psGFjK{+th_THjZoVBGXY*c3Mgv zrK;Pi{8%XM``+%Tu(C4x=vDr$ow>M`AK06TiQ3Y?;L0x}ue3jbGxA<#HT3>!KQ*}i zTz?E;00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n z00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJVBr5`pmy=AIx0W&8ifhi+68Y1seI2a zO?PSf4jy)7c9YkpFBhsj9j&_8ZanjRulAx>OMN;TtG}%|hXk6iwXiTUiZu#2Y0q*b?i3I8?UYl9_p; z#cJt~HBl(-KRysU-En?fpU8oV`Z^1%VMEuGhE78l3_cVUm;dfkexY-x~+=Pbz6Bsw?W{~ zcYF7PyL;n~4?h!IXP26E`qGQl)z_k$iG^9UHG$On<^8V_{jHa zZ;I_%)Sbt`=;Ls1^oF=zPdZ`8)YBwOn${(}CCUq-P!Am$({Y=jA~(JzZZ3|;_11CS W4Y>MAdK7j~Y?9Vnp|U#my}tp9Tj7WR diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx deleted file mode 100644 index 751416b47d2eb8cccafd72d7182c2439f4b75ee3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 68160 zcmeI)O-~a+7{KvvA#|CFmJI@_A;u;~SrbxDBwU0&iwC`VnQYskJE5H|-C3%f`WcLg z@nT3gndo=m;e%?RqM$m7ZpG%|_n*(!4jgBFrXIIf+PdD2lX%;TkB+!uIbwd( zT4l1aWL56&?45=*H(H+sCURaewARRBRORA@PL=7if?pETCT+_|g^gr!^1#WDE-Dlk z#}P}zy^-OzRb6>JgErdgEKb$+7Oys+EMWFZOJ!;+v|DLQ{C32(&xzX;Aa2dFD2I*4 zbxU?n@?Jg8_i1Uhyl^Rxa^ZnzZ0KkQEou2eX3vdcL76L%@U ztcCLi0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n z1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q z7{CAqFn|FJU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJ zU;qOczyJm?fB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm? zfB_6(00S7n00uCC0SsUO0~o*n1~7mD3}65Q7{CAqFn|FJU;qOczyJm?fB_6(00S7n z!1x*n<~OZX#_HI}s@&b#lRc%Ik@de$xkpaPsr0AdksI^R_`V*Q{a|kN6{+bkc~j^` zI_aqHyXT^yO((Uo>Me;!uH!9;d!DF;wjRku?<+Ck{wj*RyHQ&fZk*bG>biE;b=7+z lmItFg9T|1jdyuHzRGGSxEoGWE%H{})q7dqp*qHaF(jSTCB_{v? diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past.onnx new file mode 100644 index 0000000000000000000000000000000000000000..122981061479aed495535fbc636ce4ee0e629fa4 GIT binary patch literal 5949 zcmbtY3tUuH8fOq;xFW(8g-~0C5$4``+LgA6bV1Hyn}X6B>X zw0vZ0+nQK_Fx;8Bk7I6`Yo%^J$cn9emy(vhNX^V#^TFO3WMp6krqEvEu-F~u z!cs?3sl&`BM*3w*>ei<;Npvr>V9%Hby%-Y|2ur?gp2cA+wBNErd%HuD$c`{{91&)jYb!};ZhfgLRH1Y% z5n(MVvE>)q`|!I{hg7IkgoEf~wa+NbR(uIusu(Xz!-tQzIL4Reo1w`>YM53{?*cHh z%tgh8GntzTev0~ap_LTgqJ7)5rlqx?wd}TL4bbw>dO}6~aUzJmcRL(bC0G@5So7>z zW#*f7$F-MxnI=NW0k?=_CRt1JEJao`=jWA-Dk`>?D0pUXL?rfT1CkW!|3SwPk@nI8 zbDq_bU6K&$<*pAAU23;w+ls9-ytY{K6TFlT?>VEez*u0LQCwJ3nBySk5nWGDw4)Ub zC6n9OtFRb+eG@xZ@-OcKBHkb6##SVwZ&hw*4su|yIRbJsz7vg>bZzdCuD^@;rxO|l zM0uNH&{a;8AJj+((Phm5g)HgujFAYJmw>|sgh=AxKB5d|->s&KS z#O6;Oa~VYR6lz*kFr-CiNm33?*(8k!YV!e>SCw#J1fnuzrySCda0f+&`Y@(9eBhd_Y>8mBN;xG>rc% z-}8w-W>^Q}AQ{Kx%m5u|>d{8(Cnb*vs!XUA1Yem~TAxK}PGv1SnH)Rg-8rdtyrkw9(Ajz zrPqwT8F#n%B{~iw<8JPX0(Mr3;l2*Q`G82Gz+$tTbBirSc{HacVu*~+nrLR+W17)Y zislI09aP|vbxY-kg$EdVh;L9=8`I3Vko1mk+T8^T$eN++FfC{Tt{vtlq*>8#%31svp4J(rV!+Gwzr^q$UjH zOBa!P%VE5SPeRdJ`dgKn@WFN1&1CLT?Mwl*d1tx*vqsZdgNX1i=dJ_#h!FWBMOZMkUNAHxC*FVrX}D! zyA%&bnedGo5BF)^du#!kBE`yQq>)@7&r$Iy&yLg|g@fDy!)pCY&d<1`Yy&3%H|pHW zVTzZmd>ljcUp$)~Ea&n9{*Q|?KF>7a642<;8jXw+3XsdkIP|xm`-eq40^F$4+O)YjsfUM+R_>jC! zG!2@EuDXtc-mk95D{+Hsr2A_&!@UXp+Za*vz0j$bCX?c2ruYdyiyuHcUA>(5;ys%^ z!ZbKve*|92y`W)l;Lavkmwo0VoU9=y~ z5q`o;Q&&Os}z?Bjxw+38x&Q0RK~Zp5PgO1y3qa?xq&9Dh(6!!P3ZyW{m1 zJRT&1A`~GkOQwn$7jEv(zkxFW@{p-?@h0Hu$o1NM2!lymCMKhrSYD7tN_FrO|l2_5!nn4NsjZ zT|j%p-ts}TQuqk%=9=Jp_^FyG*C+aks8L)-udG!&XFB_FCaQ<#1C(A%p<2cA{I>vp zPM(OKF+}5BIUIi}q`S|-FtAzwg>VuMgl7$3^COtvXr3$D*htNX52H~clsolu2vEEv zaXS6BaS@)ed?C9-=w_M;C)JHp$HGE91pdu9S-9b5-FNUC;0A6bUJmZkeu5&=(}snh zmv*gY40;va19(lAAyJN$_qks1R57olj$_8cH$V z%2ZY7#0$|C)Q^eYavg7yu6Yvly|z3lPcr-+y-3&LZ*X@Z8C$CcD;)fV+Ye9aP4FPp z>3ZO+sJ`Mm@R%@wD@NsT0lMUghLeSOx(GPgVYOyv4d!5;b1-U(pY_j-U&ibv9z)7HPNE~1%Jn#lK)9hlD?%jqH ziZU(Guhm^%mju2MbE}8C3gA$7JzOL%a&!2KRL&N->qQH96p7;d)KNH%-&)02B~%_~ zi}7%wIc2q!uQ?8C!N+*IE|Q+K`B_N^Hps8*BIT#hSK>A~l+&UeHD&1Q>P7Ab=}c8W z_A7oEdyv_MBe&GMtK%;4N5$iA6Fz~*p!3Yrpb6}$TkiVYa1Exx9Qc!a4o6=JQ##3= zO4se=l%IJh?C|wTqC0^K9@b3O7r=Y*#kxVl2>vX59)|EL zlOC<(rc$4SkAzR9L&p1bH2xCP{0iwX9mY0`zc3OvQt6UE3G1Z;N=<(WuSUbUWVu?7 zqaI?SsU>a(JpuOOc${O*5Q}9jy~@=J@8Y@C03$$ATUIga&^}gMd9>gzpv;dv)1KP2ykuYg>%Mi;u}JUoI&a?5c!c+|NH zA7_)~6XMJ3voL^ZAXNG*;}+M^FY_051%~E91lcF4Hkd5&BwZPY~&fbuAot z3O)p17xx^CG0aoH+d{9PTj+t zmd{XqJa_2T;#mF_9EEgPElpzjqJ8d6VNU;J0)^hF*~AZKRqQlFo-ojLhCQ@txOfab zD&0`H|3sgH3*m=yqA;f_9?u7{8wXPTaC7Z1hDPUy(sPD-1NZg;cS0>a#TR&Oj)X1O zOWq@4X8KOl0A~sbt_Iw&VWWINE~|{AgKWn_Qt$|81e4|QFx%n%;*3OwosvTLiHXQ) zJ^N>>TZgq3gG{yRro5vaw~nTn={-#s&|7jM=H*dqDOZJ*;i=ZO0=E4dhR1i5#(6j5!!2V_bdYu?7$ub+ozrp8WWolO9Wo- H#i;%V;i9Oq literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past_split_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e7127357d081823e13f186e2dfe2c98b828d88b3 GIT binary patch literal 6207 zcmbtY3tUvi9%d0?IU>Rlg;2W+E9^P@IA`X};v-*SrER69)%8(si!87T3&Mh8R_3F+ zW%;PoZr7v&gyrn+d5n4Uajn#w547S|zAMSfU!-PM-b&@3U0`LG#jLygTYl`EZ~pWD z=Ksw%v!hm%Da8)Qg3@w_(_UmVI!mpNvb^HbB3o%$YBVrHXkLGLvx{>TF2Pfk;H%Qek@r}g_mme}pvgpPm=2_G z0+_j$lG5TiOgkV@q5;=IrA4=)Z+O(Ktk6*_ZfMjXRAAI4Dk7yb1!Eh`=qnKs!Y(L6 zpep2$V0P`OAOWRx2Wb=47b@b9tAXf$tJ7&yR*a$$)-rEyg{56{Nulk2E)nFQYwBU9 z*vbm5B{mCpjYeEasjW7DXbq!hl$sojGDHGZ5( z4lFfeN2In2-^du`;BaFE)i=2tL|b`VyLiU$M*P#Rf`XLR$&s#SEM$c8?&J475D-La zpQ&^0B>^9?uxb5+8Up?l#zI1?KamJh7ZGE3|H$2!YClI#N3-NU+Hhe{3)@c zz&@$?Y^99s+tJikbAMxJ=K^V@vQath6BW0*j)jG)b*YN@-E=}VFbt3x!4x`-rps

kH4#sy<>SU_%6KsX$ zortS_6X0tL72!k5hIze6U3hSQ9xyDAj@AH=Y+uH(%~O1gIsk!Ek%|7+AkR>DG_`re zgm&Fst1r>D6PeV03Jf|)rbhcW8_pkyB#Nwdhb6z%T2erBY9fKi?ry&}Cq1YcCuL}k zu{}W*9@DT~zE8M|p@;c9UGE@nPKp@S^-Y`n9lJlIE8W7pXyoq}d~R&>NB@kPlX{yj zpf}`H%*$ifQtm1#+gqy(+pE)Q)0`BA&ue<1t?pgybo3USESA*G^3YlSv97Z)%}LEe zBDp&Rt7vEQF&MN-@W8h)t?~B@=+c~Yi#${RD*6?xH?!16PAiq`2XWU>ZCWrVozO?r zMS();B2sTYiudwqC|*l{sZtaEa9s{*=Dn(2830%FDc+;f8fh`wq^*SSi1E}6oa#Np zUBxjR_d+<_h<{umpL;S|p}La;;Lb z1au3*msm3HlhawPyb`rAE6^l%C)z3YF&%(jdAH{xsOCr&r)Xr-#ku-8VTPufnu5o}(Ky%mgHVgM$Y%|W>>lZD_7$*D)UevjQg;l7L_d4J?E`yC$Mz}0YLK|?EdT8Nzm8*6)r}J>~=v@pNg6rKA_TsT>g{#9nu&njyWVe zAxEJK<2W>r8`xm2TMD1X1$dF`C4Sq`OVSZ}rEz-o0rU@jHNGNRGFM6C@IBfK%rZ7Q zbB=TY?G^jVhtVqGJ+y~wfm`t-b#d;4`pKwSe2!jKuXfFG4d6^vFU>n>R6T|2mCOs@ z0QgCHGJ4DqkMred{E0Bia}Gv`tMFX%3vyFj1Wdi}5h{H{(>{s)zO5#BYEbxK(%sxJ7#q#iBFds$s8o9h;hVm*G2I_bj2pG94CB@j|UQ3h9+~uQDyy zYbBMbron|5qhHVfCVul3{JM16o1*Wt`C)mA;qT}fx&eQVdkX2;Rx?zo!C$xo@U-3p z4?~@<7yb=3R(%Z~6b5mns1iPoe)7h{sX{Vc0$l8fdJFRs=3s$qD0r;)U9Ji&!pGzi z>}9sXMx0AIgcj5jYigEyr^(;QtMW1c>)<}E}T z7l_NK8OHbVox&aI-LvErmZObdxPQxNu02NY0uiWG+Su^2Xx0CSzvfQM|D>l#Us4-U zr7#1Vb20j<_Ylpjw;Ba<5Tp}*Vb#giTr{sNCj=b1-A3)tJR!u^rqGR%T`@O#gEj{YS| z*-36zcKB{CFU&`xj;_DiGlMzp&DK=Hm)J&m3!cji6)rH7_ycG=Zfqz+)3j0O!?dRL zAJ+6A;5FVaek|@6v_ojAS#W@#8W6G|R#clvcCPmcbTd%FqnfGuB6vH#*f2yG&7Xx& z!3bVu(xY|U4C*8Bp74Qm#CV5}#-CuCUnw1>qu5sQM@HhtD7)nM!g}eD;_1)gHE1N4 zF4xLQ)V)kRwamkyhrm9ZjPs1yVyTR!7rA=jZM=XQWCSQ~^J-=t+RtjMkJmg8X5ssp zTEmF+AEb@^+QHADjD!~VZqA}tBhPCteWO<|Jmn(gd*yxb1(1(k(nW4Q50Bw7+zMO? z9&oM3C)qUl3-S5&IT*k!5GnnY@rdi_=lP4eB17wtiP{}_xcHMeuzo&ja;=hX!>uA! z*C?FA?7)1O!K@McaWvHf@m(+YRcJrD^cJRO|}M`S^thu;uJ z=_jS16E4?s>s#b1W|@%9w-|3nlf<3wOqs!}+20Inq@NAX$dg%S(+SBUEn^=OddL%C zY35$$EBOr7&wGjUV6bOUeXV@dJ zj}%X!2c)Y??GNfRa4~#WP8H_YB;!ROapPcW0B)`S(a`LAS9;RWXyD#F2! z_IAuXY8PsPbA%Lk6K>kDQ9dMBR4374j(!mtc(f~q$?+Zb}GY8%b@!u#AJ6| zX_(a=-8we|vs%^OK7$H-4^5Ipf|?-Xh)8v$@*#Ym8Mxbk}&1tD660cP&}+@t(tGp5Q^UT!q& zMI4*aiiN<$YV8sC$>K=K62V;>4VW=^!FxX6W=FvD5mpUW^4~8LRSd?M8$(k}W@{zPi zB8em{+E{Nj)zq1%INr#a)gqB#>m}H)4-QN8lrS&vMuY9fh8jIh0er@2xa7V1dQ06> z^1@wT@lwLMWK0@%{hflD6%b7I#e)fO`F(={xdFj6UObop7hf=YBC(6FaDYodQa)M* zxCBf!>LR<^W;2T2DYB8$-&@_N_fQy%5U;p;gVm<5Ha6)!bjA|Af4#luHQ?!f6)rYf zs|@uvt$6YDKTtBCAF!0{2wC`jsLzv(k@kzY2`LZhMFXC)G`7;{6dzDhJ&QzAXQ zo`S_Um`&v-i?PfTwxK4)L-)M6vbtJrt*OjXXRWKSVJool;~UFrdKRAPmypQg(A7%~ z@8qlBN-Wu@0NN2Gt#@HBWe$9Ls5t_3D>;nEu#^o2N!djaf3w&@(r7NLD;L#qj8KlJ zcX)(&C3vM6_!yI#u(;Y(?~^w@l`%EI;FHY|$~Lz8n%C!%U)us-+ls{TF9tc%m40jN*Kgp``fcpeu909Doo4JZ zGP!tmU}z&<5%kDkI|s!m(O|$*ixm80xDCMv5h3 zh2b79MY7xFbNOufQEDP+L36s+3AeDAP7M{Tw3X-f=#z6v?g>?oe-sKOt#1rQFFV_y%=4GwL;5 zpCZ{Ybwo!LsNv5*Y4;Jd&yfM+6~reJDdvR>+PfC*lN@Qj!c{_^bQj8-=qIUO;dxX| zn!(vF2=$QSq&l3cc0T~G00ep;5st_CBSaK6#C=cl%pArdAbvB`$-aP=E7P?A#&$nJ z_PW*VgLYiFS=diM1}fo`s>tqI&Cra6idRrBcb2=na~bTjZ{(Lg9YH+>PjRL6BPuc@ zTHR@n0}~)l^XO0TG&n4<>I(D>o3AE!ALI4>TI$d21YrSf$=OGKD%_!|Vy8P#QHOWs zamV0o{CV!*T=*=hxvHrCPcq2n`dzI##T*X+;jO!J`7 zLat>C9fR6y;VI=5_cXPXTi|#M#X=R5^2Oxa>TFa;?-5ejWo^mmevq_tx-x8*AXm(? zule1!scCNQt>StdU==e6{sbH0km%HKD^sF;54^^{!ynepSILw&lkv*6E)w1e_M>D} zp)KTgI<`!I7-lC9I`=YqwG_IZwyPg)E=7Qz2_pGFk}hs5@u=gBicn4f75t~@e)nsx zHHd)){2Jj4K_{%IZegOjJDl}!ziWtog;>P(ICBINJwct*Y~uf;d04oKnhFoNir8hB zA7hnpd&e$E7A2w5&`{_1niKXt{9iP^8s^!9u9VKOS<;?bzMc(Gt0TpIgx|!kg1Z$> z^s8LFvIM2NPci3FG($SycTCJ)Vm~Xy>T=+9U5Th&uu+J6lf41nm0d$Sz;W^y<{Nf3 zbaCyF%ZOx#_+%caR)9|suiYMyh!zJk6qTRw?*BQ{toTQGs;-M1G z1vGOQu5{Ka6Lf3AX=)wHf<^T94ma~o*NapwT*4;_r}#O{6n8&&kGm)5OZE_RP_s$B z-u^DrPxUb@aKZ3ddTL1t1BizIMfXx!LZyR6|8gp|50Zn(3QoAs!wH;0naLE=Cmh#{ zT#eSvPB&u(Jm(w$mu+rF8&RKgj_X6J(6tNxQybm!IU6=h#*-zE@`-1) zt5Mm8Ra6f?0FaEn&#L+?_f5f4d6<}+b{;+r&$FO6kFu; z;AVIYaL7wFsX~l!!1<87jeIO;A-RCw&Y9eIUFIUb=qx}cYMVM)u<=*11iJwmsAKjN z>Ia8`z0-l|2Axg3RQHp6QnOZRpq^#QVFdp*iV_Am2;1Co>~dx%yH|+Q>E!F|GgY^? z{X>1&WzCcdUpV`bQl;$Lr`SrRW?ZlNQWZYS*W8fv;3KZGJ=a-F=Tgn|YHqcQK|k_MRIRI*Gcf&-<6c(w(|0(YZ*#Pzv<^@f zl*bNbZ{lm@1E3SUiSAa#5XHM6;8kFUuuT;s+yg)0UJ@c11?=f)gdeuAcJ=Y6+R~^G z9J8rIHV8A(lw55H}Wn zFBPL_0Y&+*!#;W`o8s(4eLHpv2ZhE~84=>aACZj;?9pVY+fBYgU&ocDlGKvyVC2!) zm7@GO-M8==Ar(17<_PDi82gqM0#xP&C4%4I!uggGSja75sH{Za;4_xsXeWDO bC{<%2PG3Sq;q@+zOiU_F!U9iJ5+(lwfrjt< literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past_split_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4213079e30c9267bae277122ec4cb550173c18f8 GIT binary patch literal 5882 zcmb_g3wRS%7H&$R>Coqx($HEL`b;K|nLBd>rBcMAfK^e8x@>*Kv`yPUo7N;P6cv$& zx`;1OD9eWwZA+WxF^^ka6}LWc1+1vuRUV}(h{|VGRAgO6=}r=wHcir^YvxP7FXx`~ z-*fKy&$)N5REo=M&E^%Bdb7<`W7OI#2D7!I)>31%SW_nfElq2w{5=}10LRIBwPvf$ zV7BRN>uq)QHa(RZj)EjTXIG^jy2#XsWET zVN*j6mYNJ!9FU>L$m={9i$n-RwP~fnW~wz0`=Pj4lMFl zs6Tn~a4H#-e5zOjmh4|Atq(=(SwC!ibo{bl^nko3haC4M`j1(JnE{KiZ2UzCQ1LIq zAL0zi3vi}kd}jhw{LTbq#HSLRkpNRag99>hK_zTPhR1yw*&T_Elh6Q81gxarg8+!( z8HUFrFwYZLXRzAz%Z!bB5BabJ?_Y06@e1(t5zHwzTB{6oMm^&P!4vB&MypZmQbPzh zJ|r0iNu7M9%_b^&N}S5T#j3HH78(uZekqZN#hB}B^i`tJDv$J3YcdvJZ#I>iEXFdA zouOLpDdPONvf3JLjj7C1Ypt!YVJoq*{Tp4Y^h`W+L_lJ7hUP|UcqKnOR$(c_^*V&) zwb{GNicVu>TF-zSL>X&L}C)B{FjbVPN7i>vO@ ziCine#+_z#(wICvC(yN_8VuTd)5p~FC$1$DyGTTAG^>AaVPUBw7gswJfwzm`8nVJ9 zHog&ut&6NSdbrwB0*Xs7Qm9ZoQv6?4SB?EP$AS_!3zv}=Tr^2R}jjAKaZ zkE77zG40zPJP{2k_)~=^d8GkGE*%p2ih98X-VNs$3pT=KB8!3!|EZI_eVFkA;aH8q zWY$+&40TlmkYb5gVd$f0k?alFOD|>R&R*2JW2bOXXlRiU!6v@rg*Uh&Sq>_&N0X&) zH~9*E4Of;*QcH3O#yd;64YNpgYt~#oC+jFR5j3MYo$H00*lS5*rq^h!*N!4t#O&aQ z(0L8RzmkY%AwC53nuTrNMY1^Ecko#u6*+_}wkVy=e4)EdMa=M8jZw%&vJ=xInfa`N zGf|ZYr0-^3^z(?w^4j@%TqL_(xJ|ts{*0tulyWDd;Oo`t%&?Y^J4Lc%>WKCzP|cr# z(yk+DpCbdtD~L}eQp^h%^lnnLPjaN`GFJ(C(p{L(eM zJPi&Dthxd{%g$DlyN~gDejRlmJ3&}LTXOf2p9!~Vs@Q4HQ`F&I`P?yh3xA&bHy1w2 znPQaL?04iM$J=lZ>glw?C5kBcUWSA|lC@Y}Ltl-~bWUdr9ADE9(hB}twuwIo(;OSn zCYaA;3T?txcz{y0^tC<;?nF0}ZJJq`-}9&Gqi7DZ5jBEa>`$V8Dns~$y9y0*%Jv@i zIHE4Cq;tqkTrv~mOmeQCc9Oo4-o`rGKXQD|&37-#{D%FdjcFPXn#px+p<_UMH9VzE zc28AHxdo2LP%Km-DPK&!qs~FK^lL&YySz08-3yX-PE&@B66A?d_BOrOIwj4my;)q3 z1FT{?!Jl9w91@)xZe>c8?}OLbclpEG`6`+6MlxQx&PBr8!G4s2Dzt^%PREvM55t_q z0q0&uua-jBGj{c(O{ECXGe9K&XVS%OB_4I0Q4z`spo0Gl-RpkcwH7h3fL|+oDd>a^ z)J;rOSG%(g?spB+uMms4ZfCAQq9>_ynoazFG!F|mP*dOmR}s7X(qpU=Zg1b^$fhJz z8XD~QL36_X8vlT%N5edK&?WB(8zt?@J(LCq%h2K#$VAJxmSzy-rc z>B%J}^dlPn7u`)|3zZHQ{mZG;K1dEAD>&gk4<~R2WhPTdpKx3&ay43ar@9xYP!zzy zGfp{qL-s*B9egi5-o6l|2)}|y+_wnb`~l?RZQ7~iwDv3DLMYdq;pVz^a2BiuKcWLn zWxI{jl9}8xbu3$w)uJp$^XLLpuKk{EgD(hQYkH_R`4_2eU?rDDDRM2&Xi%D&fDQ{U zbGqp(;W=kNxMXt^dII%2=eRzi3SGP4Kef^AU$9}LWIR#gXg2Yjb`2`qxSHx_C+L>a z#hnYKNpvlmN&i*5m_6^JTvHu8=pD=xXd}2>@eYiE_h?pw1jUxDd2lnl3OKS#HK{_3 zaKQPHyOn$_cOkie-p-lacUQ#umrmi8mMD-IrXE%z~1h_bOX*t zUaI@qJ*i2nG*Hhm2z7^?K4z2xBf$Y*k#R>3ST<= zkW!`W+^5(|rDj~K`AQW&%2{HRnRJyho^9i1L$$~iiEh1G!Ao?lop!Vu{shy=_^w~k zF8&v{T%FK$hfu8f8+@4PL?5GxY$h_cW@o($|HQmapH%DUL$pdY8J&YY&7Xr?*>uJN z8|hW>tUI1w%%%`^z)sET(34LghOV+_gZtWEW17Le=%`SK+~iAi4=Puh*#S5aWwow# zFA=^F=Ak#yL)?}Q2|riys`iog2FlJXW)6{iv=!>u*5@IvUBRtWmT2EX*Rb<5!$t`y zG0Fn%C(d&@C)$n@Z-5AB;dgdE%^B1`pwF3;!rzHv{!`^n*vOW^Sluf17S)eiGr-4O zWm}%JhR&m!=r!CL7lVG{8>t#s4`*QdAjj=d_R+UFUTAf+%3Jy=3(98)b2jnSS^b~` zyoK&m#Sq22@8?xuhpGaymZIst-b$Uu%?sA*klNt*smZg~d?c=fOI8M)Y zIXzyh(`#{6dMhfu7AAE>*rwC=j*>VwyvpjRw2`q1m^nIXTD>L)n~gnccLrn1?8U`4 zr$xo_+B{eUOnekQtGKvtikRS|ZkT$*E@-fUgV=bF%~4{pdc97I!&+l=Tb%03scM_O zw9Jc*i#S+lw|a0uM;fD_3}O_h5LSnMmDOu^Is4X;`%6P~_+S_^35&6=w0lyzdq0|b zO*EN^#n~!6c8AM3hCh@uxn85eX5mTUnp!~L^pEtj4q5B3ZU<{U0D_GO!43`a`yj-A z4fg9xgpHl$_1aV#sbym|{Kn z-%?a|TPobHW%T{=FxyjycGAW6IAgD3$zf_yy+~9j;a@Tbo)&420fI;#!lOqcKMm5z z|3v)1Rb-$`>`tqrpY$H4AUsO_#)rh+kJs_fEVoyL1!G}!V+R<*LX4o52ixwze%*ow z(O`jWT`*KUPu^uRQiX;jfoXLStT^S5FhoN%y4VL>09^ zYqNY3c1i?{Mdx|)YmF8Q!42B`%!f5C*#I^0DZ(kaR(=ic%B^AlDvc(VqEz8BcMrvJ zwEs(fc=i(CEoG!J2R_l9fEpDqgQUNUPr~12J6IlEpzm_m#W!F;YJ`$5n(LO6WstWD ze2rx4h?2qNDw|*zy#W?52jM{}-guN1l+Obsb0P@9p|hN# zX2~e$5DR}oE0`>$lo!!I{RFj+?m`~WDcpm@BrB1@nb=N#J{-@-8DG}7gFF-iaInpv zLO+#tjGYK>D4Uw*fn?<_cw1Pk9F$v8K=xAO=(MKC;5?XOxGBvPjBqM+fge#TSK8#2 zC^|!0t{*8b&1@hRq1o(oR7~9v8{uB%ilL1;B=2KB0IQ@-CO60Jj{}7n@u)-DFB#KU z!Q1|G;L&&M&}P)`pAk6Em;yWC->A5zZ^faroT_GNmcBq&(0th9kN3?%hj$6$i|lIs zY4o#n8UR#(Wn5GlIFzkw#;-Nr+ZffFP{>x2 z?<``-pJPU96LjsoYtQ zYGFQ?p)@K<#6o&Bu_i#nm%tH}j7lhz}fQE z+kk4oV&4{Yj?pP!NgL}55nz7H*}}Z%RPXmNVlX3yQ^WlZ?pVc z)GZNBZQ=#QOekfu=~`(FM-sz8iF^sI7CsBSi8wf2ep6YtEsFUQyet*6Z|BjvIQ?$l zC@{jG=zlHkBD;|NK;)aw^WR913;$M(%6jH$E@n@YzXBc!bh8J@VGw6!^a8vFIuub~ zf<6|f=?gNhigy~hx-O-jUL%_LE@}=ekPiBD6dG+|ZX0Ul9}I6PPcihav$92A!~8)U zrp#sCIfv;>%4K4VFix+P=JM~LkuVQw$Shu$^5druf@Y`$yZAV@g#+5*cCpj><;UeQ|W5AkYWcC#0fOnpu*)l!*98 n2*k8WLP*) z5D?5Y7KC~6_Wb< zzP|r-BGvg*M^f~~x+~7=eUqih=LY$4yaWOsC&(*`8^+;@#%{cv8?eM*z*6t`rgwlB zpKWT=-#rt@P7G%egjZ0&3NJxG2=AY9jQZL~bl9K5h*m|)D>8sTw4?hYULr1GsKgq? z4dn*}h455&|2)#XL|mLDV5xF>J|VuMx%6Ht#ZI0qES$UqPT@fuvxyVgdM4eLoeN+R z$O#P#S!(jZpob*QFIeded~f~Ly-u3cx7UcT>a~|uPp@~xS^u60)(~rffGcuW5jo>H z%K{@fALv%;+uq4EON_1c4{;`LT)w|oD3>$7hgY_2Xc(6-!sB#$gvQ{HKy+k!ei$mU zWO(7h9DlBtFMp^+CwHnW`EXu+y2qZY1NbRr#7e*xLQ|(JMC9K4bAdTsuAq z*Z+$6KL_-LUS9r@7?ndeo z-QHNH{bOZ|v;L>$`oveK^c90;{X1(P3F>9lUr6pb2LA#eT^(Pui-K*ZoUuUfI*~;Bb?}_MY{8&y*b=V@|76jOfB;*Ua($|I7 ziL?IeaQYHsZpt?3J%#S*|0FwTebKImoNW@UuS8_^ji>)Lq*)ptXSd$^#U5+o7a)?{ ze;?;&J=|-k^$Kf_NZ4#d3FH42??J48H?U8LGh?fm`1zQALet=nT6;+GXZ35$k`^I! z$-F4_d5Mx!#QRFdk7v?ZIeesa@^;*y@Q)@g-MIjrZGD~4LzKn7g!IWg#&^xbtlxQb zxAx4#rgOw*?-Nt!jJgeasA4O1wf0Iq(%nt(bB)~xzx#(@)~64#RX$Wjy(%k>fn6D3 zd>0U#73>wj<17jD3iUU%lV&N3+k7sGI*h9RX#*qIVuwGTj$;hI)dcqvZCKHkjjm^w z(86t@IM|4VW%HVd;mj_eFBPi}qdF_)uel6obb)=oxb1}cx{skl7t|L5&0*AiYlfLK zelgS@mc}1OjKL}mIdEx8!Ke}CmiwDku(e_?OvjCAIJpa}$M+mYc?T^)ThC;2CMX+* zS#Loju~Iz$LZ9%j@hG_~*&@KHxA)EwYfr{?#G;D@7Z$lAW;DyBJ2j zE5X(3g}6p0pDJ3%!KVG<#Bi6SG^?xIPpT+~3`q;$I_9-e0zQb!p-wUlG)HYAYCF2& zhF9rOd0B;AQ+9zF?dJHpas)}TNJb<325K+03JNQ>VA74naN+WHn!B@r?#ej99DVCR z1(~ZcDZUiuK8*thb^uKtuohd+3-H+rd3e2k5+0}@iV7>%LtE}|%(rPy@F095)Qxyc z{@k_(cC~BaMXL-DJ1`Nap38yj%JY~DmW#ptG6&DyvcVuFKeE8g1oQ%n@PXA8Qfm~2 zgM*}CsRS3B8Z)8JtDQ)uXTpfnG9bL42MKlNXnLa*mijCJnTz50BD{_^xdAw=7eI_) z7_3vhNj(n5Lb72z9DY6$CQb5#6E6#3Xw5$QFmo)~R#ys>62^gZy%wy>%fN_?IB0U5 z0aAg=^zX)SkZFvfV=dHxpOp%de>vdHvEI1#jR!ujSpvs**3)nwOH7r2MB^U5V$LUO z;gdTnz~}rivc@)xymfG;W`kCM=VDDbWVZqa%S?j*2y@}{Y)uidRr{Qaw%ZCZX?g;L zw-uwmo;;k-^Tn67DO9X27wzU&&_9R%OaNQRYuC*fvCtUyTfL=K2e-gIA8lNyH-=2} z@ddpJ_7Fb4i0%qn4@F+PAMet+IA_^;V$E(Kc5fu< zf9ewE49hrE0dIg<#{tR99izc*No~3t`a-gKLj)>n8qjMUD;b+;JI8nsb zK-1CquqYp2m2F0U*Y$93(NW63YfXc4lW?|Jo3OQ372`Z+f|zN) ziU~NNIt{fJy`Xwl(dc2A4-wWskeTap=;aMtV9}YI)F+~vdfd#yW~VGXcxVxNy-olZ z`+V}`Yz=w-U>GuY9Fcjh1Cn*~@XgFJqBABBmbm3Z)m3>c`g=2m<+U(to|}@{b0bMk zsw*zT6T*^_)k1xJ1(>co7>^%oB-0lkrPKJWBz#0X*yY?79^APJg2WO)CP5ctN6LZv zRb_O$m;&`#9;jEdnGSjEjwgeYsJ&GNy|?fho%diZnJ^>;o6AQ-u(K|Vdyxu>wPF~b z6UtoJoY!@+_mPi;EDRrou^O33!yIhGNNjG|6Qhyta>~^%tiQ zGiNEx2uq>o_E^xr^XFsH2|K(SbC-j9RFiI>~myrg_kQ?@Ya;7y`D-{ot>z-s}I~OG$1X$ zu}pbe9)yODq_IZvAghuKX+bLi1x~_6x(NtFL$T@UdT7eoN46b|MawHec&H`@C*)~h z!m=~8^@Jxp3e$#lw{uAA8V$U!=>ezOrh#i2CHcP;((0o7Bz?|Wh_l>B7gwAlRZ|%9 zEX)OFnQVZthLO0C_UfJX4O{lv9K?f2iSwc_r}E zG$*(6UW&^QxsW`e}I>_o>AZ`IH9DedP9roAHbdTyfI;U_p{9Mk5!)=+U z8D|P^oIxbDBp=?eV?lN4J~H5y3H;I^PQry*jCw#4n0jQvv-D9gEi9K&Guut(n)~1t zBQKPLWAsr|3*B1wg0%kB%J{cEATF2fVUX*1tX6qU2NVTb96LOhjF~{FLtY|^50QgP zwQN%Nr#IdWbH?v34?^W{2jJ@snXr4fEa2LO^cNXP;5;a%9VW8q5t~e}+mwR22c_!9 zE_kP-1j;6L(7+R7mBCb>9M+J9IXnTpS)KxR&#ch<*k(~)UyK&nme?Gf0FB?jpl|nz zW6=v9FgOs4mqp?qTT5X@NdhysU>40j=0~KWcM1pB1D4% z3{~Nh?}u=yk(fG(Yi7g7DbtuUi*}IQs8ye@GR)0HY}w8K79u*vkO~yX>LU{{F**^p zCFq0QBQ3b9RstvXW{`8@wis?w0++T4=;O(`7$%(o)#_WoV_pMEIZ%r4wr;_+VRLXp zatFP|T1Uj{OR=dW5JyefNk%(n!d|-=38q{gKfb{Wn zkl*@_`kk@{S8o=MykSG;w2TGS8LDtmdn8S4XeJUXvoL}?9bU9J(i_=&@ax9kg%aA4 zknwUn>`hF^;f^cu;hR((e)%QYU_T!^!t=2vX$(<6Cx;=6H-l_h4E`RyP3RlgLRZNJ z;_ug@v29Tjp8a7HE@>EttuiS&Fj z(Afdaf*A5mQ#KfVqk!@j$>6@B6y&SYNu@B1thxWLqG5s_9C~1h1$8yT;%DX5Y@3s) zni&YTFWq3V!#Gs7uBBsr6VQ&$!60n~_&)m_Jrz3z`9?v|5?lyn!zZC}nmJLLyAW2* zPNvL-!$Q-whlINC)(S0hG(q5;hV|Biu=?R`<`mlkH$>t1dmFK4 z>0{DdB!z(n@5tM#^;Bk2KAeirL*-Rb^ssjteACV%Gg4!~{^bH@pRGpaw^9|tDf_)G zCF(AdqsfLiKGgzwlhiRnL7!2YY>U>G!*I``v2d?oH4O?&fYOF!kd`t)-pdkrJw*bx z&3jBwUr&c$PwK(~{iz@rQ^=HrY#^Q@zF(Cmp`7b#I!sjx4L2pyxfz}yptE4jz!;io zSOUY70$@ zdH@!jo<$c}dSk`w01VSw08^_?;qJW@FboYN`c@^#Z}NmhmqKKnUjk#oGI6Vt30&Lj zK#yx>VcI55P`35ODT5d^D^{vJ?UI5vs?#C#kO6G8nu?K456SYQ6LC$m8%g*+lOFI~ zO*(e&A+3+BnA7qy#9(V4inSOZdncf6v@Q+5nFbN#-Kn918{SV!!pg^17&vkO-Mda5 z7aA1dJn;m}x;u{W%aMDez}gHG%e{g3tv(40(1z1T9dS!e5t_$8r9ar$(QNk&JnNPK zvx@A{nHP#%ozGR8NQps3K^BA$B-r8X2nC+Yh~po&OiV)z*=8`8UfeeV%1lb>%;9Bp z;X4IJf0YOP);l)R0$^nHP^NHc&7%+FVf$e4+QEzNAaBmXWDZdgIC~PF{zoy{N7An}0l1R$V z1~Bu5T3EF5w$SOjftA;NR$=L1&q;l%B1Vl*qJgt-5ng5zNFN)G!>%5ohwuC1+4^{N zGGfElImd~YgBzBgLb6YO8GiRj2H1^_!r;V(*m8X)+N@TDm?8yi;w{Ei4feosdBWJr zwUJ*7qH)X{h6H8Q(}POuNY!&~-2O0$8dhXtTx2W^3XT`Ow@<~maSsWft_A9FjXokW z*j}4vnGmOmrbFXF!CjdW^&G1GgA~jdD+V(-yU6zva_IAND;0k>43wT(2M`Eg@BuZb}D-67{78iw^0J7^VPfO~N72I{G zU28?}i(amFiVBtTq2J-eA^VAR@-dR^?}4yy$)mf;RuaqD#}&2Q^^;z@#>5RbkPyPS5M+}R18|yZf2$} zih(hP}nsh8CIU22jw?M%X>OQ*Pb7VG~r6?`VmSssuONcI3=jCqn znUmB$#wwl0-iEAx_Xhjeb^(iQKRLc(4`Hc@2B7$10NsNa;5~-b&jHP6t&Ldy94vea f-8tj|-eXz)9Xa%a-DPegp{Qh|#Oi%KsVM$`I|c{6 literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past_split_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..677dcf062a80d80f5d75d954a2862d8718449668 GIT binary patch literal 8430 zcmbtZ2|ShA`*)+_X1&OkNw!20EsA>{Q6gDlS|-uW8sU&tOgr~H zl}t4=YBX)yrXs(R7FwpEeu}K*XO+-_q@;hJ?DA9-}61s`yO#|mU1wc8xb1D z<@*Qu81h3sxIEwB&>){sp6W1TLlYCjiDm~ZnHDUT(ac~jkMF_dbArS8Az^&ZL{*8t zHcj!~v-%3mq2V5ZVLtlevJ5*h5s_$%5{4LaCR5&@+jW(nzqhv!mm@^t`|y~9pt7Wl zEWN!ybRyREQ&Uv<#k|L0_CLu~(iMEN!FS$gUe; z6u=1y4PIgNKA?{zwJ$_z4(o~j>RzW#9oTF5SM}O2s;}2r22=Ut6byU947(^|4#=(_ zLLV!vG5xc9AJtDl%AX7A5!GD=^WUq1Int8P_Yuw+p+Y8dRtAJ~-fK=_V0(8iA}sX} z>M?Th;rV%l_;4nDNJBCt)Q2a;<8;YVb;L&?ni73K^ktb6+^`^ypO1$(PkBg}fE1b1 zVO)Q2|4<*VuCaIoDtGD1Ov)=b$S}y?D>RrF?8|4aW`4PQ<^XL?p2pJpXuLux4NkJ^ z=Y#q2C{FiXGrcA3}O0+s=?rQN<`+ssr$dW2UGQ968c*uYl4xlztG+P zy``A+2}V_^pM!9OinJ20v;Qw^31t4efqg=p2}{Aq*UR`5nnrxo+DD2XbI@R>xDcUx zh6-0%58*l(-cz~xa*etwrH@Qi+Qug^?4!iRyKX>N+dvVz2^Y*SA$^j^r0zUS2F;_l zwJ#5=t`VE`cuH4By#akxvE;g2`+0_VZ&Qy)*u2->2Y#8KUc^#(Uj_E7BvprXFE*39 zfmqBS4}UIad8kK-pT3PaQ_TX&WkxRz+Z>b7ijKD(S2 z?g+sV223nl*hKW_b_0E>IJeX5dNTN$Lw`;;*yoG+E~u~j=-YNf^{+X5ytSJRJH5)% zQYao%4F8z&QOj}z#_+-r8qZdgD}GN!Uw(_%<~*cJkMhZ&NFPMyv*%A*2DL}T@y9Xa zu}W16oSITFa!k4Tfo274tyloFaWm>q>&EJv4|aNd_J(-*h%!W~x+t5Iy6feBg zA-romP-R+UY-SpX-|ht;uI$mJot{!eA&L3BQ}B0+2B;@HV6A5{=)Eh!bt;9pULv2$ zTE@Ya0}P_STU4Cc-R&nuq+N!nnRgw#-5>!UM&?ili3XaZv;;NmU2xOubf~ z!JH0LY_A+clFX9Pz_x+fimiphifx#5a~WK`@;%MnRX}%VoMaz+Ye)H+>o6(46c#*- z13Q*KO&%JHEf@Ise6=*RZ=8w;>y=Sv^+ssR{hd84%>f>UO@X>GZ^>Uf*2C@&b-ZMe z0U`&d;EeM*a6^6}`=a?WaJ|C8^S7-qP|lYuHZcP2fFgWoah22>MB<1*F<3Fg2b&r* zq0Xa&h^A-4m@^U}c#sDPb*5;1vlLc%Ee458VfZqvjyAaf*lpxPG=CIqP`pLm4#z;U zemop`F&3sy^@WqK3P8DLKYf%rf$XR&g{cV>!LeQ)*5+kkct#vF+0Oy70D1aPV;D#@ zM$!ppO2ErX1<}9laP9<8-2TQ5U(_sz6T9kZn3p-GN)+KA!NZ;GT3ihKwrIgIi*?w3 zQ;urA(SsMaJkfiP1@MAfsa4hrY`eG!*^rO-=mMO-@&d7BH4vLOqVzxYi4Z2U5jHco zljr_`-%C8CDcK3IKqC=;yEFj~Tr0+3r$oZOE0T0gydjz&X(ZR&wJ|Lqm7Ggr;3%_X zw%^Lh#A(hZJnL%U_InO6Z-is&K1!yitANhV?{K}c5~LeW0;9=B zI5$oVE_B{!@0hiZ8sF7`z1D&3gMY^0HJLR?R+u8UK$jhEl@4P$CurkNbC6u*S@E)bT8L8(EHKE)90Fw%aA zwJoC1O+O#PEq^3)H{{SOo3_Eyv$v>Mcr|sqm4(d?S$OF1QuJs~04Ljg^7LE{dGT-* zvhUg>`-LWm)-A+0bIXY4_&iwdk`Gnaq_OCqtr(iu!e0Nvn0&h+g5;z+<4Qa!C>dKV z(9w~B*;*s;#PLQld)YBMlh;bZ#>9h7&K<#_U0WbfBmpE6v@m9@6sTO2N0&<}P@m<7 z+BIA0$S1CNDkzECT4d1sORm#}4`a#Xktx_*J`RE$wP51QR7k89!T6jI_Qmfr=!qOZ zYOK2%8(Uuy*{8Ws{JSDLdYY1?hYP6t!aOXy^psjwJHVNWweWF1zP#_0`iR^$+})9xV=d7^K)zJ zR3Td_0t;NPGZ8Mkm(g{_MF`JQ;gtLkc$}t$BFWk`$!P<$+eXp)OVf#oqZnp{rqJ_y z&FDY*i?HaV4c?2sM<)KJf;vsX_;Xw|=pVa4iiRXXfhre#V!f#0I7w>bHxD{$GmzVn zf@%p)^h{(HO$#{6w!FF&$G=>SQ4JSKqgJxO(lsAu#AK73sX6Go)dIEao7rB^ozd-e zBI;#Jkb{cvNcvh?xO(m`yV}+Pbw?LNgy?jY@iEeSR#&neV5|Do&BA>1^^m)CuMpZGzB- zvABfzetvG58VlJvV;qYK;Wos<|1w!+nSoo}Q;y$~l7;nuD&fV2CGhi12e>_X1?4j& zrY}>nf|$Bg9M>=#B(*OR7k?&>K6QtV`ui8US8)TKU-&KjQqF@TZJDSRXACZ!;Uu*r zAKtKHKyk%>GW4|({Mx`EVS+5SihmLqyJf-ibUm0En#)!)*+UkXdf`=`Wy7`uf2Dj^D-hsCthND|#*lIabrQZRL+RK?Hp;@*$Hb*5u;}0+C+kFfyddUSg4#VP= z5%}l!QdnJ*z+O-=k7ghDC1O##1S4wlNb~C(^iUv|eXeC0KAS`76^}v;QSc!@jP#)f zA}S=VnFX7t&t#unx|8HauKj$KVQL~|%VyCm2=5$EDv*KIM<-)qR3hw1&;gsr>Tpe| z1WxYDAmnqZ^N`v^KnyhC%w(wKt$?G zv8f~g^``G4FhJckt zq$+DRv^{kNNmm~-j2nYbIorS@QbnMBfeqTnm{`Bl3$~f1L7j#Zh@VIY>Fw{R?`cbL z_GIGNn^t6g%LGuIqX?HY#?r)wW-??=7KZ!GhLZI!4QoI$apmg_9dp{ zX!|wz=uIk)zVeD}vRwq7Vfk2-G@huOm%`v>TS2lc8vlseA@B}pp=+fA@Q>?J*tRqY z&;7Usmp4qr)*2HUu08`N4HqYm#p=jTe>vKGOd1ZEg@LB$3fTTvJ9)k88p%DCLKKop z1(yzNhRf5c$fmc3xF>ERIW+Bel6!j#7RX9qO_d+}v+I`WBq3Xxd(RTz;R5<>@g4RS za!YX8dI?eVn+B;J>*1xTBmVM;M~h2M@aM59WJz){j;V@*KQ~62zX{JPO!4g>eZ8BwFJR&fT zJuJ|A7b`H!Q3JkX8rEA5$LdFS*r(Yp=#`$zuHI!1`_-nv*vDH?wK{K`==6Z{xuDp{`6d{ zwn8*qo&@$cJBaOKBlyvaKNsoA@!8aXDGAA_}Y+o&A@3&U1oF!Hvn0~<1d`R6Dax7UN zC#9Mpcd818%jmG>rdgw<`6%3bcmmunSVsdx6QHyq8N|hOk^8Cy+NTeJ9Sfh(GdI%V zw^LfMSZ4XgdlTgZe9UY}8hWc9)>4FS*;L~}qepock)GvY2$?;J8 zEEHCn7!XTCGf>NR68i&wMBXYMZ|^O@c-we7=e{$Z(po}-;zWQ`Yk}7j z-jPG5;ZXI*axjs$#%G=C`14B_{Of2wOm^QuAI!PL@NXzc=nW%SYItNk+a9 zJ{cbzLvF6ArKR=pU>Icp9EE)R<4!E5RMn7;Q}l4I6BieSEyP#zLSde)B7QeM59how zBzt~1O-ywYKqs~oJ@uXwv!*PNG*f|_t#JZ1*Gj?sTw}1WNk;kPN2KuFA>t^%oc+RK zDbB9Xrel9%v8N5qgg1x8%wEi+G;m`Cq%|j#mr^lMcIryyHPc)Y8q!8}qK0C@nR#@v zxhGb%`(vp3Vwh2F4EOG*fPP3Q(Xl8&UXwc{Iu#=G!g3fNnu*)xjNtk{J9g6M9~Ly;C;WbFd9j?$uGx6&Ydk}K7halr>kNm%*B0t3blrTaFh;1b;;T*yc; zue)mxzaG6$3M@@9vD_25vvf$PzXqH+W{=x)iqJIv8U4|=j%K@N;5nBBm{(+jj@%I3 z?s&e^NK6DO3bG(<7{N|Qdnj;UN$mf$W=A(flO4JX=%xK*pv#TKy z-#g~OXjOTPU!_RhX(8QtfQ@s{Kcw62*5b%=U!>-8prF12jzyN>^Y%m3sVx z&p05na}m*B77eEMR`9*aX4IaL3_iCA?2=xCi)A*Gj^9%77c&LyOi3hV=lt1=1nO9{ z=8nMOyJ3|#yw+mr-!Djgsw_rMN}>VZ-X`44BoIG74o6)(N{>A7#&h-Y=wQHt?ek9% z4?7nuKaFI+^h*5hu>`Oh8wEj$OR(j}T(nvz3(-X~*u-6iYa48V8jGOp~@MtVhyX*8Zk-(1HH1mWw zH8fU^2N_p+N>p;F#*bn!XMzaK;p`?q43R>wSKBG$`6!TkZYEf0GZqhr7eY&p9Xi!m z;N>tg9AA)6CLD@G=SBw@r4xav_M#|$d7mKQYAh}dGXiAQSDq2oCaZlmphm3)eIR_e z+9fPh%9X#vDI*UM@#Nzq+0R$-s`@(pXM`5j&tY^s+MwHUb`dA$ZGmmZu3&8}C7ixx z^tO2c8qcVxjJm-?m$Q+0S&$|OT5UzX?zNJLh;(HfR`0fA@!2z+|z{1#1wux9HnF_)I$i5#y?;wVHj%NpWJLvJK|zXwA!iU2DS{|#152a?i9kS-px6sdDT0UvSOl^6 zwF^1pihUOq+lt)<`4lS_bWxE%;EJH&`tARb=kh$6JNG?v&b;%U$+^nPY`rLBKWt;8{bLQ!Z`Oe8->WYEdR%GTDZKj-)$#+l8w9ug%Ki35dVUX(-}EfMqj8>qC{ zdMSUpYoX03#Rf)5_!i0NQG9)oy#I%FTR<(8s z3HdyTV$;`N3bKF9pDd>BO~z2fj|k-jip4@+L|`00h9{K8ZZ(51m>MQ#^xM4|BnT9- zZLHh7XYJjDF+v6r7$KMyC>BHsKgVJIy^r)_e+?sB6~(}r0+C+hr$1H!OWs10>CBH7 z2_m9|I@?e5aVTK1n58V{zYmVlXi6f7;}a2`2oXgK6XPzYWI5u}v_s{~SykmzH3T$}|7G!fjmIOsa)eZCu(6 z_1Q(+MJ(pa>QlzX{=DhovAiYV+p<-qh^X>S)WVNjWyLYPz$#~p; z_6$09YRTU`Mz=))tUe)GH81SWo3_#cf;H?0(jv z0+|%GY81Yc|F)m`uiJ+1JXyiJ%Cz|ZRRr?zSZD*nKyOU-*Z`bHiLEC5H~@6-=Bq7Gt@T<&CC40 z3zs|EZVHqiP{z$Y`(mkqD)`oj5wGjg6X0{3^IQ~VxsG*=^Q1@IXRO0I|X6!TLH!x`NQDjHt^@Y zRIrGSA*Rj+*tm5EdH&d$ds0n>m=&!;xmq)1Zv}Ll)0;|erbDb<0JTu}!=e#qORN>; zpg3naNIDU09N`T)lcp2zi*DQnRSU>^voZA2uAX2qWdS&NyTS(Bm1xp01^71!Y*m|$ z{^~19!*8j0f}@R%sY!%gT`i4Fnt-*}hokE}4Ooz;jy1w5IJe3Jc)m}$ZmRX;T@huy=nk4MIMT81eQj5Kukx1t(qG?Yc#?=ul(@#!8Kq%X)(R;nTd|}%i#r?3e|!T+q0ub$2vOUgF4rA?nL;#!4rBIXko$(9Xg5T(k;8WIQ;Aby6&`qJ5FkZd9&|G zea3exxgIA1FR z9A(+sF8XSr%z!z`#QQ!JB=wx(__Ikzx z^)2Iw^5Gw``_+T=!2J+BU73JBmTV{*eS`#h`QgT6NOq}B$MKIAh&k_ z4nDgAu4|3uUT~TM0ati<_LeI~XoeDhTWc^0&%+1Ke~@y^IP4Ok2-D>GSW}$|6@d*z zVOb{hJgEZG`>P(BTH^=cwL}aH#NA=B&P_UT|3XNyNPq(` z`@q0~q3~;64(OHbq7O6sk@XdYFfg${jHon%xvMfTHX|NtygflNT#NoyEdiD4INFb+ z3!>#|pzzuYhxZG@qIVPVW!Y3XvbB;*f}Jo;?JD8U zpN;Ngis@~=69ll9ydCv3#!j$;-Olf6>AtmaFW49-nDiw>f-UB3dd30OE637c& zOr_5XKBI!AjW>q99=#xP_g-50p^``+Y~hxpKkQuH8xA?o!?!myY43OD z@bYF5hHQ5RQPgwlx_lbeUl@m6Sc8AkF*thqIpV^uBJS@L=zl7cK%%|`Rx(B8g#d7a z$^)93l?Y>uli;^Y{b2Xid^|HC4t8GYN@pinq0@nCa&?jkriZ7I)5$FC&Pm~hO}8h$ zo=fpms0ZxWmV`GKmJ!Z?AMnucD}cEki)lM48MI6vOt*~3`Fgsr%*qa|?X7Wmyds=y zyvJQXbSJgBYYaQwBDi}mF2t+qvcj3>fWjOzZmjDv=)*fgtG75o*Kt9`PaBq_Tw^BQ z8CFcko;yY7o-u?5r_=PV<_ah%sUWPoa&&Z6_m1~1yQXZNv{{6SvF^l2aGA_-$-vc< zQV-u$)qwdIb@9U30@yRe2X5I_!KLLY1j^3f&U!R~m zbQaUmxj(^)jUqTupNWR?HsHtWOwtP0z&rLr(3!T2D7~?UGgT}ikuK-z3zETRqHHf) zW)4GQvbnmp+sPP*VEn@}5LMwYeOyyZi`Km&&-Xm%hSfhHzE?b;^C&w!uKk26<%M$& z9~eXW+7s%vDhXNLRG~ySi&WeW!ark1;P@+@QER9YzFnFL+k12cTr`26QBeTigM8X( z-4!P;Orh6Z3qdAGRNu-M?=}{|x`B-}{8zb>NNP$B7<7fvLNUCXkqYiFoH6L|&ob?q zf}AWTd^9H!syDr&?{~5=@0Ad^*a`Ee&%}#Gg)plikvk@5B+WV;N)+d8m3ApxMIODm zPWMF!xuc*#*a{46x5gS$x8p1s_c?-#oo7?feWj9EcE%$Ll zA-1tEDMl6^KWL9hbCO_vqA9pPHiE0V1@P<6404v`h7#)nxV&CWpEzb?jB*AX*Ix@0 z$5xTly@mLpXf399AB{^>8tE-&F_Eh*#F~O|G#|8;{OFwp^FNFv8^g14<*i)g=gXlM zFCP8YY?p>@i05{Fy$Z@M1Vc)gKZH-uBc;oSLH*MJ=o-K$orDYVDQ_(}$LUK=&T+xy z5QCLlf?+Ku9V(1{LHWosP%HXCLyx(@s2~RW+;AnMYx{wYrw&{)?n9HR9ufK3%Q2Qe z3|`fG(;HbP@Y~AsQhDQz6?o0H~s$$fX zpP}oz1$chWdTB^_EuE_xj_0qYO0<&;rI&WEgv*YlWa)b=+#cVb>~s8`WZzniIT|Wh zRvO0bPzmk(%d*vL$0-upSxTSz-{Gz%H>H=|CJ>!4M@Va!53d|X;E9JKnqO#(d-|l3 z2`Ty5vvdwzTrvlrDvu+bMtGrZ1eg3!lLeMPsG}Mu1p<~9f?DY^QX-8Z^Y4EsuCh0Q z{STZlr=m=n|6(JxmGzU7`cB|h=Lb`~`lFUhIqesci0*72Mi{HZrmVB{*up_5vW$S* z$Xr<0V<1|kI}pt=6JYL7DU^HRfYfHuerfLyi=>@&2EFsC++-(=-g1LBY)?e(`dqY{nhf4I8i8i=JeFvez`afd_a{vGU literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past_split_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2f00c55afe9a53341e24945af0682113ee1bdb91 GIT binary patch literal 7664 zcmb_h2Ut@{*Cq%GS7`yH$%+cdf`SwUL*^hTQUp=f29`((5`lmuL9wGWrHfd=MG$*m zyO24q*mqH}t=L_VPqD1n7Zv#ftSAENxBr(sH_wwZXWlt8?|IM6y;n(zts5y6&We@@ z#exXFr8qi3C<=*;j^Ia&^gCNy+SphQ;Ouc^oY-uup^-w7I6x@oMM}g`5;1RpzH*DL zx6;Q}b1giu(+(NO{i-o-KfLMMsPbiDsawcCeEmX|twc8jd2oSNY zt=em|@@m8wDT4?I7t9V23nGP|;xPN(F1^`pVPvbK5HL$1(rx(oVHL3C%{7=V{3wwi zJW{B={X}2;0v3x|%3}We;28Br5;+|20I|0uoM$t@n%(PDxK=ZGQPGjpt^R3kp-;CX zm)@P5U6~k0yJej$n^7zq*N^!&I{SyV=&g9wZnwRe=>N?k*tc4Qux~FyE0>l<_@4x0 z?a(Tim~RKu%B3ZkR*W>ev}L4K*Om-!#mGM{9WwH%y_u0s^2|2~ZIu(PR? z_|z`PG=5d00z_h77(bTRn0$;{)8D3=(qv#8ewI1Dd{JmX6rX3?0>W00isp;>vO4t< zfzAGC`2?ch)mbc-IWX70Ma)qNij1&~5Cla>iXubA%xvcC@qMXP8+&%2&n6^OXZyLK-?39mcFbpV zKI!!%QgeQM&lqf@_Qq(HBRU;;G-pXen*#k!#D87F?X391pvYiZ&HtkOf31qM8B$x6hxKp52)@iz)e#HfF!{$`7uEKOU% ze5?VVG5VRht-nv*UlO^P#e8#^FG;g?V4JoM?V|?U?7fCx^=lz+y*%@sh|HH+{cnZE z=zqSrmLF%(R7KeRtwIE{!uQ!>>pRtE^GR(wt;EHXrI3s4Nb&ztT{!dK1=fbR!E7z7 zkRacg=Ix7%Jx)NcRe z-6wvTww>5ovMg$I_}A~;REKSwfY?k#fI!Hb799{3YVNAUsIy$!KYDta9T~L>)>`|6 ztF;O~ZCFQdIp(7Ekm8cL*G1@iDi$wG)1?u!T}WG)<|9YbggwnvOXBhPHaEKC$z8H? zlLw%10--H){JS@BpHjw{gIp$V%+Q=Vpl;nUD{atb*W zo((;m*P?}7A)b3_LPS?ZpwGDCqO5e1P!t58dN${_r?AYHQ?&Z#>;n9guSd{`w3Z^K0RRC`X%ftF{Pw2_(8SvZD zKA`E6fVXz!VuE`D^|&`0kM@~JBI4zMSMG$@5=i zU>+4sOq>d^Ve1a^?6DK~q^dG8En10kwWi443g|kw50%_ZhZtLbYOdysMI+CaSSiRs zaqcpZbSBs^(hG7Y&mdkGUAYUZ7LxU*W9g+`y}*3xLa_I8fekh*(71mJ@NW{>syYY# z)K-xC-%{}eM++NLlL))IS{jiw5o@oHK$rRIuyBnU)(EHKyefC#`TWIoRjDJt<<3P9 z-ZrvHUIl|*6;amn9-#4@BOT|~7x%~HLG5x+^eJ<~%MuQD%$B{3tw(DVFwFD=d0)DO zDozcsJhCkjUn!N>5I1yt_Hiwi|KujEOfA22G7Yfs1^j{)>C5S)GUK; z96h-4EM98hUm_ioV+~$qDX5w9kmQ}-M@DK+w|qi}#b0vE+#phV@mXI~VKWMAJ1mj+N+Gan}pZ9K1(zooz5_ zV<4=vDx@QNuA>t_sBul^O@iO+J)oz)CML|(rjuzN-LjjDBhEgc>rM-}6;-KKj0zhm%Ezu-xrUgtG%Iz$t+Cs>V!&{cS^(lScFp~R)Flvl9Qc`JJy!9 zgN&_RuKFdz6pE#TcLzGkS6m^7Qq0jdjf28LdKjZ-!qsqaMQ6tzxS;bwnq^)9JyQ~( z{Ao1Iu(2S{mK-o}&_%_2et5ad4dS&sl)&QNr27k5&E4CJ+?Z2N3o8@Aa;^pNv{vK! zJBu*2w2UknXokZpvuWR7*jxw2EO@t1f%9S%rQu6vLHeT<@=_%Z)*Zc4a@9VE*yL-J zob*XW7wzE?wciw0I1RyBH4n+mLj!TaBVUrZDU0r%Jf8@MnUH9K5u7~ag=?3uLHmTK zbhCQ}&Gyg4)4qu?YK24u@h3@cG+))Tb^DIvZZ2y`ScSL z(;XAz4bWOQ0o438DbZU_jW#QQM}Ik(&0ma09jsD%W5A9G}HjUci1Xpn1S|s9w*yYq)xr#2=nTSSSzOtOifRZa(LOfML<32 zl#vC!PAY@+{z^!!ut)0~g)lwH50o!S@TH`J*7yQ=E)l~*aSvFmeUnbw9|tMs32@*= zUl=qf1b%&$3%X^y=);3`!gTBP$JI-pWjj$&7~@FAq=%)1-e_OF+3gmiFi9 zfM{77bb9THBl-to(Ys0bqHG!**;+{@L5`TF`k2N)e8WALWQc#=oee?f4wD70%gB4r z(bT5PY?wUN0QS4hhOWwk;6KtF=$NfWZI!>^qwD%SOdg&HlDd2hHCBalD}(V>c`B8w z%R#rX#q_rB2?AJ4-j4nmV5N`l`m^@rV8^YP5USlD@`8=aG2iH--V$<@imm>!l!PA9Xl2PcIaI>V0m zcr3+JA?~naTN2)kD2Ge#@GB`sIOtwtG1-d$rVQC9ic2+ndUIETE z+~ckvwv$@lHG&qVFt<3_v%m&r`$Ok6cN z_3%v247wml2))9-m&9Ad-^V-_{Iv(RI!Lex{RwQNCxXkvb`+B z42DMMaCL09ld<+e_=iORs=#6TxTcmCt$RtH?Rmxxt$RRxuDC;&(YAP8>j_m{6UI4w zU@YlpN2uq@BxH40ffAi;QgJ&F|BN1q6RvbY&0&i8c4-!D@7WD-(L{PixfAdnwS5qS=Q-h{D{h(ynDI$)h*d z>Ar9w_jK)4eCk2zm4G~q(&Cd%-TBl)PLIStV#A8TL%FA>Y#}+Z^E%$QYzNEOavMJk zVjB99Vr1d*gLarSHwo4!ntMltSJaXv%y=*k6ziZ;KL}gF)RmH+{#0Kz8q@u z;?Z~Yc4_E_cy71XE1~Q{5Tu0qLD-Bnq;%PEsQb$wy7}`-XJH)v#ajzbv3gSDb6hY! z#9-x?AXv*uhYBMfP&$$Uszo1Y$T4Ra9mrtc8!lu_ZGX`A(1uG!eQ8qFBO*U%8OHF3 z!^>JPdL!Eyep_)~DsMCkGGE!k&ZG?N=`{x*zDvWNS6-2&?&G0BvKq^h`w_jfDi}HS zXXv(WA)cSRUK$)$OXsPC;rVNGv2IE-p5D9)r&SHWXJs}t#&8JQc2Oda6)MOUfd+ka zNEP;RB+xr>IuzZ0OWw@9N^*{-60PJy>80H(;Icz0S^C}*x5p14`y75JIk#3}uDUXo zm4%Fp1!GM zVoE;tDxC`#m(0b#l*W_JBR$b3oJ)SF$p(ub)KHa^0{%-2LA5l4lt`n=g8LtetL%(n z{{u(NttgY`Ki^1gWamjqU1xB8{c`IS4J&?TNLiV?lG<}24$pi zk8gE^T?PZ8@8eade>{x5)^#RF;Gs0F#s{k-Odw;Y2xs*WkzDftywx=oM(pQcPrsSC zoii0~zlw)>7j@{PGwI-S{)Uu0qMUYe^~drzQfbQQPGEXEmugvcf%*$?Nq^BQ%;{c; zdFj`v>95N<`hD`L+!RNQ+H!-|Z%;(6x;(U;mJD7u>WRCb6>JV7P}nt=RxkO5?y)E) z>Wg@gg~Q>O^Inkr>=x}@Aq4*)+kGp@W;5zaEOs|WUP)edv1AwH-}j5x0_^ULmdr{0 zAE!n~Ms_L6Fx$K^=*ZRV%cqU)fNC{?nzwxp*^zIcIn(arj`n6H^KJXKqHVNY8)FFU a$F%bvqn#+4yj|qgHC!~9*6$|NS^o#tAMQf{ literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_encoder_self_attention.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_encoder_self_attention.onnx new file mode 100644 index 0000000000000000000000000000000000000000..578ff6ea32bac53095ce9970f6190dc5ce4c92de GIT binary patch literal 4685 zcmbVQ30PCd76t)@t876O6UED{ieKJaJ!;SJ(OiR;(5KeDCJF$#>7p{B!>M|7YgjXlt`f z;v|wq337={6f3loB?u(a$hd@9VS;pMUk5u!N4v3{)mDrPi)H^=oJ1-UNM!stxh!5T z=B#Bw29+d!SCsiv0VQmL-Ve8w1wBweGTMG+A~315XI6H1w0 zP+7Y6t`QM$C(-P>w9rs}G1t_YcN>{uQlU7KFObP3e6b)&n825)a<}_JD2j@fF++N7 z3=;{YEC>7E+U$M0F#J>y0h(x5_v2?S}2H+n)FdI;OK_O#oEP+!V}`8agj16 znfZHU#-Ph>9FDI*<|`NT9UL85mVZh`BOud<7o|1F{SY%q>B&^|KU;XU2^|0&udbTrdzXGU9v z(5<)1-hyzMDDia#M@sCwz4n&5o`+B@=c}Z3Q&oeCxOUfXSM0C6LR5DCJEZq`jPK^* z*ej1W(Hctv~RZ~ z9J_&7Osqg8;YTG1;-lFf+Khp^+lQ)&H?rZIRb1T%9{7)AX~G6pb>L0H_U;DyyG5EevTKL%%u%{v%bCf9U3-e;#T?*< zKbnEbKbOI7A1r%Ub4HGj<<0bF001vb z1}QRASZREjhU`m)Y_f<&N47*+)Zy4j3Jwwt6*aKSnz8xgQX?;n3$gi zH+{T7GscKMX_JFaTM`|^84S|GT+sM85532P;il&y*wGjThqkv+dAKv?>fNDfw|?WE z$}q!wSCb+9)IqY`y^y@%&7h9`k|A{dFxcmj4E=Q`!oQTo@L{#asnlxf5Tg6$a?G5T z4)UiPFxpBFPL)JpXHyPUds>Vhv+C$CCf^aj2jtm|^_V!<4)(abphs#pLTk7=&b1m% zJ`0ZktMQ&7A6G$lh_j$Vu#zh8SAy1oY~nkf14lv>_#n+3_Ig@C?4G@}<@Xk%Y~RLh z!XVhS&Jt=}mf^EYhSc);Na(m6h7n)8fHdwQbt_zePfyQAE|lUmItxEvc#62Pt`m>v z8uVW+86elsg0;*h@<;@@MW>zStVxGi<{9wAnK7{E{097fLK5scH-Iiy*rD_OHgZ1H z3iD!e$%#yLH05M-qZf`R{$8u_c%&!n+>wEoQyU3q!Y5dJvIv-qiI}^KlF6%wfc3UH zxZGqgthO5m_T%l*J53W#y=vudp1O-VTs4QC?qcrVf2QJj{l!QYIH9D>hMVZN8bH_5ud30{ubbkB)f%r- ztFOZ-yj4+(zg4ft=owkiIT&S&K+f z?hIUr-zqCdG$^gD^`-sjZB+gOFxr7B=X@3@F@CKS+ji|h}F_TC*2ZLN9e+k z^F|nOCI?yyL(r;mJsos67!Sv0QcstB+B)|F4QyXQ#t+KD+gnFLte+)}eVhv!O=_qp ziszo*l1~p6MNedZo@ZSaL-bxYwc)r2bz zDWDrx3j3FQia#I8g81S!nDJ1G$FFUGqnp0q9{<(?wAZhP+fj?b*RBvOOIwuW6FX9P zp;~!m_5fg?)={m5pQv`24?MN<_UO=&!|LJSQOb9|O}O02=6CYqNEX$ZR|-cJC1|uXiS7@}gHK*E#49%iJUfH9yWNM@Pt~kb zPTmvd+^6{*sm*5NxLgiOCJwnR``UCynHTSY=Q`@(NqGf{WOGHPn9iR^V5bXI;S$mn!f z-ZzC7uq)xyYy~vkPk@Dvw#3zr1H&enpjK-Tp6z#u`0R-$Ms5ncva<{oo(k&KIs*?| z&Ly#FYQS%D!3*iXlNzT)IPy~zIO@6M{a0rA&Eo+4 zPpS11Ei5}Wod!9FVcjzkCYS}mlm-X5)|vzC_yl6@Qi;-=p^)KUj?Aej7@bgnn+)yY z!Y&>?WLAiI>xO}mdjwAI$3@2thV{q%bI{Fr8pQ9jfweAEaM8_MjP* z2uO$N6&~m(iN{TTKi1o8szF^@A;|j@eC6i@WuXg+&p+L{Dc4iTW}8{`%0q!kC8Aa=L8~7rlRMryF@oanROCq|Syg&};$JCRO62 zXEoISX)g3NyFe}O^Fe>xY{H(O0!}_|u*GpLT8+sD;bj8b^%i50{#x?#haCKlGZ0_p zWRU6;BJOOZ8CER*Rp~pYZ~evarC9av4$_iqfJx&rY0S(kL{g9m+6PCW>G=b6|BVPd z(V{?KTNZ5k{16fF0&we5B)jz%;+#7=z-nt##%9dL2N%83ZJ7b2ROsVP$$VUT-4po! z_qgu5PstBuOK|jaE)nOq&>F*)37*>we&wkQ?)#VS@6atD}eGY z&^ts2UpD1Ar>6}=2NMP82OCi`q==e-r3qeR)WD0sgM8UX7sER@QT0crVEBlm4D=X* z`x49HK@kuA8(r|MoP(pwO39d-G@Q}q3#QhKFxN)|wa@NS#{9ej=gI8>SuOR)G^|Ln za3z{IxzHP`O=`PS2;FKj2PX{LL$tFGlI-Y6WoN?$`lP=lWf!S;uV}r#Z}S!%n!XM; zIs}8egRZK4tLYWzGIW?ySHI+<6a$VY;aO##GB(+beB9%!nwnmH8aC|kpgZqgB_)1$ zsN`%XS-5pJ+#EOxZ3I5fx7xHIo;k@49@XpLE-V&fpsmgt!1U4XqdJ+cGmOpj`j*gJ SU$;I6hHi$;yYK!6>i+>2l#rtU literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5ec93d3cc3b53373ff8f5e25fb3211942207241f GIT binary patch literal 5657 zcmbtY3s_Xu7G@A(I3dCj1@V#coO9-J&faGaiQywGHA*e5uG(W7WPni^aAr_3EAx@v zvNS8T>oqX}VK_5$9^1NQUMuzH1Fg83544ih{32gjdCS*512Ql$nD@Ig-x)Z2uf6tK z|62dr54D<1D6!e**~@HBYq7=Pw3}^?{1SVy#qLN92L=cY?C?e{F@hwSF(o#K(`<8^ zO3Iw2Wlj^D80l|IQn&TeB++w@xv0!ShsO{(p&?lnsxV><(Z_1@jVZL|3+cllEPaB3ESyQoF^Ws4_V$Mfr($bQqr$>3^YPh)7#mv8m8v z&T}M$`XbYph%U2P^Q?Bu44)b1q6A;ShWDCLQfw%;`fxAFcM@iz+nI^Z@d9N3Rx6e8 zbippBbC(~p$B167f%hdi%wUP`wUGVVwFJgwPQc@m7Mv2DQmLex^%$UbdN ztys4-cBC*!B9&6;w2oJ3dkZdi&JqXlC$@AB!LYjmOKmU$2b9`&0YCgdsRFTW7ikB- zGc|6_ynwneEQ3sDii=8bpMlUnZ(@7^;Y>>D$X3eg|HVWbaXaR820D`^6MbIz9s}E@ zNXB;Q^TS(0++tuLQkhB}Zkq!zzO}SNF^V2%DJtuLfZB(IudQKH4#H(ciLv<#D!#a{nDBi*?=vWA{pdEQ!SE#i(yXOV=1Gh zw3MT4Pf&qJ)-9DE79L>e?^SA|MY_uxoU>cCBMsn6KEZoLS|u$+n>FR|V=zX7yG#0jeH+XdQ&>&9-5mjPlY8UC z@>bC}bUwP`J`Vc4u@=hn|3v@^E@8VGa@Ep9!+a&Kn+4@TUDmv{+@vOnC@Nd#Ev{#pJ_(3R;62SWsss8}& zLvw`-cxn1dYKL~Y6ka{Fx}2RR7?>UIjdtQB&RLRFaU#AFSObjg+F81H$ui z7@A`kjb?KL>&!Ka;WAu^7r5T!x29f}4$CVH(<=9&zw0XTHPMv5QW}jXXf85K*zok3 z(nYje>?0pSD}_(dF0L7Vh@Y;Ba_`qoLQUc;^vYVbYo==eXQb{(`50x@Qm9t3yy#tk zUyvuEXZ6vzKn}+Tg)GlG7zQ@$z7kHrA@HpJ8-4`S2hDd!8=9yE@L@DcgmMrH56j4j zYZ@BLjTOuayagBx+AMhBd$?ObhTWwALgFlXlVK5_v3w!BUFcz)2`AT$SI5E8zy zVTy3w!+QGj>*0ECC0-8h*6c@-=o$S&&|9-QWej=)-3$1XTz#S(DerYJ^HwphrH^OE z!F8h5`(z&v{bTJ2Y-Qip#mi1V zc_>6Whr{IKB0^Q(Ucx*sL)a+rX@r;v;-RjEp3@8tHulR6_j1hv{b z@D)^FaT+`(4C3sl96pDB^+v-fLOfjxTx@2oiFp%qu+WtXo~{0bs{jk|QMnX*nGaz- z&ZBHXGwO*`sup{v%0I|s@GiVeTwSY@Mr%GcyizlVb#YU;L(E5pd|g!4dx$j56PHla z4WHrrg|W#&0Lg?oamREaHDRDgx`fnQj^N#V5{lMz4+y6KGQ#kk`$}3<^-+2U2tjsf zL)}}VS@#P*&7GA0K~I*xr#7H+VLFO3KBrr)y}Tv~d@B}I4|5m8VeDGCNL=LM@b6MN zTkNS9&0HfA#gC{)_&EQ56yPaZn3B!%t}=>B*a)m$YEL{I)hy zej0r(Zk0nh4ccBa2Ypk$$kQO5sT#n3&1bTQn4LIsOTDK$?jqkP9`_jWcX$js&pZQ~ z!S1@{?l1LMVFt{H7d&%0`u8wpDI2UTr`;Tha}I|cSv%M>ojK{vPAP|Pvi0)&covf? zTx2Hl`_M;dAM*VXT(M2QkgBkdDw{><#f3 zM&d>)$$CLpD;-c4m6!1A{$V(_KrrEudGwIaG^Nveb@6`z}x=8sU zc@KOI6reY?pISp znr(QP_^UXub}njgt(5M?H$Q4$_z-+s;A_6&Pm5!{Kgve=W%gk%Y)g&16z%cc zfZOOEh=dZJhR%b-vY^Yy?+RJEiOJ`LtJU1vX1Rh{B4qQ;hWpS&al1QRX7KCm75ysd zH~mZUB$nBHOfpGJ*k^?v@;GQu-_4wo&rp56{dH<_9RC`QLRzerCNureUQdoNci=IB zLhEWa@~NzfeOzBC3~`@f4{sbU9z%~x*Ol1s*QMbS_=%h-%&m&Y3qb6KA=ChTqxKhl zlj{@d1%17q+jPK_P)kp3NxTlnxz@WF-*GM{YX@q8Glc|q18!KqK|UbQsf?q8Y}-Q8 z@Ca7~lk4>|+u;4;j6{Z=nnw4HiO6m{oad<9hIMQPIcn9-_O5ocA1iaRdKoXGcjZLP z%cE9P?g}Z}Tdi$V*zs=|KA+M9edyl7PDAg(Nn&ZuQ=aKm&>KdQBx2MA8AXJuLzRc{ zJ=3+_+?90_j`IoU(nQcLb0^sr#fof~+X=EC5vLdsbJGBy5j|V<>WCos4?(1TntSGI ji6DpS;Hu+t^+d4ab`ae@4|`^Z#>8gF5`ouYF{=Lpq&g-& literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past_split_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..11b324293fb3771c4886985f1189f8568a252444 GIT binary patch literal 5851 zcmbtY3tSY{8fOt<84+OwLDW(}*qMFIoHJ(;OwF*;w$#$Tsq+0t18wcIY}iF`OisQKfdL%VKw# zOUqql3AZ_TIRGBm)aBg zotZ-_RVrdSk!ZC~E6uk$I)okPM||keu@=|Z@?ta0$|CP%ra4NTPP5-(W`?zbxiJ?& zx=(0<)W}vr!hXB9a>zmLRkbkH$y#z(BJJ$;Yz@Ls7R<^f{#+}0M`x`;R40acU zinuL0i9YwZTvjEQ6^dE&%o(%HH&P>^X_u|D$`{=m=p7&x9$oCP0#Hnp$cuSQiMb zYoUtxA3*AqHz>H)XO>$XojXfVt+y8mnKZiUj3VP^|~%Tl4FT ziU?{If$4yH;?MeJu!7bm*WSKT_(3K@lz)61U4G(Qx-;3kZK##OmcjOo0i=>jLwDK6 zD9zv&9in6IIfK3J>`O;tvxfw47@^ui(1vkSAVBI*s$dXrirhrGW5M0Jo`bR^JBv(X zN{Y*Fzpi1xZc0*68JiX2a+PNNf7Q30xE*&o0-eK>DgG$@KLgvT^~QIa3&2}S++txc zQaMU@Z?gdy|B0Yoi;Ep?EiP||fZC6Qf2?B?0~yEJ5SgZD^a=%eWW%OeaBZ!R z`@H@`wggR*;^ouQ2rkigSbW;IE%PT~Ketc6LidXM3+^!6$O*uULIJ{=5_%fXqj%&K z%*%JLq&$^UuCGShH6Tat!kR)FPulgR7w@G`jS8^HpG6QSxMq^O2 z{(?BtXGA%u6kNo6xWZbOXkgOB8M+u@N_rJF0gr;ialYXfp$5G#pV2q4JEafUHDIoo z&T29po+yx))(anyH;cwWbJ1nbQPBI1)p!|h^o;NxV{^S5(7y~(wLb`5N13HWX%auf zXYhk)yQi1?KD=wAPj~{((;dQ>#6tk!S@>OFle`<{>Z|xG=#)3zy9%$szf0L@kFLP* zgHR|XgZCv;{{h;IW(z;##hJ^f?b;<$WX+(O3U;z!V77afFci94vlFqXzi4CE3x(*U zbe7*Qjp7&Zd%a0I3mywnKpBb>7N;%d@2eNF!1$0?*i8A{oAGR)9iHVf*hs~~6=57& zhb!faXlq&@a0VUbmZQXtGpVHNX>2IkvrP#R_*{B7^r2@vI~lzP$BSjPQ@vdSlA{1} zO?(0L_MF6}D4yLSq}QdOruxsRR3V@Lx^Am9LW*Jb3D3(BXqI6ln#uLAx703#FX19Q z&;2I9dGJN)fV|W&xoR)^yRHgf70sE;q>*^6<~*~Ajm(@bokzRG-tvC5O!x%tGkI=9>3e_n=7QX}V3-WmMtUeYO%8~eMVVL(U zi~t*TUkS(IKzK%fj33VQMsq!}h9+trd<5MsLOBE=*;%H9ra|G{{enedq7Z`t8wD?X z4|fSj_7DXS5)Y%_H7vl>mdt0j3Ehm-;e`4z>Uda+hrquYCJNWQthX<}4zA;t;U(ZM z&1WbYJ)@rwdTCask3w&tdjX%Gr%#ciBMR<}YjLn`3Np1YY7rEL%S=e{Ed=w^~#S!vR5us{dPhk$1 zCG3!U8jb0T-C5d4tN*Dx;B}^{B0~ zWL>XKPstPXe?u?R_4r%dT}Z>$>cI-LcXE5-37rw{hgxk9d>J)Vo&t{x12_k&fX|`d ze6et%kVKaOH=9#uX5PdcEOHM9&(?g*Rf2i=kX(j+%!jZ6=Tml}8Fk0$)eC)-NHxhEMSW!u@F>0Lg)-gd@7J z+6YiAT|nwh2k|aG6~$`0282@p8E*K_b2+1_<`BIdgdvBtzWy!IqWcw};!eo_peIP* zQ|nQMFa^aJpVO_>UR;$5z7-2=vOOg*n_UeThzq|lpWxrG z=BtydjW;-Noy>=|IH6)pvNOFeiMu=@sxzwn2U$&tL`%=b3T*UbF=_)H~56O$7QP zwQ=t40aWl{`b1p`yboWfA0!Ot z&%hU97_Tzw&?;^U^(FX3_*^<*xL-@-uQAOpl@8Jo>~-;1M&d>&$@;UfTH2@VKd;~w zXegH^*T@Of!%Qr-$jhK7!ET&{3k|HNxpavZoO@u3InflV?>W&>;?kVHtS1JBrEk`IxQn0dZOi!%oVe6XT+C z+b$AJ>b7O=hk;40x-mY53VBCLl0=-EAY+Jdb-3~){J$yME^hle2q*Z3^JpUEak+zR zi(*B#(|rfohe%Kyh`Zr{--+%mdUZsI8;uase$CzUv_yy(>d>ka^7KTgSM3nG{SkK0 Q4Udb@jVFRH*y2?G1*ey56#xJL literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1211a2f29caa3c4eccd8997913ec945bdcdebca3 GIT binary patch literal 5225 zcmbtY33wA#9&bWR)3N7RXeffvGnpJSZ|0?!Enrc=s;EU>tVc{~+D4kRBq>n5SPpd+ zP*I^QA5xT-Hk~AMyy|+e^}rRdqIOp~N<9#j&#I`%x{A`Bv@}hc1Xwroy?o#N-~auO z-~0XFy!R9eTv>0kEw?w?oEEED@3fn2j_P{5)ogd9j|Y01)>B2DDy#&@$=UTbhtp(p z8tWUK4UJACl^*ApWhh2^2^@=AYN~5A6XTPyGI>;Si!2(Oja_Q7`Nq^*s;bO3qh!Zv zc3@#_;tAxHRaK)dqJ6I!TMI) zw{r7aWHM|Cc2N;v$-W6FD!Sg}yuPu{NEhbe!!eOd8qM}5@|@fcIzAL|P-n*YVI|D^ z2VEB6lZLG>HitS)9^2)2_s241C@_-Y5K$s!7})7|*)GSaHq|M894ty$RBzQ=EsN~+j{0gRX2K$dHimNyG>i~HqFKm%oFl+!mwz^H#S(_e z>#JD?!V&mKRY%d7IE<68E2zQ~g5<-c{y|8oM;Lv9C4eA(U^t8yu@K-<(vvh48t@05 zp0r@GrNO@l7^#ejK?47J5~lvgPRD8)_CJIZGD8LU^nhf7%6~LmWh52MJE+68T2Syc!n&}h6 zEBHbCMuuI;1~!&D=VxY6-Q+WPCTXo}kiKN+`&0!mlpcjNEr^gtskBO^mFxdvstvoC za6*92$J2dT_?n&ySwvDo`uyA zR3)ubk*thOtAGkfi#|_UBYpm~ZVL{*Fd)kjUH>%VY5sud;N>wfxMO&Iy#5^BMtOim zGJvN_j1Sp*rH>!-!VH=}jF*~cW)TrXnhfoda2JN~e%`iXW zbD07gQV!YdlE z`aIQcRKTC+FY8zadtFZm_dFX#Jp)hhmGom8G9y;o;YtAGAWjSDHh2;o5;<)(dX6jB zlDm!yMqv$gKQ~UCOWSkzl3$3o>uR}a?i19Z9Yy?6c#CkB{}&%Q%K|COZ0-kgKKl-Q z74>vF;6hb2d_O}*AI@5!wbECkQ=QYf682mAAzCGT&ov7N;AHkuv=J6DnPQu`8SbZ4 zJNsJKgS*hpWSg!q^GD$%eFV*5)}bbFi|a|$Pi2Uo@>iijUftfqeTJyZYUmttBR_$O zbEmjhPCHKDNN?fT_K(>w`8nSCncs20wK2^DVhg#3D`f}tSHlzP3EoLs1wWVFfa0MB zDTH$JU2P7kr(YM-xn-?s=su9LeVRH_meRIP%Cfil{nm+-z51IaJoa;{>4b2Mi(+L4 zE!@mhs6PO2aPJ9+^m8<+>Kn;K^%@Te?*RKy8miWp^4r->(;k62$ph{^j8UtAo@ZU! z_05$C(0M>E{E77Nn~C-8DGi|>2dafH(0$%FJgX4{ON7kBzlrMquVI_Tla`~12qxu_mpwVE<4Jp;nwyYY&IpMCZoZQpLEAuuM2G$PgGZ+bngk~EQ)1F_lImu&O+B|G2W01ujx!iZK9Jx{M+1h@Xnk%ngySc zzcb%+tDuK(gM5aZ85Gh4pjiPvM}mHzm`SO`$Ke3E4$h~#VK<*-c!&0idpxJWPUblE znI{pd;VeKihu{Nlt2)WB2ArhUqHI`3Z*BK7?{&URS>Zw^6EE^8{K4u28)T!67`>QYiavnK8i7BdXEn*@d7SMT^{z1l=Q z&s4!E;X4#9_VW<7dK0+iOdhvKOfVR-*1GaEH@E&>d&uL+RES@@`;c0r?%b=|Or>XB ztNTW(wU9c<&nrv4LcCqO75<78U6gt|qY@gmQ<%trD@NIs1iy4IBq?r(dYX#w}4BVq&c zk}uOes7h_)24D=zYF+JJD1If*MsJ~q`Ar=%VV3GO{bTJ*DHpSVIY_>$uhz!5z5sFk za(<1vLjN|phMSWa4v@mK)T7#{_GnNioPvt3!)Py?0TWdb0V$v)KuYwVy3gbsYdb=` z38J7~*xvaJZ_@sRzGRMzeESyc>SRzaq*R73^+b3O{aJ<>?hpv`(fzVhgE*CyKTF8KwP>r zl?byzh{{1Fu2{0t>m_&7*YJzdNorwEsO8b4MpDcEOXegr4XEJ%r+?kucq<}>MtD+Wlu|8r-(>BtkHA#Wui{+uN zA}T7BIFGk2~+fvXyAw&nI_o7380F*xmJo1>=D-e9piGRFf0O&h3^6B?`($H_U3Hiy$} zbDA2PolVV76P1}58p~1yr38*8EH&3RTZr*#Sh+lIMw=`in}c0wwfX1NS*xopHj`w> zX>njtY~sn}W!2RqOT_yhv&Kn(m`8?%3$YZl(`hr++YHMr*4jEJmK1!hvYH)u4teI0 zyxQDiu`iIKGS^#In4Q)}+sFb_&Z3otN6sXaSiE_;)u9>+{$y=3IgyS{urxWW^^LZP z>|6PTZ88~l7j|AAU}=5=N=mLVJFjW3H_^of_^?mpUCkDI3wdTd1UepYI82x|epm^+ z6-ypA=EsmOHezS9gsVnd3PU}0jAWRAhCdik5k7g?-V#e>IB8==pICr{7DmG_BpJqUy&Q`PITjnD;G~eXkOx>?;7V>XJDes7 zT$966UxTFv9s}BefuNri7Dt`A$zm!D&BxQ4>=uVbD(T>491o6+%+I{wGN)5g@|O*j zg-c25w9dDft3#ti8kT5lZZOqJgASG4Uk($olxCZ?+G@8f_RlldtNc|`lDxRF!O&n` zY;SZl);KY9FtkY5##99Zjlm?63L%2iGf6Ze3Nd=q=!#$<4B|YZI)cWe;Z*dC!q`j> z8z1KKgNSAnWA-_k6hZpIaG1`cRpD{cl^Spjgq+SyUAV;B6lzlvL6gD;LLDzkhZ;-A zY8m!FxD&xp5k7rXFkua9#9dV^l@MdLaK;L;b6dqPj>Sgvlu&^M z$imV_RoL)IT(}Ftd^U~*KB|bo`4!PvdK}??My@fIkwS=HGCJ@mIiEQxxAnAdm4FR9A9o^vF2pna zQTT5I715p2BV-|X8?m!Y3`eR!YJkBt0O{WY40mia2 zHmo9QR9N&`!V1blVci_=dhw_zgSsIa@r;l|ba>xQ2^S2P>osT6Z6pt{SO##Vl<^U} z&dl+FRhWP|9BVLJZKhhgxv7pI6j&NoHfB>ZrFxX(v0#WGy%=&4zm@~E#Yr-gos>J?u=OGq0y z*#)5qQoO>5Q%ijJ!;1ldK12ljm~fDYr-t~SNP(Hj*hM5f$#ilrqUD;Y1^|<~pCEgE zI_?2CE?y_@q1S_2_>@-O-Jlg^FvJxPDt?#!_n&LCuAVIg6Q%@&^BU~ISGwR?brc(0( zc!PUSIAEBoRcfvyQ#5P5B)k>uK^dsVP{wa#H%@y9=A{jKb}=TM0(zfu>mF^bLVzv+ za^X*;m)}G@%AU{?nsJ~;_yXPId&9dLF|brvEq*N;#fPcunfUGwPZQkZ9im?)7Vtfu ze33*?P^a`8gn#QF60fBu!M)ycZrMeLISt&>v7ODOWYiQi)cKSCi2HTn&-z|H^V~kK zsxt~85SA-%%9$>afMyQB z`#cSrRO1?OoLY-=VL82}!^gbW^%B(p7YXU&F<~Y%(bv!4;p@r&hTG5V({Io{?0%o= zr}`KUcwsa^3d@zr#vkCbVkTn6SsOK;HlfVt(#8hFUow<9fN1z1bQhH?*0LP>r$=LW zfE+{)aKv{Sj^oXmY^IDp!d@j6%LL=?x;~&q@c;+UcvR%Ix%=p;;79SXj`<)%{2e^v zyHV^B_9Cy~G)yL^bg>;r#u7T!Y5nNoNt5K4gHXa@0|VHH}uZY9;j zjWga!S9Hx+q|=S4kp7EdA$QtKc_*`5>8;G;XdSpz{Vq&|ck5SzRQ1N3Iq*q%Ibd_D z^qFF!xYzTbubo_lw|WzmnRS)^8>uYEzB-t~f^y}h;w`!@@K>bhrZn3awa~1a%EShwPy!^I zuG6G&E`A2oNy#cRuGOgpnX$dgjaI^6;1n{Y`**Zm_|2!%rFP#YR_OlD z>)g5EUf1hP8@LA@5}S~Ze2MNw)fyW&2oq3F`)c1J@hfo-dJ8?sZ|sx_v(>K|9_d(0 zxtWE`e)1JVjV`JEd59a9^J_GfhPTm`+}!MFfE1T259{JO;z7M|0xG%>qTOs3Oi{-K zq==FLDK&iRIhA+Bb%=Nq#6i2Tt?OyttosRl$s85`MpOu&Yqr4_t`a60@6&D6{=6v* ze8ShdiaZT;5!Fhs;#YYY^o!6!HF$e@Gt&=w{uNC>eGB_SJKL_>F+kZ-2{)9tL8#9e z0G;4%bh|c@sMvYGpaomS&Duoq4)_uOvM6WNu%}}w{Mfb1+b0}rpF(}a7E}AlT_~}; z*XvUL!1nV4UL*Px&4I_syTKsX-L=m1f&MqTfUcpx_b!uO08tOta2!ihV0aQHSIDJX zCf!(NJ{EpXfKS4dl2Y0kCBGsj-3gM-h!ccZs_LR literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_encoder_self_attention.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_encoder_self_attention.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3ca5fc563eca7d815c62bd6d9caf9e1e6238ac00 GIT binary patch literal 3161 zcma)84Nz3q6=nfpedYIspB0V6KfK58zIWfdYn2gV)*8(ki7_!n(q_BJ0vp&}WET*} zwrL_x#%NTMMlwzm6GT|v?)$rrW359IVk=0{(q_cpsYDZ|)SqZfnu$suAmH+2^WNOu zx#ygF?m6fC?s-?O#x+idV`F8N!)>!$_3lcG!&U68v|B4(2_ZnwFnT(BuLjG(acZg4 z;c{CXZnLw>T~XyW(+RW3Wy$I>Es0~Zsx9SJRx%_CGX(}LZd3(hOR*wxS}d z!>ssmTV2>RKJjqs>7t^EErN%q$wA5w=2c<-N-VHj4{S zr5;YH=Pb3>$~+}2OSx^6#cgvsCN`M+7^`G_>Pj*O3$|>uximv#kE&4>NJe9!)(V%c z-06trPGtPNQKiCKRoG7o0gHRu?Y1iQSHhyr8_KH957UX6EFVq_3&1raAyaFtu2M^d z)yz&L^Rq`*B+nk7kWtudN0r@NYPA%(G=WOI+2QM*cD>!UzS8M(7P~PEHltH4YM3C| z*{dw>RaNC?09ky_XaEWisE5*VN<>XoC5McfU@v3gqmG901(3mJ_JYSvc=2OHxSxEK z7F|2Ykq1C%#%grQ)f>X9#&*(%1FVy2Vh`~>5qn}#Hlad2Qi zC5$PMEWi_nqaPl6dI~1dQ*`5?II+hp^asqKz&^GIphl=LQG$r&*77Q|f|XIBs8Rv- z@c*!Ea)Lz)>@#_plGdU}(qbNy)|hTQt#$sPFMK4+G28Lj@%Zt882{cA;jb8_+bV1_ z{@6?o@R*0k!msNkgpBOMjJU%wyT#@(msDCRN=Z_UMPbHCP0bV&uz2cAW{%kUH9W;~ znujDtEd9@(Ddy$WB~p6od3qLTM7eEs@(aRBiu_Wg#zw5C9C_18RYy8NO`L{5C)G=@ z!$X<1%wNO^q5viE1MCA7%2M7hxLN7>p1X3GAp<_!7KNH+H;u%<3D3b_rvikCpvCh#Sa^qq1Fohk2xgVc7IN4LXvG172~;pNYKcR&++mA>MO zfCO9u7`6}o&TA(k4O_rXx(=m56SKF4XHU2NiMGRhDOw(oa@kn^g7_Q0J>x&ZS++-4 zuifT3!(O0!SpoRqbOkY(V!{pIz>nkv#K}*+Pk0+8Bj21gb3qm;1)=3jh=Kn_8|XB- zgcHy|yo7!$HHciGpML;@L<^C^8kv4>C7jEJ8eY}*f-DpaaInv-p`K6cVV(rH<((}n zK)ie(yv?tb+odk#lid0_RAS3ha0S%p?ubix16%~1;Cs}?mbADWl86J_Z>$>R^(oy;Yut`j%Gczi^p`b7&67|W)L_^Xh_@nm{c;ej#vd|-7LfEA()_*OONE&cN z!gWXCDOe-iLUkFth>oo7Qb==Bb1l72&{G}W?Gyp;H=lqsoG04o{Xz-6Cf(r9N=vyn zxKqA(tp%+D37`Uo3Uw)U+_E+i2^0@~!ZCvP978p{gSo+`(;>=K~dU zF7-vNomqzNv?U2S+;wIv6TqnqTDXTTApQv?pr6e|& z&6I8l4N{Lnvu&szX0s`Bvm8ULrXq+fJ_`N{bfS1vtT&34GLqh9TZN-&Bk`mjz_7!+ zsXefZ&TP7Ha0hr5y+k$Z7Ny*l_H*wiz6H~x2ECuLX000f_Ib43yjIxi!R6I*C-V*{ zf$wDn9=^$(N4e~FR14O6cB4ylvivV`M?)b3%mNT7{gU#Dd&nK!oh-X^}UH$m?OO@zqT)c{v#X^ z3z@gGsN_&>t0x@H_C|YOPrS;kW+~Tc(oxUOFI2i%{-EpD zu^;#NG_7R*2=S)9`k5AYz+UDxLF4U3y&vwEd*te-7;>7PACQi6JfTz}&r`>lXT|ji u6rG=5KumiH$8ju5jp1QfpgK?)*zhPMv1vbD)7csWqoR$`nE!n;O7%Y|iKLeR literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8177298b24fb6d5dea6dde8e3b4a040f376aaa4a GIT binary patch literal 7901 zcmbtZ30#fY`!5ysYP}T|M%qxcp;GU8QYIxWYAjRaYPpo^mhLS@DB5V>q(v&G8M_&K z-SZ#~W+oKmU6_@BQ5Ooaa2}d7kg{ectn)BP+|%2;=jk zBO>|2pirK%Fv6EFSQHi!%8L+a4Ky=0H#au5-b@%<4##9_7+)at$Wcqe8ioQZ2pButAUd9Vr94KVu-`)EK`3g9CCjZ=;xcJ8K zBHYA&`Gy27_Z0?(@!wA{>|dmGIG>IbdnoN29VF1`diRkmlN?~E!VKbt3xYzz`06`< z9&K4BDakBh)OmdWumE1fM{X;3$K2Z5)mP{m8NxNUFz1Xk@s9`-2)Nw_n*{S>O#bQX zjrEZJ1#||!6=e8dI@cu5s1H^3VD-st8lMyCo$st9^F`rhhS&*(JaKl3LvG4l8XU#_ z2P(CX+q)rSBsdy9aF}@V1cAQcJnn>Epg8j35j=rd6<5FuS)?`igYi0Yy*Gx+j2u5Q zlpDzN4G?Gy=tf4JQHK*V2a$3Vqt&^5Wbly!g!M|G^i- z4^S7A>fczykr02CdYEd~i;$MmH^D*Sy^P_S>I~`A&`Z94$nr59PD~%TnL$DP0A4iL zl%xDE3&bpWn+dXxzQRB;$4qqwzWrx7_#cK!h;e3|``2W~;k?t-!-2lK%=*{0R+9NY z1nJ|vu`un+y`_=7h)j3isNzHo5+f!G+ z`XEpGBdW@~z;;8_Q`b1E@0#Amx~K6&3Y+P0#HCOeZIOyo0PV$LUlw%}-Wd(nLvn|@LxRX*wK4R09pF9Z8VYLd7PznjC%r2AR=K`~Nv z5B+HQnk1FW<@b!pIrgYmj;0 z)sSM}j<@ySm3fj_?f+Gz3{jSgdb#2-&4um)cb1Mf0zSl*F~ z9;X)5;;rF0*oeW3IW5F+hm}xh9%7bRKHEc;fmO znNWREja*T6hv}V`__BHwNwH2vBc~?nB)t-ft2SfGwS{o*;&(KEdlB7{b&Ngy+L;Qo zS7AzGIm~{X0M48snkusf+s+E{$#O+_xy}OjHEN*J@^#RW{~P;thAZ5QoCFP{UXwq! zu7(|*Bk;U!7D()ygj3Gs!By2c>^Yl-;C+#cXKvVIh{_`3Wo`m`!6kUl_7bT#iowAl z(y(Lz4_lhEp~1J4NM&ZjsFQLax?2EA4VGwjtsIv4dx6~fNPHIAKwCTkoYx5=ClsA&0zn_Mqm2mV?j2VQ|24ISiJwfd7c{;j?I2n2OP=f67CLt;Lu!EeRq! zN-wXxp(-s6_?bJSJ33i;Di!2o4Z8{rgEj>%@I8DUyl@$F? zV=_c4t%D8B7V;zr@H@GCG%Ys?W@{(Iujj|Z-pi%<%cL0CeNmpSNHj*9gU#f!j~-?O zr<2nuk~rKtl?_}vk+@G^kEa$n!LA+2cs;(3SWo&25C4`2%+)AN-%ZKnOij?=_ARc~ z7z&xj6ToDm3C>86hO=FF*sWjhre?RaVV6S)yYG*9ysWeW$r4NC7a6co_L(r6dz3bB zvjO?JepQb;b5Noy8*fglqI1rkqAPzH2Awvi=`EE!D64KDlD8!2tft}kd9=*MXbm(Q zi}ynpGk}?i_yZ!7DL@_AVfI{ zWok2U#DZs3&o&n48x}&8-4A5O+B|x3{bpEj>N@q0s-^R<=U}UA4(>m&0DWI3fxA;7 zd33stJiRv@*;_8iKGgxKhB^3ZMg`FsR{)DW3!&z+B9{EU2_p*H*ws(X$Tzd2NnW}K zF2!S_veC67eSIaEraKsq9%&}i79OTk1??nqR3bR$-4yNLz7awsl0Ys=7vo1OfaYaY z^gN#ijXCpCuWl0^^3WTPho(>`+bnu#-W59M-WoD-NE)_Qj)hP+T`+x?4$1Wrn3xyN zp8GC~9?c7+W(FIux&1j&ev}WTzp0~}pCw7TH=Fv*DZq;JkEmU(E1axa2@k0>+^CHM z1;0W#7&{U#)vSZ?{9H_K7vQPerEp@)GWOIl9gy9$30fDg09WH2&@F5fk*|zN&Xo$$ zjk)sw!mUYSw9-VjxLV&7F8EZ?Riz~ekJI6}>Hv6HmCgyI^ zm=%#m&+M|Me;3Zhl4Fi|JMK0y{Z$k7Tf*?Cgg7ufe3q09NP!|PKJeD~Q{%Dn)G=@- zbk=7fzcUSoCArg+F*!6N_z-J%X#tLVwj5)d&XH!_RFR!`Axw$SCD+pPaM31P)N5>I z{hxT?{1?eMCRdK^Q-4D;S1QA$)3?}KCs#BWSqjlolkv(tKU{vP2->AMSZTizZkI~b;3c&+)aX7I+3zL?f zqV30g;6a2oti72>+E;7g-C^_LM8{O{sGy|qmttC5a+hSzS_26-d+5Tda_VIkBXl*J!Ljl&u(1#oXk!4~eG1w=~)?iYUBXJG7K$h8M;YOddBi9v_Vf7zF@!Xs;_-U#u+?cq83MJ(x zFC1zQ@eS!XwrLv3>zyN>K@5&Oev=OW>u0)4eJ!0;{0;nEDS(3=**Gl03_Q7mNP1Zz zyyC=z`jS0F=7kCT(j-YDMLDcyPzspM&w(eIV_<4TK0DNWCz);OkC%*mQ2~z72Q6)M zOT{zN{!=>}*l~}zUvz>&9uu%u?ID#Z3AR3Ra5fn?kx=J?WRx7D0M$ctNyDFhcss%k zzr8pJRlk{z zH2$%r9F~_Qv9pV2(%d79h;;0B(croQ()!{m-5 z&?E`19M~{q`E!~`+P-ppvt=`844;MTQ@iL5W-XCuEXS6zU>q}fI~nVe z3#;GEB$dJWxZy@I@=7I8m79Q`g*!!o-zBi}e-%L8Ie$nE^n&1}C8Q>28gxAJ26=BD z8OV>vN8HU|8>1=GJIjLJVFnww`NL-G3~11H2ic>Upt$7?U39_@Jp33Oea)WCY8wyg z)79a;_Gp^i)Jg`d$iXPyGW zj5`xd$$qQfNdAqDSfnh6bv1!(e_&fVj~N(tWc=WSut*gl-wwXg3IJ@84uk zu%76jna6Qh*$Srsb>w6htGyAF(p zJ4LH#NJJ8pH>HBCv;p#;m%+=)17Pc%hxFvtO!)P) zJXX=+>e6VqF`3TJ@&O^839ARj(QLyq7@3*~^^YT9sksrcGq#3dRvIXC#|tkEx<*{~ z1`<{KM7*)92os$W>GV4ucwBcL2~CgyZoMsDNqR%}TSh_6?~B1)(E%TKjliFtdE%dk z3SpwpT6%YSHd;*-KS9W1XbuX%?WcqYYjwzwr{Y(=eWS>=74@{dF%gVojex6Gh`--l zgK0H&WZk4OIK!Qfb0g>A^O+GaQ&}Cq9an(UUm25~-=83s21%g5rX2moJR#OCIUsMX z3D?>aM8mwRMYHnFz@;t~Ra5Vi;?w(yo9bfrsp|rq)|g92|HxsjWU}GaerfBcGbs&O z7Y!M$spOeLJX9RNSbf*zaS_?LMa2ZJ>~^L{N916}#$ll95P*{hv1neZQhm}r4eiyZLHGd! z*kC&aqg(EiWrru>>Q+yZ^nEtn=d+4*ZQn)OAK0=d73GM*mI9P$GeFLEK!;de8hJef zq9%A#LnTkVo05Xn4{b4cv<%(7RuktLl;9l6B%6j?F7V5tJEX|Y9Fr^kfd92Vi3rk$ zlZRb!b6yErCO)P=I5p5*?<_p+nFKRS9MO#*j$7Q$RGUakKvhu=L=Gg_<>mrKK1+$q z9}a9>Qykf9Fq@v=GYTq9%IS=e6?EPkC02jseE6+%I*im(#l&Um)Q1++ZF^aqapoS~ z?7R|(R4zhlqXKFpmcZedGJNuKKXva&hk+xmP@TtIP}(+^7%q$hOBZ|i&U^#vjZX#M zbpqQJSD=^D2GaR!8vbmphFxjNq~dfCJ6ALUOIF+zxqdsa`l|m*EdT2%X-rqfm``2b-#(B7PII#;G^nbpENU8f5J4YX#xe+#-d4q$Qf|$M)^oC6lnoX&yj=d^C&r>mY zL6jj1U2adlZ2Lz-La{mlYj-%(T@P=O0=Ea0f8jYG0pa=yYD!l;Q2 zQ2xh&ZX;xR^p0ctIYH@1x?8i1pApl~u|WT-)%?aY{T&$eqr2Nf8T$dsD)uT&-vd2m G$^QZ-8-+Xo literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past_split_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5b387854f4b181b44c2e8dc125242eff0a862363 GIT binary patch literal 8074 zcmbtZ30#fY`|ql#tMyh?7-^HD4Xy8aQYIxWYAjRaYPpo^mhLS@DB5V>q(v&G8M_&K z-SZ#~W+oGzK|OkBaDg>a?Q2n z``dKo-uD`-vt?rg!ee;Gaw;r087Zkmhf0Dn-Req<|FCzl*pi)Wr2wIj&kg6BF5`tR4iT~y-rom?1_(I%W`Et9xd+7a zqCLfN1%!t#4-ke%^8cP-#6MW+ay}g?j!-5bE>xh|`~HVj$r@y=${xat5`>0F@-=q; zIM%w7#bPgEYw-9%k-@y^kHS{zTXQ=*j{u=ZOgPul%91nMEGRlsAmH{nY!=3gH~Xuv zf31fNY@j>%ogth5Z|9oK8S^33TdqEtP4jamAIzutE{Hy|EUhN7|2Hbim;_Pns4bh^Wh0X0-|`_iT#Pe zQHYA>3B*ub0WW-!_RtT;>&o}v7^|@5`7sgP5MDsAKyy%ECTOsgV)&uKq0zjcJ~slw zHTx38PBADl!ZaeZkLSomLUsWAvoo_l%8$v`9Iba=iqqi}u2eqm>gQ+<`$LWj^J6+(*%EZKhlKKjd2w8Gj>`Ki5UcWCCdj!22t&l0HrE~e?my$9e^mr2 zG0$wbfsrgYocBm?H7AK>F)-Fz7W@CONFV2omARzWm&Wj-KdmehTK~|bKacp5d1Nc8 z^$&;iQzNPMcP0{fC872H{&)HHkj0iL75&})@a+FpFC_uq6}b=R7YkoviUH)W`kt;n z;@?(hj@tXCcMQI5{ID=s=yJpbT^Q;rE(ZTlAwJGMfuy&e`tmWcHWMWu7z33Bi3i5t z2K}TyU`a0G-zI;r^2dVvXFW?~iH#*k&nzN5>fh_yvVUIf;Sw_DiivX)7wi9*zWMBb z$?4@#Vp3okt8gB6hL zTZz85kCP}^T-`s1@<~{>AB1K34`IFU>krFU(si2;qI_@LPa|h|zeARidpae_j(4i} zV82VZzw?0oIRK8jI9os2O==H*JIlnka2$3-Kq#NPI65FI#2Dn*DlDguHj8dXeVfyN za~U&l$4hK3`$qq}A7k=%wB)1M&8Xpr^^9Vt3x0nxnKArw1o&0z!1C?_^ggwimTir~ zp(bpsp3_c@cSy^z-y(dnD{#w|whL}%UYexhz4$`vA>T?1htESDcOP8;A_wX&s*@{f zUNF7K8ei6pA?bD*XyV#RU1e56S?y*_zqSz0UHpa?Z7-!ea*r{GU%OFZ-YQH_se;*$ zlfjJ>N;72FVCPvOK3T2=FV|V&z7|bXUcL^xi+*9g%=UnLF_WQr%xm)d*4411XB3`y z$OWl=lX1$KLb$3nhdF1z5d1E3@yrb;3|C!5d@apDKdc<@Ib0%5Ch<5lTn3g5;$eGR z9yAB^5b2ye7;{n{M0blJwb>diu2sR3AYYI_AA`?gnrXWafZIADBnn5uT8-;;{(&UO zFiwGkPsf6l)gn0dyc9GW_t5)!6Uf%)DzHj52hWyKu(CK8V{?TojpQU zJLi+vZr;>#$a3&sI06p1EQg`;R`4HD5e!7Ot(#F@RTWf>senJNg{bGa7Lx=q2yLfo zYc#&0($Bua%hT`C1&4&6C)>(gOm@>#9y)I=!}MvX5Yt_OA^J*iwm2A{H)T?(?jm%V zT}yx0{E+~5l9%3_Fm|3P>~(le>-TSlJ3%@)Pk%g_8WarrlUyNYVmaLrz7EO*)>6^q zN{~I2K|CheLH+y`>`c~yeXhC?v3DPBdDB8f_qH)j=nK0y>cL@$RrvCnD%E>s1W&IA zV(?A}5JYxSr~D<@eQqu?P=dGVY@D_9EOF$t5|>xf^gk_W5Tm>fHn6vlC!v7f$ls%x z1*tGwCk=i+KLPe$uE3ur$HVT63Uoz^DcT=wBbWX4F*_`aoK9!qNV^OsWa%X0HGMsv zTI345cBJ9;q()*l`3pS!OChkY#$wiPN^Enqz+l_gxLR{KiINzqxYBk{A+w6YYlHP~Vt-Y+k~7uB0E#CsjwS#X#N zZadQOqICR9s$106q=CuvXMkZ+4l%JVLn{*pjPSCDaFtY)ZOFz^3!YJZhXkB&TmrF< z-;o(>3+cu6n_oh2~fzH34j~yQQxc|Tc40xFeUalqN(dkC=^xjBhZn-1#R2QV1 z=isXu)kJrEF)a2ef%?lzSpMfGj4tkERzI~MU(Jpqg<0OX6px82$2N!z43uG--cUSx zq>W5lc$iKVbdi`bDd19gQ?!5kMhKTm1^HAxOd6{QT9?((=X@r#ESE@xf<|+&%O{+<6MknJF(Y|RD zGj%~CjCVdnyF)Z#$CwOQadHmSTpx!=UMtex#$V{?$(}4GJ5%IqGg*)$kE@{!5()8 zmj`$g-`5s&_{%9&!_$MZyo2CQnIY*6PGV}hiy=0R{CU$PQltDD)65&`U+o zjKcQE>!7`G581ju3GFY1PLkEyZW+o6tKwXiZ6=DzK54NAQ)69-N9YBz%rO@? z`ez=wuBZa5e;bbH=2XHDQ$66uq$O0ylDAzr+zFDJvv6GNG*Hk#M|?urIQsZaI`WSn z=`M}6bXM6{@MDbt4tD3^h-3@!;SM2Ll_l_slLQ(|_7K?@X7E!hi^PcX8LiNCu$Z3@ zPjZZ4YIG4Z+;S(GZ5@P{Oaf35j?f3~opej}Gt%`#7ZcKbk9b{lg(2P(u|fSIl`Rjm zJ92O~89#|ox8gKp4O4`=;RU4m_dvWI?TKGs9D-_J%HqrQd9ZV|0^pi?^e1^~;NGjC zy=DqHKPiJ=b*cjE`IKszdf~0!N~pH#rD4aU>LRECIjF4wv-m=IwJZ}{o;YCOkxk-F z$wIU%u*Z&sRA~G58GXH*h2_uqz~Er4SQ>}FZK;Cgm8s0^(wVg2$RZ+>uw68?v6ytc zxJvhj^O@6~3-R%EN-qYKVU#+Ld^?OsO{BC)at8-C*iL0mE!aki;#ZP^$e!$G6d1l3 zo&7V&sqg|A>9`q9q^j`jGXo;HEC6k`Gp@J_R26cvfWO`xfj6dH#131SP^cxQHHzv)0O(Hv0P@`f%t;RxP=Y#e*diOlMp02v*6B{K73jmy z8-5iH(usrI=M!OfS`LnOUxD{uW#Q8auyHTz^E)gXTQS+XyH1fsLqznk z?@eYSxh}fkJdbFE*g#g#YItVti9g;K(26Qc{9$Y+nU_(4W9k#&w{;2lNNz3}?CFM< z;SBkry#P$UP(~%Y4DefD1xocfq)rq~R^NS7+d4@f4&1ZH(&k1{#giIpxz$6wlN$`q z&wXH_n>ngEHqi;esp!JtVz`bnd|Pmao=CDqfk`-YMwCJIXe%_$wkE2x=fTRaGAMKI zpvYp)0g>LDH6pvh5g_!;#umpR*l_opNG zqlrp8`(e`yktoAk8VpaDQgyQ-&~xr3nIPDRMZ>DFEc*&IJXUC@tye*%7T9CdwrjL! zXDX_9m!aw6ba21cLtK5$;JY9KRYT)x+q&=R4<@xlWepeda2kC7t2?B3-JpY;`Cxob zUi|X&h4|&|GoN+c_R3SE5p{6n1m)Uyknf^j(=3wCG-`Rl@njBUtdGZucQ#<-l82!|#K5;&1kjA|?6>A}Ek_@akRre`IB>vLabkF$2&mol{?+r5GIgPJdr z!x_dnG0P75R$3UVY{016IHRNeNZfT`0^BKGMZ=?0p{g|l>XdcFytC>-y1^IoKo<{u2M{KO`+58c;j)sc_boP3b;)U zcqR1>*>4>S^}j9#OC@J~+&c3L`~N&EyMi=izv7`LAiBCORV z!=8#?H1~}m*H$#qs+JTmO)vqjdI|n|a}8$JHO`}OEWo`n1JyF_ld{wMiKp6P z=BdX5oYqo6$9~UYY-ID`)qWYfr!y%HUl#}29U0`AViHszzgTzKx`;$abyI@`Su8y{ zlls~RV(rUNj2`6+QyMJb_MJ>Hj*2D*4wWcq_lGpEGGw1!4CABoaEq!LT-ohLkB-X6 z?2RKp%{drthcIYap;~v+D-)eGra{yJL)hRj1>@T9lVyh|NG^oc0lI@JsNX88)7H=QDbEvyqlhmbq^gdY^*HZy;ck78J6Q5 zR;qpTEqD0o&>d3hXo+bxfx!RLfJBGtz{$hzxVf+#ty3P;?_8T{fnP43_DO}AKtW>s{?=xRFejWT1f zaz6ahGaW{2t6|DA4eC$J=(fEK&Ny?AZgyLV!)g{GwO0l8QA^-(d?h}4xu1Gd3)N-NM;c?0SBITL@hQ^(%SG*W#!l$k3U zh2<-5iafp^Tz55SC070Ml(b~2VEn{%8urxnI6u^TC=ENcJc##jhX81E;M`6p=O$JFm__r&TJDSgwri{Dru()fKp2j~HjgZt`i>`F$dwYk<04p$~{W_B3VN zrzVd;3(XWz_EV!otB~q^Cj-+bNWpaO4)X0FMGShrg|eQE1l1>YqB$;O@jz@DbQZdy zSEB=7h_S=*r6pv-{$%uS^MH{CahT;UjdB-ui^49g!38m9fSi`PlhXQRId3iMG&#__ z;up*9;%!om=GQoR*j^%+afD=qED}9$xI+INsz;3rS#K-afCo};aYtuvgv}Oy;B28N zmTxt^VPA?CQ)=rHt_slSR6Je~Ws4$~JCT3)@kvTbsV*5CcDT@84{woT&j*x$;W=4a zGZ)&`$Dv_>yZ!w(S%_l)!We6)3~_pY>ft8$9{$5-3pY8|U+sS520YN^aM&txERF(u zklY~g#qPWOOg?io`H``DpK+iud%$mdAKAVI7TbPuIKvsnRu?;<@|OdBPRPE+9nT)% t_+I;pZHPXkq0{{q~F=b($?0}nseWQab~lvhDHfR;vk`z7bOu#OT;{DeU+xUKFVKL z&9xcD*q{gr-&|RP<)t7eH{Us*rN9hjv;@M2U15UIP`;2S<0R&bn0C0bRjph?Lz{&t zG(7g{B>Q1@v6$A2j9QRbEaXK9Eobrt)563|*RRjPf*=vw*y^80EAOCqe#}UjT|p6o zSwUh!l(1PqkMB_RVYl5Wvrr)@P9V~)|N3F&v*gV+86|$SNDvVv)Y*BupJP6Y#mr!S zNCQTrAp{(bPmtJ062Y?_Y|H+llhq6fKPKMlpZO-Yvl|s|wb-Zg7a}I|e-vrY*89RW zoZnWU#H9&}wlYLdAgG6GU%rQI>KVmAQ-uT$aq;Z0+`^7zed{J0XG@s|##FMQN9m5yN zm^N)@v%k(X@7M3rLUJFr3<6?-2R|sZY1UkWQ4vZad110ik*<6L*mM{*iBJ$Kh~bAc zXdD!w+kioiYDiS1Wuzb^CQ1}FRm=o29jf*X=Irckz5+l7vo`Vg2LR1WP0<*~sD1%& zL(baSTYq(bGikFox_$9e!&0jZS8X-_CDkvtN%anP|4Rulw`y>Pi1!19fGsEcsyC`; z)1(y;~U5D*c<))$NF+DPokx z{5u~;tqJpu{$n&+fd0!&i+r(Z$`^Lemh!)5#K30#T43Gg$FTmb7e&m!$oSTAj@Im6 zEwyR~n)=<^CunloGZyo2JeY5x^L=Et${UBl?7mi!5z(@f?0ZL53xwJ>yR>^NK=Nd% z zx;(12UieZUeg1)4GY*)J7_hZvh;MWC((l|T#jcSZn~4k(2zk?Df}+FBL7CBDxqfrA z@-l0<$^FPM=Qi$dbrWmY{bL{V9i5b!Mh0yzMPAD~QL`S!aP%bQIXxv?V?NOglFBX7 z^M&JyY)D%Yk9H4NVafF8WY@s~@GxgCjfhExf{HXyRxm|jO+I`WEDsyUJ*Q`GErs8Y z_XRE2B)q#T2a`OKsOQ7cc)agK5}7Cmyi#Ypnf#d?aEyf$m!`o8Um=c{jKi8yF)&I) z2PgDjjh-JZ$<8e&iKA&Um@F*7V6)dGCc2iIBq(CenNif=DHw}B2r$ObABLQ=fxjN6 zgL!leF>%hv`t7^Ot7p#KGioZtv~V@bRhuGvJD^)aUn;qi39)to)Lh*U3rAikwo;IT zqMYR*=}fSGq&MVDos#kiUZ;3NrlW?%(XxYo6^>m#AhMRsSpH3{M zDpS_N$)we&H7A}P3eJR{bqw*$m=7K`{#<{PAt3Ib%gv8kLMCtYf%ltJQFZiOYM`Tl z=4(>u*kzMJOh!}}@H-HY^aqUkyu$4Yi8d*BV{4?iS1F1DDm zIT$ur70?mAHqeQm)ww2fCcz(dp3qBQ3zKH*(8)BHZrjVn5f>iQ4d(>h@lr!vH~YTS zXF})VTOo6>;N2Tio}q#9cBwS{=etChl?weoa!JIpa(X~>5jpY32sb@RrRGIhm>9PJ zlp>RG$@?LgX#Irr&sj?b97x2`l|Eo#5{DVyolyDOZfW?R3vr6X3Xol1e5R8z={Q=B zvS{sb(=Q&TP$V6^H`qzO>^eD;W{!3l927d}VXV3dS96dXx;Po&yw3A!mU%w(N=t&$ zmoYHI)`GZLa-hc`T~vJNkJprL6YsrY1Qzun-QH9Y^ZmWa?b)TYpgajI6D)wIy%sOs zUx?`^O331YW;ncj1?~4Mn>$D`3qBrD;Jg_{X~g0<$b6bc-l{Hu4act+-*C((wt1Sx zXMEGqRcAOvA2fwk&OIs>7WFXFa>PM2dWYPVT=Mv#C6A~jZf-^_FasA45=$Q19 zZuKamD*~3`Ilp8Wway(!3Zrqo*BtD&c`8yTP0%)+4oBki@%4uT)VDSRIvd`keO~fF zecO1VeDo(YxN(>sdK8N1%9GH?f(?aZj*%cQKiqs0$sV;CIN_NJi0DF~UN`A8qJnj$ znNG=xJp@$`lwXUmE9+zk++`=QcI!ux;D9~!d!zP^cOZ_lASmi@*(^2v*ev*u!IQUQ#8nFwBN0Zmg} zh}9Rx_A&HjTsEYk-bO9nz&?@#33#(A)#&JeFUF9WÆZO zm2gXI9QU$QAOu|J;f1@d7@;|p_}f~6arioX?EEJwwTQKM>;Ld zg5GCTK>BDkB$qj&&Fuo19^wxwS0(sXQbw!%0K689VZPV^7U|rflMXI`H1i}l^rjy; zI824#YH~ogWDk9kHGpg^D*%UNYZzH>2y<31!`NksQ046j3gKGx?@9@%RL0W*obDi6 zo&lZSdEtlw!C3fl622*!2FJFSQ%Q&uW~e=*iBH~hFQypciwCnH&uuyRWa<$p$ zKDLP7(>+Z9tI3DazhLY{OW5oDiJmyH9v+4m;Y8#9WN1hz7~6Y*#BLqk5wRH71ude| zm-(Q0IF0z&bKu0JB&<#}g8d$SAad`1TK>75NFQ(GmZCrGUegzjIM2low>4?sk7n@Z zPB4bE9$y@I@Vquk6c)bf6=ixX2wP0!mc3hA3M?il&641eKD+J3dw5$;3k#F zG<`)fj5SJu->(jUy*Kjk{J?nFeZ4E4on(nlhbqa9$;Ox&o z;Mu7juxm#O-dRvWI0Jj)kv~=fb1N1zc2hEVsUDbYn}G9lyTejTJFv31!V!rIaIyX& zcjK_#)aHQ^>~f3X?!U4CZ>Y=iZn`51b4VLgzeV=4@=oz!#~>}B@Ih07gd7YL=ttJT>hct!2SJW3v#GW8I05;2N3fvJBTuPCt4_RRiW- z>5iAj<-@+AK5*B5Iu)~21_yR`g#~39_*2Dj=xThK_z4*7b^Jaxcz2ra(pf~uw&q z-ATqehTxwTL8uBx>9eY8TDajYdA08qH>~zC@xAT=N~7)Yl=gF~xGtP?^w3z+-=0vf z)hWp8rV7Q~SCF!M!T49qNStt83AKhP;)f+!u(MZJz=ad(d6iDUdz?q>t-9i*1!?q_ zYXQj0A=R_=#Rv8Iu)(38hW{p497#>cA^ol}Mkt1lGtGhyojMnCpE#3IkLQan)K4rtI!9%%cr)(qH@pDK!!FIIuG7FBFhHjJ7$k~V$ zVBoSIE#wOD;#(6Ux*-C6#tj!{Ws;=Akd8Ck(M}d@_wmCZw!S|pLKdDnY>z1kDX=lw z1l*q)!j10v@Z0WXoL<{3@%Bl zr+1k}M6SF5tMbFqZ18sSllKak_jwfA9G;Dr!;I+BDd?i)lhOd z1k%F%A$-OH(nRsRLJy`q7k%r$m1Ca*X814m-lRZ@-YK@U9TZaJjO%4WG$AY_9uE5R52>>7wEcSK3+=LC=Cs-rgK!o z@zTu%teujI=eDlFX%*J^s>GJY8V*4_C1vtVp^R)3Xws)g)ZhR|0)2v~L*cy-Q+e(M~OpUfsJ2t_?atmVC0ror%`uz@R@!_T4p@qoIN&C&IYx1$&^ES@*bHvfy34 z^j*3U7^eAPX>cBxea^?Zdbv1HWi8ckNrW|fSsfp`jt;V5_t|xpOjVN7m;U#;Ysekx zHMfaGCu|U8)XjspjwA8(6A{fTu*H4-(#gcMJnVfU0j?}gz!%EnN#~JXXdA&LJ*!rL zMNf59M`mKVixQrCa-VyW>xUsrGq|U=JHehF1EJruHK>0ooV?R@A;;i}G^5HF zD249m8Hh-Zc@W<1k`5yda)4H9>s9l?j zmeW$f`*t1i@VA1kAp{D##M8>fztVjcMMPsE53+DL{Cde7QeWMroy&w^zLKTfQ}(Lc zqLbC!`SprBHfjy5w+R3@8&%nB#|C=WDFb8*E;2#RUbB*Ebo#&1!cj)q;4gVsc@0fhO{R7CUW4^tV06~F literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past_split_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5798ee37ce128f3bcd50cdab6433fe7543da1e9d GIT binary patch literal 7207 zcmbtZ30%x;_is|vPwP}va)nB_XxA{!b5f*5iLR|Iqh)ARQ<|BGLbTCB$Wm@f_Wjz` zJjbt}JmwdZ*k{D%b!0&wPG;KELOj=RD_pzvrCinI{z$wtl2YG&@=< zk_aONR+8vIk$75Ubc7&UY}nb>%FfQphFis9oY`#ap^+l7Bv2&bM@l78QVHM2P_=oi zk4lr&LYGmF2@IDCEL5~uUP=lI3!L*=O3YA3M<{BT6)Fq~5s3J5OA>*YX~!yC-P$E2 z33Xtp9nn_702{M30om z6&Nm@6DSczioPt+;~QFi*llOZBUB2UEfnk5H+@+7EJX`#ri&m-EDVnn>Fqq-kCV@0 zF*BL(@_^B5pn%Ku36%Iq!})fD?bx4vvYsgwM8{hHGu{k4yV2oR!akipD>0K8_0Pr{ zFti(N%Y4fcTfe2HR;X)!CHYK1#QzwO1KZ%Uy)Z%BaJ0S_H3##*;>T&iuQ3rj!mp*U zdC9Nvb8Nz|rE+uXG~xGI*dBh(lF#^MDl#2H&U6e>)1+o=%^olcjbHVsK(T}`2bV7v zgim8M8($l>8wZ;F^b?3f1EU0dPV;=WR#db=ERYv<(_}WgY2?d%!>%q8iCn3nY;5|l zcQk-QMvaQaUYeGxS?I<93!*gM!X;rvDQ3mV-%r(%Ps z6`Sg<#y@j;xlJ?bVD`T>DGTccFvR@tXjNt(bx2dTl?8|jjtKu*sy!G47{D3 zH|*BFLNgoBSj@lKV7|`DZ>M}KdL0L|`&vhYN69zwZ|%}8Qs3?iq}^L(I$zGXt9;-8 zf0clU`5t%L0Bz4UY>1-av9f8?owVD?n&FLPzA>>CQXD&Z18iEu*1F-S(4q@yxC_Fi zEf7#?K%!yn>m&-1H^7dpvfnhc5Yn7zR?_;i3k*m_|u~Uqe3k}h0$WUes!qyGH*D< ze$QjUYnl>qO0a6vNSzl<)MEY>oa%Ge|10#K0GhC$b@JNi5nw zT!kewo|9b%2f)Lec{Dsa844=WKt;(6MK$^Gey}2J9QT}_xwRC2JKh&`T$Awbt{hDA zNTQw(N8|Cn6G=p(0`N78hWu+`4Wg=GJ8iB6!v|zzHO{@}4#rO&j;QRi?b5pM+zvaZCCw~XoqNt9+HHDP* z$_TVyab@G&`{BWuT&P~@g}x=ucumU1j@fdMv-NbH5{8+5A|FpIrm9oc!pWr7s1qMc z4+Uky4|NRj%vb;(HU2z*(;*<~pUcaSTtX&q^nrJqQ&D~NJZhw;gcfU3=-6eGK|)7? zoraq{cY5N!H-32U@LF(~yof&X%)&tq%i$H74wb?X+AcrIQ&3zr8B%<6k-0b>`bTGBp|&;L-0ek=^<0jbYkGi= zTL=#B!ox?YsaX8n8N>Q1)7^^5>4tMc-gsG0TsQZ=%x6O9;#5O_S0P=%1VX)A9y5uSvftRy@;H6ZH$|qq*9BbEKHo809_)IaLKzNm}v8a z^v_vK1{_Gl(Um@6WI7u&ygQ-FwcWC?KNjK?sWl+Gy!cEf6VhI3I(nIF6er@T9qx43 z^9N-0$Y)e^t%l6pJRYiae?+rDZ>J}f$`Hl;&g+suE{qb6t z+r)csD1k+NNVnHj#A1JMa(iwmEhtX{t2j&G>#oI1_ZMRNi4wASpg9gNUqSo*!sZQ9 z&VmmIl(?@)Q5wE@He^0cBX86bV8ik2#Wy(F#4b;}_>6Bly6O#wsDoy(%6SORu6jad z9T|x8pZbyHEm?H`k z;5k46<_H#{aj7$XG#ANs&Lp~7e*zBdwwI`+9VKa@(_}TLZqmQI_N8sfmj7z-GC#F% z2~WM+9e+8K$TJ((0|N4mVNUG|9DQ~=&D|J8aB-ZB{< z$F8J4sugr)_laoi?I+L3rBHlbm)z9x1L!V>~AREgHz%khd zMwa)4_|?lWW?3Rsd3%CVm=68BQVOb-v2*~pJBXKOK&Q7}IATB$7Jit7uS=%GvF+tl z8tjA_8qa9rlXtv}DLwJ82Xi3!;!!f+Z8`bqHJaLWnFEuj_JD)#bD*oLBm75}4egO_ z?`2+GPyl(o^Wble1pB%y!UVAtq4I3eX1z_c)0?00hUa5C<*?)%WH+?&($@lX+n9@~ z!;?W;n}?w$8gOxS2-cLQQ-#`WbRSzp@9Cc=fYs#v=$|oWq803Q{zy+8SPu__jd7w$ ze=;;U1WX(}Kx)5^?g(ED>jD>1*^7KoKAc8;9Jp{|QW91t8pD2%J`k~YKP~@MPGpa_ z@k-GjcCYCRN1W&3``g;I?+0^ueJ2P*b~=MN@+Ea$J_Bnnk4GM?#Xsp-95eGGabZ^w z_Ya-uf67xps<{|eF@@xn5O9;~W17Aq8O9o?z;9Ovz}_2qcz$3k?7pr>=O$U9)1gXo zW3mZmhGmd*sVp?&rtv~&IuKvaC3tq42khFBf_D;12zTHQc;xq$z}$+#jNOzBUTOfQ z+a}5iAj<-@+AK5*Az29>Z>2T$$p3JGNy_+!OzP&2tq z{DchlI)0xTy**8L=`EsTa({x;o5gUbHVb6tR)EtIJXY;z-?IdG4!T5(|AgaSr`mCy&7H)V$UhaFz3$1-je6M>z zm(likO7}TcUKhqadT1=^??9;6>J(&kQ-|X2D@fVBApA3WBu=>A1$BlgSqv(pG z(}+^sc3Ibw)#T~BTl7G03(<6XsJ+u7vGo?@eMH;GH$pqE0ZJ@21AE!oMSI1 z+kN~nh^g;Sijak;4m)5f?8eE4OisY!1uDRd;hy zkf(q;{6zFyyHggrDUql4b~Th-4u-T)e+Zkoj+|IN9BTgx0JQ)C=`2dXzxeCHIo3dC za*+onM;I*M77XjTnNViz3o6H!f=1ydI_;zjj1FS3-)&birg{MAdFsJc<9;-y;we#_ zyBuQ#!{JS}H@&^W1b$m}Nv3E#8GBaC(IezAUk$F+GQ%eHRt-OsR}) z6Kd0^M>OC7R|>JpDvW^9t;6U%zxRF)a^! zpNNAii{tPwmGPwWNH4St=aC<(R)FOXnyA4|gMcLkpmAa;DV9Z(`HwyoRXCWy!N*RR zQ&u9&d$pO`$xrgq+Rors;|Ei{Y*5Fglnw|mmOkwG6G0rv;lN^gcyxTP$MjYf~FaKG%lRFje z)g(gvmG1QE`AqP=bX&$7QA#_x1z_nrnJjH|ConsgLv^jYK;7l{WPo@LW_K&V+{~NQ z?AMiC!@hY`VTuz*ZM#kDb|#~4Z7y0(PX+JWb;QHp8ny-#DCinXD;NJl_gNMZt%ZEZ z!r}1CC2vT5d6#xB6G6a_?fx@~&1SSzSZp;$QAJVyu;dRD(2r?-H_qeyY literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_encoder_self_attention.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_encoder_self_attention.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9162f2f20a2ce849775f414fcd08521618f4ec0e GIT binary patch literal 4524 zcma)A3tY_E7cYhS(QA?owVPg;C?#6c@1E33HNBJ-rHtvJQcY=QT1Awa9(t&h^g!jY zi}m>bh1A?*Jr+xAvsiiU+SFfplm5Gf^`ElXy#9Xkxig=8&OPUz^ZS19x%ZBm8b?1? zBw7+L6-gpv!Yn26LXmi3Yu#*71pWA#O>NGuVGB!XC}Bu*+3 zOdP5HN@lF~Qp+923``J4OT)No+RA)YC8ZRHDrHsX3r06m)N2(H85$ZU5-2=L!oXB+Ss4{oqKUm}6@7ch|~8 zm>3rCrN~Pd9hoGQM8=9py?_C|4;D;Oa1g!{Wq|(Xk@E zZ@(X7SEa1XY*S`FDgKtK`JM8V>y1i|0w40}`Fn--(Q(eDe{$~R0LA&d(X zOnRNnkbhVhtH0`SwHbAhG)52+CJYtp4^YI*(+G);v5bifiH{Y>E|f4r=DkiC?OuxT zc%DLur!-n%ZEMXL@ge{P6RiH0z9NP8n=0c$FNQFFjK&K`z4%(&PGsKig){uMmws%} z=qG?hb|^oG0|}fGJS|G93{m^tB;}9uiv`Xtfo2 zrxlyyaEJwiHGUn0yF*2d>GSYbx!@P5@Isr}9*YcYPb{XM>NjZdC-czA!xvZmu^JjK z3?r9xy}|979riSgCRw~3v~ayaT~(8zqJAT0U0DEUFKnSD+so;Wg74V6r+g|YT#i|? zYM6aL9r&C`nlo@EJ~$)6M@gE{lRX9Zw&ptN@~PW!PKF6;WSXYCh2zy$&Rc+Cl;bruqT8!Ir4s>F7OUr3WhA`XdGg+&9x zu%oRInuX7Z%IZQGeOw*nH%lR-*$%C*RKubWe^5Us#mCZS+Tjbp&z3-n!~oLtuF}AL zsgT2!!T!!MFlEX@_^zuQ^c#26TZI$I=H_ack}(myS_~n%v;Y$d(xJn{4OF9a>7Q*< zP;X156L_D1xF`=){>w-A3BkDONg#GMhQp!lEmRs}k9nGRX!@-`*fW`ic<*`=gq%4@ zRyY@tr+go3GdKx?7L0;@E=e#%eG2?XUIOm|dlG+aLv=N57`*}hw3DETV;ZK4r3h^& z>bL4`p(>AO;YGK0I{$#=17O=Kz;5acL+8yEm^C8}Z!c-H^l!jtgQ!Z8NEw4Kc!brsK&MA)#Pu>k86Im2^jRHdG4JsQi8v z3_OrSJZ*V!Bv6JA(v4uRt1-mv*-KmgY9aFWZEO?z!>)BEQ0K55d#>nElP6=L^J*}L ze(M0@*oV}qXc2awor5ft;qP=de!2Jzapc?}E>BeGe_Ap@s+A3EnN8$TB;XeHcAC2; z17;g#!l`o;V9&)3_`~Ex*mYqjT`IFg`~7X?Vvs53N9B=|S;}a@%V8rH+Y)cLRd`~d zE9~5niC0q_32*YJSa-S@n9B*6w~Lagt4Dy@wz;@M{}WhkISH(6tZ#wE z6EyjUQSi+EB)zUv3{?%yMESZB{qlwZz63vY{(-YO!}o9 zc6+dxY9WJtjVW06Lj`^R%C=L0J;-_--l{CaKWf%vgikiK&ab25-yLamNfyph>Xtuj z(!=yXcQ8*~O)TsxaEgTk#(3L9v~~s#JeH4!^B+@Fhh-SZErSHdZ-{$ZF}<*ABg{W> zm4+l7qk&h8@U~|W*6y2+!k!H9b}b|KPBxOxb^~Oudm!6s3@Xij_{6=27>_T7aNjaG za#0g2|6Gspr4QH@oz`U5>?Nc)&j%Odck-$+$K+;aS}?<82p&4vMrJIiqhE+066t6e zxD@{;uid^5qLng0J;MZ3$7sNai@NB0E*Dyg0@1W_Jso~G01wAxQCEio+B)wN^=n^A zY=`IK?XBY=#>)gIKF)*8CMA>=$FXO(6wpJ(5!BjzEw(-EBHH&#V8dxW^a{2kS?#lF zkY6d*oV!OIk9orJ`ee9ERpHvP6wnARgZ;}s!(WbMLtM!k%zP-u6TfeOqno~BPkd(# zYU|g-?eL}GX;}m&Wi4{@sU<18R3pDOXDD#bsw?X7r;2^#J*P_k9(6i)R0AA5N(J_} z$v5#&X&%X98%KD<;dBn@!BRNu$}HC1yIbyY#n)B&*B?=C%lc{Q9mYy$pf(?Bx5f~|^OMS?ba!XI0+ zP{U_AHPBN4k2XgcHl=Z5(2$G-UM zfikcSN~1U33h{H>B6vi?p)E2Lx1W$8q#2Xpor*ng?`U#mX%nq(k%8qh3lIz|!=HXz ziMdA_N%rKi=we@iB_)>IxVoARJ`Y5q8&j7QP)fhbX5#e4vkNp7U`=%Njc7UmW)Hfrap-R+ZN|lhWwy zC6IqRhdkCug_^?`8ZO$EkodT6YPM`3mLH!<{q2LXz9$mn4gF!-F>Cm}H5a&X@x;ub z3dJ2kkm+54%$aZ)A76-@bgbagE4&|>28 zvoo7=BZX`>pH0v09t|~C)ztm78anSUE!HeK5KcdHgU?3lqUlzUiW`WwlacFSy0NsBx6i>Fu(9?ngo4!0mgnVDzdKAfS&BZwPjyiDK+T<~r^YFoC zcXV2=4Jnmc*dbbg$v0d<;C+vE*61du%9r8zCoG9BXrZ+_Y2-+!5pKDaMY;8bn7$+x z2FJ)0r>|+4KJgY2)0IH^m*^d$j?bF%?K9Fxp|!pYv;uS~8Bt7)zEK6Y2}o2c?51JHTIll!@h!F>r8@SvEF-i;1;UdqGqFH9Mq5pp jGn5&iHbBvs-WJSd-gy^z*I1_k+B!};O#chIw(|b~-o-?! literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_whisper.py b/onnxruntime/test/python/transformers/test_whisper.py index ceda5a88c3925..3ac723e166084 100644 --- a/onnxruntime/test/python/transformers/test_whisper.py +++ b/onnxruntime/test/python/transformers/test_whisper.py @@ -8,14 +8,10 @@ import unittest import onnx +import torch +from parameterized import parameterized from parity_utilities import find_transformers_source -from whisper_model_generator import ( - create_whisper_decoder_attention, - create_whisper_decoder_multihead_attention, - create_whisper_decoder_with_past_multihead_cross_attention, - create_whisper_decoder_with_past_multihead_self_attention, - create_whisper_encoder_attention, -) +from transformers import EncoderDecoderCache if find_transformers_source(): from fusion_options import FusionOptions @@ -27,6 +23,422 @@ from onnxruntime.transformers.optimizer import optimize_model +# Dummy constants smaller than openai/whisper-tiny +class WhisperConfig: + def __init__(self): + # Hugging Face attribute names + self.hidden_size = 10 + self.num_heads = 2 + self.head_dim = self.hidden_size // self.num_heads + self.d_model = self.embed_dim = self.hidden_size + self.encoder_sequence_length = 20 + self.encoder_ffn_dim = 10 + self.decoder_ffn_dim = 10 + + # OpenAI attribute names + self.n_state = self.hidden_size + self.n_head = self.num_heads + self.n_mlp = self.encoder_ffn_dim + + +# From https://github.com/huggingface/transformers/blob/31f8a0fe8a7e2db1ee30bf32ed5976cd11f3283c/src/transformers/models/whisper/modeling_whisper.py#L222 +class WhisperHFAttention(torch.nn.Module): + def __init__(self): + super().__init__() + config = WhisperConfig() + + self.embed_dim = config.embed_dim + self.num_heads = config.num_heads + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim**-0.5 + self.layer_idx = 0 + + self.q_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.k_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor | None = None, + past_key_value: tuple[tuple[torch.Tensor]] | None = None, + attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + """Input shape: Batch x Time x Channel""" + is_updated = past_key_value is not None + past_key_value = EncoderDecoderCache.from_legacy_cache(past_key_value) + past_key_value.is_updated[self.layer_idx] = is_updated + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim) + query_states = query_states.transpose(1, 2).contiguous() + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + past_key_value = past_key_value.to_legacy_cache() + return attn_output, past_key_value + + +# From https://github.com/huggingface/transformers/blob/31f8a0fe8a7e2db1ee30bf32ed5976cd11f3283c/src/transformers/models/whisper/modeling_whisper.py#L583 +class WhisperHFEncoderLayer(torch.nn.Module): + def __init__(self): + super().__init__() + config = WhisperConfig() + self.embed_dim = config.d_model + + self.self_attn = WhisperHFAttention() + self.self_attn_layer_norm = torch.nn.LayerNorm(self.embed_dim) + self.activation_fn = torch.nn.GELU() + self.fc1 = torch.nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = torch.nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + """ + hidden_states += 1 # Add fake add to help with fusion testing + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + return outputs + + +# From https://github.com/huggingface/transformers/blob/31f8a0fe8a7e2db1ee30bf32ed5976cd11f3283c/src/transformers/models/whisper/modeling_whisper.py#L651 +class WhisperHFDecoderLayer(torch.nn.Module): + def __init__(self): + super().__init__() + config = WhisperConfig() + self.embed_dim = config.d_model + + self.self_attn = WhisperHFAttention() + self.activation_fn = torch.nn.GELU() + + self.self_attn_layer_norm = torch.nn.LayerNorm(self.embed_dim) + self.encoder_attn = WhisperHFAttention() + self.encoder_attn_layer_norm = torch.nn.LayerNorm(self.embed_dim) + self.fc1 = torch.nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = torch.nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, + cross_attn_layer_head_mask: torch.Tensor | None = None, + past_key_value: tuple[tuple[torch.Tensor]] | None = None, + use_cache: bool | None = True, + cache_position: torch.LongTensor | None = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + """ + hidden_states += 1 # Add fake add to help with fusion testing + batch_size, target_length = attention_mask.shape # Get shape to create 4D attention mask + sequence_length = hidden_states.size(1) # Get shape to create 4D attention mask + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask[:, None, None, :].expand(batch_size, 1, sequence_length, target_length), + layer_head_mask=layer_head_mask, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + ) + hidden_states = residual + hidden_states + + # add cross-attn to positions 1 of present_key_value tuple + if past_key_value is None: + # Skip if cross-attention has past KV cache inputs since the outputs are identical + present_key_value = (present_key_value, cross_attn_present_key_value) + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +# From https://github.com/openai/whisper/blob/dd985ac4b90cafeef8712f2998d62c59c3e62d22/whisper/model.py#L44 +class WhisperOAILinear(torch.nn.Linear): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.linear( + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), + ) + + +# From https://github.com/openai/whisper/blob/423492dda7806206abe56bdfe427c1096473a020/whisper/model.py#L62 +class WhisperOAIAttention(torch.nn.Module): + def __init__(self): + super().__init__() + config = WhisperConfig() + self.n_head = config.n_head + self.query = WhisperOAILinear(config.n_state, config.n_state) + self.key = WhisperOAILinear(config.n_state, config.n_state, bias=False) + self.value = WhisperOAILinear(config.n_state, config.n_state) + self.out = WhisperOAILinear(config.n_state, config.n_state) + + def forward( + self, + x: torch.Tensor, + xa: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + kv_cache: tuple[torch.Tensor] | None = None, + ): + q = self.query(x) + present_k, present_v = None, None + + if kv_cache is None or xa is None: + # If xa == None: self-attention without KV cache inputs + # If xa != None: cross-attention without KV cache inputs + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + + if mask is not None and kv_cache is not None: + # Self-attention with KV cache inputs and outputs + past_k = kv_cache[0] + past_k = past_k.transpose(1, 2) + past_k = past_k.reshape(past_k.shape[:2] + (-1,)) + past_v = kv_cache[1] + past_v = past_v.transpose(1, 2) + past_v = past_v.reshape(past_v.shape[:2] + (-1,)) + + present_k = torch.cat([past_k, k], dim=1) + present_v = torch.cat([past_v, v], dim=1) + + present_k = present_k.reshape(present_k.shape[:2] + (-1, self.n_head)).transpose(1, 2) + present_v = present_v.reshape(v.shape[:2] + (-1, self.n_head)).transpose(1, 2) + else: + # Cross-attention with KV cache inputs + past_k = kv_cache[0] + past_k = past_k.transpose(1, 2) + past_k = past_k.reshape(past_k.shape[:2] + (-1,)) + past_v = kv_cache[1] + past_v = past_v.transpose(1, 2) + past_v = past_v.reshape(past_v.shape[:2] + (-1,)) + k = past_k + v = past_v + + n_batch, n_ctx, n_state = q.shape + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + wv, qk = self.qkv_attention(q, k, v, mask, n_ctx, n_state) + o = self.out(wv) + + if mask is None and kv_cache is not None: + # Cross-attention with KV cache inputs + return o, None, None + + if mask is not None and kv_cache is not None: + # Self-attention with KV cache inputs and outputs + return o, present_k, present_v + + return o, k, v + + def qkv_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.Tensor | None, + n_ctx: int, + n_state: int, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + scale = (n_state // self.n_head) ** -0.25 + + qk = (q * scale) @ (k * scale) + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + + w = torch.nn.functional.softmax(qk, dim=-1) + out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) + qk = qk.detach() + + return out, qk + + +# From https://github.com/openai/whisper/blob/dd985ac4b90cafeef8712f2998d62c59c3e62d22/whisper/model.py#L142 +class WhisperOAIResidualAttentionBlock(torch.nn.Module): + def __init__(self, cross_attention: bool = False): + super().__init__() + config = WhisperConfig() + + self.attn = WhisperOAIAttention() + self.attn_ln = torch.nn.LayerNorm(config.n_state) + + self.cross_attn = WhisperOAIAttention() if cross_attention else None + self.cross_attn_ln = torch.nn.LayerNorm(config.n_state) if cross_attention else None + + self.mlp = torch.nn.Sequential( + WhisperOAILinear(config.n_state, config.n_mlp), + torch.nn.GELU(), + WhisperOAILinear(config.n_mlp, config.n_state), + ) + self.mlp_ln = torch.nn.LayerNorm(config.n_state) + + def forward( + self, + x: torch.Tensor, + xa: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + kv_cache: tuple[torch.Tensor] | None = None, + ): + x += 1 # Add fake add to help with fusion testing + + self_attn_output, self_k, self_v = self.attn( + self.attn_ln(x), mask=mask, kv_cache=(kv_cache[:2] if kv_cache is not None else kv_cache) + ) + x = x + self_attn_output + if self.cross_attn: + cross_attn_output, cross_k, cross_v = self.cross_attn( + self.cross_attn_ln(x), xa, kv_cache=(kv_cache[2:] if kv_cache is not None else kv_cache) + ) + x = x + cross_attn_output + else: + self_k = self_v = cross_k = cross_v = None # Set to none when creating encoder model's attention block + x = x + self.mlp(self.mlp_ln(x)) + return x, (self_k, self_v, cross_k, cross_v) + + class TestFusion(unittest.TestCase): def verify_fusion(self, optimized_model, expected_model_filename): optimized_model.topological_sort(is_deterministic=True) @@ -50,144 +462,457 @@ def verify_fusion(self, optimized_model, expected_model_filename): ) ) - # Attention type #1 in fusion_bart_attention.py - def test_encoder_attention_fusion_with_skiplayernorm(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_encoder_attention( - num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False - ) - dir = "." - model_path = os.path.join(dir, "whisper_encoder_attention_sln.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "encoder_attention_with_sln_fused.onnx") - - # Attention type #2 in fusion_bart_attention.py - def test_decoder_attention_fusion_with_skiplayernorm(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_attention( - num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False - ) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_attention_sln.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_attention_with_sln_fused.onnx") - - # Attention type #4 in fusion_bart_attention.py - def test_decoder_multihead_attention_fusion(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_multihead_attention(num_heads=num_heads, hidden_size=hidden_size) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True + def export(self, model, inputs, input_names, output_names, dynamic_axes): + torch.onnx.export( + model, + args=inputs, + f=os.path.join(os.path.dirname(__file__), "export.onnx"), + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=17, + do_constant_folding=True, + verbose=False, + ) + + def setUp(self): + # Reset the seed to 0 so that the tensor weights stay the same for each test case + # whether FP16 or FP32 is tested in a CI + torch.manual_seed(0) + + self.config = WhisperConfig() + self.optimization_options = FusionOptions("bart") + self.optimization_options.use_multi_head_attention = True + + self.batch_size = 2 + self.sequence_length = 10 + + def postSetUp(self, precision, split_bias=False): # noqa: N802 + use_fp16 = precision == "fp16" + self.device = torch.device("cuda" if use_fp16 else "cpu") + self.torch_dtype = torch.float16 if use_fp16 else torch.float32 + self.optimization_options.disable_multi_head_attention_bias = split_bias + + def tearDown(self): + path = os.path.join(os.path.dirname(__file__), "export.onnx") + if os.path.exists(path): + os.remove(path) + + @parameterized.expand( + [ + ("fp16", "cuda"), + ("fp32", "cpu"), + ] + ) + def test_hf_whisper_encoder_self_attention(self, precision, ep): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision) + model = WhisperHFEncoderLayer().to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, self.sequence_length, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + inputs = (hidden_states,) + self.export( + model, inputs, input_names=["input_hidden_states"], output_names=["output_hidden_states"], dynamic_axes={} + ) + + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_mha_fused.onnx") - - # Attention type #3 in fusion_bart_attention.py - def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_with_past_multihead_self_attention( - num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False - ) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_with_past_self_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, + ) + name = f"hf_{precision}_encoder_self_attention.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) + + @parameterized.expand( + [ + ("fp16", "cuda", False), + ("fp16", "cuda", True), + ("fp32", "cpu", False), + ("fp32", "cpu", True), + ] + ) + def test_hf_whisper_decoder_no_past(self, precision, ep, split_bias): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision, split_bias) + model = WhisperHFDecoderLayer().to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, self.sequence_length, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + attention_mask = torch.ones(self.batch_size, self.sequence_length, device=self.device, dtype=self.torch_dtype) + encoder_hidden_states = torch.randn( + self.batch_size, + self.config.encoder_sequence_length, + self.config.embed_dim, + device=self.device, + dtype=self.torch_dtype, + ) + inputs = ( + hidden_states, + attention_mask, + encoder_hidden_states, + ) + self.export( + model, + inputs, + input_names=["input_hidden_states", "attention_mask", "encoder_hidden_states"], + output_names=[ + "output_hidden_states", + "present_key_self", + "present_value_self", + "present_key_cross", + "present_value_cross", + ], + dynamic_axes={}, + ) + + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_with_past_self_mha_fused.onnx") - - # Attention type #5 in fusion_bart_attention.py - def test_decoder_with_past_multihead_cross_attention_fusion(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_with_past_multihead_cross_attention(num_heads=num_heads, hidden_size=hidden_size) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_with_past_cross_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, + ) + name = f"hf_{precision}_decoder_attention_no_past{'_split_bias' if split_bias else ''}.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) + + @parameterized.expand( + [ + ("fp16", "cuda", False), + ("fp16", "cuda", True), + ("fp32", "cpu", False), + ("fp32", "cpu", True), + ] + ) + def test_hf_whisper_decoder_with_past(self, precision, ep, split_bias): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision, split_bias) + model = WhisperHFDecoderLayer().to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, 1, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + attention_mask = torch.ones( + self.batch_size, self.sequence_length + 1, device=self.device, dtype=self.torch_dtype + ) + encoder_hidden_states = torch.randn( + self.batch_size, + self.config.encoder_sequence_length, + self.config.embed_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_key_self = torch.randn( + self.batch_size, + self.config.num_heads, + self.sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_value_self = torch.randn( + self.batch_size, + self.config.num_heads, + self.sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_key_cross = torch.randn( + self.batch_size, + self.config.num_heads, + self.config.encoder_sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_value_cross = torch.randn( + self.batch_size, + self.config.num_heads, + self.config.encoder_sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + + # past_key_values is of shape (num_layers) where each element is of shape (4) + # + # Ex: + # past_key_values = (layer_0_tuple, layer_1_tuple,) + # layer_0_tuple = (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0,) + # layer_1_tuple = (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1,) + past_key_values = ( + ( + past_key_self, + past_value_self, + past_key_cross, + past_value_cross, + ), + ) + + inputs = ( + hidden_states, + attention_mask, + encoder_hidden_states, + None, + None, + None, + past_key_values, + ) + self.export( + model, + inputs, + input_names=[ + "input_hidden_states", + "attention_mask", + "encoder_hidden_states", + "past_key_self", + "past_value_self", + "past_key_cross", + "past_value_cross", + ], + output_names=["output_hidden_states", "present_key_self", "present_value_self"], + dynamic_axes={}, + ) + + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_with_past_cross_mha_fused.onnx") - - # Attention type #4 in fusion_bart_attention.py - def test_decoder_multihead_attention_split_bias_fusion(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_multihead_attention(num_heads=num_heads, hidden_size=hidden_size) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True - options.disable_multi_head_attention_bias = True + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, + ) + name = f"hf_{precision}_decoder_attention_with_past{'_split_bias' if split_bias else ''}.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) + + @parameterized.expand( + [ + ("fp16", "cuda"), + ("fp32", "cpu"), + ] + ) + def test_oai_whisper_encoder_self_attention(self, precision, ep): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision) + model = WhisperOAIResidualAttentionBlock().to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, self.sequence_length, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + inputs = (hidden_states,) + self.export( + model, inputs, input_names=["input_hidden_states"], output_names=["output_hidden_states"], dynamic_axes={} + ) + + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_mha_split_bias_fused.onnx") + name = f"oai_{precision}_encoder_self_attention.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) - # Attention type #3 in fusion_bart_attention.py - def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skiplayernorm(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_with_past_multihead_self_attention( - num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False + @parameterized.expand( + [ + ("fp16", "cuda", False), + ("fp16", "cuda", True), + ("fp32", "cpu", False), + ("fp32", "cpu", True), + ] + ) + def test_oai_whisper_decoder_no_past(self, precision, ep, split_bias): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision, split_bias) + model = WhisperOAIResidualAttentionBlock(cross_attention=True).to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, self.sequence_length, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + encoder_hidden_states = torch.randn( + self.batch_size, + self.config.encoder_sequence_length, + self.config.embed_dim, + device=self.device, + dtype=self.torch_dtype, + ) + attention_mask = torch.ones( + self.sequence_length, self.sequence_length, device=self.device, dtype=self.torch_dtype + ) + inputs = ( + hidden_states, + encoder_hidden_states, + attention_mask, + ) + self.export( + model, + inputs, + input_names=[ + "input_hidden_states", + "encoder_hidden_states", + "attention_mask", + ], + output_names=[ + "output_hidden_states", + "present_key_self", + "present_value_self", + "present_key_cross", + "present_value_cross", + ], + dynamic_axes={}, ) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_with_past_self_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True - options.disable_multi_head_attention_bias = True + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_with_past_self_mha_split_bias_fused.onnx") - - # Attention type #5 in fusion_bart_attention.py - def test_decoder_with_past_multihead_cross_attention_split_bias_fusion(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_with_past_multihead_cross_attention(num_heads=num_heads, hidden_size=hidden_size) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_with_past_cross_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True - options.disable_multi_head_attention_bias = True + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, + ) + name = f"oai_{precision}_decoder_attention_no_past{'_split_bias' if split_bias else ''}.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) + + @parameterized.expand( + [ + ("fp16", "cuda", False), + ("fp16", "cuda", True), + ("fp32", "cpu", False), + ("fp32", "cpu", True), + ] + ) + def test_oai_whisper_decoder_with_past(self, precision, ep, split_bias): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision, split_bias) + model = WhisperOAIResidualAttentionBlock(cross_attention=True).to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, 1, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + encoder_hidden_states = torch.randn( + self.batch_size, + self.config.encoder_sequence_length, + self.config.embed_dim, + device=self.device, + dtype=self.torch_dtype, + ) + attention_mask = torch.ones(1, 1, device=self.device, dtype=self.torch_dtype) + past_key_self = torch.randn( + self.batch_size, + self.config.num_heads, + self.sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_value_self = torch.randn( + self.batch_size, + self.config.num_heads, + self.sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_key_cross = torch.randn( + self.batch_size, + self.config.num_heads, + self.config.encoder_sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_value_cross = torch.randn( + self.batch_size, + self.config.num_heads, + self.config.encoder_sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + + # past_key_values is of shape (num_layers) where each element is a past key/value + # + # Ex: + # past_key_values = (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0,) + past_key_values = ( + past_key_self, + past_value_self, + past_key_cross, + past_value_cross, + ) + + inputs = ( + hidden_states, + encoder_hidden_states, + attention_mask, + past_key_values, + ) + self.export( + model, + inputs, + input_names=[ + "input_hidden_states", + "encoder_hidden_states", + "attention_mask", + "past_key_self", + "past_value_self", + "past_key_cross", + "past_value_cross", + ], + output_names=["output_hidden_states", "present_key_self", "present_value_self"], + dynamic_axes={}, + ) + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_with_past_cross_mha_split_bias_fused.onnx") + name = f"oai_{precision}_decoder_attention_with_past{'_split_bias' if split_bias else ''}.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/whisper_model_generator.py b/onnxruntime/test/python/transformers/whisper_model_generator.py deleted file mode 100644 index 5527df489b846..0000000000000 --- a/onnxruntime/test/python/transformers/whisper_model_generator.py +++ /dev/null @@ -1,2021 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - - -import numpy as np -import onnx -from bert_model_generator import float_tensor -from onnx import TensorProto, helper, numpy_helper - - -# Adapted from bert_model_generator.py -def get_tensor_and_weight(name: str, shape: list[int], random=False, zeros=False): - low = 0.0 - high = 1.0 - total_elements = 1 - for x in shape: - total_elements *= x - weights = ( - [np.random.uniform(low, high) for _ in range(total_elements)] - if random - else [0.0] * total_elements - if zeros - else [1.0] * total_elements - ) - return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights - - -def create_whisper_encoder_attention( - hidden_size=768, - num_heads=12, - epsilon=0.000009999999747378752, - add_before_layernorm=False, - add_k=False, - fused=False, -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - ] - outputs = [ - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - ] - - nodes = [] - # Create layernorm (Add + LayerNorm or SkipLayerNorm) - if add_before_layernorm: - nodes.extend( - [ - helper.make_node( - "Add", ["input_0", "input_0"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" - ), - helper.make_node( - "LayerNormalization", - ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul"], - "layernorm", - epsilon=epsilon, - ), - ] - ) - else: - nodes.append( - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - if fused: - nodes.append( - helper.make_node( - "Attention", - ["layernorm_output_to_matmul", "Attention_0_qkv_weight", "Attention_0_qkv_bias", ""], - ["attn_output"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - ), - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", "q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "k_weight"], ["k_matmul_output"], "k_path_matmul" - ), - ] - if add_k: - k_nodes.extend( - [ - helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), - helper.make_node("Reshape", ["k_add_output", "bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ] - ) - else: - k_nodes.append( - helper.make_node("Reshape", ["k_matmul_output", "kv_bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ) - k_nodes.extend( - [ - helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["k_4d_bnsh", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - ) - v_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "v_weight"], ["v_matmul_output"], "v_path_matmul" - ), - helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), - helper.make_node("Reshape", ["v_add_output", "kv_bsnh_reshape"], ["v_4d_bsnh"], "v_reshape_to_4d"), - helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["v_4d_bnsh", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_attn_heads_output"], "mul_num_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["kv_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, -1, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with Q x K' and softmax(Q x K') x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", ["gather_0_output", "unsqueeze_axes_input"], ["unsqueeze_0_output"], "unsqueeze_0" - ), - helper.make_node( - "Unsqueeze", ["gather_1_output", "unsqueeze_axes_input"], ["unsqueeze_1_output"], "unsqueeze_1" - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], ["bsd_format"], axis=0 - ), - ] - ) - - # Create nodes for computing softmax(Q x K') x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Softmax", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", - "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=(not add_k)) - v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) - qkv_weight = helper.make_tensor( - "Attention_0_qkv_weight", - TensorProto.FLOAT, - [hidden_size, 3 * hidden_size], - q_weight_data + k_weight_data + v_weight_data, - ) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", TensorProto.FLOAT, [3 * hidden_size], q_bias_data + k_bias_data + v_bias_data - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - ] - if fused: - initializers.extend([qkv_weight, qkv_bias]) - else: - initializers.extend( - [ - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - if add_k: - initializers.extend([q_weight, q_bias, k_weight, k_bias, v_weight, v_bias]) - else: - initializers.extend([q_weight, q_bias, k_weight, v_weight, v_bias]) - - # Construct graph - graph = helper.make_graph( - nodes, "whisper_encoder_attention_graph", inputs, outputs, initializers, doc_string="whisper" - ) - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -def create_whisper_decoder_attention( - hidden_size=768, - num_heads=12, - epsilon=0.000009999999747378752, - add_before_layernorm=False, - add_k=False, - fused=False, -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - # Dummy inputs are used to prevent the nodes in the path for the decoder attention mask to be fused together - # before attention is fused - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - ] - if not fused: - inputs.extend( - [ - helper.make_tensor_value_info("dummy_input_int64", TensorProto.INT64, ["dummy_input_1d_int64"]), - helper.make_tensor_value_info("dummy_input_fp32", TensorProto.FLOAT, ["dummy_input_1d_fp32"]), - ] - ) - outputs = [ - helper.make_tensor_value_info( - "present.0.decoder.key", TensorProto.FLOAT, ["batch_size", num_heads, 1500, head_size] - ), - helper.make_tensor_value_info( - "present.0.decoder.value", TensorProto.FLOAT, ["batch_size", num_heads, 1500, head_size] - ), - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - ] - - nodes = [] - # Create layernorm (Add + LayerNorm or SkipLayerNorm) - if add_before_layernorm: - nodes.extend( - [ - helper.make_node( - "Add", ["input_0", "input_0"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" - ), - helper.make_node( - "LayerNormalization", - ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul"], - "layernorm", - epsilon=epsilon, - ), - ] - ) - else: - nodes.append( - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - if fused: - nodes.extend( - [ - helper.make_node( - "Attention", - [ - "layernorm_output_to_matmul", - "Attention_0_qkv_weight", - "Attention_0_qkv_bias", - "", - ], - ["attn_output", "present_0_decoder"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - unidirectional=1, - ), - helper.make_node( - "Gather", - ["present_0_decoder", "index_0"], - ["present.0.decoder.key"], - "Gather_0", - axis=0, - ), - helper.make_node( - "Gather", - ["present_0_decoder", "index_1"], - ["present.0.decoder.value"], - "Gather_1", - axis=0, - ), - ] - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", "q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "k_weight"], ["k_matmul_output"], "k_path_matmul" - ), - ] - if add_k: - k_nodes.extend( - [ - helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), - helper.make_node("Reshape", ["k_add_output", "bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ] - ) - else: - k_nodes.append( - helper.make_node("Reshape", ["k_matmul_output", "kv_bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ) - k_nodes.extend( - [ - helper.make_node( - "Transpose", - ["k_4d_bsnh"], - ["present.0.decoder.key"], - "k_transpose_to_bnsh", - perm=[0, 2, 1, 3], - ), - helper.make_node( - "Reshape", - ["present.0.decoder.key", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - ) - v_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "v_weight"], ["v_matmul_output"], "v_path_matmul" - ), - helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), - helper.make_node("Reshape", ["v_add_output", "kv_bsnh_reshape"], ["v_4d_bsnh"], "v_reshape_to_4d"), - helper.make_node( - "Transpose", ["v_4d_bsnh"], ["present.0.decoder.value"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3] - ), - helper.make_node( - "Reshape", - ["present.0.decoder.value", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_attn_heads_output"], "mul_num_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["kv_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, -1, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with mask - nodes.extend( - [ - helper.make_node( - "Shape", ["k_output_(num_heads*batch_size,seq_len,head_size)"], ["mask_shape_output"], "mask_shape" - ), - helper.make_node( - "Gather", ["mask_shape_output", "idx_1"], ["mask_gather_1_output"], "mask_gather_1", axis=0 - ), - helper.make_node( - "Unsqueeze", - ["mask_gather_1_output", "unsqueeze_axes_input"], - ["mask_unsqueeze_1_output"], - "mask_unsqueeze_1", - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "mask_unsqueeze_1_output"], - ["mask_concat_output"], - "mask_concat", - axis=0, - ), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_mask_heads_output"], "mul_mask_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_mask_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_mask_heads_output"], - "unsqueeze_mask_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_mask_heads_output", "unsqueeze_1_output", "mask_unsqueeze_1_output"], - ["concat_input_for_reshape_after_add"], - "concat_for_reshape_after_add", - axis=0, - ), - ] - ) - - # Create nodes used with Q x K' + mask and softmax(Q x K' + mask) x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", ["gather_0_output", "unsqueeze_axes_input"], ["unsqueeze_0_output"], "unsqueeze_0" - ), - helper.make_node( - "Unsqueeze", ["gather_1_output", "unsqueeze_axes_input"], ["unsqueeze_1_output"], "unsqueeze_1" - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], ["bsd_format"], axis=0 - ), - ] - ) - - # Create nodes for computing softmax(Q x K' + mask) x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Reshape", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)", "mask_concat_output"], - ["qk_output_(batch_size,num_heads,seq_len,seq_len)"], - "reshape_qk_to_bnsh", - ), - helper.make_node( - "Add", - ["qk_output_(batch_size,num_heads,seq_len,seq_len)", "attention_add_qk"], - ["add_qk_output_(batch_size,num_heads_seq_len,seq_len)"], - "add_qk", - ), - helper.make_node( - "Reshape", - ["add_qk_output_(batch_size,num_heads_seq_len,seq_len)", "concat_input_for_reshape_after_add"], - ["add_qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "reshape_add_qk_before_softmax", - ), - helper.make_node( - "Softmax", - ["add_qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create nodes that make attention mask - if not fused: - nodes.extend( - [ - # "attention_mask" is (decoder_seq_len, decoder_seq_len) but is assumed to be (1, 1) for this test. - # There are other nodes that automatically set the attention mask size correctly but those nodes do not - # impact the attention fusion. Hence, this assumption is made in order to simplify the inputs for the - # following nodes. - helper.make_node( - "Where", - ["all_ones", "where_filter_constant", "dummy_input_fp32"], - ["where_output"], - "mask_filter_where", - ), - helper.make_node( - "Unsqueeze", - ["where_output", "dummy_input_int64"], - ["unsqueeze_mask_output_1"], - "unsqueeze_attn_mask_1", - ), - helper.make_node( - "Unsqueeze", - ["unsqueeze_mask_output_1", "dummy_input_int64"], - ["unsqueeze_mask_output_2"], - "unsqueeze_attn_mask_2", - ), - helper.make_node( - "Expand", - inputs=["unsqueeze_mask_output_2", "dummy_input_int64"], - outputs=["attention_add_qk"], - name="expand_mask_from_(b,1,m,m)_to_(b,n,m,m)", - ), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", - "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=(not add_k)) - v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) - qkv_weight = helper.make_tensor( - "Attention_0_qkv_weight", - TensorProto.FLOAT, - [hidden_size, 3 * hidden_size], - q_weight_data + k_weight_data + v_weight_data, - ) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", TensorProto.FLOAT, [3 * hidden_size], q_bias_data + k_bias_data + v_bias_data - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - ] - - if fused: - initializers.extend( - [ - qkv_weight, - qkv_bias, - numpy_helper.from_array(np.array(0, dtype="int64"), name="index_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="index_1"), - ] - ) - else: - initializers.extend( - [ - numpy_helper.from_array(np.array([[1]], dtype=bool), name="all_ones"), - numpy_helper.from_array(np.array([1], dtype="float32"), name="where_filter_constant"), - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - - if add_k: - initializers.extend([q_weight, q_bias, k_weight, k_bias, v_weight, v_bias]) - else: - initializers.extend([q_weight, q_bias, k_weight, v_weight, v_bias]) - - # Construct graph - graph = helper.make_graph( - nodes, "whisper_decoder_attention_graph", inputs, outputs, initializers, doc_string="whisper" - ) - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -def create_whisper_decoder_multihead_attention( - hidden_size=768, num_heads=12, epsilon=0.000009999999747378752, add_k=False, fused=False -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("encoder_hidden_states", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - ] - outputs = [ - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - helper.make_tensor_value_info( - "present.0.encoder.key", TensorProto.FLOAT, ["batch_size", num_heads, 1500, head_size] - ), - helper.make_tensor_value_info( - "present.0.encoder.value", TensorProto.FLOAT, ["batch_size", num_heads, 1500, head_size] - ), - ] - - # Create SkipLayerNorm (since there's no Add + LayerNorm variant for this attention subgraph) - nodes = [ - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ] - - if fused: - nodes.extend( - [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("MatMul", ["encoder_hidden_states", "k_weight"], ["k_matmul_output"], "k_path_matmul"), - helper.make_node("MatMul", ["encoder_hidden_states", "v_weight"], ["v_matmul_output"], "v_path_matmul"), - helper.make_node( - "MultiHeadAttention", - ["q_matmul_output", "k_matmul_output", "v_matmul_output", "Attention_0_qkv_bias"], - ["attn_output", "present.0.encoder.key", "present.0.encoder.value"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - ), - ] - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", - ["layernorm_output_to_matmul", "q_weight"], - ["q_matmul_output"], - "q_path_matmul", - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", "q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node("MatMul", ["encoder_hidden_states", "k_weight"], ["k_matmul_output"], "k_path_matmul"), - ] - if add_k: - k_nodes.extend( - [ - helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), - helper.make_node("Reshape", ["k_add_output", "bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ] - ) - else: - k_nodes.append( - helper.make_node("Reshape", ["k_matmul_output", "kv_bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ) - k_nodes.extend( - [ - helper.make_node( - "Transpose", ["k_4d_bsnh"], ["present.0.encoder.key"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3] - ), - helper.make_node( - "Reshape", - ["present.0.encoder.key", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - ) - v_nodes = [ - helper.make_node("MatMul", ["encoder_hidden_states", "v_weight"], ["v_matmul_output"], "v_path_matmul"), - helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), - helper.make_node("Reshape", ["v_add_output", "kv_bsnh_reshape"], ["v_4d_bsnh"], "v_reshape_to_4d"), - helper.make_node( - "Transpose", ["v_4d_bsnh"], ["present.0.encoder.value"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3] - ), - helper.make_node( - "Reshape", - ["present.0.encoder.value", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_attn_heads_output"], "mul_num_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["kv_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, -1, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with Q x K' and softmax(Q x K') x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", - ["gather_0_output", "unsqueeze_axes_input"], - ["unsqueeze_0_output"], - "unsqueeze_0", - ), - helper.make_node( - "Unsqueeze", - ["gather_1_output", "unsqueeze_axes_input"], - ["unsqueeze_1_output"], - "unsqueeze_1", - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], - ["bsd_format"], - axis=0, - ), - ] - ) - - # Create nodes for computing softmax(Q x K') x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Softmax", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", - "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=(not add_k)) - v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", TensorProto.FLOAT, [3 * hidden_size], q_bias_data + k_bias_data + v_bias_data - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - ] - - # Add Q/K/V weight tensors as initializers - initializers.extend([q_weight, k_weight, v_weight]) - - if fused: - initializers.append(qkv_bias) - else: - if add_k: - initializers.extend([q_bias, k_bias, v_bias]) - else: - initializers.extend([q_bias, v_bias]) - - initializers.extend( - [ - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - - # Construct graph - graph = helper.make_graph(nodes, "whisper_decoder_mha_graph", inputs, outputs, initializers, doc_string="whisper") - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -def create_whisper_decoder_with_past_multihead_self_attention( - hidden_size=768, - num_heads=12, - epsilon=0.000009999999747378752, - add_before_layernorm=False, - add_k=False, - fused=False, -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info( - "past_key_values.0.decoder.key", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len", head_size] - ), - helper.make_tensor_value_info( - "past_key_values.0.decoder.value", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len", head_size] - ), - ] - outputs = [ - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - helper.make_tensor_value_info( - "present.0.decoder.key", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len + 1", head_size] - ), - helper.make_tensor_value_info( - "present.0.decoder.value", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len + 1", head_size] - ), - ] - nodes = [] - - # Create layernorm (Add + LayerNorm or SkipLayerNorm) - if add_before_layernorm: - nodes.extend( - [ - helper.make_node( - "Add", ["input_0", "input_0"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" - ), - helper.make_node( - "LayerNormalization", - ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul"], - "layernorm", - epsilon=epsilon, - ), - ] - ) - else: - nodes.append( - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - if fused: - nodes.extend( - [ - helper.make_node( - "MatMul", - ["layernorm_output_to_matmul", "MatMul_0_qkv_weight"], - ["MatMul_0_qkv_out"], - "MatMul_0", - ), - helper.make_node( - "Slice", - ["MatMul_0_qkv_out", "MatMul_0_q_start_index", "MatMul_0_k_start_index", "MatMul_0_qkv_last_axis"], - ["MatMul_0_q_out"], - "Slice_0", - ), - helper.make_node( - "Slice", - ["MatMul_0_qkv_out", "MatMul_0_k_start_index", "MatMul_0_v_start_index", "MatMul_0_qkv_last_axis"], - ["MatMul_0_k_out"], - "Slice_1", - ), - helper.make_node( - "Slice", - [ - "MatMul_0_qkv_out", - "MatMul_0_v_start_index", - "MatMul_0_end_of_qkv_index", - "MatMul_0_qkv_last_axis", - ], - ["MatMul_0_v_out"], - "Slice_2", - ), - helper.make_node( - "MultiHeadAttention", - [ - "MatMul_0_q_out", - "MatMul_0_k_out", - "MatMul_0_v_out", - "Attention_0_qkv_bias", - "", - "", - "past_key_values.0.decoder.key", - "past_key_values.0.decoder.value", - ], - ["attn_output", "present.0.decoder.key", "present.0.decoder.value"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - unidirectional=1, - ), - ] - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", "q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node( - "MatMul", - ["layernorm_output_to_matmul", "k_weight"], - ["k_matmul_output"], - "k_path_matmul", - ), - ] - if add_k: - k_nodes.extend( - [ - helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), - helper.make_node("Reshape", ["k_add_output", "bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ] - ) - else: - k_nodes.append( - helper.make_node("Reshape", ["k_matmul_output", "kv_bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ) - k_nodes.extend( - [ - helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Concat", - ["past_key_values.0.decoder.key", "k_4d_bnsh"], - ["present.0.decoder.key"], - "concat_past_k_and_curr_k", - axis=2, - ), - helper.make_node( - "Reshape", - ["present.0.decoder.key", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - ) - v_nodes = [ - helper.make_node( - "MatMul", - ["layernorm_output_to_matmul", "v_weight"], - ["v_matmul_output"], - "v_path_matmul", - ), - helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), - helper.make_node("Reshape", ["v_add_output", "kv_bsnh_reshape"], ["v_4d_bsnh"], "v_reshape_to_4d"), - helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Concat", - ["past_key_values.0.decoder.value", "v_4d_bnsh"], - ["present.0.decoder.value"], - "concat_past_v_and_curr_v", - axis=2, - ), - helper.make_node( - "Reshape", - ["present.0.decoder.value", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", - ["gather_0_output", "num_heads_int"], - ["mul_attn_heads_output"], - "mul_num_heads", - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["kv_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, -1, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with Q x K' and softmax(Q x K') x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", - ["gather_0_output", "unsqueeze_axes_input"], - ["unsqueeze_0_output"], - "unsqueeze_0", - ), - helper.make_node( - "Unsqueeze", - ["gather_1_output", "unsqueeze_axes_input"], - ["unsqueeze_1_output"], - "unsqueeze_1", - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], - ["bsd_format"], - axis=0, - ), - ] - ) - - # Create nodes for computing softmax(Q x K') x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Softmax", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", - "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=(not add_k)) - v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) - qkv_weight = helper.make_tensor( - "MatMul_0_qkv_weight", - TensorProto.FLOAT, - [hidden_size, 3 * hidden_size], - q_weight_data + k_weight_data + v_weight_data, - ) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", - TensorProto.FLOAT, - [3 * hidden_size], - q_bias_data + k_bias_data + v_bias_data, - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - ] - - if fused: - # Add packed QKV weight tensor as initializer - initializers.append(qkv_weight) - - # Add Slice indices as initializers - initializers.extend( - [ - helper.make_tensor(name="MatMul_0_q_start_index", data_type=TensorProto.INT64, dims=[1], vals=[0]), - helper.make_tensor( - name="MatMul_0_k_start_index", data_type=TensorProto.INT64, dims=[1], vals=[hidden_size] - ), - helper.make_tensor( - name="MatMul_0_v_start_index", data_type=TensorProto.INT64, dims=[1], vals=[2 * hidden_size] - ), - helper.make_tensor( - name="MatMul_0_end_of_qkv_index", data_type=TensorProto.INT64, dims=[1], vals=[3 * hidden_size] - ), - helper.make_tensor(name="MatMul_0_qkv_last_axis", data_type=TensorProto.INT64, dims=[1], vals=[-1]), - ] - ) - - # Add packed QKV bias tensor as initializer - initializers.append(qkv_bias) - else: - # Add Q/K/V weight tensors as initializers - initializers.extend([q_weight, k_weight, v_weight]) - - if add_k: - initializers.extend([q_bias, k_bias, v_bias]) - else: - initializers.extend([q_bias, v_bias]) - - initializers.extend( - [ - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - - # Construct graph - graph = helper.make_graph( - nodes, "whisper_decoder_with_past_self_mha_graph", inputs, outputs, initializers, doc_string="whisper" - ) - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -def create_whisper_decoder_with_past_multihead_cross_attention( - hidden_size=768, num_heads=12, epsilon=0.000009999999747378752, fused=False -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info( - "past_key_values.0.encoder.key", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len", head_size] - ), - helper.make_tensor_value_info( - "past_key_values.0.encoder.value", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len", head_size] - ), - ] - outputs = [ - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - ] - - # Create SkipLayerNorm (since there's no Add + LayerNorm variant for this attention subgraph) - nodes = [ - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ] - - if fused: - nodes.extend( - [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node( - "MultiHeadAttention", - [ - "q_matmul_output", - "past_key_values.0.encoder.key", - "past_key_values.0.encoder.value", - "Attention_0_qkv_bias", - ], - ["attn_output"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - ), - ] - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", "q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node( - "Reshape", - ["past_key_values.0.encoder.key", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - v_nodes = [ - helper.make_node( - "Reshape", - ["past_key_values.0.encoder.value", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_attn_heads_output"], "mul_num_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with Q x K' and softmax(Q x K') x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", ["gather_0_output", "unsqueeze_axes_input"], ["unsqueeze_0_output"], "unsqueeze_0" - ), - helper.make_node( - "Unsqueeze", ["gather_1_output", "unsqueeze_axes_input"], ["unsqueeze_1_output"], "unsqueeze_1" - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], - ["bsd_format"], - axis=0, - ), - ] - ) - - # Create nodes for computing softmax(Q x K') x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Softmax", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", - "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=True) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size], zeros=True) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", TensorProto.FLOAT, [3 * hidden_size], q_bias_data + k_bias_data + v_bias_data - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - q_weight, - ] - - if fused: - # Add packed QKV bias tensor as initializer - initializers.append(qkv_bias) - else: - # Add Q bias tensor as initializer - initializers.append(q_bias) - - initializers.extend( - [ - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - - # Construct graph - graph = helper.make_graph( - nodes, "whisper_decoder_with_past_cross_mha_graph", inputs, outputs, initializers, doc_string="whisper" - ) - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -if __name__ == "__main__": - np.random.seed(2) - num_heads = 4 - hidden_size = 64 - - model = create_whisper_encoder_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_encoder_attention_sln.onnx") - - model = create_whisper_encoder_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) - onnx.save(model, "./test_data/models/whisper/encoder_attention_with_sln_fused.onnx") - - model = create_whisper_decoder_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_decoder_attention_sln.onnx") - - model = create_whisper_decoder_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) - onnx.save(model, "./test_data/models/whisper/decoder_attention_with_sln_fused.onnx") - - model = create_whisper_decoder_multihead_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_decoder_mha.onnx") - - model = create_whisper_decoder_multihead_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) - onnx.save(model, "./test_data/models/whisper/decoder_mha_fused.onnx") - - model = create_whisper_decoder_with_past_multihead_self_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_decoder_with_past_self_mha.onnx") - - model = create_whisper_decoder_with_past_multihead_self_attention( - num_heads=num_heads, hidden_size=hidden_size, fused=True - ) - onnx.save(model, "./test_data/models/whisper/decoder_with_past_self_mha_fused.onnx") - - model = create_whisper_decoder_with_past_multihead_cross_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_decoder_with_past_cross_mha.onnx") - - model = create_whisper_decoder_with_past_multihead_cross_attention( - num_heads=num_heads, hidden_size=hidden_size, fused=True - ) - onnx.save(model, "./test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx") From 89a2ff9259a3e566fd98577aff7204e72ec19439 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Sun, 1 Jun 2025 16:16:46 -0700 Subject: [PATCH 57/57] [TensorRT EP] Address GPU bf16 support check (#24915) BF16 support is primarily available on NVIDIA GPUs with the Ampere and later architectures with compute capability of 8.0 or higher. If trt_bf16_enable = true and compute capability < 8, TRT EP will make trt_bf16_enable = false --- .../core/providers/tensorrt/tensorrt_execution_provider.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 5da7be0f758e0..fc8281ce51a1b 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1369,6 +1369,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv max_workspace_size_ = info.max_workspace_size; fp16_enable_ = info.fp16_enable; bf16_enable_ = info.bf16_enable; + // BF16 support is primarily available on NVIDIA GPUs with the Ampere and later architectures with compute capability of 8.0 or higher. + if (bf16_enable_ && prop.major < 8) { + bf16_enable_ = false; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_bf16_enable is set, but platform doesn't support bf16."; + } int8_enable_ = info.int8_enable; if (int8_enable_) { int8_calibration_cache_name_ = info.int8_calibration_table_name;