From 0482251e6253a76a6cb019514896bb4d79bce3ca Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Aug 2025 09:13:32 +0000 Subject: [PATCH 01/22] Bump @babel/helpers from 7.25.6 to 7.26.10 in /js/react_native/e2e (#23993) --- js/react_native/e2e/package-lock.json | 68 +++++++++++++-------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index a537c53b5e1ba..9abe23940e0f2 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -57,14 +57,14 @@ } }, "node_modules/@babel/code-frame": { - "version": "7.26.2", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", - "integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", + "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", "license": "MIT", "dependencies": { - "@babel/helper-validator-identifier": "^7.25.9", + "@babel/helper-validator-identifier": "^7.27.1", "js-tokens": "^4.0.0", - "picocolors": "^1.0.0" + "picocolors": "^1.1.1" }, "engines": { "node": ">=6.9.0" @@ -417,18 +417,18 @@ } }, "node_modules/@babel/helper-string-parser": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", - "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", "license": "MIT", "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", - "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.27.1.tgz", + "integrity": "sha512-D2hP9eA+Sqx1kBZgzxZh0y1trbuU+JoDkiEwqhQ36nodYqJwyEIhPSdMNd7lOm/4io72luTPWH20Yda0xOuUow==", "license": "MIT", "engines": { "node": ">=6.9.0" @@ -458,25 +458,25 @@ } }, "node_modules/@babel/helpers": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.25.6.tgz", - "integrity": "sha512-Xg0tn4HcfTijTwfDwYlvVCl43V6h4KyVVX2aEm4qdO/PC6L2YvzLHFdmxhoeSA3eslcE6+ZVXHgWwopXYLNq4Q==", + "version": "7.28.3", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.28.3.tgz", + "integrity": "sha512-PTNtvUQihsAsDHMOP5pfobP8C6CM4JWXmP8DrEIt46c3r2bf87Ua1zoqevsMo9g+tWDwgWrFP5EIxuBx5RudAw==", "license": "MIT", "dependencies": { - "@babel/template": "^7.25.0", - "@babel/types": "^7.25.6" + "@babel/template": "^7.27.2", + "@babel/types": "^7.28.2" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/parser": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.9.tgz", - "integrity": "sha512-81NWa1njQblgZbQHxWHpxxCzNsa3ZwvFqpUg7P+NNUU6f3UU2jBEg4OlF/J6rl8+PQGh1q6/zWScd001YwcA5A==", + "version": "7.28.3", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.3.tgz", + "integrity": "sha512-7+Ey1mAgYqFAx2h0RuoxcQT5+MlG3GTV0TQrgr7/ZliKsm/MNDxVVutlWaziMq7wJNAz8MTqz55XLpWvva6StA==", "license": "MIT", "dependencies": { - "@babel/types": "^7.26.9" + "@babel/types": "^7.28.2" }, "bin": { "parser": "bin/babel-parser.js" @@ -2189,14 +2189,14 @@ } }, "node_modules/@babel/template": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz", - "integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==", + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", + "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", "license": "MIT", "dependencies": { - "@babel/code-frame": "^7.26.2", - "@babel/parser": "^7.26.9", - "@babel/types": "^7.26.9" + "@babel/code-frame": "^7.27.1", + "@babel/parser": "^7.27.2", + "@babel/types": "^7.27.1" }, "engines": { "node": ">=6.9.0" @@ -2240,13 +2240,13 @@ "license": "MIT" }, "node_modules/@babel/types": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.9.tgz", - "integrity": "sha512-Y3IR1cRnOxOCDvMmNiym7XpXQ93iGDDPHx+Zj+NM+rg0fBaShfQLkg+hKPaZCEvg5N/LeCo4+Rj/i3FuJsIQaw==", + "version": "7.28.2", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.2.tgz", + "integrity": "sha512-ruv7Ae4J5dUYULmeXw1gmb7rYRz57OWCPM57pHojnLq/3Z1CK2lNSLTCVjxVk1F/TZHwOZZrOWi0ur95BbLxNQ==", "license": "MIT", "dependencies": { - "@babel/helper-string-parser": "^7.25.9", - "@babel/helper-validator-identifier": "^7.25.9" + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.27.1" }, "engines": { "node": ">=6.9.0" @@ -10443,9 +10443,9 @@ } }, "node_modules/picocolors": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.0.tgz", - "integrity": "sha512-TQ92mBOW0l3LeMeyLV6mzy/kWr8lkd/hp3mTg7wYK7zJhuBStmGMBG0BdeDZS/dZx1IukaX6Bk11zcln25o1Aw==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", "license": "ISC" }, "node_modules/picomatch": { From a7b3b66941e9e15b14b55fd2e07fa7ebe647b43a Mon Sep 17 00:00:00 2001 From: Savelii Pototskii Date: Tue, 26 Aug 2025 03:23:11 +0500 Subject: [PATCH 02/22] Fix a typo in ORT_API_CALL macro (_stdcall) (#25834) ### Description Fixed the macro `ORT_API_CALL` by replacing `_stdcall` with `__stdcall` ### Motivation and Context Recently, I found an issue that prevents ONNX Runtime from being built using the MinGW toolchain on Windows. After investigating, I discovered that the ONNX Runtime C API header contains a typo in the `ORT_API_CALL` preprocessor macro. It is incorrectly defined as `_stdcall` instead of the correct `__stdcall` (with two leading underscores). This causes build failures on compilers like MinGW that are strict about this syntax. --- include/onnxruntime/core/session/onnxruntime_c_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bedeeb972c3a7..beda0e2a9c0d0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -87,7 +87,7 @@ extern "C" { #else #define ORT_EXPORT #endif -#define ORT_API_CALL _stdcall +#define ORT_API_CALL __stdcall #define ORT_MUST_USE_RESULT #define ORTCHAR_T wchar_t #else From 8da3e40c7d4eeba8d5358497ab477db6b057fb6a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 25 Aug 2025 16:20:32 -0700 Subject: [PATCH 03/22] [build] allow custom CMAKE_C_STANDARD and CMAKE_CXX_STANDARD (#25782) ### Description allow custom CMAKE_C_STANDARD and CMAKE_CXX_STANDARD Fixes #25756 ### Motivation and Context --- cmake/CMakeLists.txt | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1f37432cc530c..98548957d0b42 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -22,7 +22,9 @@ if("${CMAKE_CXX_COMPILER_ID}" MATCHES "IntelLLVM") endif() # Needed for Java -set(CMAKE_C_STANDARD 99) +if (NOT CMAKE_CXX_STANDARD) + set(CMAKE_C_STANDARD 99) +endif() include(CheckCXXCompilerFlag) include(CheckLanguage) @@ -32,11 +34,13 @@ include(CheckFunctionExists) include(CheckSymbolExists) include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables -# TODO: update this once all system adapt c++20 -if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") -set(CMAKE_CXX_STANDARD 20) -else() -set(CMAKE_CXX_STANDARD 17) +if (NOT CMAKE_CXX_STANDARD) + # TODO: update this once all system adapt c++20 + if (CMAKE_SYSTEM_NAME STREQUAL "Darwin") + set(CMAKE_CXX_STANDARD 20) + else() + set(CMAKE_CXX_STANDARD 17) + endif() endif() if (MSVC) From 5b021b24f3a131a3a49c77e099aa24a970f4b74b Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Mon, 25 Aug 2025 18:09:05 -0700 Subject: [PATCH 04/22] [webgpu] Add support for Q2 in matmulnbits (#25763) ### Description This change adds support for Q2 quantized matmulnbits, in webgpu. ### Motivation and Context An alternate way to support bitnets is through adding support for lower bits in matmulnbits, this reuses our shaders and is more maintainable than a separate op. The model size grows a bit however for a 2B parameter model using 1.58bpw vs 2bpw the size difference is just 100MB. The simpler dequantization also improves perf, on an Intel XE matmul looks to be 20% faster using q2 weights vs q4 weights for the same matrix dimensions. Q2 version of the bitnet model is here https://huggingface.co/sushraja/bitnet-b1.58-2B-4T-fp16-onnx/tree/main/bitnet_q2 --- .../quantization/dp4a_matmul.wgsl.template | 21 +- .../dp4a_matmul_common.wgsl.template | 286 +++++++++++++++++- .../webgpu/quantization/dp4a_matmul_nbits.cc | 12 +- .../webgpu/quantization/dp4a_matmul_nbits.h | 12 +- .../dp4a_matmul_small_m.wgsl.template | 22 +- .../webgpu/quantization/matmul_nbits.cc | 21 +- .../webgpu/quantization/matmul_nbits.h | 7 +- .../quantization/matmul_nbits.wgsl.template | 41 +++ .../matmul_nbits_zero_pt.wgsl.template | 2 + .../test/contrib_ops/matmul_2bits_test.cc | 58 +++- 10 files changed, 441 insertions(+), 41 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template index e4e3730eba808..bea6588ea72eb 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template @@ -108,7 +108,26 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) } #endif +#if n_bits == 2 + fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let b_global = b_global_base + row; + if (b_global >= uniforms.N) + { + return; + } + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value); + let block_idx = kidx_v/(block_size/16); + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx]; + } +#endif + $MAIN { +#if n_bits == 2 + LoadDequantizationTable(local_idx); + workgroupBarrier(); +#endif // During the load phase we use all 256 threads to load 64 rows of A/B. // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. let a_global_base = u32(workgroup_idx / uniforms.num_N_tile) * tile_size; @@ -131,7 +150,7 @@ $MAIN { var lane_output2: vec4; var lane_output3: vec4; var lane_output4: vec4; - // K's vectrorization is 16 items per index. See input_a/input_b. + // K's vectorization is 16 items per index. See input_a/input_b. // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is // k tile size is 32. In vectorized space that is 32/16 = 2. for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template index 38fe0388c5954..186685b9a8dd4 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template @@ -7,6 +7,7 @@ #include "quantization/matmul_nbits_zero_pt.wgsl.template" #if n_bits == 4 + alias mul_precision = output_element_t; fn DequantizedFrom4BitsTo8Bits(in: vec2, zero: i32) -> vec4 { var out = vec4(0); @@ -23,6 +24,9 @@ #endif #if n_bits == 8 + // For 8bits, in case data overflow when converting from int32 (output of dot4I8Packed) to f16, we force it convert to f32. + // Then do the scale. Finally, convert to output element type. + alias mul_precision = f32; fn AlignWithZeroPoint(in: vec4) -> vec4 { var out = vec4(0); @@ -34,8 +38,282 @@ } #endif -// For 8bits, in case data overflow when converting from int32 (output of dot4I8Packed) to f16, we force it convert to f32. -// Then do the scale. Finally, convert to output element type. +#if n_bits == 2 + alias mul_precision = output_element_t; + const lut_size = 256; + var shm_dequantization_table : array; + const q2_dequantization_table = array( + 0xFEFEFEFE, + 0xFEFEFEFF, + 0xFEFEFE00, + 0xFEFEFE01, + 0xFEFEFFFE, + 0xFEFEFFFF, + 0xFEFEFF00, + 0xFEFEFF01, + 0xFEFE00FE, + 0xFEFE00FF, + 0xFEFE0000, + 0xFEFE0001, + 0xFEFE01FE, + 0xFEFE01FF, + 0xFEFE0100, + 0xFEFE0101, + 0xFEFFFEFE, + 0xFEFFFEFF, + 0xFEFFFE00, + 0xFEFFFE01, + 0xFEFFFFFE, + 0xFEFFFFFF, + 0xFEFFFF00, + 0xFEFFFF01, + 0xFEFF00FE, + 0xFEFF00FF, + 0xFEFF0000, + 0xFEFF0001, + 0xFEFF01FE, + 0xFEFF01FF, + 0xFEFF0100, + 0xFEFF0101, + 0xFE00FEFE, + 0xFE00FEFF, + 0xFE00FE00, + 0xFE00FE01, + 0xFE00FFFE, + 0xFE00FFFF, + 0xFE00FF00, + 0xFE00FF01, + 0xFE0000FE, + 0xFE0000FF, + 0xFE000000, + 0xFE000001, + 0xFE0001FE, + 0xFE0001FF, + 0xFE000100, + 0xFE000101, + 0xFE01FEFE, + 0xFE01FEFF, + 0xFE01FE00, + 0xFE01FE01, + 0xFE01FFFE, + 0xFE01FFFF, + 0xFE01FF00, + 0xFE01FF01, + 0xFE0100FE, + 0xFE0100FF, + 0xFE010000, + 0xFE010001, + 0xFE0101FE, + 0xFE0101FF, + 0xFE010100, + 0xFE010101, + 0xFFFEFEFE, + 0xFFFEFEFF, + 0xFFFEFE00, + 0xFFFEFE01, + 0xFFFEFFFE, + 0xFFFEFFFF, + 0xFFFEFF00, + 0xFFFEFF01, + 0xFFFE00FE, + 0xFFFE00FF, + 0xFFFE0000, + 0xFFFE0001, + 0xFFFE01FE, + 0xFFFE01FF, + 0xFFFE0100, + 0xFFFE0101, + 0xFFFFFEFE, + 0xFFFFFEFF, + 0xFFFFFE00, + 0xFFFFFE01, + 0xFFFFFFFE, + 0xFFFFFFFF, + 0xFFFFFF00, + 0xFFFFFF01, + 0xFFFF00FE, + 0xFFFF00FF, + 0xFFFF0000, + 0xFFFF0001, + 0xFFFF01FE, + 0xFFFF01FF, + 0xFFFF0100, + 0xFFFF0101, + 0xFF00FEFE, + 0xFF00FEFF, + 0xFF00FE00, + 0xFF00FE01, + 0xFF00FFFE, + 0xFF00FFFF, + 0xFF00FF00, + 0xFF00FF01, + 0xFF0000FE, + 0xFF0000FF, + 0xFF000000, + 0xFF000001, + 0xFF0001FE, + 0xFF0001FF, + 0xFF000100, + 0xFF000101, + 0xFF01FEFE, + 0xFF01FEFF, + 0xFF01FE00, + 0xFF01FE01, + 0xFF01FFFE, + 0xFF01FFFF, + 0xFF01FF00, + 0xFF01FF01, + 0xFF0100FE, + 0xFF0100FF, + 0xFF010000, + 0xFF010001, + 0xFF0101FE, + 0xFF0101FF, + 0xFF010100, + 0xFF010101, + 0x00FEFEFE, + 0x00FEFEFF, + 0x00FEFE00, + 0x00FEFE01, + 0x00FEFFFE, + 0x00FEFFFF, + 0x00FEFF00, + 0x00FEFF01, + 0x00FE00FE, + 0x00FE00FF, + 0x00FE0000, + 0x00FE0001, + 0x00FE01FE, + 0x00FE01FF, + 0x00FE0100, + 0x00FE0101, + 0x00FFFEFE, + 0x00FFFEFF, + 0x00FFFE00, + 0x00FFFE01, + 0x00FFFFFE, + 0x00FFFFFF, + 0x00FFFF00, + 0x00FFFF01, + 0x00FF00FE, + 0x00FF00FF, + 0x00FF0000, + 0x00FF0001, + 0x00FF01FE, + 0x00FF01FF, + 0x00FF0100, + 0x00FF0101, + 0x0000FEFE, + 0x0000FEFF, + 0x0000FE00, + 0x0000FE01, + 0x0000FFFE, + 0x0000FFFF, + 0x0000FF00, + 0x0000FF01, + 0x000000FE, + 0x000000FF, + 0x00000000, + 0x00000001, + 0x000001FE, + 0x000001FF, + 0x00000100, + 0x00000101, + 0x0001FEFE, + 0x0001FEFF, + 0x0001FE00, + 0x0001FE01, + 0x0001FFFE, + 0x0001FFFF, + 0x0001FF00, + 0x0001FF01, + 0x000100FE, + 0x000100FF, + 0x00010000, + 0x00010001, + 0x000101FE, + 0x000101FF, + 0x00010100, + 0x00010101, + 0x01FEFEFE, + 0x01FEFEFF, + 0x01FEFE00, + 0x01FEFE01, + 0x01FEFFFE, + 0x01FEFFFF, + 0x01FEFF00, + 0x01FEFF01, + 0x01FE00FE, + 0x01FE00FF, + 0x01FE0000, + 0x01FE0001, + 0x01FE01FE, + 0x01FE01FF, + 0x01FE0100, + 0x01FE0101, + 0x01FFFEFE, + 0x01FFFEFF, + 0x01FFFE00, + 0x01FFFE01, + 0x01FFFFFE, + 0x01FFFFFF, + 0x01FFFF00, + 0x01FFFF01, + 0x01FF00FE, + 0x01FF00FF, + 0x01FF0000, + 0x01FF0001, + 0x01FF01FE, + 0x01FF01FF, + 0x01FF0100, + 0x01FF0101, + 0x0100FEFE, + 0x0100FEFF, + 0x0100FE00, + 0x0100FE01, + 0x0100FFFE, + 0x0100FFFF, + 0x0100FF00, + 0x0100FF01, + 0x010000FE, + 0x010000FF, + 0x01000000, + 0x01000001, + 0x010001FE, + 0x010001FF, + 0x01000100, + 0x01000101, + 0x0101FEFE, + 0x0101FEFF, + 0x0101FE00, + 0x0101FE01, + 0x0101FFFE, + 0x0101FFFF, + 0x0101FF00, + 0x0101FF01, + 0x010100FE, + 0x010100FF, + 0x01010000, + 0x01010001, + 0x010101FE, + 0x010101FF, + 0x01010100, + 0x01010101); + fn LoadDequantizationTable(local_idx:u32) + { + // Move dequantization table into on chip memory. + shm_dequantization_table[local_idx] = q2_dequantization_table[local_idx]; + } + fn DequantizedFrom2BitsTo8Bits(in: u32) -> vec4 + { + let unpacked = unpack4xU8(in); + return vec4(shm_dequantization_table[unpacked[0]], + shm_dequantization_table[unpacked[1]], + shm_dequantization_table[unpacked[2]], + shm_dequantization_table[unpacked[3]]); + } +#endif + #if has_zero_points && n_bits == 8 // If has_zero_points is true, vec4(unpack4xU8(b_data)) - vec4(zero) may be out of the range [-128, 127] since zero can be any value between [0, 255]. // To avoid the data overflow when use pack4xI8, we still use |pack4xI8(vec4(unpack4xU8(xxx)) - vec4(128))| to process the b data. In SDP8AI, we use the @@ -61,7 +339,7 @@ local_sum += dot4I8Packed(a2[3], b2[3]); dequantized_a_sum += vec4(unpack4xI8(a2[3])); local_sum -= dot(dequantized_a_sum, vec4(bias_zero)); - return output_element_t(f32(local_sum) * f32(scale)); + return output_element_t(mul_precision(local_sum) * mul_precision(scale)); } #else // Scaled dot product of 8 packed unsigned integers. @@ -75,6 +353,6 @@ local_sum += dot4I8Packed(a2[1], b2[1]); local_sum += dot4I8Packed(a2[2], b2[2]); local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(f32(local_sum) * f32(scale)); + return output_element_t(mul_precision(local_sum) * mul_precision(scale)); } #endif diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 6d0107a779f33..61764979a5dd6 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -51,6 +51,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), WGSL_TEMPLATE_PARAMETER(output_type_i32, true), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec_)); @@ -86,6 +87,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor .AddUniformVariable({M * K / kU32Components}); ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); const bool has_zero_points = zero_points != nullptr; + const bool single_scale_weights = (block_size == K * N); if (M < min_M_for_tile_optimization) { uint32_t tile_size_k_vec = 16; uint32_t tile_size_n = 32; @@ -94,18 +96,18 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor tile_size_k_vec = 32; tile_size_n = 4; } - - DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points}; + const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components); + DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, single_scale_weights}; uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n; mul_program.SetWorkgroupSize(128); mul_program.SetDispatchGroupSize(M * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components * kU32Components)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(b_components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}) - .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points); + .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } @@ -121,7 +123,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? kVec2Components * kU32Components : kVec4Components * kU32Components)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast((nbits / 2) * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({{static_cast(M)}, {static_cast(N)}, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 5c07885011aac..470a4e187f37a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -43,11 +43,12 @@ class DP4AMatMulNBitsProgram final : public Program { class DP4AMatMulNBitsSmallMProgram final : public Program { public: - DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits, bool has_zero_points) : Program{"DP4AMatMulNBitsSmallMProgram"}, - tile_size_k_vec_(tile_size_k_vec), - tile_size_(tile_size), - nbits_(nbits), - has_zero_points_(has_zero_points) {} + DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool single_scale_weights) : Program{"DP4AMatMulNBitsSmallMProgram"}, + tile_size_k_vec_(tile_size_k_vec), + tile_size_(tile_size), + nbits_(nbits), + has_zero_points_(has_zero_points), + single_scale_weights_(single_scale_weights) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -64,6 +65,7 @@ class DP4AMatMulNBitsSmallMProgram final : public Program(bits_); if (has_zero_points) { ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType()); + ORT_ENFORCE(nbits != 2, "Currently, zero points are not supported for Q2 quantization."); } MatMulComputeHelper helper; @@ -124,10 +127,12 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const uint32_t N = onnxruntime::narrow(helper.N()); const uint32_t K = onnxruntime::narrow(helper.K()); const uint32_t block_size = onnxruntime::narrow(block_size_); - const uint32_t nbits = onnxruntime::narrow(bits_); - const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; - const uint32_t blob_size = (block_size / 8) * nbits; + // Special case matrix used by bitnets where there is a single scale for the entire + const bool single_scale_weights = (block_size == K * N); + const uint32_t block_size_per_col = single_scale_weights ? K : block_size; + const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; + const uint32_t blob_size = (block_size_per_col / 8) * nbits; const uint32_t blob_size_in_words = blob_size / 4; const uint32_t components_a = GetMaxComponents(K); const uint32_t components_b = GetMaxComponents(blob_size_in_words); @@ -157,6 +162,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && + nbits != 2 && M >= kMinMForTileOptimization; if (use_wide_tile_program) { // Enforce output components to 1. @@ -212,7 +218,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t kU32Components = 4; uint32_t components_b_with_u32 = components_b * kU32Components; uint32_t num_N_tile = (N + tile_size - 1) / tile_size; - MatMulNBitsProgram program{tile_size, nbits, has_zero_points}; + uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; + MatMulNBitsProgram program{tile_size, nbits, has_zero_points, single_scale_weights}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count); program @@ -220,8 +227,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, {scales, ProgramTensorMetadataDependency::TypeAndRank}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) - .AddUniformVariables({{M}, {N}, {K}, {K / components_a}, {n_blocks_per_col * blob_size / components_b_with_u32}, {block_size}, {n_blocks_per_col}, {zero_blocks_per_col}, {num_N_tile}, {batch_count}}) - .CacheHint(nbits, has_zero_points); + .AddUniformVariables({{M}, {N}, {K}, {K / components_a}, {K_of_b}, {block_size}, {n_blocks_per_col}, {zero_blocks_per_col}, {num_N_tile}, {batch_count}}) + .CacheHint(nbits, has_zero_points, single_scale_weights); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index aabc73ca05d03..1f7bd16d9cb6f 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -37,7 +37,7 @@ class MatMulNBitsWideTileProgram final : public Program { public: - MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points) : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points) {} + MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool single_scale_weights) : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), single_scale_weights_(single_scale_weights) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -55,6 +55,7 @@ class MatMulNBitsProgram final : public Program { uint32_t tile_size_; uint32_t nbits_; bool has_zero_points_; + bool single_scale_weights_; }; class MatMulNBits final : public WebGpuKernel { @@ -65,8 +66,8 @@ class MatMulNBits final : public WebGpuKernel { block_size_ = info.GetAttr("block_size"); bits_ = info.GetAttr("bits"); accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); - ORT_ENFORCE(bits_ == 4 || bits_ == 8, - "Only 4b/8b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(bits_ == 4 || bits_ == 8 || bits_ == 2, + "Only 4b/8b/2b quantization is supported for MatMulNBits op, additional bits support is planned."); } Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template index 5fd0a9b3dd788..aba6e3d57c72a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template @@ -6,6 +6,7 @@ #param component_b #param elements_in_value_b #param n_bits +#param single_scale_weights #param sub_tile_count #param tile_size_k_vec #param tile_size_k @@ -35,6 +36,12 @@ $MAIN { let idx = local_idx % tile_size_k_vec; let idy = local_idx / tile_size_k_vec; +#if single_scale_weights + let block_idx = 0; + let scale_b = scales_b[0]; + let zero = mm_read_zero(0, 0, uniforms.N, uniforms.zero_blocks_per_col); +#endif + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { for (var id = local_idx; id < a_length_per_tile; id += workgroup_size_x) @@ -49,9 +56,11 @@ $MAIN { var k_offset = kidx / elements_in_value_b + idx; if (b_global < uniforms.N && k_offset < uniforms.K_of_b) { +#if !single_scale_weights let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; let scale_b = scales_b[b_global * uniforms.blocks_per_col + block_idx]; let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); +#endif var b_value = input_b[b_global * uniforms.K_of_b + k_offset]; #if n_bits == 4 @@ -91,6 +100,38 @@ $MAIN { a_offset += 1; #endif } +#elif n_bits == 2 + var sum = output_element_t(0); + var a_offset = idx * (16 / component_a) * component_b; + for (var i = 0u; i < component_b; i++) { + let b_data_0 = vec4(unpack4xU8(b_value[i] & 0x03030303u)) - vec4(zero); + let b_data_1 = vec4(unpack4xU8((b_value[i] >> 2) & 0x03030303u)) - vec4(zero); + let b_data_2 = vec4(unpack4xU8((b_value[i] >> 4) & 0x03030303u)) - vec4(zero); + let b_data_3 = vec4(unpack4xU8((b_value[i] >> 6) & 0x03030303u)) - vec4(zero); + + let b0 = vec4(b_data_0[0], b_data_1[0], b_data_2[0], b_data_3[0]) * scale_b; + let b1 = vec4(b_data_0[1], b_data_1[1], b_data_2[1], b_data_3[1]) * scale_b; + let b2 = vec4(b_data_0[2], b_data_1[2], b_data_2[2], b_data_3[2]) * scale_b; + let b3 = vec4(b_data_0[3], b_data_1[3], b_data_2[3], b_data_3[3]) * scale_b; + +#if component_a == 1 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1) + + dot(vec4(tile_A[a_offset + 8], tile_A[a_offset + 9], tile_A[a_offset + 10], tile_A[a_offset + 11]), b2) + + dot(vec4(tile_A[a_offset + 12], tile_A[a_offset + 13], tile_A[a_offset + 14], tile_A[a_offset + 15]), b3); + a_offset += 16; +#elif component_a == 2 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b0) + + dot(vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5]), b2) + + dot(vec4(tile_A[a_offset + 6], tile_A[a_offset + 7]), b3); + a_offset += 8; +#elif component_a == 4 + sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1) + + dot(tile_A[a_offset + 2], b2) + dot(tile_A[a_offset + 3], b3); + a_offset += 4; +#endif + } #endif inter_results[local_row_offset + idy][idx] += sum; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template index 0da5bd09609af..9135708adf153 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template @@ -17,6 +17,8 @@ #elif n_bits == 8 const default_zero_point = 128; const bit_mask = 0xFFu; +#elif n_bits == 2 + const default_zero_point = 2; #endif #if has_zero_points diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 884922ec5c098..57b1edfd0edce 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -198,6 +198,9 @@ void RunTest2Bits(const TestOptions2Bits& opts) { std::vector> execution_providers; if constexpr (std::is_same::value) { execution_providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); +#endif test.ConfigEps(std::move(execution_providers)); test.RunWithConfig(); } @@ -244,22 +247,47 @@ void TestMatMul2BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) { } // namespace -TEST(MatMulNBits, Float32_2Bits_Accuracy0) { - // Currently, only fallback option enabled for 2bit datatypes - // where the 2bits are dequantized to fp32 - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); +template +struct TypedTestParams { + static constexpr int batch_size = BatchSize; + static constexpr int M = MVal; + static constexpr int N = NVal; + static constexpr int K = KVal; +}; + +using TestTypes = ::testing::Types< + TypedTestParams<1, 1, 16, 16>, + TypedTestParams<1, 2, 16, 16>, + TypedTestParams<1, 32, 16, 16>, + TypedTestParams<1, 32, 32, 16>, + TypedTestParams<1, 32, 16, 128>, + TypedTestParams<1, 288, 16, 16>, + TypedTestParams<4, 1, 16, 16>, + TypedTestParams<4, 2, 16, 16>, + TypedTestParams<4, 32, 16, 16>, + TypedTestParams<4, 32, 32, 16>, + TypedTestParams<4, 32, 16, 128>, + TypedTestParams<4, 288, 16, 16>>; + +template +class MatMulNBits : public ::testing::Test { + public: + static constexpr int batch_size = T::batch_size; + static constexpr int M = T::M; + static constexpr int N = T::N; + static constexpr int K = T::K; +}; + +TYPED_TEST_SUITE(MatMulNBits, TestTypes); + +TYPED_TEST(MatMulNBits, Float32_2Bits_Accuracy0) { + TestMatMul2BitsTyped(); +} + +TYPED_TEST(MatMulNBits, Float32_2Bits_Accuracy4) { + TestMatMul2BitsTyped(); } + } // namespace test } // namespace onnxruntime From cf05366785adc3d59ff026a58496a5e8864bd024 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 26 Aug 2025 17:00:22 +0800 Subject: [PATCH 05/22] [webgpu ] Optimize flash attention for Nvidia (#25777) In the flash attention algorithm, each thread in a subgroup needs to access the same range (0-15) of data in workgroup memory `q_tile` and `v_tile`. If we use `subgroupShuffle`, there will be bank conflicts for `var k_local = k_tile[capped_sg_id][i];` since the sg_size is 32 and thread16~thread31 are accessing the same bank address. To avoid the bank conflicts, we can directly access the same address in workgroup memory by all threads which is a broadcast and well optimized in the NV GPUs. See ~10% improvement for phi4 prefill (1K) in NV RTX 2000 Ada. And as the input gets longer(total_sequence_length), the optimization effect gets better (~12% for 2K). Before ``` Batch size: 1, prompt tokens: 1000, tokens to generate: 128 Prompt processing (time to first token): avg (us): 2.0991e+06 avg (tokens/s): 476.394 p50 (us): 2.08457e+06 stddev (us): 36140.3 n: 5 * 1000 token(s) Token generation: avg (us): 25477.8 avg (tokens/s): 39.2498 p50 (us): 25028.2 stddev (us): 4841.89 n: 635 * 1 token(s) ``` After ``` Batch size: 1, prompt tokens: 1000, tokens to generate: 128 Prompt processing (time to first token): avg (us): 1.91138e+06 avg (tokens/s): 523.183 p50 (us): 1.92379e+06 stddev (us): 44768 n: 5 * 1000 token(s) Token generation: avg (us): 25237.2 avg (tokens/s): 39.624 p50 (us): 24860.9 stddev (us): 4874.52 n: 635 * 1 token(s) ``` --- .../webgpu/bert/flash_attention.cc | 7 +- .../contrib_ops/webgpu/bert/flash_attention.h | 7 +- .../webgpu/bert/flash_attention.wgsl.template | 134 +++++++++++------- 3 files changed, 93 insertions(+), 55 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index ef72eea76b2d3..ab611a8e5a7c0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -139,6 +139,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_), WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), + WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_), WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_), WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_)); } @@ -283,6 +284,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const uint32_t tile_size = 64; bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; + bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); FlashAttentionProgram program{"FlashAttention", has_attention_bias, @@ -290,7 +292,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co is_fp16, parameters.head_size_, parameters.num_heads_, - parameters.is_unidirectional_}; + parameters.is_unidirectional_, + is_nvidia}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}}); @@ -303,7 +306,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile) .SetWorkgroupSize(tile_size) - .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm) + .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(present_sequence_length)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 9839b43ee8a69..c75494df253c1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -40,14 +40,16 @@ class FlashAttentionProgram final : public Program { bool is_fp16, int qkv_head_size, int qkv_num_heads, - bool is_unidirectional) + bool is_unidirectional, + bool is_nvidia) : Program{kernel_name}, has_attention_bias_(has_attention_bias), is_qualcomm_(is_qualcomm), is_fp16_(is_fp16), qkv_head_size_(qkv_head_size), qkv_num_heads_(qkv_num_heads), - is_unidirectional_(is_unidirectional) { + is_unidirectional_(is_unidirectional), + is_nvidia_(is_nvidia) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -67,6 +69,7 @@ class FlashAttentionProgram final : public Program { int qkv_head_size_; int qkv_num_heads_; bool is_unidirectional_; + bool is_nvidia_; }; class FlashAttentionDecodeQKTProgram final : public Program { diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index a8c28388292f9..0674702bd6030 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -3,6 +3,7 @@ #param is_fp16 #param is_qualcomm #param is_unidirectional +#param prefer_subgroupshuffle #param qkv_head_size #param qkv_num_heads @@ -110,6 +111,22 @@ fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> } #endif +fn fetchKTile(k_idx: u32, vec_idx: u32, k_val: q_value_t) -> q_value_t { +#if prefer_subgroupshuffle + return subgroupShuffle(k_val, k_idx); +#else + return k_tile[k_idx][vec_idx]; +#endif +} + +fn fetchVTile(k_idx: u32, vec_idx: u32, v_val: q_value_t) -> q_value_t { +#if prefer_subgroupshuffle + return subgroupShuffle(v_val, k_idx); +#else + return v_tile[k_idx][vec_idx]; +#endif +} + $MAIN { let head_idx = u32(workgroup_idx / uniforms.num_seq_tile); let capped_sg_id = min(sg_id, max_k_step - 1u); @@ -149,47 +166,54 @@ $MAIN { var qk_4 : vec4; if (sg_size > 8) { for (var i : u32 = 0u; i < head_size_vec; i++) { -#if is_qualcomm +#if prefer_subgroupshuffle + #if is_qualcomm var k_local = q_value_t(0); if (sg_id < max_k_step) { k_local = k_tile[sg_id][i]; } -#else + #else var k_local = k_tile[capped_sg_id][i]; + #endif +#else + var k_local = q_value_t(0); #endif var q_own = q_tile[i]; - qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); - qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); - qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); - qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); - qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); - qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); - qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); - qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); - qk_3[0] += dot(q_own, subgroupShuffle(k_local, 8)); - qk_3[1] += dot(q_own, subgroupShuffle(k_local, 9)); - qk_3[2] += dot(q_own, subgroupShuffle(k_local, 10)); - qk_3[3] += dot(q_own, subgroupShuffle(k_local, 11)); - qk_4[0] += dot(q_own, subgroupShuffle(k_local, 12)); - qk_4[1] += dot(q_own, subgroupShuffle(k_local, 13)); - qk_4[2] += dot(q_own, subgroupShuffle(k_local, 14)); - qk_4[3] += dot(q_own, subgroupShuffle(k_local, 15)); + qk_1[0] += dot(q_own, fetchKTile(0, i, k_local)); + qk_1[1] += dot(q_own, fetchKTile(1, i, k_local)); + qk_1[2] += dot(q_own, fetchKTile(2, i, k_local)); + qk_1[3] += dot(q_own, fetchKTile(3, i, k_local)); + qk_2[0] += dot(q_own, fetchKTile(4, i, k_local)); + qk_2[1] += dot(q_own, fetchKTile(5, i, k_local)); + qk_2[2] += dot(q_own, fetchKTile(6, i, k_local)); + qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); + qk_3[0] += dot(q_own, fetchKTile(8, i, k_local)); + qk_3[1] += dot(q_own, fetchKTile(9, i, k_local)); + qk_3[2] += dot(q_own, fetchKTile(10, i, k_local)); + qk_3[3] += dot(q_own, fetchKTile(11, i, k_local)); + qk_4[0] += dot(q_own, fetchKTile(12, i, k_local)); + qk_4[1] += dot(q_own, fetchKTile(13, i, k_local)); + qk_4[2] += dot(q_own, fetchKTile(14, i, k_local)); + qk_4[3] += dot(q_own, fetchKTile(15, i, k_local)); } } else { for (var i : u32 = 0u; i < head_size_vec; i++) { +#if prefer_subgroupshuffle var k_local = k_tile[capped_sg_id][i]; +#else + var k_local = q_value_t(0); +#endif var q_own = q_tile[i]; - qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); - qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); - qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); - qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); - qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); - qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); - qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); - qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); + qk_1[0] += dot(q_own, fetchKTile(0, i, k_local)); + qk_1[1] += dot(q_own, fetchKTile(1, i, k_local)); + qk_1[2] += dot(q_own, fetchKTile(2, i, k_local)); + qk_1[3] += dot(q_own, fetchKTile(3, i, k_local)); + qk_2[0] += dot(q_own, fetchKTile(4, i, k_local)); + qk_2[1] += dot(q_own, fetchKTile(5, i, k_local)); + qk_2[2] += dot(q_own, fetchKTile(6, i, k_local)); + qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); } } - qk_1 = qk_1 + loadAttentionBias(q_idx_global, k_start, head_idx); qk_2 = qk_2 + loadAttentionBias(q_idx_global, k_start + 4, head_idx); if (sg_size > 8) { @@ -326,36 +350,44 @@ $MAIN { #else if (sg_size > 8) { for (var i : u32 = 0; i < head_size_vec; i++) { + #if prefer_subgroupshuffle var val = v_tile[capped_sg_id][i]; - var sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; - sum += subgroupShuffle(val, 8) * qk_3[0]; - sum += subgroupShuffle(val, 9) * qk_3[1]; - sum += subgroupShuffle(val, 10) * qk_3[2]; - sum += subgroupShuffle(val, 11) * qk_3[3]; - sum += subgroupShuffle(val, 12) * qk_4[0]; - sum += subgroupShuffle(val, 13) * qk_4[1]; - sum += subgroupShuffle(val, 14) * qk_4[2]; - sum += subgroupShuffle(val, 15) * qk_4[3]; + #else + var val = q_value_t(0); + #endif + var sum = fetchVTile(0, i, val) * qk_1[0]; + sum += fetchVTile(1, i, val) * qk_1[1]; + sum += fetchVTile(2, i, val) * qk_1[2]; + sum += fetchVTile(3, i, val) * qk_1[3]; + sum += fetchVTile(4, i, val) * qk_2[0]; + sum += fetchVTile(5, i, val) * qk_2[1]; + sum += fetchVTile(6, i, val) * qk_2[2]; + sum += fetchVTile(7, i, val) * qk_2[3]; + sum += fetchVTile(8, i, val) * qk_3[0]; + sum += fetchVTile(9, i, val) * qk_3[1]; + sum += fetchVTile(10, i, val) * qk_3[2]; + sum += fetchVTile(11, i, val) * qk_3[3]; + sum += fetchVTile(12, i, val) * qk_4[0]; + sum += fetchVTile(13, i, val) * qk_4[1]; + sum += fetchVTile(14, i, val) * qk_4[2]; + sum += fetchVTile(15, i, val) * qk_4[3]; o_tile[i] = o_tile[i] * o_ratio + sum; } } else { for (var i : u32 = 0; i < head_size_vec; i++) { + #if prefer_subgroupshuffle var val = v_tile[capped_sg_id][i]; - var sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; + #else + var val = q_value_t(0); + #endif + var sum = fetchVTile(0, i, val) * qk_1[0]; + sum += fetchVTile(1, i, val) * qk_1[1]; + sum += fetchVTile(2, i, val) * qk_1[2]; + sum += fetchVTile(3, i, val) * qk_1[3]; + sum += fetchVTile(4, i, val) * qk_2[0]; + sum += fetchVTile(5, i, val) * qk_2[1]; + sum += fetchVTile(6, i, val) * qk_2[2]; + sum += fetchVTile(7, i, val) * qk_2[3]; o_tile[i] = o_tile[i] * o_ratio + sum; } } From 17ede5a82b26500ffe9a143dfabae967bbac3f81 Mon Sep 17 00:00:00 2001 From: Arseny Date: Tue, 26 Aug 2025 20:07:29 +0300 Subject: [PATCH 06/22] safeint.h: quelch gcc's -Wreturn-type (#25655) Quoting cppreference.com: ``` (the [[noreturn]] attribute) Indicates that the function will not return control flow to the calling function after it finishes (e.g. functions that terminate the application, throw exceptions, loop indefinitely, etc.). This attribute applies to the name of the function being declared in function declarations only. If a function previously declared with `[[noreturn]]` is invoked and that invocation eventually returns, the behavior is runtime-undefined. ``` The `SafeIntOn*` member functions immediately throw, so if they are used in a function with non-void return type, g++ 14 issues a warning that there exist control paths in the function where no value is returned. Fix this by marking the member functions explicitly noreturn. This is needed so onnxruntime builds correctly with `-Wall -Wextra`. --- onnxruntime/core/common/safeint.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/common/safeint.h b/onnxruntime/core/common/safeint.h index 3ee70f369b65d..6aba5871ac62e 100644 --- a/onnxruntime/core/common/safeint.h +++ b/onnxruntime/core/common/safeint.h @@ -13,11 +13,11 @@ class SafeIntExceptionHandler; template <> class SafeIntExceptionHandler { public: - static void SafeIntOnOverflow() { + [[noreturn]] static void SafeIntOnOverflow() { ORT_THROW("Integer overflow"); } - static void SafeIntOnDivZero() { + [[noreturn]] static void SafeIntOnDivZero() { ORT_THROW("Divide by zero"); } }; From a6c92bd3e62272ed0284acd84b161d7e1c80b235 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 Aug 2025 10:09:28 -0700 Subject: [PATCH 07/22] Fix focus contrast ratios for accessibility compliance (WCAG 2.1 AA) (#25832) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR addresses accessibility issues with focus indicators on the ONNX Runtime website documentation where contrast ratios were insufficient for keyboard navigation users. The accessibility audit revealed that focus states for key navigation elements like "Learn more about ONNX Runtime & Generative AI", "Quickstart", "Tutorials", "Install ONNX Runtime", and "Hardware Acceleration" had contrast ratios as low as 1.152:1, well below the WCAG 2.1 AA requirement of 3:1 for UI components. ## Changes Made ### 1. Enhanced List Group Item Focus Contrast - **Before**: `color: #555` on `background-color: #f5f5f5` (6.8:1 ratio) - **After**: `color: #333` on `background-color: #f5f5f5` (**11.6:1 ratio**) ### 2. Improved Info List Group Item Focus Contrast - **Before**: `color: #31708f` on `background-color: #c4e3f3` (4.1:1 ratio) - **After**: `color: #1e4a5f` on `background-color: #c4e3f3` (**7.1:1 ratio**) ### 3. Added Visible Focus Indicators for Form Inputs Previously, search and filter inputs only removed the default outline (`outline: 0`) without providing alternative focus indicators, making them inaccessible to keyboard users. - **Added**: `border: 2px solid #0050C5` and `background-color: #f8f9fa` on focus - **Contrast ratio**: **6.7:1** (exceeds requirements) ## Accessibility Compliance All changes now exceed WCAG 2.1 AA standards: - ✅ **3:1 minimum** for UI components and focus indicators - ✅ **4.5:1 minimum** for normal text (all exceed 7:1) - ✅ **Keyboard navigation** fully supported with visible focus indicators - ✅ **Screen reader compatibility** improved with clear focus states ## Impact - Low vision users can now clearly see focused elements during keyboard navigation - All mentioned navigation elements meet accessibility standards - No functionality broken - purely visual accessibility enhancements - Compliance with MAS 1.4.11 Non-text Contrast requirements ## Files Modified - `csharp/ApiDocs/_exported_templates/default/styles/docfx.css` - Enhanced input focus indicators - `csharp/ApiDocs/_exported_templates/default/styles/docfx.vendor.css` - Improved text contrast ratios Fixes #24995. --- ✨ Let Copilot coding agent [set things up for you](https://github.com/microsoft/onnxruntime/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: MaanavD <24942306+MaanavD@users.noreply.github.com> --- csharp/ApiDocs/_exported_templates/default/styles/docfx.css | 4 ++++ .../_exported_templates/default/styles/docfx.vendor.css | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/csharp/ApiDocs/_exported_templates/default/styles/docfx.css b/csharp/ApiDocs/_exported_templates/default/styles/docfx.css index 64dcde3385eff..0d3cf530038d9 100644 --- a/csharp/ApiDocs/_exported_templates/default/styles/docfx.css +++ b/csharp/ApiDocs/_exported_templates/default/styles/docfx.css @@ -323,6 +323,8 @@ article section { } .docs-search > .search-query:focus { outline: 0; + border: 2px solid #0050C5; + background-color: #f8f9fa; } .search-results-frame { clear: both; @@ -597,6 +599,8 @@ body .toc{ } .toc-filter > input:focus { outline: 0; + border: 2px solid #0050C5; + background-color: #f8f9fa; } .toc-filter > .filter-icon { position: absolute; diff --git a/csharp/ApiDocs/_exported_templates/default/styles/docfx.vendor.css b/csharp/ApiDocs/_exported_templates/default/styles/docfx.vendor.css index 609602eb134dd..0c9f7c0c0aec3 100644 --- a/csharp/ApiDocs/_exported_templates/default/styles/docfx.vendor.css +++ b/csharp/ApiDocs/_exported_templates/default/styles/docfx.vendor.css @@ -1220,7 +1220,7 @@ to{background-position:0 0} .list-group-item.active .list-group-item-text,.list-group-item.active:focus .list-group-item-text,.list-group-item.active:hover .list-group-item-text{color:#c7ddef} a.list-group-item,button.list-group-item{color:#555} a.list-group-item .list-group-item-heading,button.list-group-item .list-group-item-heading{color:#333} -a.list-group-item:focus,a.list-group-item:hover,button.list-group-item:focus,button.list-group-item:hover{color:#555;text-decoration:none;background-color:#f5f5f5} +a.list-group-item:focus,a.list-group-item:hover,button.list-group-item:focus,button.list-group-item:hover{color:#333;text-decoration:none;background-color:#f5f5f5} button.list-group-item{width:100%;text-align:left} .list-group-item-success{color:#3c763d;background-color:#dff0d8} a.list-group-item-success,button.list-group-item-success{color:#3c763d} @@ -1230,7 +1230,7 @@ a.list-group-item-success.active,a.list-group-item-success.active:focus,a.list-g .list-group-item-info{color:#31708f;background-color:#d9edf7} a.list-group-item-info,button.list-group-item-info{color:#31708f} a.list-group-item-info .list-group-item-heading,button.list-group-item-info .list-group-item-heading{color:inherit} -a.list-group-item-info:focus,a.list-group-item-info:hover,button.list-group-item-info:focus,button.list-group-item-info:hover{color:#31708f;background-color:#c4e3f3} +a.list-group-item-info:focus,a.list-group-item-info:hover,button.list-group-item-info:focus,button.list-group-item-info:hover{color:#1e4a5f;background-color:#c4e3f3} a.list-group-item-info.active,a.list-group-item-info.active:focus,a.list-group-item-info.active:hover,button.list-group-item-info.active,button.list-group-item-info.active:focus,button.list-group-item-info.active:hover{color:#fff;background-color:#31708f;border-color:#31708f} .list-group-item-warning{color:#8a6d3b;background-color:#fcf8e3} a.list-group-item-warning,button.list-group-item-warning{color:#8a6d3b} From ce9060936201ef979a5075cd825a5942dec433ee Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 26 Aug 2025 10:09:39 -0700 Subject: [PATCH 08/22] Fix keyboard navigation accessibility for DocFX tab controls (#25819) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The DocFX tab controls on onnxruntime.ai were not accessible via keyboard navigation, violating MAS 2.1.1 keyboard accessibility requirements. Users could not navigate between language tabs (Python, C#, Java, JavaScript, C++) using keyboard-only input. ## Problem The existing implementation in `docfx.js` only handled mouse click events but lacked keyboard event handlers. This prevented keyboard users from: - Navigating between tabs using arrow keys - Activating tabs using Enter/Space keys - Jumping to first/last tabs using Home/End keys ## Solution Added comprehensive keyboard navigation support following the WAI-ARIA tabs design pattern: ```javascript // Added keyboard event listener alongside existing click handler container.addEventListener('keydown', function (event) { return handleKeyDown(event, state); }); ``` The `handleKeyDown` function implements: - **Arrow key navigation**: Left/Right and Up/Down keys move focus between tabs with wrapping - **Tab activation**: Enter and Space keys activate the focused tab - **Quick navigation**: Home/End keys jump to first/last tabs - **Proper focus management**: Only the active tab has `tabIndex="0"`, others have `tabIndex="-1"` - **Event handling**: `preventDefault()` and `stopPropagation()` for handled keys ## Accessibility Features - Follows WAI-ARIA tabs pattern specifications - Maintains proper ARIA attributes (`role="tab"`, `aria-selected`, etc.) - Provides visual focus indicators via existing CSS - Supports both horizontal and vertical arrow key navigation - Implements circular navigation (wrapping at boundaries) ## Testing Validated functionality with comprehensive keyboard navigation tests: - ✅ Arrow keys navigate between tabs with proper wrapping - ✅ Enter/Space keys activate focused tabs and switch content panels - ✅ Home/End keys jump to first/last tabs correctly - ✅ Focus management works with proper `tabIndex` handling - ✅ Visual feedback shows focused vs selected tab states This ensures keyboard users can fully access all tab functionality without requiring mouse interaction. Fixes #24997. --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: MaanavD <24942306+MaanavD@users.noreply.github.com> --- .../default/styles/docfx.js | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/csharp/ApiDocs/_exported_templates/default/styles/docfx.js b/csharp/ApiDocs/_exported_templates/default/styles/docfx.js index b6167dbd6193d..6412c718be37d 100644 --- a/csharp/ApiDocs/_exported_templates/default/styles/docfx.js +++ b/csharp/ApiDocs/_exported_templates/default/styles/docfx.js @@ -802,6 +802,7 @@ $(function () { } } container.addEventListener('click', function (event) { return handleClick(event, state); }); + container.addEventListener('keydown', function (event) { return handleKeyDown(event, state); }); if (state.groups.length === 0) { return state; } @@ -820,6 +821,7 @@ $(function () { while (li) { var a = li.firstElementChild; a.setAttribute(contentAttrs.name, 'tab'); + a.setAttribute('role', 'tab'); var dataTab = a.getAttribute('data-tab').replace(/\+/g, ' '); a.setAttribute('data-tab', dataTab); var section = element.querySelector("[id=\"" + a.getAttribute('aria-controls') + "\"]"); @@ -915,6 +917,91 @@ $(function () { } } + function handleKeyDown(event, state) { + var info = getTabInfoFromEvent(event); + if (info === null) { + return; + } + + var handled = false; + var tabGroup = info.group; + var currentTabIndex = tabGroup.tabs.findIndex(function(tab) { return tab.a === info.anchor; }); + + switch (event.key) { + case 'ArrowLeft': + case 'ArrowUp': + // Move to previous tab + handled = true; + var prevIndex = currentTabIndex - 1; + if (prevIndex < 0) { + prevIndex = tabGroup.tabs.length - 1; + } + while (prevIndex !== currentTabIndex && !tabGroup.tabs[prevIndex].visible) { + prevIndex--; + if (prevIndex < 0) { + prevIndex = tabGroup.tabs.length - 1; + } + } + if (tabGroup.tabs[prevIndex].visible) { + tabGroup.tabs[prevIndex].focus(); + } + break; + + case 'ArrowRight': + case 'ArrowDown': + // Move to next tab + handled = true; + var nextIndex = currentTabIndex + 1; + if (nextIndex >= tabGroup.tabs.length) { + nextIndex = 0; + } + while (nextIndex !== currentTabIndex && !tabGroup.tabs[nextIndex].visible) { + nextIndex++; + if (nextIndex >= tabGroup.tabs.length) { + nextIndex = 0; + } + } + if (tabGroup.tabs[nextIndex].visible) { + tabGroup.tabs[nextIndex].focus(); + } + break; + + case 'Home': + // Move to first visible tab + handled = true; + for (var i = 0; i < tabGroup.tabs.length; i++) { + if (tabGroup.tabs[i].visible) { + tabGroup.tabs[i].focus(); + break; + } + } + break; + + case 'End': + // Move to last visible tab + handled = true; + for (var i = tabGroup.tabs.length - 1; i >= 0; i--) { + if (tabGroup.tabs[i].visible) { + tabGroup.tabs[i].focus(); + break; + } + } + break; + + case 'Enter': + case ' ': // Space key + // Activate the current tab + handled = true; + handleClick(event, state); + break; + } + + if (handled) { + event.preventDefault(); + event.stopPropagation(); + } + } + function selectTabs(tabIds) { for (var _i = 0, tabIds_1 = tabIds; _i < tabIds_1.length; _i++) { var tabId = tabIds_1[_i]; From db4b0f4ea51554548e613aefdeccd783b3f3af69 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Tue, 26 Aug 2025 13:28:16 -0700 Subject: [PATCH 09/22] [CPU] Improve QMoE kernel (#25822) This pull request introduces several improvements and refactorings to the quantized Mixture-of-Experts (QMoE) operator in ONNX Runtime, focusing on enhanced support for FP32 mode, improved SwiGLU activation handling, and better test coverage. The most important changes are grouped below by theme. ### Operator Registration and Type Support - Added explicit registration and support for `QMoE` operator with both `MLFloat16` and `float` data types, enabling FP32 (non-quantized) mode in addition to quantized modes. This includes updates to kernel registration and schema/type constraints. [[1]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L109-R110) [[2]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L275-R277) [[3]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1467-R1467) [[4]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1548-R1548) ### SwiGLU Activation Improvements - Refactored `ApplySwiGLUActivation` to accept configurable `activation_alpha` and `activation_beta` parameters, matching CUDA behavior and allowing flexibility in activation function tuning. Also, dropped support for non-interleaved memory layouts (now not implemented). [[1]](diffhunk://#diff-4e4afb8dcdade0abe18bd8bea68b148b4090cd86d60a1b1422c049960231737dR49-R60) [[2]](diffhunk://#diff-edb344a38502bba9a0083ab98e274ec1b5b2606639a61df7be474a600a7b99d2L29-R61) [[3]](diffhunk://#diff-f85806c745243652a0336da094126687a6c0d14b19fe760abe73df1d940dc4cbL12-R13) - Now reads `activation_alpha` and `activation_beta` attributes from operator parameters, defaulting to values appropriate for SwiGLU. ### QMoE Operator Implementation Refactor - Refactored the QMoE operator to clarify separation between quantized and FP32 implementations, and restructured internal methods for better maintainability. Added template parameterization for data types and improved handling of expert weights and biases. [[1]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5R13-R35) [[2]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L38-R55) [[3]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L58-L59) ### Shape Checking and Layout - Removed legacy shape/layout support in QMoE input validation, enforcing only the new memory layout for expert weights and improving consistency and forward compatibility. ### Test and Documentation Updates - Updated unit tests for QMoE to use correct zero-point values for quantized weights (e.g., 0x88 for int4, 128 for int8), ensuring that test cases accurately reflect expected zero-output behavior for zero weights. Also clarified comments and expected outputs for SwiGLU and quantized scenarios. [[1]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1340-R1349) [[2]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1379-R1380) [[3]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1404-R1413) [[4]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1525-R1538) These changes collectively improve the flexibility, correctness, and maintainability of the QMoE operator in ONNX Runtime. Unit test result ``` sRunning test: batch_size=1, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000372 .Running test: batch_size=1, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000392 .Running test: batch_size=1, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=1, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=4, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000609 .Running test: batch_size=4, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000702 . ---------------------------------------------------------------------- Ran 9 tests in 46.754s OK (skipped=1) ``` --------- Co-authored-by: Tianlei Wu --- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 6 +- .../contrib_ops/cpu/moe/moe_base_cpu.h | 12 +- .../cpu/{quantization => moe}/moe_helper.h | 0 .../cpu/moe/moe_quantization_cpu.cc | 393 ++++++ .../cpu/moe/moe_quantization_cpu.h | 34 + onnxruntime/contrib_ops/cpu/moe/moe_utils.cc | 66 +- onnxruntime/contrib_ops/cpu/moe/moe_utils.h | 4 +- .../cpu/quantization/moe_quantization_cpu.cc | 596 --------- .../cpu/quantization/moe_quantization_cpu.h | 63 - .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 3 +- onnxruntime/contrib_ops/cuda/moe/moe_base.h | 2 +- onnxruntime/test/contrib_ops/moe_test.cc | 173 +-- .../test/python/transformers/test_qmoe_cpu.py | 1097 +++++++++-------- 13 files changed, 1130 insertions(+), 1319 deletions(-) rename onnxruntime/contrib_ops/cpu/{quantization => moe}/moe_helper.h (100%) create mode 100644 onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc create mode 100644 onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.h diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 7623e2d88f3cd..34410a5f42630 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -106,7 +106,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE); // ******** End: Quantization ******************* // #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -272,7 +273,8 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h index 73e7ee6014b95..eae96c186d471 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h @@ -6,7 +6,8 @@ #include "core/common/common.h" #include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" -#include "contrib_ops/cpu/quantization/moe_helper.h" +#include "contrib_ops/cpu/moe/moe_helper.h" +#include namespace onnxruntime { namespace contrib { @@ -46,12 +47,21 @@ class MoEBaseCPU { if (use_sparse_mixer_) { ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2"); } + + swiglu_fusion_ = op_kernel_info.GetAttrOrDefault("swiglu_fusion", 0); + swiglu_limit_ = op_kernel_info.GetAttrOrDefault("swiglu_limit", std::numeric_limits::infinity()); + activation_alpha_ = op_kernel_info.GetAttrOrDefault("activation_alpha", 1.0f); + activation_beta_ = op_kernel_info.GetAttrOrDefault("activation_beta", 0.0f); } bool normalize_routing_weights_; bool use_sparse_mixer_; int64_t k_; ActivationType activation_type_; + float activation_alpha_; + float activation_beta_; + float swiglu_limit_; + int64_t swiglu_fusion_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/quantization/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h similarity index 100% rename from onnxruntime/contrib_ops/cpu/quantization/moe_helper.h rename to onnxruntime/contrib_ops/cpu/moe/moe_helper.h diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc new file mode 100644 index 0000000000000..9b35a40f64f2a --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -0,0 +1,393 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/moe/moe_quantization_cpu.h" + +#include "core/framework/allocator.h" +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "contrib_ops/cpu/moe/moe_utils.h" +#include "contrib_ops/cpu/moe/moe_helper.h" + +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +// Helper function to dequantize weights. Supports 4-bit and 8-bit symmetric quantization. +// The source quantized weights are stored as a row-major representation of the transposed +// logical weight matrix (W^T). This function dequantizes it into a float row-major W^T matrix. +template +void DequantizeBlock(const uint8_t* quantized_data, + const TScale* scales, + int64_t /*block_size*/, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data) { + const float zero_point = num_bits == 8 ? 128.0f : 8.0f; + if (num_bits == 8) { + for (int64_t r = 0; r < rows; ++r) { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; ++c) { + // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) + dequantized_data[r * cols + c] = scale * (static_cast(quantized_data[r * cols + c]) - zero_point); + } + } + } else if (num_bits == 4) { + const int64_t packed_cols = (cols + 1) / 2; + for (int64_t r = 0; r < rows; ++r) { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; ++c) { + const uint8_t packed_val = quantized_data[r * packed_cols + c / 2]; + // Unpack the 4-bit value. Low nibble for even columns, high nibble for odd columns. + const uint8_t quantized_val = (c % 2 == 0) ? (packed_val & 0x0F) : (packed_val >> 4); + // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) + dequantized_data[r * cols + c] = scale * (static_cast(quantized_val) - zero_point); + } + } + } +} + +template +QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) + : OpKernel(op_kernel_info), + MoEBaseCPU(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); + ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, + "Attribute 'expert_weight_bits' must be 4 or 8."); + block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); +} + +template +Status QMoECPU::Compute(OpKernelContext* context) const { + // --- 1. Get Inputs and Attributes --- + const auto* input = context->Input(0); + const auto* router_probs = context->Input(1); + const auto* fc1_experts_weights = context->Input(2); + const auto* fc1_scales = context->Input(3); + const auto* fc1_experts_bias = context->Input(4); + const auto* fc2_experts_weights = context->Input(5); + const auto* fc2_scales = context->Input(6); + const auto* fc2_experts_bias = context->Input(7); + const auto* fc3_experts_weights = context->Input(8); + const auto* fc3_scales = context->Input(9); + const auto* fc3_experts_bias = context->Input(10); + + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias, fc1_scales, + fc2_experts_weights, fc2_experts_bias, fc2_scales, + fc3_experts_weights, fc3_experts_bias, fc3_scales, + expert_weight_bits_ == 4 ? 2 : 1, + true)); + + if (fc3_experts_weights || fc3_experts_bias || fc3_scales) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); + } + + const auto& input_shape = input->Shape(); + const int64_t num_tokens = moe_params.num_rows; + const int64_t hidden_size = moe_params.hidden_size; + const int64_t inter_size = moe_params.inter_size; + const int64_t num_experts = moe_params.num_experts; + const int64_t fc1_out_features = inter_size * (swiglu_fusion_ > 0 ? 2 : 1); + + auto* output = context->Output(0, input_shape); + auto* tp = context->GetOperatorThreadPool(); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + const size_t output_buffer_size = static_cast(output->Shape().Size()); + + const T* input_data = input->Data(); + const T* router_probs_data = router_probs->Data(); + + // --- 2. Routing Logic: Assign tokens to experts --- + IAllocatorUniquePtr router_logits_float_buffer; + const float* router_logits_float; + if constexpr (std::is_same_v) { + router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + router_logits_float = router_logits_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs_data), + const_cast(router_logits_float), + static_cast(num_tokens * num_experts)); + } else { + router_logits_float = reinterpret_cast(router_probs_data); + } + + auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + int* route_expert = route_expert_ptr.get(); + auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + float* route_scale = route_scale_ptr.get(); + + // Parallelize the routing logic to improve performance for large token batches. + // Minor performance regression for single-token decoding is an acceptable trade-off + int num_routing_threads = (tp == nullptr || num_tokens < 4096) ? 1 : std::min(static_cast(num_tokens), concurrency::ThreadPool::DegreeOfParallelism(tp)); + + std::vector>> thread_local_expert_token_maps(num_routing_threads); + for (auto& map : thread_local_expert_token_maps) { + map.resize(static_cast(num_experts)); + } + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { + auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); + auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; + + // Pre-allocate buffers for this thread to reuse, avoiding allocations inside the loop. + std::vector> sorted_logits(static_cast(num_experts)); + std::vector top_k_exp(static_cast(k_)); + + for (int64_t i = work.start; i < work.end; ++i) { + const float* logits = router_logits_float + i * num_experts; + for (int64_t j = 0; j < num_experts; ++j) { + sorted_logits[static_cast(j)] = {logits[j], j}; + } + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); + + float max_logit = -std::numeric_limits::infinity(); + for (int64_t j = 0; j < k_; ++j) { + if (sorted_logits[static_cast(j)].first > max_logit) { + max_logit = sorted_logits[static_cast(j)].first; + } + } + + float sum_exp = 0.0f; + for (int64_t j = 0; j < k_; ++j) { + top_k_exp[static_cast(j)] = std::exp(sorted_logits[static_cast(j)].first - max_logit); + sum_exp += top_k_exp[static_cast(j)]; + } + + float scale = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + route_expert[route_idx] = static_cast(expert_idx); + route_scale[route_idx] = top_k_exp[static_cast(j)] * scale; + if (route_scale[route_idx] > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } + }); + + // Merge the maps from each thread into a single global map. + std::vector> expert_token_map(static_cast(num_experts)); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + size_t total_tokens_for_expert = 0; + for (int t = 0; t < num_routing_threads; ++t) { + total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); + } + expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); + } + + for (int t = 0; t < num_routing_threads; ++t) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; + if (!local_tokens.empty()) { + expert_token_map[static_cast(expert_idx)].insert(expert_token_map[static_cast(expert_idx)].end(), local_tokens.begin(), local_tokens.end()); + } + } + } + + // --- 3. Parallel Expert Computation --- + IAllocatorUniquePtr input_float_buffer; + const float* input_float; + if constexpr (std::is_same_v) { + input_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + input_float = input_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(input_data), + const_cast(input_float), + static_cast(num_tokens * hidden_size)); + } else { + input_float = reinterpret_cast(input_data); + } + + int num_expert_threads = (tp == nullptr) ? 1 : std::min(static_cast(num_experts), concurrency::ThreadPool::DegreeOfParallelism(tp)); + if (num_expert_threads == 0) num_expert_threads = 1; + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); + float* thread_local_outputs = thread_local_outputs_ptr.get(); + memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + + // Pre-calculate workspace size per thread to avoid allocations inside the loop + size_t max_tokens_per_expert = 0; + for (const auto& tokens : expert_token_map) { + if (tokens.size() > max_tokens_per_expert) { + max_tokens_per_expert = tokens.size(); + } + } + + const size_t A1_size = static_cast(max_tokens_per_expert * hidden_size); + const size_t C1_size = static_cast(max_tokens_per_expert * fc1_out_features); + const size_t A2_size = static_cast(max_tokens_per_expert * inter_size); + const size_t C2_size = static_cast(max_tokens_per_expert * hidden_size); + const size_t B1_dequant_size = static_cast(fc1_out_features * hidden_size); + const size_t B2_dequant_size = static_cast(hidden_size * inter_size); + const size_t bias1_size = static_cast(fc1_out_features); + const size_t bias2_size = static_cast(hidden_size); + + const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size + bias1_size + bias2_size; + auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * workspace_elements_per_thread); + float* workspace = workspace_ptr.get(); + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { + int thread_id = static_cast(thread_id_pd); + auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + + float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; + + for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + if (routes.empty()) { + continue; + } + + const int64_t num_expert_tokens = routes.size(); + + // Partition the workspace for the current expert + float* A1 = thread_workspace; + float* C1 = A1 + num_expert_tokens * hidden_size; + float* A2 = C1 + num_expert_tokens * fc1_out_features; + float* C2 = A2 + num_expert_tokens * inter_size; + float* B1_dequant = C2 + num_expert_tokens * hidden_size; + float* B2_dequant = B1_dequant + fc1_out_features * hidden_size; + float* bias1_float = B2_dequant + hidden_size * inter_size; + float* bias2_float = bias1_float + fc1_out_features; + + // --- Gather input tokens for the current expert --- + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + memcpy(A1 + i * hidden_size, + input_float + token_idx * hidden_size, + static_cast(hidden_size) * sizeof(float)); + } + + // --- FC1 GEMM (X * W1^T) --- + DequantizeBlock(fc1_experts_weights->Data() + expert_idx * fc1_out_features * (hidden_size / (8 / expert_weight_bits_)), + fc1_scales->Data() + expert_idx * fc1_out_features * (block_size_ > 0 ? hidden_size / block_size_ : 1), + block_size_, expert_weight_bits_, + fc1_out_features, hidden_size, B1_dequant); + + MlasGemm(CblasNoTrans, CblasTrans, + static_cast(num_expert_tokens), static_cast(fc1_out_features), static_cast(hidden_size), + 1.0f, A1, static_cast(hidden_size), + B1_dequant, static_cast(hidden_size), + 0.0f, C1, static_cast(fc1_out_features), + nullptr); + + const T* B1_bias = (fc1_experts_bias) ? fc1_experts_bias->Data() + expert_idx * fc1_out_features : nullptr; + if (B1_bias) { + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), bias1_float, static_cast(fc1_out_features)); + } else { + memcpy(bias1_float, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + for (int64_t i = 0; i < num_expert_tokens; ++i) { + for (int64_t j = 0; j < fc1_out_features; ++j) { + C1[i * fc1_out_features + j] += bias1_float[j]; + } + } + } + + // --- Activation --- + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + + // --- FC2 GEMM (A2 * W2^T) --- + DequantizeBlock(fc2_experts_weights->Data() + expert_idx * hidden_size * (inter_size / (8 / expert_weight_bits_)), + fc2_scales->Data() + expert_idx * hidden_size * (block_size_ > 0 ? inter_size / block_size_ : 1), + block_size_, expert_weight_bits_, + hidden_size, inter_size, B2_dequant); + + MlasGemm(CblasNoTrans, CblasTrans, + static_cast(num_expert_tokens), static_cast(hidden_size), static_cast(inter_size), + 1.0f, A2, static_cast(inter_size), + B2_dequant, static_cast(inter_size), + 0.0f, C2, static_cast(hidden_size), + nullptr); + + const T* B2_bias = (fc2_experts_bias) ? fc2_experts_bias->Data() + expert_idx * hidden_size : nullptr; + if (B2_bias) { + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), bias2_float, static_cast(hidden_size)); + } else { + memcpy(bias2_float, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + } + + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t route_idx = routes[static_cast(i)]; + const int64_t token_idx = route_idx / k_; + const float weight = route_scale[route_idx]; + + float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + token_idx * hidden_size; + const float* src = C2 + i * hidden_size; + for (int64_t j = 0; j < hidden_size; ++j) { + dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); + } + } + } + }); + + // --- 4. Final Reduction (accumulate expert outputs to a float buffer) --- + auto accumulate = [&](float* buffer) { + memset(buffer, 0, output_buffer_size * sizeof(float)); + for (int i = 0; i < num_expert_threads; ++i) { + for (size_t j = 0; j < output_buffer_size; ++j) { + buffer[j] += thread_local_outputs[static_cast(i) * output_buffer_size + j]; + } + } + }; + + if constexpr (std::is_same_v) { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + // --- 5. Convert final float buffer to output type T --- + MlasConvertFloatToHalfBuffer(final_output_float, + reinterpret_cast(output->MutableData()), + static_cast(output_buffer_size)); + } else { // T is float + accumulate(output->MutableData()); + } + + return Status::OK(); +} + +// Explicit template instantiation +template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); +template Status QMoECPU::Compute(OpKernelContext* context) const; +template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); +template Status QMoECPU::Compute(OpKernelContext* context) const; + +// Kernel Registration +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, kMSDomain, 1, float, kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoECPU); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, kMSDomain, 1, MLFloat16, kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoECPU); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h new file mode 100644 index 0000000000000..890580e051a8e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/moe/moe_base_cpu.h" + +namespace onnxruntime { +namespace contrib { + +/** + * @brief QMoE is the templated CPU implementation of the Quantized Mixture of Experts operator. + * + * This kernel supports both float and MLFloat16 data types for activations, scales, and outputs. + * It parallelizes expert computation using the ONNX Runtime thread pool and minimizes memory + * usage through on-the-fly block dequantization of weights. + * + * @tparam T The data type for the kernel (float or MLFloat16). + */ +template +class QMoECPU final : public OpKernel, public MoEBaseCPU { + public: + explicit QMoECPU(const OpKernelInfo& op_kernel_info); + Status Compute(OpKernelContext* context) const override; + + private: + int64_t expert_weight_bits_; + int64_t block_size_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc index 6214b7819b765..2c59210bfabd4 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc @@ -4,6 +4,7 @@ #include "contrib_ops/cpu/moe/moe_utils.h" #include #include +#include "core/common/common.h" namespace onnxruntime { namespace contrib { @@ -19,74 +20,31 @@ float ApplyActivation(float x, ActivationType activation_type) { case ActivationType::Identity: return x; case ActivationType::SwiGLU: - // SwiGLU: This is handled specially as it requires gating, not applied here + // SwiGLU is a special case handled by ApplySwiGLUActivation, this is just a placeholder return x; default: - return x; // Default to identity + return x; } } -// Helper method for applying SwiGLU activation with different memory layouts -void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format) { - constexpr float swiglu_alpha = 1.702f; - constexpr float clamp_limit = 7.0f; // Clamping limit as specified - +void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format, + float activation_alpha, float activation_beta, float clamp_limit) { if (is_interleaved_format) { - // For interleaved format [linear, gate, linear, gate, ...], process directly - // Make a temporary copy of each pair of values before modifying them for (int64_t i = 0; i < inter_size; ++i) { - const size_t idx = static_cast(i); - const size_t linear_idx = 2 * idx; - const size_t gate_idx = linear_idx + 1; + float gate_val = input_data[2 * i]; + float linear_val = input_data[2 * i + 1]; - // Store original values - float linear_val = data[linear_idx]; // Interleaved: even index - float gate_val = data[gate_idx]; // Interleaved: odd index + gate_val = std::min(gate_val, clamp_limit); + linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit); - // Apply clamping to the values - if (gate_val > clamp_limit) gate_val = clamp_limit; // Clamp gate max only - if (linear_val > clamp_limit) linear_val = clamp_limit; // Clamp linear min/max - if (linear_val < -clamp_limit) linear_val = -clamp_limit; - - // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1) - float sigmoid_arg = swiglu_alpha * gate_val; + float sigmoid_arg = activation_alpha * gate_val; float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg)); float swish_out = gate_val * sigmoid_out; - float result = swish_out * (linear_val + 1.0f); - // Store result in first element (linear position) - data[idx] = result; + output_data[i] = swish_out * (linear_val + activation_beta); } } else { - // For chunked layout [linear..., gate...], handle separately - // Need to work with original data in-place - // First, store all the gate computations since they depend on original gate values - std::vector computed_gates(static_cast(inter_size)); - - for (int64_t i = 0; i < inter_size; ++i) { - const size_t idx = static_cast(i); - float gate_val = data[idx + static_cast(inter_size)]; - - // Apply clamping to the gate value (max only) - if (gate_val > clamp_limit) gate_val = clamp_limit; - - // Compute the gate part of SwiGLU - float sigmoid_arg = swiglu_alpha * gate_val; - float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg)); - computed_gates[idx] = gate_val * sigmoid_out; - } - - // Now apply the full activation with the precomputed gate values - for (int64_t i = 0; i < inter_size; ++i) { - const size_t idx = static_cast(i); - float linear_val = data[idx]; - - // Apply clamping to the linear value (min/max) - if (linear_val > clamp_limit) linear_val = clamp_limit; - if (linear_val < -clamp_limit) linear_val = -clamp_limit; - - data[idx] = computed_gates[idx] * (linear_val + 1.0f); - } + ORT_NOT_IMPLEMENTED("Non-interleaved format not supported for SwiGLU activation"); } } diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.h b/onnxruntime/contrib_ops/cpu/moe/moe_utils.h index e20dc101c7412..de238e8d7ae66 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_utils.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.h @@ -9,7 +9,9 @@ namespace onnxruntime { namespace contrib { float ApplyActivation(float x, ActivationType activation_type); -void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format); + +void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format, + float activation_alpha, float activation_beta, float clamp_limit); } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc deleted file mode 100644 index 8bd4dcf1afbab..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc +++ /dev/null @@ -1,596 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/cpu/quantization/moe_quantization_cpu.h" -#include "core/framework/allocator.h" -#include "core/framework/buffer_deleter.h" -#include "core/mlas/inc/mlas.h" -#include "core/mlas/inc/mlas_q4.h" -#include "core/mlas/inc/mlas_qnbit.h" -#include "core/platform/threadpool.h" -#include "contrib_ops/cpu/moe/moe_utils.h" -#include - -using namespace onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { - -#define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCpuExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", BuildKernelDefConstraints()) \ - .TypeConstraint("T1", BuildKernelDefConstraints()) \ - .TypeConstraint("T2", BuildKernelDefConstraints()), \ - QMoE); - -REGISTER_KERNEL(); - -// QMoE CPU kernel registration is handled in cpu_contrib_kernels.cc - -QMoE::QMoE(const OpKernelInfo& op_kernel_info) - : OpKernel(op_kernel_info), - MoEBaseCPU(op_kernel_info), - prepacked_fc1_weights_data_(nullptr), - prepacked_fc2_weights_data_(nullptr), - weights_allocator_(nullptr), - is_prepacked_(false) { - ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); - ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4, - "expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_); -} - -Status QMoE::Compute(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* router_probs = context->Input(1); - const Tensor* fc1_experts_weights = context->Input(2); - const Tensor* fc1_scales = context->Input(3); - const Tensor* fc1_experts_bias_optional = context->Input(4); - const Tensor* fc2_experts_weights = context->Input(5); - const Tensor* fc2_scales = context->Input(6); - const Tensor* fc2_experts_bias_optional = context->Input(7); - const Tensor* fc3_experts_weights_optional = context->Input(8); - const Tensor* fc3_scales_optional = context->Input(9); - const Tensor* fc3_experts_bias_optional = context->Input(10); - - MoEParameters moe_params; - ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( - moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc1_scales, - fc2_experts_weights, fc2_experts_bias_optional, fc2_scales, - fc3_experts_weights_optional, fc3_experts_bias_optional, fc3_scales_optional, - expert_weight_bits_ == 4 ? 2 : 1, - activation_type_ == ActivationType::SwiGLU)); - - // Dispatch based on input data type - if (input->IsDataType()) { - if (expert_weight_bits_ == 4) { - return QuantizedMoEImpl(context, moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, - fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional); - } else { - return QuantizedMoEImpl(context, moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, - fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional); - } - } else if (input->IsDataType()) { - if (expert_weight_bits_ == 4) { - return QuantizedMoEImpl(context, moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, - fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional); - } else { - return QuantizedMoEImpl(context, moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, - fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "QMoE only supports float and MLFloat16 data types, but got ", - DataTypeImpl::ToString(input->DataType())); - } -} - -template -Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const { - // SwiGLU validation - FC3 not supported - bool is_swiglu = (activation_type_ == ActivationType::SwiGLU); - if (is_swiglu && fc3_experts_weights_optional != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "SwiGLU activation is not supported with fc3."); - } - if (!is_swiglu && fc3_experts_weights_optional != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "FC3 gating is not yet implemented on CPU."); - } - - // Check if we need to repack weights - if (!is_prepacked_ || - cached_num_experts_ != moe_params.num_experts || - cached_hidden_size_ != moe_params.hidden_size || - cached_inter_size_ != moe_params.inter_size || - cached_is_swiglu_ != is_swiglu) { - // Need to prepack weights - Status status = const_cast(this)->PrepackAndDequantizeWeights( - context, moe_params, fc1_experts_weights, fc2_experts_weights, - fc1_scales, fc2_scales, is_swiglu); - ORT_RETURN_IF_ERROR(status); - } - // Get thread pool - auto* thread_pool = context->GetOperatorThreadPool(); - - // Get input data pointers - const T* input_data = input->Data(); - const T* router_probs_data = router_probs->Data(); - const T* fc1_bias_data = fc1_experts_bias_optional ? fc1_experts_bias_optional->Data() : nullptr; - const T* fc2_bias_data = fc2_experts_bias_optional ? fc2_experts_bias_optional->Data() : nullptr; - - Tensor* output = context->Output(0, input->Shape()); - T* output_data = output->MutableData(); - - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - - const int64_t num_threads = std::min( - static_cast(concurrency::ThreadPool::DegreeOfParallelism(thread_pool)), - moe_params.num_rows); - - const int64_t total_output_size = moe_params.num_rows * moe_params.hidden_size; - std::fill_n(output_data, total_output_size, MLFloat16(0.0f)); - - // Using prepacked weights - no need to convert scales - - auto thread_fc1_buffers = IAllocator::MakeUniquePtr(allocator, static_cast(num_threads * moe_params.inter_size * (is_swiglu ? 2 : 1))); - auto thread_fc2_buffers = IAllocator::MakeUniquePtr(allocator, static_cast(num_threads * moe_params.hidden_size)); - - // Set up output buffer - IAllocatorUniquePtr output_float; - float* output_float_ptr = nullptr; - - if constexpr (std::is_same_v) { - // For MLFloat16, we need a separate float buffer - output_float = IAllocator::MakeUniquePtr(allocator, static_cast(total_output_size)); - output_float_ptr = output_float.get(); - } else { - // For float, we can write directly to output_data - output_float = IAllocatorUniquePtr(output_data, [](float*) {}); - output_float_ptr = output_data; - } - - // Initialize output to zeros - std::fill_n(output_float_ptr, total_output_size, 0.0f); - - // Prepare float buffers for input data and biases - IAllocatorUniquePtr input_float; - IAllocatorUniquePtr router_probs_float; - - // Pointers for easier access - float* input_float_ptr = nullptr; - float* router_probs_float_ptr = nullptr; - - // Pre-convert bias tensors to float (if they exist) - const int64_t fc1_bias_size = is_swiglu ? 2 * moe_params.inter_size : moe_params.inter_size; - const int64_t fc2_bias_size = moe_params.hidden_size; - - // Allocate buffers for converted biases using ORT allocator - IAllocatorUniquePtr fc1_bias_float; - IAllocatorUniquePtr fc2_bias_float; - - if (fc1_bias_data) { - fc1_bias_float = IAllocator::MakeUniquePtr(allocator, static_cast(moe_params.num_experts * fc1_bias_size)); - } - - if (fc2_bias_data) { - fc2_bias_float = IAllocator::MakeUniquePtr(allocator, static_cast(moe_params.num_experts * fc2_bias_size)); - } - - // Convert input and router_probs based on type - if constexpr (std::is_same_v) { - // For MLFloat16, convert to float - need to allocate buffers first - input_float = IAllocator::MakeUniquePtr(allocator, static_cast(moe_params.num_rows * moe_params.hidden_size)); - router_probs_float = IAllocator::MakeUniquePtr(allocator, static_cast(moe_params.num_rows * moe_params.num_experts)); - - input_float_ptr = input_float.get(); - router_probs_float_ptr = router_probs_float.get(); - - // Convert MLFloat16 to float - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(input_data), - input_float_ptr, - static_cast(moe_params.num_rows * moe_params.hidden_size), - thread_pool); - - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(router_probs_data), - router_probs_float_ptr, - static_cast(moe_params.num_rows * moe_params.num_experts), - thread_pool); - - // Convert biases to float once (if they exist) - if (fc1_bias_data) { - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(fc1_bias_data), - fc1_bias_float.get(), - static_cast(moe_params.num_experts * fc1_bias_size), - thread_pool); - } - - if (fc2_bias_data) { - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(fc2_bias_data), - fc2_bias_float.get(), - static_cast(moe_params.num_experts * fc2_bias_size), - thread_pool); - } - } else { - // For float, point to original input and router_probs directly instead of copying - input_float = IAllocatorUniquePtr(const_cast(input_data), [](float*) {}); - router_probs_float = IAllocatorUniquePtr(const_cast(router_probs_data), [](float*) {}); - - // Set pointers to the original data - input_float_ptr = const_cast(input_data); - router_probs_float_ptr = const_cast(router_probs_data); - - // For float, just point to the original bias data directly without copying - // No need to allocate or copy, just reuse the original pointers - if (fc1_bias_data) { - // Release previously allocated memory if any - fc1_bias_float.reset(); - // Direct pointer to original data - fc1_bias_float = IAllocatorUniquePtr(const_cast(fc1_bias_data), [](float*) {}); - } - - if (fc2_bias_data) { - // Release previously allocated memory if any - fc2_bias_float.reset(); - // Direct pointer to original data - fc2_bias_float = IAllocatorUniquePtr(const_cast(fc2_bias_data), [](float*) {}); - } - } - - // No need to initialize thread results - using direct output buffer - - // Determine activation related parameters - const bool is_4bit = UseUInt4x2; - const int64_t act_multiplier = is_swiglu ? 2 : 1; - const int64_t fc1_output_size = is_swiglu ? 2 * moe_params.inter_size : moe_params.inter_size; - - // Use prepacked dequantized weights - no need to dequantize here - const float* dequant_fc1_weights = prepacked_fc1_weights_data_; - const float* dequant_fc2_weights = prepacked_fc2_weights_data_; - - // Process tokens in parallel - concurrency::ThreadPool::TryParallelFor( - thread_pool, static_cast(moe_params.num_rows), - static_cast(std::max(1, moe_params.num_rows / num_threads)), - [&](ptrdiff_t start_token, ptrdiff_t end_token) { - const int64_t thread_id = start_token / ((moe_params.num_rows + num_threads - 1) / num_threads); - const int64_t thread_fc1_size = is_4bit ? (moe_params.inter_size * (is_swiglu ? 2 : 1)) : (moe_params.inter_size * act_multiplier); - float* thread_fc1_output = thread_fc1_buffers.get() + thread_id * thread_fc1_size; - float* thread_fc2_output = thread_fc2_buffers.get() + thread_id * moe_params.hidden_size; - - // Process each token in this thread's range - for (std::ptrdiff_t token_idx = start_token; token_idx < end_token; ++token_idx) { - const float* token_input = input_float_ptr + static_cast(SafeInt(token_idx)) * moe_params.hidden_size; - float* token_result = output_float_ptr + static_cast(SafeInt(token_idx)) * moe_params.hidden_size; - - // Process all experts for this token - for (std::ptrdiff_t expert_idx = 0; expert_idx < moe_params.num_experts; ++expert_idx) { - float routing_weight = router_probs_float_ptr[static_cast(SafeInt(token_idx)) * moe_params.num_experts + static_cast(SafeInt(expert_idx))]; - if (routing_weight <= 1e-6f) continue; // Skip experts with negligible routing weight - - // FC1: input -> intermediate using pre-dequantized weights + MLAS SGEMM - const int64_t fc1_weight_offset = is_4bit ? (static_cast(SafeInt(expert_idx)) * moe_params.hidden_size * fc1_output_size) : (static_cast(SafeInt(expert_idx)) * moe_params.hidden_size * moe_params.inter_size * act_multiplier); - const float* fc1_expert_weights = dequant_fc1_weights + fc1_weight_offset; - - // Bias size is always equal to output size (fc1_output_size), regardless of bit width - const int64_t fc1_bias_size = fc1_output_size; - - // Use MLAS SGEMM for FC1 - MLAS_SGEMM_DATA_PARAMS fc1_params; - fc1_params.A = token_input; - fc1_params.lda = static_cast(moe_params.hidden_size); - fc1_params.B = fc1_expert_weights; - fc1_params.ldb = static_cast(moe_params.hidden_size); - fc1_params.C = thread_fc1_output; - fc1_params.ldc = static_cast(fc1_bias_size); - fc1_params.alpha = 1.0f; - fc1_params.beta = 0.0f; - - MlasGemm(CblasNoTrans, CblasNoTrans, 1, static_cast(fc1_bias_size), static_cast(moe_params.hidden_size), fc1_params, nullptr); - - // Handle different activation types - if (is_swiglu) { - // Add bias if present - if (fc1_bias_data) { - // Use the pre-converted float bias data - const float* fc1_expert_bias_float = fc1_bias_float.get() + static_cast(SafeInt(expert_idx)) * fc1_bias_size; - for (int64_t i = 0; i < fc1_bias_size; ++i) { - thread_fc1_output[i] += fc1_expert_bias_float[i]; - } - } - contrib::ApplySwiGLUActivation(thread_fc1_output, moe_params.inter_size, is_4bit); - } else { - // Standard activation (non-SwiGLU) - if (fc1_bias_data) { - // Use the pre-converted float bias data - const float* fc1_expert_bias_float = fc1_bias_float.get() + static_cast(SafeInt(expert_idx)) * moe_params.inter_size; - for (int64_t i = 0; i < moe_params.inter_size; ++i) { - thread_fc1_output[i] += fc1_expert_bias_float[i]; - thread_fc1_output[i] = ApplyActivation(thread_fc1_output[i], activation_type_); - } - } else { - for (int64_t i = 0; i < moe_params.inter_size; ++i) { - thread_fc1_output[i] = ApplyActivation(thread_fc1_output[i], activation_type_); - } - } - } - - // FC2: intermediate -> output using pre-dequantized weights + MLAS SGEMM - const float* fc2_expert_weights = dequant_fc2_weights + static_cast(SafeInt(expert_idx)) * moe_params.inter_size * moe_params.hidden_size; - - // Use MLAS SGEMM for FC2 - MLAS_SGEMM_DATA_PARAMS fc2_params; - fc2_params.A = thread_fc1_output; - fc2_params.lda = static_cast(moe_params.inter_size); - fc2_params.B = fc2_expert_weights; - fc2_params.ldb = static_cast(moe_params.inter_size); - fc2_params.C = thread_fc2_output; - fc2_params.ldc = static_cast(moe_params.hidden_size); - fc2_params.alpha = 1.0f; - fc2_params.beta = 0.0f; - - MlasGemm(CblasNoTrans, CblasNoTrans, 1, static_cast(moe_params.hidden_size), static_cast(moe_params.inter_size), fc2_params, nullptr); - - // Add bias, apply routing weight, and accumulate to final result - if (fc2_bias_data) { - // Use the pre-converted float bias data - const float* fc2_expert_bias_float = fc2_bias_float.get() + static_cast(SafeInt(expert_idx)) * moe_params.hidden_size; - for (int64_t i = 0; i < moe_params.hidden_size; ++i) { - token_result[i] += routing_weight * (thread_fc2_output[i] + fc2_expert_bias_float[i]); - } - } else { - for (int64_t i = 0; i < moe_params.hidden_size; ++i) { - token_result[i] += routing_weight * thread_fc2_output[i]; - } - } - } - } - }); - - // No need for accumulation since threads write directly to output_float - - // Convert results back to the appropriate output type, if needed - if constexpr (std::is_same_v) { - // For MLFloat16, convert from float to half - MlasConvertFloatToHalfBuffer(output_float_ptr, reinterpret_cast(output_data), static_cast(total_output_size)); - } - // For float, no conversion needed as we directly wrote to output_data - - // Suppress unused parameter warnings for optional parameters that are not used in non-SwiGLU modes - if (!is_swiglu) { - ORT_UNUSED_PARAMETER(fc3_experts_bias_optional); - ORT_UNUSED_PARAMETER(fc3_scales_optional); - } - - return Status::OK(); -} - -template -Status QMoE::PrepackAndDequantizeWeights(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* fc1_experts_weights, - const Tensor* fc2_experts_weights, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - bool is_swiglu) { - // Get thread pool - auto* thread_pool = context->GetOperatorThreadPool(); - - // Get input data pointers - const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); - const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); - const void* fc1_scales_data_typed = fc1_scales->DataRaw(); - const void* fc2_scales_data_typed = fc2_scales->DataRaw(); - bool is_fp32_scales = fc1_scales->IsDataType(); - - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - - const int64_t num_threads = std::min( - static_cast(concurrency::ThreadPool::DegreeOfParallelism(thread_pool)), - moe_params.num_experts); - - // Prepare scales in float format - const int64_t fc1_scales_size = moe_params.num_experts * (is_swiglu ? 2 * moe_params.inter_size : moe_params.inter_size); - const int64_t fc2_scales_size = moe_params.num_experts * moe_params.hidden_size; - - auto fc1_scales_float = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_scales_size)); - auto fc2_scales_float = IAllocator::MakeUniquePtr(allocator, static_cast(fc2_scales_size)); - - if (is_fp32_scales) { - // For float scales, just copy - std::memcpy(fc1_scales_float.get(), fc1_scales_data_typed, static_cast(fc1_scales_size) * sizeof(float)); - std::memcpy(fc2_scales_float.get(), fc2_scales_data_typed, static_cast(fc2_scales_size) * sizeof(float)); - } else { - // For MLFloat16 scales, convert to float - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(fc1_scales_data_typed), - fc1_scales_float.get(), - static_cast(fc1_scales_size), - thread_pool); - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(fc2_scales_data_typed), - fc2_scales_float.get(), - static_cast(fc2_scales_size), - thread_pool); - } - - const float* fc1_scales_data = fc1_scales_float.get(); - const float* fc2_scales_data = fc2_scales_float.get(); - - // Determine quantization parameters based on bit width - using symmetric quantization for TensorRT compatibility - const bool is_4bit = UseUInt4x2; - const int64_t act_multiplier = is_swiglu ? 2 : 1; - const int64_t fc1_output_size = is_swiglu ? 2 * moe_params.inter_size : moe_params.inter_size; - - // Calculate weight sizes and strides based on quantization type - const int64_t fc1_weight_stride = is_4bit ? (moe_params.hidden_size * fc1_output_size / 2) : (moe_params.hidden_size * moe_params.inter_size * act_multiplier); - const int64_t fc2_weight_stride = is_4bit ? (moe_params.inter_size * moe_params.hidden_size / 2) : (moe_params.inter_size * moe_params.hidden_size); - - // Get or create a persistent allocator for weights - if (weights_allocator_ == nullptr) { - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&weights_allocator_)); - } - - // Allocate prepacked weight buffers using ORT allocator - const size_t fc1_weights_size = static_cast(moe_params.num_experts * moe_params.hidden_size * (is_4bit ? fc1_output_size : moe_params.inter_size * act_multiplier)); - const size_t fc2_weights_size = static_cast(moe_params.num_experts * moe_params.inter_size * moe_params.hidden_size); - - prepacked_fc1_weights_ = IAllocator::MakeUniquePtr(weights_allocator_, fc1_weights_size); - prepacked_fc2_weights_ = IAllocator::MakeUniquePtr(weights_allocator_, fc2_weights_size); - - // Store pointers for easy access - prepacked_fc1_weights_data_ = prepacked_fc1_weights_.get(); - prepacked_fc2_weights_data_ = prepacked_fc2_weights_.get(); - - // Helper lambda for dequantizing a single weight value - updated for symmetric quantization - auto DequantizeWeight = [&](const uint8_t* weights, size_t linear_idx, - const float* scales, int64_t scale_idx) -> float { - if (is_4bit) { - // For Int4, two values are packed in each uint8 - size_t packed_idx = linear_idx / 2; - uint8_t packed_value = weights[packed_idx]; - uint8_t quantized_weight = (linear_idx % 2 == 0) ? (packed_value & 0x0F) : ((packed_value >> 4) & 0x0F); - // Convert uint4 to int4 with proper mapping for symmetric quantization - int8_t signed_weight = static_cast(quantized_weight); - if (signed_weight >= 8) { - signed_weight -= 16; // Map [8, 15] to [-8, -1] for proper signed representation - } - return static_cast(signed_weight) * scales[scale_idx]; - } else { - // For Int8, convert uint8 to int8 for symmetric quantization - int8_t signed_weight = static_cast(weights[linear_idx]); - return static_cast(signed_weight) * scales[scale_idx]; - } - }; - - // Dequantize FC1 weights for all experts - concurrency::ThreadPool::TryParallelFor( - thread_pool, static_cast(moe_params.num_experts), - static_cast(std::max(1, moe_params.num_experts / num_threads)), - [&](ptrdiff_t expert_start, ptrdiff_t expert_end) { - for (std::ptrdiff_t expert_idx = expert_start; expert_idx < expert_end; ++expert_idx) { - const uint8_t* fc1_expert_weights = fc1_weights_data + static_cast(SafeInt(expert_idx)) * fc1_weight_stride; - const float* fc1_expert_scales = fc1_scales_data + static_cast(SafeInt(expert_idx)) * (is_4bit ? fc1_output_size : moe_params.inter_size * act_multiplier); - float* dequant_fc1_expert = prepacked_fc1_weights_data_ + static_cast(SafeInt(expert_idx)) * moe_params.hidden_size * (is_4bit ? fc1_output_size : moe_params.inter_size * act_multiplier); - - const int64_t output_cols = is_4bit ? fc1_output_size : moe_params.inter_size * act_multiplier; - for (int64_t out_col = 0; out_col < output_cols; ++out_col) { - for (int64_t in_col = 0; in_col < moe_params.hidden_size; ++in_col) { - size_t linear_idx = static_cast(out_col * moe_params.hidden_size + in_col); - dequant_fc1_expert[linear_idx] = DequantizeWeight(fc1_expert_weights, linear_idx, fc1_expert_scales, out_col); - } - } - } - }); - - // Dequantize FC2 weights for all experts - concurrency::ThreadPool::TryParallelFor( - thread_pool, static_cast(moe_params.num_experts), - static_cast(std::max(1, moe_params.num_experts / num_threads)), - [&](ptrdiff_t expert_start, ptrdiff_t expert_end) { - for (std::ptrdiff_t expert_idx = expert_start; expert_idx < expert_end; ++expert_idx) { - const uint8_t* fc2_expert_weights = fc2_weights_data + static_cast(SafeInt(expert_idx)) * fc2_weight_stride; - const float* fc2_expert_scales = fc2_scales_data + static_cast(SafeInt(expert_idx)) * moe_params.hidden_size; - float* dequant_fc2_expert = prepacked_fc2_weights_data_ + static_cast(SafeInt(expert_idx)) * moe_params.inter_size * moe_params.hidden_size; - - for (int64_t out_col = 0; out_col < moe_params.hidden_size; ++out_col) { - for (int64_t in_col = 0; in_col < moe_params.inter_size; ++in_col) { - size_t linear_idx = static_cast(out_col * moe_params.inter_size + in_col); - dequant_fc2_expert[linear_idx] = DequantizeWeight(fc2_expert_weights, linear_idx, fc2_expert_scales, out_col); - } - } - } - }); - - // Update cached parameters - cached_num_experts_ = moe_params.num_experts; - cached_hidden_size_ = moe_params.hidden_size; - cached_inter_size_ = moe_params.inter_size; - cached_is_swiglu_ = is_swiglu; - is_prepacked_ = true; - - return Status::OK(); -} - -// Explicit template instantiations -template Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - -template Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - -template Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - -template Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.h deleted file mode 100644 index 19caa86c0fd98..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "contrib_ops/cpu/moe/moe_base_cpu.h" -#include "core/common/common.h" -#include "core/framework/op_kernel.h" - -namespace onnxruntime { -namespace contrib { - -class QMoE final : public OpKernel, public MoEBaseCPU { - public: - explicit QMoE(const OpKernelInfo& op_kernel_info); - Status Compute(OpKernelContext* ctx) const override; - - private: - template - Status PrepackAndDequantizeWeights(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* fc1_experts_weights, - const Tensor* fc2_experts_weights, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - bool is_swiglu); - - template - Status QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - - // Prepacked dequantized weights stored for reuse - IAllocatorUniquePtr prepacked_fc1_weights_; - IAllocatorUniquePtr prepacked_fc2_weights_; - float* prepacked_fc1_weights_data_{nullptr}; - float* prepacked_fc2_weights_data_{nullptr}; - - // Persistent allocator for weights - AllocatorPtr weights_allocator_; - - // Cached parameters to detect changes requiring repack - mutable int64_t cached_num_experts_{0}; - mutable int64_t cached_hidden_size_{0}; - mutable int64_t cached_inter_size_{0}; - mutable bool cached_is_swiglu_{false}; - mutable bool is_prepacked_{false}; - - int64_t expert_weight_bits_; -}; - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 2c6a28f1c55f4..ce8c0270f5c32 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -540,7 +540,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ if (normalize_routing_weights && k_idx == k - 1) { #pragma unroll for (int ki = 0; ki < k; ++ki) { - output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + float old_val = static_cast(output[idx - ki]); + output[idx - ki] = T(old_val / output_row_sum); } } } diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 3ca7cee46b22b..5f0c30b16a8f4 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -7,7 +7,7 @@ #include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" #include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h" -#include "contrib_ops/cpu/quantization/moe_helper.h" +#include "contrib_ops/cpu/moe/moe_helper.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 05fff886080d7..ed7ca998e0b86 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -9,13 +9,16 @@ namespace onnxruntime { namespace test { +// Note: QMoE CPU implementation now always applies softmax normalization to top-k selected experts +// regardless of the normalize_routing_weights parameter value for mathematical correctness. + #ifndef ENABLE_TRAINING static void RunMoETest(const std::vector& input, const std::vector& router_probs, const std::vector& fc1_experts_weights, const std::vector& fc2_experts_weights, const std::vector& fc3_experts_weights, const std::vector& fc1_experts_bias, const std::vector& fc2_experts_bias, const std::vector& output_data, int num_rows, int num_experts, int hidden_size, int inter_size, std::string activation_type, - int normalize_routing_weights = 0, int top_k = 1, bool use_float16 = false) { + int normalize_routing_weights = 1, int top_k = 1, bool use_float16 = false) { constexpr int min_cuda_arch = 700; bool enable_cuda = HasCudaEnvironment(min_cuda_arch); @@ -27,8 +30,8 @@ static void RunMoETest(const std::vector& input, const std::vector std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc1_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size}; std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; std::vector fc1_experts_bias_dims = {num_experts, inter_size}; std::vector fc2_experts_bias_dims = {num_experts, hidden_size}; @@ -90,7 +93,7 @@ static void RunQMoETest(const std::vector& input, const std::vector& fc3_experts_weights, const std::vector& fc1_scales, const std::vector& fc2_scales, const std::vector& fc3_scales, const std::vector& output_data, int num_rows, int num_experts, int hidden_size, - int inter_size, std::string activation_type, int normalize_routing_weights = 0, int top_k = 1, int expert_weight_bits = 4) { + int inter_size, std::string activation_type, int normalize_routing_weights = 1, int top_k = 1, int expert_weight_bits = 4) { constexpr int min_cuda_arch = 700; // Test CUDA execution provider @@ -103,7 +106,6 @@ static void RunQMoETest(const std::vector& input, const std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - // Adjust weight dimensions based on quantization type for CUDA as well std::vector fc1_experts_weights_dims = {num_experts, hidden_size, expert_weight_bits == 4 ? inter_size / 2 : inter_size}; std::vector fc2_experts_weights_dims = {num_experts, inter_size, expert_weight_bits == 4 ? hidden_size / 2 : hidden_size}; std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; @@ -150,7 +152,6 @@ static void RunQMoETest(const std::vector& input, const std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - // Adjust weight dimensions based on quantization type std::vector fc1_experts_weights_dims = {num_experts, hidden_size, expert_weight_bits == 4 ? inter_size / 2 : inter_size}; std::vector fc2_experts_weights_dims = {num_experts, inter_size, expert_weight_bits == 4 ? hidden_size / 2 : hidden_size}; std::vector fc1_scales_dims = {num_experts, inter_size}; @@ -355,7 +356,7 @@ TEST(MoETest, MoETest_Gelu) { 1.3354061f, 0.5049282f, 0.72775036f, 0.90331376f, 1.2945517f, 0.9123066f, 1.1995136f, 0.7708638f}; RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, - output, num_rows, num_experts, hidden_size, inter_size, "gelu"); + output, num_rows, num_experts, hidden_size, inter_size, "gelu", 0); } TEST(MoETest, MoETest_Relu) { @@ -533,7 +534,7 @@ TEST(MoETest, MoETest_Relu) { 4.8571277f, 5.649453f, 5.485141f, 5.306299f, 4.767025f, 6.9010167f, 5.3520975f, 6.711155f}; RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, - output, num_rows, num_experts, hidden_size, inter_size, "relu"); + output, num_rows, num_experts, hidden_size, inter_size, "relu", 0); } TEST(MoETest, MoETest_Mixtral) { @@ -1322,7 +1323,6 @@ TEST(MoETest, QMoETest_Mixtral_Int4) { // CPU-specific QMoE tests TEST(MoETest, QMoETest_CPU_Int4_MLAS) { - // Test CPU implementation with 4-bit quantization (MLAS optimized path) - CPU only int num_rows = 2; int num_experts = 2; int hidden_size = 32; @@ -1336,31 +1336,32 @@ TEST(MoETest, QMoETest_CPU_Int4_MLAS) { const std::vector router_probs = {0.3f, 0.7f, 0.6f, 0.4f}; - // Generate simple test weights for 4-bit symmetric quantization - // Use 0x00 which unpacks to 0,0 (both 0 for 4-bit) - std::vector fc1_experts_weights(num_experts * hidden_size * inter_size / 2, 0x00); - std::vector fc2_experts_weights(num_experts * inter_size * hidden_size / 2, 0x00); // 0,0 values to produce zero output + // Generate simple test weights for 4-bit symmetric quantization with SwiGLU + // Use 0x88 which unpacks to 8,8 -> 0,0 in signed form (8-8=0) for zero weights + // For SwiGLU: FC1 outputs 2*inter_size (gate + linear), FC2 takes inter_size input + std::vector fc1_experts_weights(num_experts * hidden_size * inter_size, 0x88); // 2*inter_size for SwiGLU, packed into /2 + std::vector fc2_experts_weights(num_experts * inter_size * hidden_size / 2, 0x88); // 8,8 values to produce zero output std::vector fc3_experts_weights; // Empty for CPU (FC3 not supported) - std::vector fc1_scales(num_experts * inter_size, 0.01f); // Smaller scale factor - std::vector fc2_scales(num_experts * hidden_size, 0.01f); // Smaller scale factor + std::vector fc1_scales(num_experts * inter_size * 2, 0.01f); // 2x for SwiGLU (gate + linear) + std::vector fc2_scales(num_experts * hidden_size, 0.01f); // Smaller scale factor std::vector fc3_scales; - // With zero weights (0x00), the current implementation will produce all zero outputs + // With zero weights (0x88 -> 8,8 -> 0,0 signed), the implementation will produce all zero outputs std::vector output(num_rows * hidden_size, 0.0f); // Test CPU execution provider ONLY (don't use RunQMoETest which tests both CUDA and CPU) OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); cpu_tester.AddAttribute("k", 2); - cpu_tester.AddAttribute("activation_type", "gelu"); - cpu_tester.AddAttribute("normalize_routing_weights", 1); - cpu_tester.AddAttribute("expert_weight_bits", 4); // Test 4-bit quantization + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Always use 1 - softmax normalization always applied + cpu_tester.AddAttribute("expert_weight_bits", 4); // Test 4-bit quantization std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; // legacy shape + std::vector fc1_experts_weights_dims = {num_experts, 2 * inter_size, hidden_size / 2}; // SwiGLU: 2*inter_size output, 4-bit packed std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2}; - std::vector fc1_scales_dims = {num_experts, inter_size}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU std::vector fc2_scales_dims = {num_experts, hidden_size}; std::vector output_dims = {num_rows, hidden_size}; @@ -1376,8 +1377,8 @@ TEST(MoETest, QMoETest_CPU_Int4_MLAS) { cpu_tester.AddOptionalInputEdge(); // fc3_scales (use float for CPU) cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias - // When using 0x00 for 4-bit quantized weights with the current implementation, - // all dequantized values should be 0.0f, and thus output should be all zeros + // When using 0x88 for 4-bit quantized weights with the current implementation, + // all dequantized values should be 0.0f (8-8=0), and thus output should be all zeros std::vector expected_output(num_rows * hidden_size, 0.0f); cpu_tester.AddOutput("output", output_dims, ToFloat16(expected_output)); @@ -1400,31 +1401,31 @@ TEST(MoETest, QMoETest_CPU_Int8_MLAS) { const std::vector router_probs = {0.4f, 0.6f}; - // For 8-bit symmetric quantization, dimensions don't need /2 - // Use quantized weights close to zero for reasonable dequantization - std::vector fc1_experts_weights(num_experts * hidden_size * inter_size, 2); // 2 = small positive value - std::vector fc2_experts_weights(num_experts * inter_size * hidden_size, 254); // 254 = -2 in 8-bit signed - std::vector fc3_experts_weights; // Empty for CPU + // For 8-bit symmetric quantization with SwiGLU + // Use quantized weights at zero point for zero outputs (128 = 0 in signed) + std::vector fc1_experts_weights(num_experts * 2 * inter_size * hidden_size, 128); // 2*inter_size for SwiGLU, no packing for 8-bit + std::vector fc2_experts_weights(num_experts * inter_size * hidden_size, 128); // 128 = 0 in signed + std::vector fc3_experts_weights; // Empty for CPU - std::vector fc1_scales(num_experts * inter_size, 0.1f); + std::vector fc1_scales(num_experts * inter_size * 2, 0.1f); // 2x for SwiGLU std::vector fc2_scales(num_experts * hidden_size, 0.1f); std::vector fc3_scales; - // Expected output should be close to zero since we're using small weights around zero point + // Expected output should be zero since we're using zero weights (128-128=0) std::vector output(num_rows * hidden_size, 0.0f); // Test with different attributes for 8-bit OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); cpu_tester.AddAttribute("k", 1); - cpu_tester.AddAttribute("activation_type", "relu"); - cpu_tester.AddAttribute("normalize_routing_weights", 0); - cpu_tester.AddAttribute("expert_weight_bits", 8); // Test 8-bit quantization + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Always use 1 - softmax normalization always applied + cpu_tester.AddAttribute("expert_weight_bits", 8); // Test 8-bit quantization std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; // legacy shape + std::vector fc1_experts_weights_dims = {num_experts, inter_size * 2, hidden_size}; // SwiGLU: 2*inter_size output, 8-bit no packing std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; - std::vector fc1_scales_dims = {num_experts, inter_size}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU std::vector fc2_scales_dims = {num_experts, hidden_size}; std::vector output_dims = {num_rows, hidden_size}; @@ -1457,27 +1458,30 @@ TEST(MoETest, QMoETest_CPU_FC3_Error) { const std::vector input = {0.1f, -0.2f, 0.3f, -0.4f, 0.5f, -0.6f, 0.7f, -0.8f}; const std::vector router_probs = {0.5f, 0.5f}; - std::vector fc1_experts_weights(num_experts * hidden_size * inter_size / 2, 0x01); // 0,1 in symmetric quantization - std::vector fc2_experts_weights(num_experts * inter_size * hidden_size / 2, 0x10); // 1,0 in symmetric quantization - std::vector fc3_experts_weights(num_experts * hidden_size * inter_size / 2, 0x21); // 2,1 in symmetric quantization, FC3 provided + // Using new layout: fc1 has fused swiglu doubling (2*inter_size) and 4-bit pack_size=2 so hidden_size packed dimension is hidden_size/2 + const int pack_size = 2; // for 4-bit + const int fc1_inter_size = 2 * inter_size; // swiglu fused + std::vector fc1_experts_weights(num_experts * fc1_inter_size * (hidden_size / pack_size), 0x01); + std::vector fc2_experts_weights(num_experts * hidden_size * (inter_size / pack_size), 0x10); + std::vector fc3_experts_weights(num_experts * inter_size * (hidden_size / pack_size), 0x21); // FC3 provided - std::vector fc1_scales(num_experts * inter_size, 0.1f); + std::vector fc1_scales(num_experts * fc1_inter_size, 0.1f); std::vector fc2_scales(num_experts * hidden_size, 0.05f); std::vector fc3_scales(num_experts * inter_size, 0.08f); // FC3 scales provided // Test CPU execution provider ONLY (designed to test CPU-specific error handling) OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); cpu_tester.AddAttribute("k", 1); - cpu_tester.AddAttribute("activation_type", "relu"); - cpu_tester.AddAttribute("normalize_routing_weights", 0); + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Use 1 for consistency, though this test focuses on FC3 error cpu_tester.AddAttribute("expert_weight_bits", 4); std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; // legacy shape - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2}; - std::vector fc3_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; - std::vector fc1_scales_dims = {num_experts, inter_size}; + std::vector fc1_experts_weights_dims = {num_experts, fc1_inter_size, hidden_size / pack_size}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / pack_size}; + std::vector fc3_experts_weights_dims = {num_experts, inter_size, hidden_size / pack_size}; + std::vector fc1_scales_dims = {num_experts, fc1_inter_size}; std::vector fc2_scales_dims = {num_experts, hidden_size}; std::vector fc3_scales_dims = {num_experts, inter_size}; std::vector output_dims = {num_rows, hidden_size}; @@ -1522,9 +1526,9 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { const int fc1_weight_size_per_expert = hidden_size * inter_size * 2 / 2; // For 4-bit SwiGLU const int fc2_weight_size_per_expert = inter_size * hidden_size / 2; // For 4-bit FC2 - // Generate test weights for symmetric quantization (zero point is 0) - std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 0x12); // 1,2 -> small positive weights - std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 0xFF); // -1,0 -> small mixed weights + // Generate test weights for symmetric quantization (zero point is 8 for 4-bit) + std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 0x88); // 8,8 -> 0,0 signed weights + std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 0x88); // 8,8 -> 0,0 signed weights std::vector fc3_experts_weights; // Empty for SwiGLU (gate weights concatenated with FC1) // Scales: for SwiGLU, FC1 has 2*inter_size outputs (linear + gate) @@ -1532,7 +1536,10 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { std::vector fc2_scales(num_experts * hidden_size, 0.05f); std::vector fc3_scales; - // Expected output should be small but non-zero due to SwiGLU nonlinearity + // For SwiGLU with zero weights (0x88 -> 8,8 -> 0,0 signed): + // Gate output = 0, Linear output = 0 + // SwiGLU = gate * sigmoid(gate) * (linear + 1) = 0 * sigmoid(0) * (0 + 1) = 0 * 0.5 * 1 = 0 + // So output should be zero std::vector output(num_rows * hidden_size, 0.0f); OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); @@ -1582,10 +1589,11 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { const int fc1_weight_size_per_expert = hidden_size * inter_size * 2; // For 8-bit SwiGLU const int fc2_weight_size_per_expert = inter_size * hidden_size; // For 8-bit FC2 - // Generate test weights at zero (for symmetric quantization) to produce zero output - std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 0); // Zero in symmetric quantization - std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 0); // Zero in symmetric quantization - std::vector fc3_experts_weights; // Empty for SwiGLU + // Generate test weights at zero (for symmetric quantization storage format: uint8 with zero point 128) + // Fill with 128 so dequantized value (val - 128) == 0 => zero output + std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 128); + std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 128); + std::vector fc3_experts_weights; // Empty for SwiGLU // Scales: for SwiGLU, FC1 has 2*inter_size outputs std::vector fc1_scales(num_experts * inter_size * 2, 0.1f); @@ -1627,65 +1635,6 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); } -// Test for Float32 input and output type with QMoE operator -TEST(MoETest, QMoETest_CPU_Float32) { - // Test CPU implementation with float32 input/output - int num_rows = 1; - int num_experts = 2; - int hidden_size = 8; - int inter_size = 8; - - const std::vector input = {0.2f, -0.3f, 0.4f, -0.5f, 0.6f, -0.7f, 0.8f, -0.9f}; - const std::vector router_probs = {0.0f, 0.0f}; - - // For 8-bit quantization weights - const int fc1_weight_size_per_expert = hidden_size * inter_size; - const int fc2_weight_size_per_expert = inter_size * hidden_size; - - // Generate test weights at zero point (128 for 8-bit) to produce zero output - std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 128); - std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 128); - - // Scales - std::vector fc1_scales(num_experts * inter_size, 0.1f); - std::vector fc2_scales(num_experts * hidden_size, 0.1f); - - std::vector output(num_rows * hidden_size, 0.0f); - - OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); - cpu_tester.AddAttribute("k", 2); - cpu_tester.AddAttribute("activation_type", "gelu"); - cpu_tester.AddAttribute("normalize_routing_weights", 1); - cpu_tester.AddAttribute("expert_weight_bits", 8); - - std::vector input_dims = {num_rows, hidden_size}; - std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; // legacy shape - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; - std::vector fc1_scales_dims = {num_experts, inter_size}; - std::vector fc2_scales_dims = {num_experts, hidden_size}; - std::vector output_dims = {num_rows, hidden_size}; - - // Use float directly instead of MLFloat16 - cpu_tester.AddInput("input", input_dims, input); - cpu_tester.AddInput("router_probs", router_probs_dims, router_probs); - cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); - cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); - cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias - cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); - cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); - cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias - cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights - cpu_tester.AddOptionalInputEdge(); // fc3_scales - cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias - cpu_tester.AddOutput("output", output_dims, output); - cpu_tester.SetOutputTolerance(0.02f); - - std::vector> cpu_execution_providers; - cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); - cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); -} - #endif } // namespace test diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 909795e6639bb..efaaca29a01b6 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -10,28 +10,26 @@ # license information. # -------------------------------------------------------------------------- # -# Note on QMoE quantization approaches: +# QMoE quantization implementation notes: # -# Both CPU and CUDA implementations of QMoE use symmetric quantization: +# Both CPU and CUDA implementations use symmetric quantization centered around 0: +# - 4-bit: range [-8, 7] with no zero-point (symmetric around 0) +# - 8-bit: range [-128, 127] with no zero-point (symmetric around 0) # -# 1. CPU (this file): Symmetric quantization -# - 4-bit: range = [-8, 7] -# - 8-bit: range = [-128, 127] +# This follows the _symmetric_quantize_last_axis_of_batched_matrix pattern. +# Tolerance values account for numerical differences between implementations. # -# 2. CUDA: Symmetric quantization -# - 4-bit: range = [-8, 7] -# - 8-bit: range = [-128, 127] -# -# This aligned approach ensures better compatibility with TensorRT. -# The tolerance values used in testing account for minor numerical differences. +# Routing Logic: CPU implementation uses top-k selection first, then softmax +# normalization on the selected experts. This provides proper weight distribution +# while maintaining computational efficiency. # -------------------------------------------------------------------------- -import itertools -import os +import time import unittest from collections import OrderedDict import numpy import torch +import torch.nn.functional as F from onnx import helper from parameterized import parameterized from torch import nn @@ -41,27 +39,41 @@ try: from onnx import TensorProto - HAS_ONNX = True + has_onnx = True except ImportError: - print("ONNX is not installed. Some functionality will not be available.") - HAS_ONNX = False + has_onnx = False + + class TensorProtoPlaceholder: + FLOAT16 = 10 + FLOAT = 1 + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "silu": nn.SiLU, + "gelu": nn.GELU, +} +ACT2FN = ClassInstantier(ACT2CLS) + +if not has_onnx: - # Define placeholder constants if onnx is not available class TensorProtoPlaceholder: FLOAT16 = 10 FLOAT = 1 - # BF16 not supported in QMoE CPU UINT8 = 2 TensorProto = TensorProtoPlaceholder -# Reduces number of tests to run for faster pipeline checks -pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" - onnxruntime.preload_dlls() -# Force CPU execution provider regardless of CUDA availability device = torch.device("cpu") + ort_provider = ["CPUExecutionProvider"] torch.manual_seed(42) @@ -70,7 +82,6 @@ class TensorProtoPlaceholder: onnx_to_torch_type_map = { TensorProto.FLOAT16: torch.float16, TensorProto.FLOAT: torch.float, - # BF16 not supported in QMoE CPU TensorProto.UINT8: torch.uint8, } @@ -83,24 +94,17 @@ class TensorProtoPlaceholder: ort_dtype_name_map = { TensorProto.FLOAT16: "FP16", TensorProto.FLOAT: "FP32", - # QMoE CPU does not support BF16 } def quant_dequant(weights, is_4_bit_quantization: bool = True): """ Quantize and dequantize weights for testing purposes. - This function exactly matches the C++ implementation in QMoE CPU. - - This uses symmetric quantization to match the C++ implementation and for TensorRT compatibility: - - 4-bit: range = [-8, 7] - - 8-bit: range = [-128, 127] + This function uses symmetric quantization centered around 0 (no zero-point). - This implementation aims to precisely match the C++ implementation by: - 1. Using symmetric quantization (zero point = 0) - 2. Using the same scale calculation methodology - 3. Using consistent rounding behavior - 4. Properly handling edge cases + This uses symmetric quantization similar to _symmetric_quantize_last_axis_of_batched_matrix: + - 4-bit: range = [-8, 7], no zero-point (symmetric around 0) + - 8-bit: range = [-128, 127], no zero-point (symmetric around 0) """ # Handle edge case of all-zero weights tensor if torch.all(weights == 0): @@ -122,119 +126,107 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): torch.zeros_like(weights), ) - # Get absolute maximum for scale calculation + # Calculate scale like C++ implementation abs_max = weights.abs().max(dim=-1, keepdim=True)[0] + abs_max = torch.clamp(abs_max, min=1e-8) # More conservative clamping for better precision if is_4_bit_quantization: - # For 4-bit symmetric quantization, range is [-8, 7] - scale = abs_max / 7.0 # Scale factor ensures max value maps to 7 + # 4-bit: scale = abs_max / 7.0 (using 7.0 as max positive value for symmetric range) + # Use higher precision computation for better accuracy + scale = (abs_max.double() / 7.0).float() + 1e-12 # Handle potential edge cases for zero or very small weights - if torch.max(abs_max) < 1e-10: - # For extremely small values, avoid division by near-zero + if torch.max(abs_max) < 1e-8: packed_size = (weights.shape[-1] + 1) // 2 - # Just return zeros with appropriate scale to avoid numerical issues return ( - torch.ones_like(weights[..., 0:1]) * 1e-6, # Very small non-zero scale - torch.full( + torch.ones_like(weights[..., 0:1]) * 1e-8, + torch.zeros( (weights.shape[0], weights.shape[1], packed_size), - fill_value=8 | (8 << 4), # 8 = 0 in symmetric quantization dtype=torch.uint8, device=weights.device, ), torch.zeros_like(weights), ) - # Convert to int4 range (-8 to 7) - scaled_weights = torch.round(weights / scale) - clipped_weights = torch.clamp(scaled_weights, -8, 7) + # Quantize: round(weight / scale) then clamp to [-8, 7] + # Use higher precision for the division to reduce accumulated errors + scaled_weights = weights.double() / scale.double() + quantized_weights = torch.round(scaled_weights).clamp(-8, 7).float() - # Convert from int4 signed range [-8,7] to uint4 storage range [0,15] - # by adding 8 to map -8->0, -7->1, ..., 7->15 - quant_weights = (clipped_weights + 8).to(torch.uint8) + # For symmetric quantization, we use signed int4 representation + # Convert to uint8 storage for packing: shift [-8,7] -> [0,15] for storage only + storage_weights = (quantized_weights + 8).to(torch.uint8) # Pack 4-bit values into uint8 (every two elements) even_indices = torch.arange(0, weights.shape[-1], 2) odd_indices = torch.arange(1, weights.shape[-1], 2) - # Handle odd length by padding + # Handle odd length by padding with zero (which is 8 in storage representation) if odd_indices.shape[0] < even_indices.shape[0]: - # Pad with 8 (which represents 0 in symmetric quantization) - # Create a new padding tensor for more predictable behavior padding = torch.full( - (quant_weights.shape[0], quant_weights.shape[1], 1), - fill_value=8, + (storage_weights.shape[0], storage_weights.shape[1], 1), + fill_value=8, # 0 in symmetric quantization, stored as 8 dtype=torch.uint8, - device=quant_weights.device, + device=storage_weights.device, ) - quant_weights = torch.cat([quant_weights, padding], dim=-1) - odd_indices = torch.arange(1, quant_weights.shape[-1], 2) + storage_weights = torch.cat([storage_weights, padding], dim=-1) + odd_indices = torch.arange(1, storage_weights.shape[-1], 2) - even_weights = quant_weights[..., even_indices] - odd_weights = quant_weights[..., odd_indices] + even_weights = storage_weights[..., even_indices] + odd_weights = storage_weights[..., odd_indices] - # Pack two 4-bit values into each byte + # Pack: low nibble = even, high nibble = odd packed_weights = (even_weights & 0xF) | ((odd_weights & 0xF) << 4) - # For dequantization, unpack + # Dequantize: scale * quantized_value (no zero-point subtraction) + # Unpack for dequantization lower = packed_weights & 0xF upper = (packed_weights >> 4) & 0xF - # Restore original shape, taking care to handle dimensions correctly + # Restore original shape and convert back to signed representation unpacked_weights = torch.zeros_like(weights, dtype=torch.uint8) - - # Assign values ensuring we don't go out of bounds unpacked_weights[..., even_indices] = lower - # Calculate valid odd indices that fit within our original tensor dimensions valid_odd_length = min(odd_indices.shape[0], weights.shape[-1] - even_indices.shape[0]) - valid_odd_indices = odd_indices[:valid_odd_length] - - # Only assign upper bits to valid positions if valid_odd_length > 0: + valid_odd_indices = odd_indices[:valid_odd_length] unpacked_weights[..., valid_odd_indices] = upper[..., :valid_odd_length] - # Convert back from uint4 to int4 by subtracting 8 - int4_weights = unpacked_weights.float() - 8 - - # Dequantize with proper broadcasting - # Make sure scale has the right shape for broadcasting - scale_expanded = scale.float() - if scale_expanded.dim() < int4_weights.dim(): - for _ in range(int4_weights.dim() - scale_expanded.dim()): - scale_expanded = scale_expanded.unsqueeze(-1) - result = (int4_weights * scale_expanded).to(dtype=weights.dtype) - return scale.to(torch.float16), packed_weights, result + # Convert back to signed values: [0,15] -> [-8,7] and apply scale + signed_weights = unpacked_weights.float() - 8.0 # Convert storage back to signed + dequant_scale = scale.float() # Ensure FP32 precision for computation + result = dequant_scale * signed_weights # No zero-point in symmetric quantization + + return scale.to(torch.float16), packed_weights, result.to(weights.dtype) else: - # 8-bit symmetric quantization, range is [-128, 127] - scale = abs_max / 127.0 # Scale factor ensures max value maps to 127 + # 8-bit: scale = abs_max / 127.0 (using 127.0 as max positive value for symmetric range) + # Use higher precision computation for better accuracy + scale = (abs_max.double() / 127.0).float() + 1e-12 # Handle potential edge cases for zero or very small weights - if torch.max(abs_max) < 1e-10: - # For extremely small values, avoid division by near-zero - # Just return zeros with appropriate scale to avoid numerical issues + if torch.max(abs_max) < 1e-8: return ( - torch.ones_like(weights[..., 0:1]) * 1e-6, # Very small non-zero scale - torch.full_like(weights, fill_value=128, dtype=torch.uint8), # 128 = 0 in symmetric + torch.ones_like(weights[..., 0:1]) * 1e-8, + torch.zeros_like(weights, dtype=torch.uint8), torch.zeros_like(weights), ) - # Convert to int8 range (-128 to 127) - scaled_weights = torch.round(weights / scale) - clipped_weights = torch.clamp(scaled_weights, -128, 127) + # Quantize: round(weight / scale) then clamp to [-128, 127] + # Use higher precision for the division to reduce accumulated errors + scaled_weights = weights.double() / scale.double() + quantized_weights = torch.round(scaled_weights).clamp(-128, 127).float() - # Convert from int8 signed range [-128,127] to uint8 storage range [0,255] - # by adding 128 to map -128->0, -127->1, ..., 127->255 - quant_weights = (clipped_weights + 128).to(torch.uint8) + # For symmetric quantization, we use signed int8 representation + # Convert to uint8 storage: shift [-128,127] -> [0,255] for storage only + storage_weights = (quantized_weights + 128).to(torch.uint8) - # Dequantize - convert back from uint8 to int8 by subtracting 128, then multiply by scale - # Make sure scale has the right shape for broadcasting - scale_expanded = scale.float() - if scale_expanded.dim() < quant_weights.dim(): - for _ in range(quant_weights.dim() - scale_expanded.dim()): - scale_expanded = scale_expanded.unsqueeze(-1) - result = ((quant_weights.float() - 128) * scale_expanded).to(dtype=weights.dtype) - return scale.to(torch.float16), quant_weights, result + # Dequantize: scale * quantized_value (no zero-point subtraction) + # Convert back to signed values: [0,255] -> [-128,127] and apply scale + signed_weights = storage_weights.float() - 128.0 # Convert storage back to signed + dequant_scale = scale.float() # Ensure FP32 precision for computation + result = dequant_scale * signed_weights # No zero-point in symmetric quantization + + return scale.to(torch.float16), storage_weights, result.to(weights.dtype) def create_cpu_moe_onnx_graph( @@ -254,38 +246,23 @@ def create_cpu_moe_onnx_graph( use_swiglu=False, use_quant=False, quant_bits=4, + swiglu_interleaved=False, ): - # Make sure we have onnx available before proceeding - if not HAS_ONNX: - print("ONNX not found, skipping graph creation") + if not has_onnx: return None - # Define intermediate_size variable consistently inter_size = intermediate_size topk = top_k - # Note: SwiGLU requires 2 components (gate and value) - # Force use_quant to True - we only want to test QMoE use_quant = True - # Note: In QMoE, biases are not used at all, only scales - # The following parameters are only relevant when use_quant=False (which is never the case here) - # fc1_bias and fc2_bias are completely ignored for QMoE - - # Ensure all variables are properly initialized for safety if fc1_scales is None and use_quant: - print("Warning: fc1_scales is None but quantization is enabled") return None if fc2_scales is None and use_quant: - print("Warning: fc2_scales is None but quantization is enabled") return None - if not HAS_ONNX: - print("ONNX not found, skipping graph creation") + if not has_onnx: return None - # Using uint8 storage type with symmetric quantization - # 4-bit: range = [-8, 7] (stored as uint8 values [0, 15]) - # 8-bit: range = [-128, 127] (stored as uint8 values [0, 255]) assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" assert fc1_scales is not None, "FC1 scales must be provided for QMoE" @@ -293,12 +270,9 @@ def create_cpu_moe_onnx_graph( assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" - # Make sure we have onnx available before proceeding - if not HAS_ONNX: - print("ONNX not found, skipping graph creation") + if not has_onnx: return None - # Always use QMoE, never MoE op_name = "QMoE" inputs = [ "input", @@ -311,10 +285,6 @@ def create_cpu_moe_onnx_graph( "", ] - # Note: In QMoE mode, biases are not used at all - # This code path is never executed since use_quant is always True - - # Use SwiGLU activation if specified, otherwise use SiLU activation = "swiglu" if use_swiglu else "silu" nodes = [ @@ -324,8 +294,13 @@ def create_cpu_moe_onnx_graph( ["output"], "MoE_0", k=topk, - normalize_routing_weights=0, + normalize_routing_weights=1, # Use proper routing normalization to match PyTorch behavior activation_type=activation, + # Add new attributes with backwards-compatible default values + swiglu_fusion=1 if (use_swiglu and swiglu_interleaved) else 0, # 1 = fused and interleaved + swiglu_limit=7.0, + activation_alpha=1.702, + activation_beta=1.0, domain="com.microsoft", ), ] @@ -333,15 +308,10 @@ def create_cpu_moe_onnx_graph( if use_quant: nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) - # For 4-bit quantization, we need to pack 2 values into each byte - pack_factor = 2 if quant_bits == 4 else 1 - - # For SwiGLU, we need to double the FC1 dimension to accommodate both gate and value paths - act_factor = 2 if use_swiglu else 1 - - # FC1 shape needs to account for both SwiGLU and quantization packing - fc1_shape = [num_experts, hidden_size, (act_factor * inter_size) // pack_factor] - fc2_shape = [num_experts, inter_size, hidden_size // pack_factor] + # Weights are store in column major order. Need pack 2 int4 values into uint8. + # Use the actual tensor shapes instead of calculating them to avoid size mismatches + fc1_shape = list(fc1_experts_weights.shape) + fc2_shape = list(fc2_experts_weights.shape) torch_dtype = onnx_to_torch_type_map[onnx_dtype] @@ -354,39 +324,90 @@ def create_cpu_moe_onnx_graph( weight_onnx_type, fc1_shape, fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), - raw=False, ), helper.make_tensor( "fc2_experts_weights", weight_onnx_type, fc2_shape, fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), - raw=False, ), ] - # QMoE always uses scales, never biases - # For SwiGLU, FC1 scales shape needs to be doubled to account for gate and value components fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] fc2_scale_shape = [num_experts, hidden_size] + + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) + fc2_scale_size = num_experts * hidden_size + + # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions + # Handle different possible scale tensor structures for fc1_scales + if len(fc1_scales.shape) == 4: + # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output + if use_swiglu: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() + else: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc1_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: + # For SwiGLU, duplicate the scales to cover both gate and value components + fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() + elif fc1_scale_tensor.size > fc1_scale_size: + # Truncate to expected size + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc1_scale_tensor.size > fc1_scale_size: + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + elif fc1_scale_tensor.size < fc1_scale_size: + # Pad with ones if too small + pad_size = fc1_scale_size - fc1_scale_tensor.size + fc1_scale_tensor = numpy.concatenate([fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)]) + + # Process scale tensor for proper shape + fc1_scale_data_list = fc1_scale_tensor.tolist() + fc1_scale_data = fc1_scale_data_list + + # Handle different possible scale tensor structures for fc2_scales + if len(fc2_scales.shape) == 4: + # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output + fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc2_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + # Truncate to expected size + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + elif fc2_scale_tensor.size < fc2_scale_size: + # Pad with ones if too small + pad_size = fc2_scale_size - fc2_scale_tensor.size + fc2_scale_tensor = numpy.concatenate([fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)]) + + # Process scale tensor for proper shape + fc2_scale_data_list = fc2_scale_tensor.tolist() + fc2_scale_data = fc2_scale_data_list + initializers.extend( [ helper.make_tensor( "fc1_scales", onnx_dtype, fc1_scale_shape, - fc1_scales.to(torch_dtype).flatten().tolist() - if fc1_scales is not None - else [1.0] * (num_experts * inter_size), + fc1_scale_data, raw=False, ), helper.make_tensor( "fc2_scales", onnx_dtype, fc2_scale_shape, - fc2_scales.to(torch_dtype).flatten().tolist() - if fc2_scales is not None - else [1.0] * (num_experts * hidden_size), + fc2_scale_data, raw=False, ), ] @@ -427,13 +448,6 @@ def __getitem__(self, key): return cls(**kwargs) -ACT2CLS = { - "silu": nn.SiLU, - "gelu": nn.GELU, -} -ACT2FN = ClassInstantier(ACT2CLS) - - class PhiMoEConfig: def __init__( self, @@ -452,7 +466,96 @@ def __init__( self.router_jitter_noise = router_jitter_noise +class SwigluMoeConfig: + def __init__( + self, + hidden_size=4096, + intermediate_size=14336, + num_local_experts=8, + num_experts_per_token=2, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_local_experts = num_local_experts + self.num_experts_per_token = num_experts_per_token + + +def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) + return y + + +class MoEBlockSparseTop2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class PhiMoEBlockSparseTop2MLP(MoEBlockSparseTop2MLP): + def __init__(self, config: PhiMoEConfig): + super().__init__(config) + + +class PhiMoESwiGLUMLP(nn.Module): + """ + Phi3 MoE expert converted to 2-weight SwiGLU structure for CPU compatibility. + This converts the traditional 3-weight Phi3 structure to SwiGLU format. + """ + + def __init__(self, config: PhiMoEConfig): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): + """ + Updated to match the CUDA implementation's routing logic for fair comparison. + This now uses the same complex jitter-based masking approach as the CUDA tests. + """ assert top_k == 2 assert not training @@ -467,8 +570,6 @@ def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): dim=-1, index=selected_experts[:, 0].unsqueeze(-1) ) - ################ second expert gating ################ - mask_logits_threshold_2 = mask_logits_threshold[:, 1].unsqueeze(-1) factor = scores.abs().clamp(min=mask_logits_threshold_2) @@ -489,60 +590,6 @@ def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): ) -class MoEBlockSparseTop2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.ffn_dim = config.intermediate_size - self.hidden_dim = config.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - -class PhiMoEBlockSparseTop2MLP(MoEBlockSparseTop2MLP): - def __init__(self, config: PhiMoEConfig, use_swiglu=False): - super().__init__(config) - self.use_swiglu = use_swiglu - - def forward(self, hidden_states): - if self.use_swiglu: - # SwiGLU implementation matching C++ implementation exactly - gate_output = self.w1(hidden_states) # Gate - value_output = self.w3(hidden_states) # Value - - # Apply SwiGLU exactly as in the C++ implementation - # C++ uses swiglu_alpha = 1.702f and clamp_limit = 7.0f - swiglu_alpha = 1.702 - clamp_limit = 7.0 - - # Apply clamping to match C++ implementation - gate_output = torch.clamp(gate_output, max=clamp_limit) # Clamp max only for gate - value_output = torch.clamp(value_output, min=-clamp_limit, max=clamp_limit) # Clamp both for value - - # Compute gate activation: gate * sigmoid(alpha * gate) - sigmoid_input = swiglu_alpha * gate_output - sigmoid_output = torch.sigmoid(sigmoid_input) - swish_output = gate_output * sigmoid_output - - # Multiply by (value + 1) as done in C++ - current_hidden_states = swish_output * (value_output + 1.0) - - # Apply FC2 - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - else: - # Original implementation with standard activation - return super().forward(hidden_states) - - class SparseMoeBlockORTHelper(nn.Module): def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() @@ -555,7 +602,6 @@ def __init__(self, quant_bits=0, onnx_dtype=None): def create_ort_session(self, moe_onnx_graph): if moe_onnx_graph is None: - print("No ONNX graph provided, skipping session creation") return None sess_options = onnxruntime.SessionOptions() @@ -563,9 +609,7 @@ def create_ort_session(self, moe_onnx_graph): try: ort_session = onnxruntime.InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) - except Exception as e: - print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") - print("Skipping ONNX Runtime execution for this test case.") + except Exception: return None return ort_session @@ -574,20 +618,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: pass def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: - # If session creation failed, we can't run inference if self.ort_sess is None: - print("No ORT session available, skipping ONNX Runtime execution") return None batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states_flat) - # Determine the correct torch dtype from the onnx_dtype torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] - # Prepare tensors on the correct device for ORT inference with the CORRECT dtype tensors = { "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), @@ -595,11 +634,9 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False } try: - # Bind inputs and outputs to torch tensors directly. iobinding = self.ort_sess.io_binding() for name, tensor in tensors.items(): - # Ensure tensor is on the globally defined device if name == "output": iobinding.bind_output( name=name, @@ -624,200 +661,63 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False iobinding.synchronize_outputs() if enable_performance_test: - import time # noqa: PLC0415 - - repeat = 100 # Using fewer repeats for CPU tests + repeat = 100 s = time.time() for _ in range(repeat): iobinding.synchronize_inputs() self.ort_sess.run_with_iobinding(iobinding) iobinding.synchronize_outputs() e = time.time() - print(f"QMoE CPU kernel time: {(e - s) / repeat * 1000} ms") + time_ms = (e - s) / repeat * 1000 + is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu + is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" + print(f"ORT Performance - {act_type} {self.quant_bits}-bit: {time_ms:.3f} ms/inference") - # The output tensor is on `device`. Reshape and return it. return tensors["output"].reshape(batch_size, sequence_length, hidden_dim) except Exception as e: - print(f"Error running ORT session: {e!s}") raise - def parity_check(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) - torch_output = self.forward(hidden_state) - ort_output = self.ort_forward(hidden_state) - - # If no ORT output was produced, we can't do a parity check - if ort_output is None: - print("ORT execution failed or is not supported, skipping parity check") - return - - dtype_str = ort_dtype_name_map[self.onnx_dtype] - max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max() - non_finite = torch.isnan(max_diff) or torch.isinf(max_diff) - - print( - f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," - f" batch: {self.batch_size}, seq_len: {self.sequence_length}," - f" max_diff: {max_diff}" - ) - - # Report if NaN or Inf values are detected - if non_finite: - print( - "Warning: NaN or Inf values detected in the output difference. Numerical comparisons will be limited." - ) - - # Maps "ort_type:quant_bits" to (atol, rtol) - # Note: Now that both CPU and CUDA use symmetric quantization, - # we can use more consistent tolerances across implementations. - ort_dtype_quant_bits_tolerance_map = { - "FP32:0": (5e-3, 1e-3), - "FP16:0": (5e-2, 1e-3), - "FP16:4": (2.0, 8e-3), # Improved tolerance with symmetric quantization - "FP16:8": (1.5, 8e-3), # Improved tolerance with symmetric quantization - } - - tolerance_key = f"{dtype_str}:{self.quant_bits}" - if tolerance_key not in ort_dtype_quant_bits_tolerance_map: - print(f"Warning: No tolerance defined for {tolerance_key}, using default") - atol, rtol = 10.0, 1e-1 - else: - atol, rtol = ort_dtype_quant_bits_tolerance_map[tolerance_key] - - # Report stats but don't assert (just for information) - # Handle NaN/Inf values more gracefully - try: - diff = (torch_output.cpu() - ort_output.cpu()).abs() - mean_diff = diff.mean().item() if not torch.isnan(diff.mean()) else float("nan") - median_diff = diff.median().item() if not torch.isnan(diff.median()) else float("nan") - p95_diff = ( - torch.quantile(diff, 0.95).item() if not torch.isnan(torch.quantile(diff, 0.95)) else float("nan") - ) - - print(f"Stats - Mean diff: {mean_diff}, Median diff: {median_diff}, 95th percentile: {p95_diff}") - - # Check if results are within tolerance - max_diff_val = max_diff.item() - if not non_finite and max_diff_val > atol: - print(f"Warning: Maximum difference ({max_diff_val:.6f}) exceeds absolute tolerance ({atol:.6f})") - elif not non_finite: - print(f"Success: All values within absolute tolerance ({atol:.6f})") - - # For quantized models, the relative difference can be very large for small values - # This is because quantization has a greater effect on small values than large ones - # Add a larger epsilon to prevent misleading large relative differences for near-zero values - # Safely compute relative differences - if not non_finite: - relative_diff = diff / torch.max(torch_output.cpu().abs(), torch.tensor(1e-3)) - max_rel_diff = relative_diff.max().item() - rel_exceeds = (relative_diff > rtol).float().mean().item() * 100 - - if max_rel_diff > rtol: - print( - f"Warning: Maximum relative difference ({max_rel_diff:.6f}) exceeds relative tolerance ({rtol:.6f})" - ) - print(f"Percentage of values exceeding relative tolerance: {rel_exceeds:.2f}%") - else: - print(f"Success: All relative differences within relative tolerance ({rtol:.6f})") - except Exception as e: - # If any calculation fails, just log it but don't crash the test - print(f"Warning: Error calculating statistics: {e}") - - # Note: Higher relative differences are expected in quantized models - # This is because quantization inherently introduces error, especially for small values - # The key metric is the absolute difference, which we've significantly improved - - def benchmark_ort(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) - self.ort_forward(hidden_state, enable_performance_test=True) - - -class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): - """ - This implementation is - strictly equivalent to standard MoE with full capacity (no - dropped tokens). It's faster since it formulates MoE operations - in terms of block-sparse operations to accommodate imbalanced - assignments of tokens to experts, whereas standard MoE either - (1) drop tokens at the cost of reduced performance or (2) set - capacity factor to number of experts and thus waste computation - and memory on padding. - - CPU version: Modified to use only FC1 and FC2 for CPU compatibility. - - Quantization: Uses symmetric quantization to exactly match the C++ implementation: - - 4-bit: range = [-8, 7] (stored as uint8 values [0, 15]) - - 8-bit: range = [-128, 127] (stored as uint8 values [0, 255]) - This ensures the test exactly simulates the C++ implementation with full - compatibility with the CUDA implementation and TensorRT. - """ - - def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype=None, use_swiglu=False): - # Ensure we always have a valid quantization bits value (4 or 8) before passing to parent - if quant_bits <= 0: - print("Warning: quant_bits was set to 0 or negative, forcing to 4-bit") - quant_bits = 4 - - # Now pass the validated quant_bits to parent constructor - super().__init__(quant_bits, onnx_dtype) - self.hidden_dim = config.hidden_size - self.ffn_dim = config.intermediate_size - self.num_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - self.router_jitter_noise = config.router_jitter_noise - self.use_swiglu = use_swiglu - - # gating - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - - # Use PhiMoEBlockSparseTop2MLP for all experts - self.experts = nn.ModuleList( - [PhiMoEBlockSparseTop2MLP(config, use_swiglu=self.use_swiglu) for _ in range(self.num_experts)] - ) + def recreate_onnx_model(self): + """Recreate the ONNX model with the current weights to reflect any changes to the quantization code.""" w1_list, w2_list = [], [] w1_scale_list, w2_scale_list = [], [] - # Always use quantization for QMoE is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - # Quantize the weights w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) - # For SwiGLU, we also need to quantize w3 (value) weights - w3_qdq = None # Initialize w3_qdq to avoid unbound variable error if self.use_swiglu: - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) - # Combine gate (w1) and value (w3) for SwiGLU - if is_4_bit: - # For 4-bit, we need to combine the packed weights in the right format - # Double the intermediate size for SwiGLU (gate + value) - # Each byte contains two 4-bit values + if self.swiglu_interleaved: + pass + else: + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) + gate_weights = pre_qweight1 value_weights = pre_qweight3 - # Create a new tensor with double the last dimension - combined_shape = list(gate_weights.shape) - combined_shape[-1] *= 2 # Double the last dimension for gate+value - combined_weights = torch.zeros(combined_shape, dtype=torch.uint8, device=gate_weights.device) - combined_weights[..., : gate_weights.shape[-1]] = gate_weights - combined_weights[..., gate_weights.shape[-1] :] = value_weights - pre_qweight1 = combined_weights - else: - # For 8-bit, we can just concatenate along the last dimension - pre_qweight1 = torch.cat([pre_qweight1, pre_qweight3], dim=-1) + gate_scales = w1_scale + value_scales = w3_scale + + pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) + w1_scale = torch.cat([gate_scales, value_scales], dim=0) + + if self.swiglu_interleaved: + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) - # Same for scales - combine gate and value scales - w1_scale = torch.cat([w1_scale, w3_scale], dim=-1) + else: + intermediate_size = self.experts[i].w1.weight.shape[0] + gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() + value_dequant = w1_qdq[intermediate_size:].contiguous().clone() + self.experts[i].w1.weight.data = gate_dequant + self.experts[i].w3.weight.data = value_dequant + else: + self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() - # Update the expert weights with dequantized values for PyTorch execution - self.experts[i].w1.weight.data = w1_qdq - self.experts[i].w2.weight.data = w2_qdq - if self.use_swiglu and w3_qdq is not None: - self.experts[i].w3.weight.data = w3_qdq + self.experts[i].w2.weight.data = w2_qdq.contiguous().clone() - # Store the quantized weights and scales for ONNX model w1_list.append(pre_qweight1) w2_list.append(pre_qweight2) w1_scale_list.append(w1_scale) @@ -826,14 +726,14 @@ def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype self.moe_experts_weight1 = torch.stack(w1_list, dim=0) self.moe_experts_weight2 = torch.stack(w2_list, dim=0) - # Always use scales for QMoE moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) - self.batch_size = batch_size - self.sequence_length = sequence_length + if moe_experts_weight_scale1.dim() == 3: + moe_experts_weight_scale1 = moe_experts_weight_scale1.squeeze(-1) + if moe_experts_weight_scale2.dim() == 3: + moe_experts_weight_scale2 = moe_experts_weight_scale2.squeeze(-1) - # Use CPU specific graph creation try: self.moe_onnx_graph = create_cpu_moe_onnx_graph( hidden_size=self.hidden_dim, @@ -841,49 +741,277 @@ def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype num_experts=self.num_experts, top_k=self.top_k, intermediate_size=self.ffn_dim, - torch_dtype=torch.float32, # Assuming float32 as default + torch_dtype=torch.float32, onnx_dtype=self.onnx_dtype, fc1_experts_weights=self.moe_experts_weight1, fc2_experts_weights=self.moe_experts_weight2, - # Biases are not used in QMoE, only passed as None for API compatibility + # Biases are not used in QMoE fc1_bias=None, fc2_bias=None, # Scales are used for dequantization fc1_scales=moe_experts_weight_scale1, fc2_scales=moe_experts_weight_scale2, - use_swiglu=self.use_swiglu, # Use SwiGLU if specified + use_swiglu=self.use_swiglu, use_quant=True, # Always use QMoE quant_bits=self.quant_bits, + swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, ) - except Exception as e: - print(f"Error creating ONNX graph: {e}") + except Exception: self.moe_onnx_graph = None + return False self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None + return self.ort_sess is not None + + def parity_check(self): + model_updated = self.recreate_onnx_model() + if not model_updated: + return + + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + torch_output = self.forward(hidden_state) + ort_output = self.ort_forward(hidden_state) + + if ort_output is None: + return + + torch_has_nan = torch.isnan(torch_output).any() + ort_has_nan = torch.isnan(ort_output).any() + torch_has_inf = torch.isinf(torch_output).any() + ort_has_inf = torch.isinf(ort_output).any() + + if torch_has_nan or ort_has_nan or torch_has_inf or ort_has_inf: + torch_output_clean = torch.where( + torch.isnan(torch_output) | torch.isinf(torch_output), torch.zeros_like(torch_output), torch_output + ) + ort_output_clean = torch.where( + torch.isnan(ort_output) | torch.isinf(ort_output), torch.zeros_like(ort_output), ort_output + ) + max_diff = (torch_output_clean.cpu() - ort_output_clean.cpu()).abs().max() + + if (torch_has_nan and ort_has_nan) or (torch_has_inf and ort_has_inf): + problematic_torch = torch.isnan(torch_output) | torch.isinf(torch_output) + problematic_ort = torch.isnan(ort_output) | torch.isinf(ort_output) + if torch.equal(problematic_torch, problematic_ort): + max_diff = 0.0 + else: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max() + + is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu + is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" + + print(f"Parity check - {act_type} {self.quant_bits}-bit: max_diff = {max_diff:.6f}") + + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + "FP16:0": (5e-2, 1e-3), + "FP16:4": (0.05, 0.01), + "FP16:8": (0.02, 0.01), + "FP32:4": (0.11, 0.01), + "FP32:8": (0.11, 0.01), + } + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + tolerance_key = f"{dtype_str}:{self.quant_bits}" + if tolerance_key in ort_dtype_quant_bits_tolerance_map: + base_atol, rtol = ort_dtype_quant_bits_tolerance_map[tolerance_key] + + if max_diff > base_atol: + raise AssertionError( + f"QMoE parity check failed: max difference {max_diff:.6f} exceeds " + f"tolerance {base_atol:.6f} for {tolerance_key}" + ) + else: + fallback_atol = 0.1 + if max_diff > fallback_atol: + raise AssertionError( + f"QMoE parity check failed: max difference {max_diff:.6f} exceeds " + f"fallback tolerance {fallback_atol:.6f} for unknown config {tolerance_key}" + ) + + def benchmark_ort(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) + + +def small_test_cases(): + for batch_size in [1, 4]: + for sequence_length in [32, 128]: + yield batch_size, sequence_length + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + self.use_swiglu = True + self.swiglu_interleaved = True + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + self.batch_size = batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = None + self.ort_sess = None def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) + routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) + + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: PhiMoEConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.router_jitter_noise = config.router_jitter_noise + self.use_swiglu = True + self.swiglu_interleaved = True + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([PhiMoESwiGLUMLP(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) - routing_weights, selected_experts = masked_sampling_omp_inference( - router_logits, + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = create_cpu_moe_onnx_graph( + hidden_size=self.hidden_dim, + sequence_length=self.batch_size * self.sequence_length, + num_experts=self.num_experts, top_k=self.top_k, - jitter_eps=self.router_jitter_noise, - training=False, + intermediate_size=self.ffn_dim, + torch_dtype=torch.float32, + onnx_dtype=self.onnx_dtype, + fc1_experts_weights=fc1_experts_weights, + fc2_experts_weights=fc2_experts_weights, + fc1_bias=fc1_experts_bias, + fc2_bias=fc2_experts_bias, + fc1_scales=moe_experts_weight_scale1, + fc2_scales=moe_experts_weight_scale2, + use_swiglu=self.use_swiglu, + use_quant=use_quant, + quant_bits=self.quant_bits, + swiglu_interleaved=self.swiglu_interleaved, ) + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """PyTorch reference forward pass using SwiGLU-style routing""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) @@ -891,106 +1019,99 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if top_x.shape[0] == 0: continue - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states -def small_test_cases(): - for batch_size in [1, 4]: - for sequence_length in [32, 128]: - yield batch_size, sequence_length +disable_cpu_qmoe_tests = False +# Define test cases for different MoE types +phi3_test_cases = [ + (1, 32, 4), + (1, 32, 8), + (2, 16, 4), + (2, 16, 8), +] -# Define our test cases for QMoE (4-bit and 8-bit quantization) -# Only test QMoE since standard MoE is not supported on CPU -cpu_phi3_test_cases = list( - itertools.product( - [1, 4], # batch_size - [8, 32], # sequence_length - smaller sequence lengths for CPU - [4, 8], # quant_bits - only test QMoE (4-bit and 8-bit) - [False], # use_swiglu - standard SiLU cases - ) -) - -# Additional test cases for SwiGLU activation -cpu_phi3_swiglu_test_cases = list( - itertools.product( - [1, 4], # batch_size - [8, 32], # sequence_length - smaller sequence lengths for CPU - [4, 8], # quant_bits - only test QMoE (4-bit and 8-bit) - [True], # use_swiglu - SwiGLU activation - ) -) -# Temporarily disable CPU qMoE tests. A fix will come soon. -disable_cpu_qmoe_tests = True +@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") +class TestPhiQMoECPU(unittest.TestCase): + @parameterized.expand(phi3_test_cases) + def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + print(f"Running Phi3 QMoE test: {test_config}") + + config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) + + phi3_moe = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = phi3_moe.forward(hidden_states) + + # Verify output shape and basic properties + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + phi3_moe.parity_check() + + +disable_cpu_qmoe_tests = False + +swiglu_test_cases = [ + (1, 32, 4), + (1, 32, 8), + (2, 16, 4), + (2, 16, 8), +] @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") -class TestPhiQMoECPU(unittest.TestCase): - @parameterized.expand(cpu_phi3_test_cases + cpu_phi3_swiglu_test_cases) - def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, use_swiglu=False): - activation_type = "SwiGLU" if use_swiglu else "SiLU" - print( - f"Running PhiMoE CPU test with batch_size={batch_size}, sequence_length={sequence_length}, " - f"quant_bits={quant_bits}, activation={activation_type}" +class TestSwigluQMoECPU(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + print(f"Running SwiGLU test: {test_config}") + + config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) + + swiglu_moe = SwigluMoEBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, ) - config = PhiMoEConfig(hidden_size=256, intermediate_size=512, hidden_act="silu") # Smaller sizes for CPU tests - phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits, use_swiglu=use_swiglu) - phi3_moe.to(device) - # Skip tests if ONNX is not available - if not HAS_ONNX: - self.skipTest("ONNX is not installed") + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) - # Skip if the session creation failed - if phi3_moe.ort_sess is None: - self.skipTest("Failed to create ONNX Runtime session - CPU MoE operator not available") + torch_result = swiglu_moe.forward(hidden_states) - try: - phi3_moe.parity_check() - except RuntimeError as e: - if "FC3 gating is not yet implemented on CPU" in str(e): - self.skipTest("FC3 gating is not yet implemented on CPU") - else: - raise - - @parameterized.expand([(8, False), (4, False), (8, True), (4, True)]) - def test_phi3_qmoe_cpu_benchmark(self, quant_bits, use_swiglu=False): - activation_type = "SwiGLU" if use_swiglu else "SiLU" - print(f"Benchmarking PhiMoE CPU with quant_bits={quant_bits}, activation={activation_type}") - batch_size = 1 - sequence_length = 32 - config = PhiMoEConfig(hidden_size=256, intermediate_size=512) - phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits, use_swiglu=use_swiglu) - phi3_moe.to(device) - - # Skip tests if ONNX is not available or session creation failed - if not HAS_ONNX or phi3_moe.ort_sess is None: - self.skipTest("ONNX not installed or CPU MoE operator not available") - return + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) - try: - phi3_moe.benchmark_ort() - except RuntimeError as e: - if "FC3 gating is not yet implemented on CPU" in str(e): - self.skipTest("FC3 gating is not yet implemented on CPU") - else: - raise + swiglu_moe.parity_check() if __name__ == "__main__": From 0de1c01e1208335f62cfe161eaf2c79f6e7ba31a Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 26 Aug 2025 15:58:46 -0700 Subject: [PATCH 10/22] Enable ABSL_FLAGS flag registration for onnxruntime_perf_test for mobile build (#25849) ### Description `ABSL_FLAGS_STRIP_NAMES `is set to 1 by default to disable flag registration when building for Android, iPhone, and "embedded devices". So, running onnxruntime_perf_test on Android will see that flags are not registered. image (2) Set `ABSL_FLAGS_STRIP_NAMES ` to 0 by default for all builds. --- cmake/onnxruntime_unittests.cmake | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 5b78235988413..6847db64004ca 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1232,6 +1232,12 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ${onnxruntime_perf_test_src_patterns} ) onnxruntime_add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/platform/path_lib.cc) + + # ABSL_FLAGS_STRIP_NAMES is set to 1 by default to disable flag registration when building for Android, iPhone, and "embedded devices". + # See the issue: https://github.com/abseil/abseil-cpp/issues/1875 + # We set it to 0 for all builds to be able to use ABSL flags for onnxruntime_perf_test. + target_compile_definitions(onnxruntime_perf_test PRIVATE ABSL_FLAGS_STRIP_NAMES=0) + if(MSVC) target_compile_options(onnxruntime_perf_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") From f58f7eb7fa3c8dbcd5d2bf8fb03a6072ea345dce Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 27 Aug 2025 07:45:27 +0800 Subject: [PATCH 11/22] [webgpu] Expand Unsqueeze version to 23 (#25858) ### Description The phi4 mini in Edge is using ai.onnx v21. Without this change, it results a `MemcpyToHost` inserted and slows the generation speed. --- .../core/providers/webgpu/tensor/unsqueeze.cc | 26 ++++++++++++++++++- .../webgpu/webgpu_execution_provider.cc | 8 ++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc index 4bcef4fd79296..b89623110217d 100644 --- a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc @@ -11,7 +11,31 @@ namespace webgpu { ONNX_OPERATOR_KERNEL_EX( Unsqueeze, kOnnxDomain, - 13, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 13, 20, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedNumberTypes()) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 2afe8a964a003..bbb3fbdd221d3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -249,7 +249,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 15, Where); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where); @@ -548,7 +550,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, From 7e3174b0c17673b6e40157389457bba619ab7a84 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 27 Aug 2025 07:48:01 +0800 Subject: [PATCH 12/22] [webgpu] Optimize dp4 prefill shader for Qualcomm (#25578) This change uses subgroupShuffle for sg_size=64 to perform the matmul. It also uses a loop instead of loop unrolling to reduce the register pressure. Phi4 prefill for 1K tokens becomes 8.8s from 11.32s on Qualcomm Adreno X1-85 GPU. --- .../quantization/dp4a_matmul.wgsl.template | 93 +++++++++++++++++-- .../webgpu/quantization/dp4a_matmul_nbits.cc | 6 +- .../webgpu/quantization/dp4a_matmul_nbits.h | 10 +- 3 files changed, 97 insertions(+), 12 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template index bea6588ea72eb..ee6dde3788157 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template @@ -4,6 +4,7 @@ #param block_size #param n_bits #param has_zero_points +#param is_qualcomm #include "quantization/dp4a_matmul_common.wgsl.template" @@ -138,18 +139,35 @@ $MAIN { // During the compute phase, we have the 64x64 tile split into // subtiles of 16x16. We have a grid of 4x4 subtiles. - let subtile_id = u32(local_idx / subtile_size); - let subtile_idx = u32(subtile_id / 4); - let subtile_idy = u32(subtile_id % 4); - let base_A = subtile_idx * 16; - let base_B = subtile_idy * 16; + var subtile_id = u32(local_idx / subtile_size); + var subtile_idx = u32(subtile_id / 4); + var subtile_idy = u32(subtile_id % 4); + var base_A = subtile_idx * 16; + var base_B = subtile_idy * 16; // For each subtile we have 16 threads assigned. - let a_idx = u32(local_idx % subtile_size); + var a_idx = u32(local_idx % subtile_size); +#if is_qualcomm + // subtile_idx is always 0 + // subtile_idy is one of {0,1,2,3} + // The subtile is now rectangular 64x16 for qualcomm case and we have 4 subtiles, this way we don't need to + // increase the number of lane_output each thread needs to track. That is if we want to use a subtile that is 64x64 + // we would need var lane_outputs: array; + if (sg_size == 64) { + subtile_id = u32(local_idx / sg_size); + subtile_idx = u32(subtile_id / 4); + subtile_idy = u32(subtile_id % 4); + base_A = subtile_idx * sg_size; + base_B = subtile_idy * 16; + a_idx = sg_id; + } + var lane_outputs: array; +#else var lane_output1: vec4; var lane_output2: vec4; var lane_output3: vec4; var lane_output4: vec4; +#endif // K's vectorization is 16 items per index. See input_a/input_b. // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is // k tile size is 32. In vectorized space that is 32/16 = 2. @@ -173,6 +191,34 @@ $MAIN { var own_scale_a: output_element_t = scale_A[base_A + a_idx]; #if has_zero_points && n_bits == 8 + #if is_qualcomm + if (sg_size == 64) + { + var own_b0: vec4; + var own_b1: vec4; + var own_scale_b: output_element_t; + var zero: i32; + if (sg_id < 16) + { + own_b0 = tile_B[0][base_B + sg_id]; + own_b1 = tile_B[1][base_B + sg_id]; + own_scale_b = scale_B[base_B + sg_id]; + zero = zeroes[base_B + sg_id]; + } + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + for (var i = 0u; i < 16u; i++) + { + lane_outputs[i] += SDP8AI(own_a0, subgroupShuffle(own_b0, i), own_a1, subgroupShuffle(own_b1, i), subgroupShuffle(own_scale_b, i) * own_scale_a, subgroupShuffle(zero, i)); + } + } + else + { + for (var i = 0u; i < 16u; i++) + { + lane_outputs[i] += SDP8AI(own_a0, tile_B[0][base_B + i], own_a1, tile_B[1][base_B + i], own_scale_a * scale_B[base_B + i], zeroes[base_B + i]); + } + } + #else if (sg_size == 16) { var own_b0: vec4 = tile_B[0][base_B + sg_id]; @@ -225,7 +271,34 @@ $MAIN { lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14], zeroes[base_B + 14]); lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15], zeroes[base_B + 15]); } + #endif #else + #if is_qualcomm + if (sg_size == 64) + { + var own_b0: vec4; + var own_b1: vec4; + var own_scale_b: output_element_t; + if (sg_id < 16) + { + own_b0 = tile_B[0][base_B + sg_id]; + own_b1 = tile_B[1][base_B + sg_id]; + own_scale_b = scale_B[base_B + sg_id]; + } + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + for (var i = 0u; i < 16u; i++) + { + lane_outputs[i] += SDP8AI(own_a0, subgroupShuffle(own_b0, i), own_a1, subgroupShuffle(own_b1, i), subgroupShuffle(own_scale_b, i) * own_scale_a); + } + } + else + { + for (var i = 0u; i < 16u; i++) + { + lane_outputs[i] += SDP8AI(own_a0, tile_B[0][base_B + i], own_a1, tile_B[1][base_B + i], own_scale_a * scale_B[base_B + i]); + } + } + #else if (sg_size == 16) { var own_b0: vec4 = tile_B[0][base_B + sg_id]; @@ -277,6 +350,7 @@ $MAIN { lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); } + #endif #endif workgroupBarrier(); } @@ -287,9 +361,16 @@ $MAIN { // This creates a shader requirement that uniforms.N % 16 == 0 if (a_global < uniforms.M && b_global < uniforms.N) { +#if is_qualcomm + output[output_idx] = vec4(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]); + output[output_idx+1] = vec4(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]); + output[output_idx+2] = vec4(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]); + output[output_idx+3] = vec4(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]); +#else output[output_idx] = lane_output1; output[output_idx+1] = lane_output2; output[output_idx+2] = lane_output3; output[output_idx+3] = lane_output4; +#endif } } // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 61764979a5dd6..84954946fa6be 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -28,6 +28,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template", WGSL_TEMPLATE_PARAMETER(block_size, block_size_), WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), + WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), WGSL_TEMPLATE_PARAMETER(output_type_i32, true)); } @@ -118,7 +119,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor TensorShape reshaped_y_shape{1, M, N / kVec4Components}; uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize; uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; - DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points}; + bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; + DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, is_qualcomm}; mul_program.SetWorkgroupSize(256); mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, @@ -133,7 +135,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {num_N_tile}, {zero_blocks_per_col}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(kVec4Components)}) - .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points); + .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 470a4e187f37a..b00392cbb291e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -21,10 +21,11 @@ class DP4AMatMulQuantizeProgram final : public Program { public: - DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points) : Program{"DP4AMatMulNBits"}, - block_size_(block_size), - nbits_(nbits), - has_zero_points_(has_zero_points) {} + DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points, bool is_qualcomm) : Program{"DP4AMatMulNBits"}, + block_size_(block_size), + nbits_(nbits), + has_zero_points_(has_zero_points), + is_qualcomm_(is_qualcomm) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -39,6 +40,7 @@ class DP4AMatMulNBitsProgram final : public Program { uint32_t block_size_; uint32_t nbits_; bool has_zero_points_; + bool is_qualcomm_; }; class DP4AMatMulNBitsSmallMProgram final : public Program { From 16ae99ede405d3d6c59d7cce80c53f5f7055aeed Mon Sep 17 00:00:00 2001 From: umangb-09 Date: Wed, 27 Aug 2025 21:23:01 +0530 Subject: [PATCH 13/22] Add cuda graph implementation for NV TRT RTX EP (#25787) ### Description This change adds CUDA Graph support to the NV TensorRT RTX Execution Provider (EP). ### Motivation and Context Integrating CUDA Graphs into the NV TRT RTX EP provides: Lower latency by minimizing per-kernel launch overhead. Better throughput for repeated inference runs. Improved efficiency on GPUs with high kernel launches overhead sensitivity. --------- Co-authored-by: Maximilian Mueller Co-authored-by: Gaurav Garg --- .../nv_tensorrt_rtx/nv_provider_options.h | 2 +- onnxruntime/core/providers/cuda/cuda_graph.cc | 6 +- onnxruntime/core/providers/cuda/cuda_graph.h | 2 +- .../nv_tensorrt_rtx/nv_execution_provider.cc | 347 ++++++++++++------ .../nv_tensorrt_rtx/nv_execution_provider.h | 47 ++- 5 files changed, 267 insertions(+), 137 deletions(-) diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index dc27204017caa..a32f465e44adf 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -31,7 +31,7 @@ constexpr const char* kDetailedBuildLog = "nv_detailed_build_log"; constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes"; constexpr const char* kProfilesMaxShapes = "nv_profile_max_shapes"; constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes"; -constexpr const char* kCudaGraphEnable = "nv_cuda_graph_enable"; +constexpr const char* kCudaGraphEnable = "enable_cuda_graph"; constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable"; constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer"; diff --git a/onnxruntime/core/providers/cuda/cuda_graph.cc b/onnxruntime/core/providers/cuda/cuda_graph.cc index 8353c654681fc..88e58aec70550 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.cc +++ b/onnxruntime/core/providers/cuda/cuda_graph.cc @@ -72,7 +72,7 @@ void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec); } -Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) { +Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag) { // Although this function is not thread safe, the lock is not needed here because // CUDA EP maintains a separate cuda graph per thread LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id " @@ -81,7 +81,9 @@ Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id); CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + if (sync_status_flag) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/cuda_graph.h b/onnxruntime/core/providers/cuda/cuda_graph.h index 064b526e604bc..6b61a66671de4 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.h +++ b/onnxruntime/core/providers/cuda/cuda_graph.h @@ -38,7 +38,7 @@ struct CUDAGraphManager { void SetStream(cudaStream_t stream); void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); - Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id); + Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag = true); void Reset(); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 451be69c81cfb..b7997ce86737a 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -19,6 +19,7 @@ #include "nv_data_transfer.h" #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" +#include "core/providers/cuda/cuda_graph.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" @@ -255,7 +256,8 @@ bool ApplyProfileShapesFromProviderOptions(std::vector>>& profile_min_shapes, std::unordered_map>>& profile_max_shapes, std::unordered_map>>& profile_opt_shapes, - ShapeRangesMap& input_explicit_shape_ranges) { + ShapeRangesMap& input_explicit_shape_ranges, + bool& cuda_graph_flag) { if (trt_profiles.size() == 0) { LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Number of optimization profiles should be greater than 0, but it's 0."; return false; @@ -282,6 +284,10 @@ bool ApplyProfileShapesFromProviderOptions(std::vectorisShapeTensor()) { + if (cuda_graph_flag) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Shape tensor detected on input '" << input->getName() << "'. Disabling CUDA Graph."; + cuda_graph_flag = false; + } int shape_size = nb_dims == 0 ? 1 : static_cast(profile_min_shapes[input_name][i].size()); std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); @@ -715,10 +721,10 @@ Status BindKernelOutput(Ort::KernelContext& ctx, } NvExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) { - // TODO: figure out if PerThreadContext is used at all. If not, just clean it up. + // Only set device if user hasn't provided a compute stream if (has_user_compute_stream) { CUDA_CALL_THROW(cudaSetDevice(device_id)); - (void)(stream); + (void)stream; } } @@ -745,6 +751,86 @@ bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fu return false; } +void NvExecutionProvider::PerThreadContext::DeleteCapturedGraph(CudaGraphAnnotation_t cuda_graph_annotation_id) { + graph_id_to_run_count_.erase(cuda_graph_annotation_id); + cuda_graph_.Reset(); +} + +void NvExecutionProvider::PerThreadContext::ResetWarmupRuns(CudaGraphAnnotation_t cuda_graph_annotation_id) { + if (graph_id_to_run_count_.find(cuda_graph_annotation_id) == graph_id_to_run_count_.end()) { + return; + } + graph_id_to_run_count_[cuda_graph_annotation_id] = 0; +} + +bool NvExecutionProvider::PerThreadContext::IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + if (!IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id)) { + return false; + } + + // Safe access to map - return false if key doesn't exist yet + auto it = graph_id_to_run_count_.find(cuda_graph_annotation_id); + if (it == graph_id_to_run_count_.end()) { + return false; // Entry doesn't exist yet, not ready for capture + } + + bool allowed = it->second >= min_num_runs_before_cuda_graph_capture_; + if (allowed) { + LOGS_DEFAULT(VERBOSE) << "NvTensorRTRTX EP Graph capture allowed for ID: " << cuda_graph_annotation_id + << ", run count: " << it->second; + } + return allowed; +} + +bool NvExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_.IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id); +} + +CudaGraphAnnotation_t NvExecutionProvider::PerThreadContext::GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const { + // Actual implementation + auto graph_annotation_str = run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + CudaGraphAnnotation_t cuda_graph_annotation_id = kCudaGraphAnnotationDefault; + + // Kind of debugging head implementation, can be cleaned and made robust like CUDA EP + if (graph_annotation_str.has_value() && !graph_annotation_str->empty()) { + if (!TryParseStringWithClassicLocale(*graph_annotation_str, cuda_graph_annotation_id)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to parse cuda graph annotation id: " + << *graph_annotation_str << ", using default: " << kCudaGraphAnnotationDefault; + cuda_graph_annotation_id = kCudaGraphAnnotationDefault; + } + } + return cuda_graph_annotation_id; +} + +void NvExecutionProvider::PerThreadContext::SetCurrentGraphAnnotationId(CudaGraphAnnotation_t cuda_graph_annotation_id) { + current_graph_annotation_id_ = cuda_graph_annotation_id; +} + +CudaGraphAnnotation_t NvExecutionProvider::PerThreadContext::GetCurrentGraphAnnotationId() const { + return current_graph_annotation_id_; +} + +void NvExecutionProvider::PerThreadContext::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.Reset(); + cuda_graph_.CaptureBegin(cuda_graph_annotation_id); +} + +void NvExecutionProvider::PerThreadContext::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.CaptureEnd(cuda_graph_annotation_id); +} + +bool NvExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_.IsGraphCaptured(cuda_graph_annotation_id); +} + +Status NvExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag) { + return cuda_graph_.Replay(cuda_graph_annotation_id, sync_status_flag); +} + +void NvExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id) { + graph_id_to_run_count_[cuda_graph_annotation_id]++; +} + bool NvExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::string fused_node) { auto it = trt_context_map_.find(fused_node); if (it != trt_context_map_.end()) { @@ -846,6 +932,12 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); + } else if (cuda_graph_enable_) { + external_stream_ = false; + CUDA_CALL_THROW(cudaStreamCreate(&stream_)); + } else { + external_stream_ = false; + stream_ = nullptr; // Will be created in compute function } std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -1010,7 +1102,7 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) // external stream: // If user provides "external" cuda stream, only this cuda stream will be used even if multiple threads are running InferenceSession.Run() concurrently. // So, no need to synchronize different streams after enqueueV3. - if (cuda_graph_enable_ || external_stream_) { + if (external_stream_) { sync_stream_after_enqueue_ = false; } @@ -1038,7 +1130,7 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) << ", nv_force_sequential_engine_build: " << force_sequential_engine_build_ << ", nv_sparsity_enable: " << sparsity_enable_ << ", nv_auxiliary_streams: " << auxiliary_streams_ - << ", nv_cuda_graph_enable: " << cuda_graph_enable_ + << ", enable_cuda_graph: " << cuda_graph_enable_ << ", nv_dump_ep_context_model: " << dump_ep_context_model_ << ", nv_ep_context_file_path: " << ep_context_file_path_ << ", nv_ep_context_embed_mode: " << ep_context_embed_mode_ @@ -1060,7 +1152,7 @@ NvExecutionProvider::~NvExecutionProvider() { } } - if (!external_stream_ && stream_) { + if (!external_stream_ && stream_ != nullptr) { ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_))); } ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); @@ -1072,41 +1164,82 @@ NvExecutionProvider::~NvExecutionProvider() { } } +void NvExecutionProvider::HandleCudaGraphStart(cudaStream_t stream, bool require_io_binding, + CudaGraphAnnotation_t cuda_graph_annotation_id, bool& graph_replay_on_this_run, bool& should_start_capture) { + graph_replay_on_this_run = false; + should_start_capture = false; + + // Case 1: CUDA Graph capture is enabled AND IO binding is required. + // In this case, we force graph re-capture by resetting warmup runs. + // If a graph for this annotation ID already exists, delete it before proceeding. + if (require_io_binding && cuda_graph_enable_) { + GetPerThreadContext().ResetWarmupRuns(cuda_graph_annotation_id); + + if (GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Graph already captured and required_io_binding is true, resetting warmup runs and deleting graph"; + GetPerThreadContext().DeleteCapturedGraph(cuda_graph_annotation_id); + } + // Case 2: CUDA Graph capture is enabled AND IO binding is NOT required + } else if (cuda_graph_enable_ && !require_io_binding) { + // If the graph is not yet captured, increment the regular run counter + if (cuda_graph_annotation_id != kCudaGraphAnnotationSkip && + !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(cuda_graph_annotation_id); + } + + // If capture is allowed and graph not already captured, + // set the stream and begin capture + if (!GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) && + GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) { + GetPerThreadContext().SetCudaGraphStream(stream); + GetPerThreadContext().CaptureBegin(cuda_graph_annotation_id); + should_start_capture = true; + } + + // If a graph is already captured for this ID, mark it for replay in this run. + if (GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + graph_replay_on_this_run = true; + } + } +} + bool NvExecutionProvider::IsGraphCaptureEnabled() const { return cuda_graph_enable_; } -bool NvExecutionProvider::IsGraphCaptureAllowed() const { - return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +bool NvExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + // This is hardcoded to always return false because we are not allowing the ORT framework to have the CUDA graph control. + (void)graph_annotation_id; + return false; } -void NvExecutionProvider::CaptureBegin(int) { - cuda_graph_.Reset(); - cuda_graph_.CaptureBegin(0); +Status NvExecutionProvider::ReplayGraph(int graph_annotation_id) { + // This is hardcoded to always return OK because we are not allowing the ORT framework to have the CUDA graph control. + (void)graph_annotation_id; + return Status::OK(); } -void NvExecutionProvider::CaptureEnd(int) { - cuda_graph_.CaptureEnd(0); - is_graph_captured_ = true; -} +Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { + if (cuda_graph_enable_) { + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options); + GetPerThreadContext().SetCurrentGraphAnnotationId(cuda_graph_annotation_id); + } -bool NvExecutionProvider::IsGraphCaptured(int) const { - return is_graph_captured_; + if (multi_profile_enable_ == true) { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(nv::run_option_names::kProfileIndex); + TryParseStringWithClassicLocale(*graph_annotation_str, nv_profile_index_); + } + return Status::OK(); } -Status NvExecutionProvider::ReplayGraph(int) { - ORT_ENFORCE(IsGraphCaptured(0)); - // Please note that CUDAGraph::Replay() is not thread safe. - // ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced due to lock_guard(), - // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. - return cuda_graph_.Replay(0); -} +Status NvExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) { + (void)run_options; -void NvExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { - // Please note that this function is not thread safe. - // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), - // therefore following increment is guaranteed to be thread safe. - ++regular_run_count_before_graph_capture_; + if (sync_stream && external_stream_) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + } + return Status::OK(); } std::vector NvExecutionProvider::CreatePreferredAllocators() { @@ -1133,22 +1266,6 @@ std::unique_ptr NvExecutionProvider::GetDataTransfer() const { return std::make_unique(); } -Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { - if (multi_profile_enable_ == true) { - auto graph_annotation_str = - run_options.GetConfigOptions().GetConfigEntry(nv::run_option_names::kProfileIndex); - TryParseStringWithClassicLocale(*graph_annotation_str, nv_profile_index_); - } - return Status::OK(); -} - -Status NvExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (sync_stream && external_stream_) { - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); - } - return Status::OK(); -} - // Get the pointer to the IBuilder instance. // Note: This function is not thread safe. Calls to this function from different threads must be serialized // even though it doesn't make sense to have multiple threads initializing the same inference session. @@ -2519,7 +2636,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr profile_opt_shapes_.find(input_name) != profile_opt_shapes_.end() && profile_max_shapes_.find(input_name) != profile_max_shapes_.end(); if (has_explicit_profile && tensor_has_profile) { - apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); + apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges, cuda_graph_enable_); } else { LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Creating implicit profile for tensor " << input_name; profile_min_shapes_[input_name] = std::vector>{{}}; @@ -2546,7 +2663,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr profile_max_shapes_[input_name][0][idx_dim] = dim_value; } } - apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); + apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges, cuda_graph_enable_); } if (!apply_profile) { std::ostringstream msg; @@ -2600,6 +2717,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // Otherwise engine will be handled at inference time. std::unique_ptr trt_engine; std::unique_ptr trt_context; + std::unique_ptr trt_runtime_config; // Generate file name for dumping ep context model if (dump_ep_context_model_ && ctx_model_path_.empty()) { @@ -2622,6 +2740,13 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP failed to deserialize engine for fused node: " + fused_node.Name()); } + + trt_runtime_config = std::unique_ptr(trt_engine->createRuntimeConfig()); + if (trt_runtime_config && cuda_graph_enable_) { + trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER); + } + trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED); + if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); LOGS_DEFAULT(INFO) << "TensorRT engine build for " << fused_node.Name() << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; @@ -2681,7 +2806,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // Build context // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); + trt_context = std::unique_ptr(trt_engine->createExecutionContext(trt_runtime_config.get())); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); @@ -2777,9 +2902,17 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } OrtAllocator* alloc = alloc_; - void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); + cudaStream_t stream; + if (stream_ != nullptr) { + // Use our existing stream (either user's or our early-created) + stream = stream_; + } else { + // Create stream now (lazy creation case) + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + stream = static_cast(cuda_stream); + stream_ = stream; + } if (multi_profile_enable_ == true) { if (!trt_context->setOptimizationProfileAsync(nv_profile_index_, stream)) @@ -2860,7 +2993,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr nvinfer1::Dims dims; void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); if (status != Status::OK()) { @@ -2886,18 +3018,23 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size); } } - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - cuda_graph_.SetStream(stream); - CaptureBegin(0); - } - // Run TRT inference - if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + // Start CUDA graph capture with the correct stream + // Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream + // Get the graph annotation ID that was stored during OnRunStart + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCurrentGraphAnnotationId(); + bool graph_replay_on_this_run = false; + bool should_start_capture = false; + + HandleCudaGraphStart(stream, require_io_binding, cuda_graph_annotation_id, + graph_replay_on_this_run, should_start_capture); + + if (!graph_replay_on_this_run) { + if (!trt_context->enqueueV3(stream)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + } + } else { + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); } /* @@ -2914,10 +3051,15 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. */ + + if (cuda_graph_enable_ && should_start_capture) { + GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); + } + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } - // Assign TRT output back to ORT output // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output @@ -2951,21 +3093,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } } - // End CUDA graph capture. - // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture - // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. - // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { - CaptureEnd(0); - // CUDA work issued to a capturing stream doesn't actually run on the GPU, - // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(ReplayGraph(0)); - } else { - IncrementRegularRunCountBeforeGraphCapture(); - } - } - return Status::OK(); }; @@ -3098,9 +3225,16 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra } OrtAllocator* alloc = alloc_; - void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); + cudaStream_t stream; + if (stream_ != nullptr) { + // Use our existing stream (either user's or our early-created) + stream = stream_; + } else { + // Create stream now (lazy creation case) + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + stream = static_cast(cuda_stream); + } // Check before using trt_engine if (trt_engine == nullptr) { @@ -3203,18 +3337,23 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size); } } - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - cuda_graph_.SetStream(stream); - CaptureBegin(0); - } - // Run TRT inference - if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + // Start CUDA graph capture with the correct stream + // Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream + // Get the graph annotation ID that was stored during OnRunStart + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCurrentGraphAnnotationId(); + bool graph_replay_on_this_run = false; + bool should_start_capture = false; + + HandleCudaGraphStart(stream, require_io_binding, cuda_graph_annotation_id, + graph_replay_on_this_run, should_start_capture); + + if (!graph_replay_on_this_run) { + if (!trt_context->enqueueV3(stream)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + } + } else { + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); } /* @@ -3231,10 +3370,15 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. */ + + if (cuda_graph_enable_ && should_start_capture) { + GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); + } + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } - // Assign TRT output back to ORT output // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output @@ -3268,21 +3412,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra } } - // End CUDA graph capture. - // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture - // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. - // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { - CaptureEnd(0); - // CUDA work issued to a capturing stream doesn't actually run on the GPU, - // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(ReplayGraph(0)); - } else { - IncrementRegularRunCountBeforeGraphCapture(); - } - } - return Status::OK(); }; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index e3dd38eb837ff..22b8314649757 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -12,7 +12,7 @@ typedef void* cublasHandle_t; typedef void* cudnnStatus_t; #endif #include "core/providers/nv_tensorrt_rtx/nv_includes.h" - +#include "core/session/onnxruntime_run_options_config_keys.h" #include #include "core/providers/cuda/cuda_graph.h" #include "nv_execution_provider_info.h" @@ -305,9 +305,11 @@ class NvExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; + // CUDA Graph support bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured(int graph_annotation_id) const override; Status ReplayGraph(int graph_annotation_id) override; + void HandleCudaGraphStart(cudaStream_t stream, bool require_io_binding, CudaGraphAnnotation_t cuda_graph_annotation_id, bool& graph_replay_on_this_run, bool& should_start_capture); static common::Status RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, @@ -405,15 +407,6 @@ class NvExecutionProvider : public IExecutionProvider { // Call cudaStreamSynchronize() after TRT enqueueV3() mutable bool sync_stream_after_enqueue_ = true; - CUDAGraph cuda_graph_; - bool is_graph_captured_ = false; - int regular_run_count_before_graph_capture_ = 0; - // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: - // (1) memory pattern is enabled. (2) arena allocation for stream. - // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs - // to allocate enough memory in Arena before graph capturing. - const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. - // [Note] We don't use PerThreadContext for now since it has issue with multithreading // // TRT or CUDA objects that must be maintained on a per thread basis will be put under this PerThreadContext data structure. @@ -436,14 +429,20 @@ class NvExecutionProvider : public IExecutionProvider { bool UpdateTensorRTContext(std::string fused_node, std::unique_ptr context); void ResetTensorRTContext(std::string fused_node); - void InitCUDAGraph(); - void SetGraphStream(cudaStream_t stream); - bool IsGraphCaptureAllowed() const; - void CaptureBegin(int graph_annotation_id); - void CaptureEnd(int graph_annotation_id); - bool IsGraphCaptured(int graph_annotation_id) const; - Status ReplayGraph(int graph_annotation_id); - void IncrementRegularRunCountBeforeGraphCapture(); + // CUDA Graph management + void SetCudaGraphStream(cudaStream_t stream) { cuda_graph_.SetStream(stream); } + bool IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const; + void SetCurrentGraphAnnotationId(CudaGraphAnnotation_t cuda_graph_annotation_id); + CudaGraphAnnotation_t GetCurrentGraphAnnotationId() const; + void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); + void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); + bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag); + void IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id); + void ResetWarmupRuns(CudaGraphAnnotation_t cuda_graph_annotation_id); + void DeleteCapturedGraph(CudaGraphAnnotation_t cuda_graph_annotation_id); private: cudnnHandle_t external_cudnn_handle_ = nullptr; @@ -466,13 +465,18 @@ class NvExecutionProvider : public IExecutionProvider { // Cuda graph with multi threads will be supported in the future, so cuda_graph_ is put under PerThreadContext. // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph instance is enough (no need to maintain one CUDAGraph instance per TRT subgraph) CUDAGraph cuda_graph_; + // Map of graph id to regular_run_count_before_graph_capture + std::unordered_map graph_id_to_run_count_; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; + // Current graph annotation ID for this run + CudaGraphAnnotation_t current_graph_annotation_id_ = kCudaGraphAnnotationDefault; // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: // (1) memory pattern is enabled. (2) arena allocation for stream. // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs // to allocate enough memory in Arena before graph capturing. - const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + const int min_num_runs_before_cuda_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations. + // https://github.com/NVIDIA/TensorRT/blob/main/samples/common/sampleInference.cpp#L1258-L1291 Based on the trtexec code }; using PerThreadContextMap = std::unordered_map>; @@ -606,11 +610,6 @@ class NvExecutionProvider : public IExecutionProvider { std::unordered_map& output_map, std::vector& node_compute_funcs); - bool IsGraphCaptureAllowed() const; - void CaptureBegin(int graph_annotation_id); - void CaptureEnd(int graph_annotation_id); - void IncrementRegularRunCountBeforeGraphCapture(); - /** * Get the pointer to the IBuilder instance. * This function only creates the instance at the first time it's being called." From c9c23b0e55d981ac3feef9e759a6edbd32e30fc5 Mon Sep 17 00:00:00 2001 From: Mike Hsu Date: Wed, 27 Aug 2025 10:19:08 -0700 Subject: [PATCH 14/22] [QNN-EP] Enable einsum with QK equations for QNN. (#25861) ### Description Enable einsum op with QK equations for attention in QNN EP. ### Motivation and Context Current einsum op in QNN doesn't support equations with capital alphabets. Loose this constraint to allow more usecases. Signed-off-by: Mu-Chein Hsu --- .../builder/opbuilder/einsum_op_builder.cc | 6 +- .../test/providers/qnn/einsum_op_test.cc | 153 ++++++++++++++++++ 2 files changed, 156 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc index a22a331b2453f..0925d8e1a6062 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc @@ -48,13 +48,13 @@ std::optional ParseEquation(std::string_view equation_string) { if (term_1.size() < 2 || term_2.size() < 2 || result.size() < 2) { return std::nullopt; } - if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::islower(c); })) { + if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::isalpha(c); })) { return std::nullopt; } - if (!std::all_of(term_2.begin(), term_2.end(), [](unsigned char c) { return std::islower(c); })) { + if (!std::all_of(term_2.begin(), term_2.end(), [](unsigned char c) { return std::isalpha(c); })) { return std::nullopt; } - if (!std::all_of(result.begin(), result.end(), [](unsigned char c) { return std::islower(c); })) { + if (!std::all_of(result.begin(), result.end(), [](unsigned char c) { return std::isalpha(c); })) { return std::nullopt; } return std::make_tuple(term_1, term_2, result); diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc index 11a3d5a083aab..fefd682de203e 100644 --- a/onnxruntime/test/providers/qnn/einsum_op_test.cc +++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc @@ -150,6 +150,32 @@ TEST_F(QnnCPUBackendTests, EinsumRank2) { /*tolerance=*/1e-4f); } +TEST_F(QnnCPUBackendTests, EinsumRank3MatMul) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hij,hjk->hik", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnCPUBackendTests, EinsumRank3MatMul_QK) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hQK,hKd->hQd", + /*tolerance=*/1e-4f); +} + TEST_F(QnnCPUBackendTests, EinsumRank4MatMul) { const std::vector shape0{3, 4, 5, 6}; const std::vector shape1{3, 4, 6, 5}; @@ -189,6 +215,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) { /*tolerance=*/1e-4f); } +TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeY_QK) { + const std::vector shape0{2, 3, 4, 6}; + const std::vector shape1{2, 3, 5, 6}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bnQd,bnKd->bnQK", + /*tolerance=*/1e-4f); +} + TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) { const std::vector shape0{1, 7, 1, 7}; const std::vector shape1{1, 9, 1, 7}; @@ -273,6 +312,60 @@ TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY) { /*tolerance=*/1e-2f); } +TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY_QK) { + const std::vector shape0{2, 3, 4, 2}; + const std::vector shape1{2, 3, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bnQd,bnKd->bnQK", + /*tolerance=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, EinsumRank3MatMulTransposeY) { + const std::vector shape0{2, 4, 2}; + const std::vector shape1{2, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bid,bjd->bij", + /*tolerance=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, EinsumRank3MatMulTransposeY_QK) { + const std::vector shape0{2, 4, 2}; + const std::vector shape1{2, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bQd,bKd->bQK", + /*tolerance=*/1e-2f); +} + +// The value pair (65.1049271, 65.0625076) at index #51 don't match, which is -0.0424194 from 65.1049 +// Disable this Rank3 test on HTP since it has accuracy issue. +TEST_F(QnnHTPBackendTests, DISABLED_EinsumRank3MatMul_QK) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hQK,hKd->hQd", + /*tolerance=*/1e-2f); +} + TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll1) { const std::vector shape0{1, 3, 1, 7}; const std::vector shape1{1, 7, 1, 3}; @@ -365,6 +458,66 @@ TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY) { /*tolerance=*/QDQTolerance()); } +TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY_QK) { + const std::vector shape0{2, 3, 4, 2}; + const std::vector shape1{2, 3, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bnQd,bnKd->bnQK", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMulTransposeY) { + const std::vector shape0{2, 4, 2}; + const std::vector shape1{2, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bid,bjd->bij", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMulTransposeY_QK) { + const std::vector shape0{2, 4, 2}; + const std::vector shape1{2, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bQd,bKd->bQK", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMul) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hij,hjk->hik", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMul_QK) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hQK,hKd->hQd", + /*tolerance=*/QDQTolerance()); +} + TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll1) { const std::vector shape0{1, 3, 1, 7}; const std::vector shape1{1, 7, 1, 3}; From 568ad20287a8e0c64552bf2ff7dd2550603c49c7 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 27 Aug 2025 10:26:36 -0700 Subject: [PATCH 15/22] Fix a long standing bug on file memory mapping on windows. (#25833) ### Description While memory profiling some models I noticed multiple file mapping failures. `WindowsEnv::MapFileIntoMemory()` While it properly checks for the mapping offset to be granularity aligned, it calculates it as page aligned. Also, while saving external tensors we do not need to align big tensors to windows granularity or anything that is platform dependent. Set it to 4096 for all platforms. Granularity matters only for calculating mapping address. ### Motivation and Context Multiple failures for file mapping for certain models. This saves some hundreds of Mbs for some models. --- .../core/graph/model_saving_options.h | 26 +++++++---------- .../framework/tensor_external_data_info.cc | 4 +-- .../framework/tensor_external_data_info.h | 10 +++---- onnxruntime/core/graph/graph.cc | 8 ++--- onnxruntime/core/platform/windows/env.cc | 29 +++++++++---------- .../save_model_with_external_initializers.cc | 2 +- onnxruntime/test/platform/file_io_test.cc | 29 +++---------------- 7 files changed, 39 insertions(+), 69 deletions(-) diff --git a/include/onnxruntime/core/graph/model_saving_options.h b/include/onnxruntime/core/graph/model_saving_options.h index 6c041ec96a035..06c1b1ac6475f 100644 --- a/include/onnxruntime/core/graph/model_saving_options.h +++ b/include/onnxruntime/core/graph/model_saving_options.h @@ -9,36 +9,30 @@ class PrepackedWeightsForGraph; // These options affect how the model initializers are written to the external file. // This includes options to align external initializer offset. -// For models running on CPU, ORT will try to use mmap to load external -// initializers. To use mmap, external initializer need to be offset aligned. +// ORT will try to use mmap to load external initializers. +// // ORT saves external initializers into single data file, each initializer is // accessed with offset(start position of initializer) and length(byte length of -// initializer) of the data file. To use mmap, each offset need to be aligned -// which means offset need to divisible by allocation granularity(64KB for -// windows and 4K for other OSes). With align_offset to true, ORT will align -// offset for large initializer when save ONNX model with external data file. +// initializer) of the data file. With align_offset to true, ORT will align +// offset for large initializer (larger than align_threshold) +// when save ONNX model with external data file. It will align then to +// on_disk_alignment value. struct ModelSavingOptions { explicit ModelSavingOptions(size_t size_threshold) : initializer_size_threshold(size_threshold) {} // Minimal initializer size in bytes to be externalized on disk size_t initializer_size_threshold; - // Offset will always be page aligned and allocation granularity aligned for - // mmap support. This is done by padding previous tensor data with zeros - // keeping same length. + // Offset will always be aligned for mmap support. + // This is done by padding previous tensor data with zeros keeping same length. bool align_offset = false; // Alignment threshold for size of data. // Having a low threshold will waste file space for small initializers. // Only when tensor's data size is > the page_align_threshold it will be force // aligned. Default to 1MB. int64_t align_threshold = 1048576; - // The allocation Granularity for mmap() support. - // Typically 64KB for Windows & 4KB for other OSes. Default to 64KB. -#ifdef _WIN32 - int64_t allocation_granularity = 65536; -#else - int64_t allocation_granularity = 4096; -#endif + // Alignment factor for big tensors (bigger than align_threshold). Defaults to 4K. + int64_t on_disk_alignment = 4096; // Force embed all external initializer into the Onnx file // Used for EPContext model generation while some nodes fallback on CPU which has external data dependency bool force_embed_external_ini = false; diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index 971851db62437..d7f5b23d56c70 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -107,7 +107,7 @@ void ExternalDataInfo::SetExternalLocationToProto(const std::filesystem::path& e std::ostream& ExternalDataInfo::WritePrepackedToFileAndAddToProto( const PrepackedWeightsForGraph& prepacked_for_graph, const InlinedHashSet& blob_keys, bool align, - int64_t align_threshold, int64_t allocation_granularity, + int64_t align_threshold, int64_t on_disk_alignment, std::ostream& os, int64_t& external_offset, ::ONNX_NAMESPACE::TensorProto& proto) { size_t key_count = 0; for (const auto& key : blob_keys) { @@ -120,7 +120,7 @@ std::ostream& ExternalDataInfo::WritePrepackedToFileAndAddToProto( const auto size_in_bytes = prepacked_weights->buffer_sizes_[i]; if (align && static_cast(size_in_bytes) > align_threshold) { // return early on error - if (!AlignAndPad(os, allocation_granularity, external_offset)) { + if (!AlignAndPad(os, on_disk_alignment, external_offset)) { return os; } } diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index 2de1e01f381ec..784b3f352a78e 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -41,15 +41,13 @@ class ExternalDataInfo { size_t tensor_bytes_size, ::ONNX_NAMESPACE::TensorProto& proto); - // Pads the output with zeros according to the specified allocation_granularity + // Pads the output with zeros according to the specified alignment_factor // It updates external_offset for alignment. // need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and allocation granularity aligned) like below: + // large tensors (offset need to be page aligned) like below: // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX // |<---smaller tensor---->|<---padding--->|<------------------large tensor----------------------------->| - static std::ostream& AlignAndPad(std::ostream& stream, int64_t allocation_granularity, int64_t& external_offset) { - // Align to the larger of the page size or the allocation granularity - int64_t alignment_factor = std::max(static_cast(4096), allocation_granularity); + static std::ostream& AlignAndPad(std::ostream& stream, int64_t alignment_factor, int64_t& external_offset) { // Align to the next page or alloc granularity boundary SafeInt safe_external_offset = external_offset; int64_t new_external_offset = ((safe_external_offset + alignment_factor - 1) / alignment_factor) * @@ -66,7 +64,7 @@ class ExternalDataInfo { static std::ostream& WritePrepackedToFileAndAddToProto( const PrepackedWeightsForGraph& prepacked_for_graph, const InlinedHashSet& blob_keys, - bool align, int64_t align_threshold, int64_t allocation_granularity, + bool align, int64_t align_threshold, int64_t on_disk_alignment, std::ostream& os, int64_t& external_offset, ::ONNX_NAMESPACE::TensorProto& proto); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e4f8cd6df678e..0a228176175eb 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4536,14 +4536,14 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( continue; } - // update external_offset for alignment + // update external_offset for alignment (if enabled) // need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and allocation granularity aligned) like below: + // large tensors (offset need to be page aligned) like below: // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX // |<---smaller tensor---->|<---padding--->|<------------------large tensor----------------------------->| if (model_saving_options.align_offset && static_cast(tensor_bytes_size) > model_saving_options.align_threshold) { - ORT_RETURN_IF_NOT(ExternalDataInfo::AlignAndPad(external_stream, model_saving_options.allocation_granularity, + ORT_RETURN_IF_NOT(ExternalDataInfo::AlignAndPad(external_stream, model_saving_options.on_disk_alignment, external_offset), "Failed writing external data to: ", model_external_file_path); } @@ -4576,7 +4576,7 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( auto& os = ExternalDataInfo::WritePrepackedToFileAndAddToProto( *prepacked_weights_for_graph_, blob_keys_to_external_data, model_saving_options.align_offset, model_saving_options.align_threshold, - model_saving_options.allocation_granularity, + model_saving_options.on_disk_alignment, external_stream, external_offset, *output_proto); ORT_RETURN_IF_NOT(os.good(), "Failed to write pre-packed blobs to external file"); } diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 36c6b54a1fce0..aa237fc6441b2 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include "core/common/logging/logging.h" #include "core/common/narrow.h" +#include "core/common/safeint.h" #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" @@ -439,30 +440,28 @@ Status WindowsEnv::MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, SYSTEM_INFO sysinfo; GetSystemInfo(&sysinfo); - static const DWORD page_size = sysinfo.dwPageSize; static const DWORD allocation_granularity = sysinfo.dwAllocationGranularity; - const FileOffsetType offset_to_page = offset % static_cast(page_size); - const size_t mapped_length = length + static_cast(offset_to_page); - const FileOffsetType mapped_offset = offset - offset_to_page; - if (mapped_offset % allocation_granularity != 0) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "mapped offset must be a multiple of the allocation granularity", - " , mapped_offset = ", mapped_offset, - " , allocation_granularity = ", allocation_granularity, - " , errcode = ", error_code, - " - ", std::system_category().message(error_code)); - } + const FileOffsetType offset_to_granularity = offset % static_cast(allocation_granularity); + const SIZE_T mapped_length = SafeInt(offset_to_granularity) + length; + const FileOffsetType mapped_offset = offset - offset_to_granularity; + assert((mapped_offset % allocation_granularity) == 0); void* const mapped_base = MapViewOfFile(file_mapping_handle.get(), FILE_MAP_READ, static_cast((mapped_offset >> 32) & 0xFFFFFFFF), static_cast(mapped_offset & 0xFFFFFFFF), mapped_length); - GSL_SUPPRESS(r.11) + + if (mapped_base == nullptr) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "MapViewOfFile ", ToUTF8String(Basename(file_path)), + " fail, errcode = ", error_code, + " - ", std::system_category().message(error_code)); + } mapped_memory = - MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_page, + MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_granularity, [mapped_base](void*) { UnmapFile(mapped_base); }}; diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc index 98874874d50e9..e70d870ef6988 100644 --- a/onnxruntime/test/framework/save_model_with_external_initializers.cc +++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc @@ -84,7 +84,7 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, size_t tensor_offset; std::stringstream stream(entry.value()); stream >> tensor_offset; - ORT_RETURN_IF_NOT(tensor_offset % model_saving_options.allocation_granularity == 0, + ORT_RETURN_IF_NOT(tensor_offset % model_saving_options.on_disk_alignment == 0, "tensor offset not align"); } } diff --git a/onnxruntime/test/platform/file_io_test.cc b/onnxruntime/test/platform/file_io_test.cc index ccc703716844f..a1a863d2442d1 100644 --- a/onnxruntime/test/platform/file_io_test.cc +++ b/onnxruntime/test/platform/file_io_test.cc @@ -157,7 +157,6 @@ TEST(FileIoTest, MapFileIntoMemory) { SYSTEM_INFO sysinfo; GetSystemInfo(&sysinfo); static const auto page_size = sysinfo.dwPageSize; - static const auto allocation_granularity = sysinfo.dwAllocationGranularity; ASSERT_GT(page_size, static_cast(0)); TempFilePath tmp(ORT_TSTR("map_file_test_")); @@ -167,21 +166,10 @@ TEST(FileIoTest, MapFileIntoMemory) { const auto offsets_and_lengths = GenerateValidOffsetLengthPairs( 0, expected_data.size(), page_size / 10); - for (const auto& offset_and_length : offsets_and_lengths) { - const auto offset = offset_and_length.first; - const auto length = offset_and_length.second; - - // The offset must be a multiple of the allocation granularity - if (offset % allocation_granularity != 0) { - continue; - } - + for (const auto& [offset, length] : offsets_and_lengths) { Env::MappedMemoryPtr mapped_memory{}; - auto status = Env::Default().MapFileIntoMemory( - tmp.path.c_str(), offset, length, mapped_memory); - ASSERT_TRUE(status.IsOK()) - << "MapFileIntoMemory failed for offset " << offset << " and length " << length - << " with error: " << status.ErrorMessage(); + ASSERT_STATUS_OK(Env::Default().MapFileIntoMemory( + tmp.path.c_str(), offset, length, mapped_memory)); auto mapped_span = gsl::make_span(mapped_memory.get(), length); @@ -190,20 +178,11 @@ TEST(FileIoTest, MapFileIntoMemory) { ASSERT_TRUE(SpanEq(mapped_span, expected_data_span)); } - { - Env::MappedMemoryPtr mapped_memory{}; - - // invalid - offset is not a multiple of the allocation granularity - ASSERT_FALSE(Env::Default().MapFileIntoMemory( - tmp.path.c_str(), allocation_granularity * 3 / 2, page_size / 10, mapped_memory) - .IsOK()); - } - { Env::MappedMemoryPtr mapped_memory{}; // invalid - negative offset - ASSERT_FALSE(Env::Default().MapFileIntoMemory(tmp.path.c_str(), -1, 0, mapped_memory).IsOK()); + ASSERT_STATUS_NOT_OK(Env::Default().MapFileIntoMemory(tmp.path.c_str(), -1, 0, mapped_memory)); } } #endif From c9ec1da44fbd9f9bbb4aa22fa5785465d92c34ce Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 27 Aug 2025 10:27:14 -0700 Subject: [PATCH 16/22] Add default constructor to Ort::Status. (#25860) ### Description Fix packaging pipelines ### Motivation and Context During CIs and local builds Ort::Status() gets inherited from the base due to using directives, however, that does not work for packaging pipelines. Having default ctor is important for storing Status in containers if needed. --- include/onnxruntime/core/session/onnxruntime_cxx_api.h | 4 +--- include/onnxruntime/core/session/onnxruntime_cxx_inline.h | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 2f4fd36c8115f..c39e27088e8bc 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -725,9 +725,7 @@ using AllocatedStringPtr = std::unique_ptr; * constructors to construct an instance of a Status object from exceptions. */ struct Status : detail::Base { - using Base = detail::Base; - using Base::Base; - + Status() = default; // Same as with std::nullptr_t. But can be used in re-sizable containers and represent success. explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. explicit Status(const Exception&); ///< Creates status instance out of exception diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 73200d8852223..d0089726812a3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -823,7 +823,7 @@ inline Status Env::CopyTensors(const std::vector& src_tensors, return Status("Source and destination tensor vectors must have the same size", ORT_INVALID_ARGUMENT); } if (src_tensors.empty()) { - return Status(); + return Status(nullptr); } const OrtValue* const* src_tensors_ptr = reinterpret_cast(src_tensors.data()); From 0ba4d29613d583531b9c635b6dafa22fc60671b0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 27 Aug 2025 10:53:49 -0700 Subject: [PATCH 17/22] Bump actions/setup-java from 4 to 5 (#25840) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/setup-java](https://github.com/actions/setup-java) from 4 to 5.
Release notes

Sourced from actions/setup-java's releases.

v5.0.0

What's Changed

Breaking Changes

Make sure your runner is updated to this version or newer to use this release. v2.327.1 Release Notes

Dependency Upgrades

Bug Fixes

New Contributors

Full Changelog: https://github.com/actions/setup-java/compare/v4...v5.0.0

v4.7.1

What's Changed

Documentation changes

Dependency updates:

Full Changelog: https://github.com/actions/setup-java/compare/v4...v4.7.1

v4.7.0

What's Changed

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/setup-java&package-manager=github_actions&previous-version=4&new-version=5)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/android.yml | 4 ++-- .github/workflows/codeql.yml | 2 +- .github/workflows/publish-java-apidocs.yml | 2 +- .github/workflows/windows_cuda.yml | 4 ++-- .github/workflows/windows_dml.yml | 2 +- .github/workflows/windows_tensorrt.yml | 4 ++-- .github/workflows/windows_webgpu.yml | 4 ++-- .github/workflows/windows_x64_debug_build_x64_debug.yml | 2 +- .github/workflows/windows_x64_release_build_x64_release.yml | 2 +- ...neric_interface_build_x64_release_ep_generic_interface.yml | 2 +- .../windows_x64_release_vitisai_build_x64_release.yml | 2 +- .github/workflows/windows_x64_release_xnnpack.yml | 2 +- .github/workflows/windows_x86.yml | 2 +- 13 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index f83847c3d6df4..feba60d58b9a2 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -124,7 +124,7 @@ jobs: - uses: actions/checkout@v5 - name: Use jdk 17 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' @@ -206,7 +206,7 @@ jobs: - uses: actions/checkout@v5 - name: Use jdk 17 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index efe580c1b3b0c..8630abba416e1 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -55,7 +55,7 @@ jobs: # Setup Java to use a version that is not too old for the project - if: ${{ matrix.language == 'java' }} name: Setup Java 11 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: java-version: '11' distribution: 'microsoft' diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index c107535473786..742069542ecb9 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -25,7 +25,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Set up JDK 11 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: java-version: '11' distribution: 'adopt' diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 92429de408cda..fb5ad36d2ab1d 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -60,7 +60,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' @@ -172,7 +172,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index fe45f4af0cd46..c3db42fdb15a3 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -51,7 +51,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index 0aef550576d21..3a176ac41ebcf 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -65,7 +65,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' @@ -177,7 +177,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 730dafa0ddc43..1924e5d617679 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -61,7 +61,7 @@ jobs: node-version: "20.x" - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: "temurin" java-version: "17" @@ -249,7 +249,7 @@ jobs: node-version: "20.x" - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: "temurin" java-version: "17" diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml index c31027ce32fcd..1f71d3020bce2 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index 8a9d79aa4b8d1..dec37b8f5620d 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index a3639a9bcbfac..978f3e345141f 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index e356d9ad15c99..983bfaf411983 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index fa08c7e47cd36..c9d76ee450015 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index f7774a70dbd43..762762ebcd435 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -44,7 +44,7 @@ jobs: architecture: x86 #Add architecture - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' From 3cadbdb495761a6a54845b178f9bdb811a2c8bde Mon Sep 17 00:00:00 2001 From: adrastogi Date: Wed, 27 Aug 2025 13:15:26 -0700 Subject: [PATCH 18/22] Add API for precompiled model compatibility check using just the compat info (#25841) ### Description This PR adds a new API that applications can use to verify compatibility of a precompiled model with the underlying system, using only the compatibility info string from the model's metadata. ### Motivation and Context - This is a feature to enable apps to check compatibility of a precompiled model without necessarily having the model locally on the device. This enables precompiled models to be stored remotely and downloaded once the application has been able to confirm the validity of a given model with EPs on the device. ### Testing - New unit tests pass - For regression testing, built a private version of WinML + AMD NPU EP with these changes. Ran the Cpp Selfcontained Desktop sample successfully; ran with compilation and also re-ran using the already-compiled model to verify that session initialization continued to work as expected. --------- Co-authored-by: Aditya Rastogi --- .../core/session/onnxruntime_c_api.h | 28 ++++++ .../core/session/onnxruntime_ep_c_api.h | 30 +++--- onnxruntime/core/session/onnxruntime_c_api.cc | 76 ++++++++++++++-- onnxruntime/core/session/ort_apis.h | 7 ++ .../session/plugin_ep/ep_factory_internal.h | 6 +- .../plugin_ep/ep_factory_internal_impl.h | 9 +- .../ep_plugin_provider_interfaces.cc | 9 +- .../plugin_ep/forward_to_factory_impl.h | 5 +- .../test/framework/ep_compatibility_test.cc | 91 +++++++++++++++++++ 9 files changed, 231 insertions(+), 30 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index beda0e2a9c0d0..9ae6174817b7c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -902,6 +902,16 @@ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t n * * \nosubgrouping */ +/* + * Public enum for compiled model compatibility across EPs. + */ +typedef enum OrtCompiledModelCompatibility { + OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0, + OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL, + OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION, + OrtCompiledModelCompatibility_EP_UNSUPPORTED, +} OrtCompiledModelCompatibility; + struct OrtApi { /// \name OrtStatus /// @{ @@ -6480,6 +6490,24 @@ struct OrtApi { * \since Version 1.23. */ ORT_API2_STATUS(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); + + /** \brief Validate a compiled model's compatibility information for one or more EP devices. + * + * \param[in] ep_devices The EP devices to validate against (e.g., from GetEpDevices). + * All devices must belong to the same execution provider. + * \param[in] num_ep_devices The number of EP devices provided. + * \param[in] compatibility_info The compatibility info string produced when the model was compiled. + * \param[out] out_status The resulting compatibility status for the EP devices. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 620cb5fcf13cc..975f6b453a88d 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -482,18 +482,6 @@ typedef enum OrtEpDataLayout { OrtEpDataLayout_Default = OrtEpDataLayout_NCHW, } OrtEpDataLayout; -/** - * \brief Enumeration describing the compatibility state of a compiled model relative to an execution provider. - * - * \since Version 1.23. - */ -typedef enum OrtCompiledModelCompatibility { - OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0, - OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL, - OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION, - OrtCompiledModelCompatibility_EP_UNSUPPORTED, -} OrtCompiledModelCompatibility; - /** * \brief The OrtEp struct provides functions to implement for an execution provider. * \since Version 1.22. @@ -901,20 +889,28 @@ struct OrtEpFactory { */ ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); - /** \brief Validate the compatibility of a compiled model with the execution provider. + /** \brief Validate the compatibility of a compiled model with the execution provider factory for one or more devices. + * + * Given a compatibility info string produced during model compilation, the EP factory should determine whether the + * compiled model is compatible with the EP factory when targeting the provided hardware devices. All devices provided + * must belong to the same execution provider instance that this factory creates. * - * This function validates if a model produced with the supplied compatibility info string is supported by the underlying EP. - * The EP should check if a compiled model is compatible with the EP and set the model_compatibility parameter accordingly. + * The EP factory implementation should consider the set of devices (e.g., multi-adapter or multi-GPU scenarios) when + * evaluating compatibility and set `model_compatibility` accordingly. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] compatibility_info The compatibility information string that will be used - * \param[out] model_compatibility OrtCompiledModelCompatibility enum value describing the compatibility of the model with the EP. + * \param[in] devices Array of OrtHardwareDevice pointers that the EP would run on. All must map to this EP. + * \param[in] num_devices Number of entries in `devices`. + * \param[in] compatibility_info The compatibility information string produced when the model was compiled. + * \param[out] model_compatibility OrtCompiledModelCompatibility value describing the compatibility of the model with the EP. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(ValidateCompiledModelCompatibilityInfo, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, _In_ const char* compatibility_info, _Out_ OrtCompiledModelCompatibility* model_compatibility); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 1f491bc788870..ad0a1ad137f06 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3423,25 +3423,86 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, API_IMPL_END } +// Validate compiled model compatibility info for specific EP device(s) +ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status) { + API_IMPL_BEGIN + if (ep_devices == nullptr || num_ep_devices == 0 || compatibility_info == nullptr || out_status == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to GetModelCompatibilityForEpDevices."); + } + + // Validate inputs and ensure all devices belong to the same EP/factory + const OrtEpFactory* first_factory = nullptr; + for (size_t i = 0; i < num_ep_devices; ++i) { + if (ep_devices[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_devices contains a null entry."); + } + const OrtEpFactory* f = ep_devices[i]->GetMutableFactory(); + if (i == 0) { + first_factory = f; + } else if (f != first_factory) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All ep_devices must be from the same execution provider."); + } + } + + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + OrtStatus* ort_status = nullptr; + OrtEpFactory* factory = ep_devices[0]->GetMutableFactory(); + if (factory && factory->ValidateCompiledModelCompatibilityInfo) { + // collect hardware devices corresponding to the ep_devices + InlinedVector hardware_devices; + hardware_devices.reserve(num_ep_devices); + for (size_t i = 0; i < num_ep_devices; ++i) { + hardware_devices.push_back(ep_devices[i]->device); + } + ort_status = factory->ValidateCompiledModelCompatibilityInfo(factory, + hardware_devices.data(), + hardware_devices.size(), + compatibility_info, + &status); + } + if (ort_status != nullptr) { + return ToOrtStatus(ToStatusAndRelease(ort_status)); + } + + *out_status = status; + return nullptr; + API_IMPL_END +} + #else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "RegisterExecutionProviderLibrary is not supported in a minimal build."); API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "UnregisterExecutionProviderLibrary is not supported in a minimal build."); API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::GetEpDevices, _In_ const OrtEnv* /*env*/, _Outptr_ const OrtEpDevice* const** /*ep_devices*/, _Out_ size_t* /*num_ep_devices*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetEpDevices is not supported in a minimal build."); + API_IMPL_END +} + +// Minimal build stub for GetModelCompatibilityForEpDevices to satisfy symbol references from the API table +ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* /*ep_devices*/, + _In_ size_t /*num_ep_devices*/, + _In_ const char* /*compatibility_info*/, + _Out_ OrtCompiledModelCompatibility* /*out_status*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetModelCompatibilityForEpDevices is not supported in a minimal build."); API_IMPL_END } @@ -3453,7 +3514,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS _In_reads_(num_op_options) const char* const* /*ep_option_vals*/, size_t /*num_ep_options*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionOptionsAppendExecutionProvider_V2 is not supported in a minimal build."); API_IMPL_END } @@ -3466,7 +3527,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* _Out_writes_(num_values) const OrtEpDevice** /*inputs_ep_devices*/, _In_ size_t /*num_values*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionGetEpDeviceForInputs is not supported in a minimal build."); API_IMPL_END } @@ -3474,7 +3535,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice _In_opt_ const OrtKeyValuePairs* /*stream_options*/, _Outptr_ OrtSyncStream** /*ort_stream*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateSyncStreamForEpDevice is not supported in a minimal build."); API_IMPL_END } @@ -3493,7 +3554,7 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, _In_opt_ OrtSyncStream* /*stream*/, _In_ size_t /*num_tensors*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CopyTensors is not supported in a minimal build."); API_IMPL_END } @@ -4108,6 +4169,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::CopyTensors, &OrtApis::Graph_GetModelMetadata, + &OrtApis::GetModelCompatibilityForEpDevices, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index b3b0036c68247..e62149d04a16c 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -636,6 +636,13 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); ORT_API_STATUS_IMPL(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); + +// EP Compatibility Info APIs +ORT_API_STATUS_IMPL(GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status); ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 23e5e95af2903..093bfce462d32 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -80,9 +80,11 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } - OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, + OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _In_ const char* compatibility_info, _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { - return impl_->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + return impl_->ValidateCompiledModelCompatibilityInfo(devices, num_devices, compatibility_info, model_compatibility); } // Function ORT calls to release an EP instance. diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 6c55730d83979..f29154d19c53c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -62,8 +62,13 @@ class EpFactoryInternalImpl { return false; } - virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, - _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + virtual OrtStatus* ValidateCompiledModelCompatibilityInfo( + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + ORT_UNUSED_PARAMETER(devices); + ORT_UNUSED_PARAMETER(num_devices); ORT_UNUSED_PARAMETER(compatibility_info); // Default implementation: mark as not applicable *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 3bfca62a4d011..c8829423fbe26 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -668,8 +668,15 @@ Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std // Plugin EP did not provide an implementation of this function, so we call a default implementation. return Base::ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); } - // Delegate to the EP factory's validation method + // Delegate to the EP factory's validation method, passing hardware devices derived from our ep_devices_ + std::vector hardware_devices; + hardware_devices.reserve(ep_devices_.size()); + for (const auto* ep_device : ep_devices_) { + hardware_devices.push_back(ep_device->device); + } ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory_.ValidateCompiledModelCompatibilityInfo(&ep_factory_, + hardware_devices.data(), + hardware_devices.size(), compatibility_info.c_str(), &model_compatibility))); return Status::OK(); diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 29793b503c9d1..2cceb1d08d536 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -46,9 +46,12 @@ struct ForwardToFactoryImpl { } static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfo(OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + size_t num_devices, const char* compatibility_info, OrtCompiledModelCompatibility* model_compatibility) noexcept { - return static_cast(this_ptr)->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + return static_cast(this_ptr)->ValidateCompiledModelCompatibilityInfo(devices, num_devices, + compatibility_info, model_compatibility); } static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr, diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc index be97cf2620881..ee82d4683ab73 100644 --- a/onnxruntime/test/framework/ep_compatibility_test.cc +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -408,3 +408,94 @@ TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) { EXPECT_TRUE(has_config); EXPECT_EQ(config_value, "0"); } + +// ----------------------------- +// C API unit tests +// ----------------------------- + +namespace { + +// Helper to create an OrtEnv and fetch a CPU EP device pointer via the C API. +// Returns a pair of (env, cpu_device). Caller releases env via api->ReleaseEnv. +static std::pair CreateEnvAndGetCpuEpDevice(const OrtApi* api) { + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpCompatCapiTest", &env)); + EXPECT_NE(env, nullptr); + + const OrtEpDevice* const* devices = nullptr; + size_t num_devices = 0; + EXPECT_EQ(nullptr, api->GetEpDevices(env, &devices, &num_devices)); + EXPECT_GT(num_devices, 0u); + + const OrtEpDevice* cpu_device = nullptr; + for (size_t i = 0; i < num_devices; ++i) { + const char* name = api->EpDevice_EpName(devices[i]); + if (name && std::string(name) == "CPUExecutionProvider") { + cpu_device = devices[i]; + break; + } + } + + // Fallback: just pick the first device if CPU wasn't found (environment-dependent builds). + if (!cpu_device && num_devices > 0) { + cpu_device = devices[0]; + } + + EXPECT_NE(cpu_device, nullptr); + return {env, cpu_device}; +} + +} // namespace + +TEST(EpCompatibilityCapiTest, InvalidArguments) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtCompiledModelCompatibility out_status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + + // ep_devices == nullptr + OrtStatus* st = api->GetModelCompatibilityForEpDevices(nullptr, 0, "info", &out_status); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // Prepare a valid device + auto [env, device] = CreateEnvAndGetCpuEpDevice(api); + ASSERT_NE(env, nullptr); + ASSERT_NE(device, nullptr); + + // compatibility_info == nullptr + const OrtEpDevice* devices1[] = {device}; + st = api->GetModelCompatibilityForEpDevices(devices1, 1, nullptr, &out_status); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // out_status == nullptr + st = api->GetModelCompatibilityForEpDevices(devices1, 1, "some-info", nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + auto [env, device] = CreateEnvAndGetCpuEpDevice(api); + ASSERT_NE(env, nullptr); + ASSERT_NE(device, nullptr); + + OrtCompiledModelCompatibility out_status = static_cast(-1); + const OrtEpDevice* devices2[] = {device}; + OrtStatus* st = api->GetModelCompatibilityForEpDevices(devices2, 1, "arbitrary-compat-string", &out_status); + ASSERT_EQ(st, nullptr) << (st ? api->GetErrorMessage(st) : ""); + + // For providers that don't implement validation, API should return EP_NOT_APPLICABLE. + EXPECT_EQ(out_status, OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} From 7af42b8f7c1a8e1bcaeb45873f1ef93dd1f02048 Mon Sep 17 00:00:00 2001 From: shiyi Date: Thu, 28 Aug 2025 08:48:45 +0800 Subject: [PATCH 19/22] [WebNN] Fix the op support limit for batchNormalization (#25856) ### Description According to the [WebNN spec](https://www.w3.org/TR/webnn/#api-mlgraphbuilder-batchnorm), the batchNorm should have input names "mean" and "variance" instead of "input_mean" and "input_var". ### Motivation and Context This issue causes any BatchNorm with mean/variance inputs to fall back to wasm. --- onnxruntime/core/providers/webnn/builders/map_info.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 1c30fed7a7916..ffb8091e3ecbc 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -200,7 +200,7 @@ const std::unordered_map op_inputs_map = { {"DequantizeLinear", {"dequantizeLinear", {{0, "input"}, {1, "scale"}, {2, "zeroPoint"}}}}, {"InstanceNormalization", {"instanceNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}}}}, {"GRU", {"gru", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}}}}, - {"BatchNormalization", {"batchNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}, {3, "input_mean"}, {4, "input_var"}}}}, + {"BatchNormalization", {"batchNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}, {3, "mean"}, {4, "variance"}}}}, {"LSTM", {"lstm", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}, {6, "initialCellState"}, {7, "peepholeWeight"}}}}, }; } // namespace webnn From 1d07e94c0ced1b1b73abb5431c16ab128e536adc Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 28 Aug 2025 08:50:50 +0800 Subject: [PATCH 20/22] [WebNN] Support Round op (#25810) --- js/web/docs/webnn-operators.md | 1 + .../webnn/builders/impl/unary_op_builder.cc | 39 +++---------------- .../core/providers/webnn/builders/map_info.h | 1 + .../webnn/builders/op_builder_factory.cc | 1 + 4 files changed, 9 insertions(+), 33 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index d9a030f320c6c..0fffe99ec4f78 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -80,6 +80,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s | PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | | | QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear | The shape of x_scale should be a subsample of the shape of input | | Reciprocal | ai.onnx(7-12, 13+) | reciprocal | | +| Round | ai.onnx(11-21, 22+) | roundEven | | | ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | Input 'axes' if present should be a constant | | ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | Input 'axes' if present should be a constant | | ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum | Input 'axes' if present should be a constant | diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index 91af452c64efd..87d91178f07bc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -27,42 +27,14 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& op_type(node.OpType()); emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); - emscripten::val output = emscripten::val::object(); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - if (op_type == "Abs") { - output = model_builder.GetBuilder().call("abs", input, options); - } else if (op_type == "Ceil") { - output = model_builder.GetBuilder().call("ceil", input, options); - } else if (op_type == "Cos") { - output = model_builder.GetBuilder().call("cos", input, options); - } else if (op_type == "Erf") { - output = model_builder.GetBuilder().call("erf", input, options); - } else if (op_type == "Exp") { - output = model_builder.GetBuilder().call("exp", input, options); - } else if (op_type == "Floor") { - output = model_builder.GetBuilder().call("floor", input, options); - } else if (op_type == "Identity") { - output = model_builder.GetBuilder().call("identity", input, options); - } else if (op_type == "Log") { - output = model_builder.GetBuilder().call("log", input, options); - } else if (op_type == "Neg") { - output = model_builder.GetBuilder().call("neg", input, options); - } else if (op_type == "Reciprocal") { - output = model_builder.GetBuilder().call("reciprocal", input, options); - } else if (op_type == "Sign") { - output = model_builder.GetBuilder().call("sign", input, options); - } else if (op_type == "Sin") { - output = model_builder.GetBuilder().call("sin", input, options); - } else if (op_type == "Sqrt") { - output = model_builder.GetBuilder().call("sqrt", input, options); - } else if (op_type == "Tan") { - output = model_builder.GetBuilder().call("tan", input, options); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } + const std::string_view webnn_op_type = GetWebNNOpType(op_type); + ORT_RETURN_IF(webnn_op_type.empty(), "Cannot get WebNN op type"); + + emscripten::val output = model_builder.GetBuilder().call( + std::string(webnn_op_type).c_str(), input, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); @@ -84,6 +56,7 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op "Log", "Neg", "Reciprocal", + "Round", "Sign", "Sin", "Sqrt", diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index ffb8091e3ecbc..590614edf851c 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -174,6 +174,7 @@ const std::unordered_map op_inputs_map = { {"Greater", {"greater", {{0, "a"}, {1, "b"}}}}, {"Reciprocal", {"reciprocal", {{0, "input"}}}}, {"ReduceMean", {"reduceMean", {{0, "input"}}}}, + {"Round", {"roundEven", {{0, "input"}}}}, {"GlobalMaxPool", {"maxPool2d", {{0, "input"}}}}, {"HardSigmoid", {"hardSigmoid", {{0, "input"}}}}, {"ReduceProd", {"reduceProduct", {{0, "input"}}}}, diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 67e3123ee7af6..c761213ef4dc2 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -26,6 +26,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateUnaryOpBuilder("Log", op_registrations); CreateUnaryOpBuilder("Neg", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); + CreateUnaryOpBuilder("Round", op_registrations); CreateUnaryOpBuilder("Sign", op_registrations); CreateUnaryOpBuilder("Sin", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); From d4e31dc6613074b338aff2166d0162ca583917ff Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 27 Aug 2025 18:09:37 -0700 Subject: [PATCH 21/22] Improve SimplifiedLayerNorm by using same techniques as SkipSimplifiedLayerNorm (#25850) ### Description Use similar shaders as SkipSimplifiedLayerNorm in SimplifiedLayerNorm, to fix the performance issues with SimplifiedLayerNorm. ### Motivation and Context Prior to this change, generation in Bitnet was bottlenecked on SimplifiedLayerNorm image with this change performance has now improved to match SkipSimplifiedLayerNorm image --- .../core/providers/webgpu/nn/layer_norm.cc | 160 ++++++++++++++---- .../core/providers/webgpu/nn/layer_norm.h | 13 +- 2 files changed, 132 insertions(+), 41 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 5ac4541e8323e..7d4ae8c2197ff 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -26,7 +26,7 @@ static TensorShape GetOverrideShape(const TensorShape& shape, int components) { } Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AddInput("scale", ShaderUsage::UseUniform); if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); @@ -39,35 +39,113 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddOutput("inv_std_dev_output", ShaderUsage::None); } - int components = x.NumComponents(); - std::string bias = (has_bias_) ? " + bias[j]" : ""; - std::string simpl1 = (simplified_) ? "" : " - mean * mean"; - std::string simpl2 = (simplified_) ? "" : " - mean"; - - shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") - << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n"; - - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count") - << "let offset = global_idx * uniforms.norm_size_vectorized;\n" - << "var mean_vector = f32_val_t(0);\n" - << "var mean_square_vector = f32_val_t(0);\n" - << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" - << " let value = f32_val_t(x[h + offset]);\n" - << " mean_vector += value;\n" - << " mean_square_vector += value * value;\n" - << "}\n" - << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" - << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n" - << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" - << " let f32input = f32_val_t(x[j + offset]);\n" - << " let f32scale = f32_val_t(scale[j]);\n" - << " y[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" - << "}\n"; - if (has_mean_output_) { - shader.MainFunctionBody() << "mean_output[global_idx] = mean;\n"; - } - if (has_inv_std_dev_output_) { - shader.MainFunctionBody() << "inv_std_dev_output[global_idx] = inv_std_dev;\n"; + std::string simpl1 = (simplified_) ? "" : "- mean * mean "; + std::string simpl2 = (simplified_) ? "" : "- x_element_t(mean) "; + + if (split_norm_dim_) { + shader.AdditionalImplementation() + << "var sum_shared : array;\n" + << "var sum_squared_shared : array;\n"; + + shader.MainFunctionBody() + << " var sum_vec4 = vec4(0);\n" + << " var sum_squared_vec4 = vec4(0);\n" + << " var cur_input = x_value_t(0);\n" + << " for (var i: u32 = 0; i < uniforms.norm_size / (workgroup_size_x * 4); i++) {\n" + << " let input_offset = i * workgroup_size_x + local_idx;\n" + << " let input_value = x[input_offset];\n" + << " if (i == workgroup_idx) {\n" + << " cur_input = input_value;\n" + << " }\n" + << " let f32_value = vec4(input_value);\n" + << " sum_vec4 += f32_value;\n" + << " sum_squared_vec4 += f32_value * f32_value;\n" + << " }\n" + << " var sum = " << SumVector("sum_vec4", 4) << ";\n" + << " var sum_squared = " << SumVector("sum_squared_vec4", 4) << ";\n" + << " sum_shared[local_idx] = sum;\n" + << " sum_squared_shared[local_idx] = sum_squared;\n" + << " workgroupBarrier();\n" + << " var reduce_size : u32 = workgroup_size_x;\n" + << " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (local_idx < curr_size) {\n" + << " sum_shared[local_idx] += sum_shared[local_idx + reduce_size];\n" + << " sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + << " let mean = sum_shared[0] / f32(uniforms.norm_size);\n" + << " let inv_std_dev = inverseSqrt(sum_squared_shared[0] / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n" + << " let offset = workgroup_idx * workgroup_size_x + local_idx;\n" + << " y[offset] = ((cur_input " << simpl2 << ") * x_element_t(inv_std_dev) * scale[offset]" << (has_bias_ ? " + bias[offset] " : "") << ");\n"; + + if (has_mean_output_) { + shader.MainFunctionBody() << " if (local_idx == 0 && workgroup_idx == 0) {\n" + << " mean_output[global_idx / uniforms.norm_size] = mean;\n" + << " }\n"; + } + if (has_inv_std_dev_output_) { + shader.MainFunctionBody() << " if (local_idx == 0 && workgroup_idx == 0) {\n" + << " inv_std_dev_output[global_idx / uniforms.norm_size] = inv_std_dev;\n" + << " }\n"; + } + } else { + int components = x.NumComponents(); + std::string bias = (has_bias_) ? " + bias[offset1d + i] " : ""; + + shader.AdditionalImplementation() + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n" + << "var sum_shared : array;\n" + << "var sum_squared_shared : array;\n"; + + shader.MainFunctionBody() + << "let ix = local_idx;\n" + << "let iy = global_idx / workgroup_size_x;\n" + << "let norm_size_vectorized: u32 = uniforms.norm_size / uniforms.components;\n" + << "var stride = norm_size_vectorized / workgroup_size_x;\n" + << "let offset = ix * stride + iy * norm_size_vectorized;\n" + << "let offset1d = stride * ix;\n" + << "sum_shared[ix] = f32_val_t(0);\n" + << "sum_squared_shared[ix] = f32_val_t(0);\n" + << "if (ix == workgroup_size_x - 1) {\n" + << " stride = norm_size_vectorized - stride * ix;\n" + << "}\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " let input_value = x[offset + i];\n" + << " y[offset + i] = input_value;\n" + << " let f32_value = f32_val_t(input_value);\n" + << " sum_shared[ix] += f32_value;\n" + << " sum_squared_shared[ix] += f32_value * f32_value;\n" + << "}\n" + << "workgroupBarrier();\n" + << "var reduce_size : u32 = workgroup_size_x;\n" + << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (ix < curr_size) {\n" + << " sum_shared[ix] += sum_shared[ix + reduce_size];\n" + << " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "let sum = sum_shared[0];\n" + << "let square_sum = sum_squared_shared[0];\n" + << "let mean = " << SumVector("sum", components) << " / f32(uniforms.norm_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " y[offset + i] = (y[offset + i] " << simpl2 << ") * x_element_t(inv_std_dev) * scale[offset1d + i]" << bias << ";\n" + << "};\n"; + + if (has_mean_output_) { + shader.MainFunctionBody() << "if (ix == 0) {\n" + << " mean_output[iy] = mean;\n" + << "}\n"; + } + if (has_inv_std_dev_output_) { + shader.MainFunctionBody() << "if (ix == 0) {\n" + << " inv_std_dev_output[iy] = inv_std_dev;\n" + << "}\n"; + } } return Status::OK(); @@ -81,8 +159,6 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex const auto x_shape = x->Shape(); - const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis)); const int64_t norm_size = x_shape.SizeFromDimension(axis); @@ -116,14 +192,19 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex return Status::OK(); } - LayerNormProgram program{bias != nullptr, is_fp16, simplified, mean != nullptr, inv_std_dev != nullptr}; + // Check if we should use split norm dimension optimization + const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; + + LayerNormProgram program{bias != nullptr, simplified, mean != nullptr, inv_std_dev != nullptr, split_norm_dim}; - program.CacheHint(components, simplified) + program.CacheHint(components, simplified, split_norm_dim) .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x->Shape(), components), components}}) .AddInputs( {{scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}}) .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}}) - .SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(components)}, + }) .AddUniformVariables({ {static_cast(norm_count)}, }) @@ -137,6 +218,15 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex {static_cast(epsilon_)}, }); + if (split_norm_dim) { + const uint32_t workgroup_size_x = 128; + const uint32_t dispatch_size_x = onnxruntime::narrow(norm_size / (workgroup_size_x * components)); + program.SetDispatchGroupSize(dispatch_size_x, 1, 1) + .SetWorkgroupSize(workgroup_size_x); + } else { + program.SetDispatchGroupSize(norm_count); + } + if (bias != nullptr) { program.AddInput( {bias, ProgramTensorMetadataDependency::Type, GetOverrideShape(bias->Shape(), components), components}); diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h index c7cb9280a0b77..112b152d37130 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.h +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -11,28 +11,29 @@ namespace webgpu { class LayerNormProgram final : public Program { public: - LayerNormProgram(bool has_bias, bool is_fp16, bool simplified, bool has_mean_output, - bool has_inv_std_dev_output) + LayerNormProgram(bool has_bias, bool simplified, bool has_mean_output, + bool has_inv_std_dev_output, bool split_norm_dim = false) : Program{"LayerNorm"}, has_bias_{has_bias}, - is_fp16_{is_fp16}, simplified_{simplified}, has_mean_output_{has_mean_output}, - has_inv_std_dev_output_{has_inv_std_dev_output} {} + has_inv_std_dev_output_{has_inv_std_dev_output}, + split_norm_dim_{split_norm_dim} {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"norm_count", ProgramUniformVariableDataType::Uint32}, + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"components", ProgramUniformVariableDataType::Uint32}, + {"norm_count", ProgramUniformVariableDataType::Uint32}, {"norm_size", ProgramUniformVariableDataType::Uint32}, {"norm_size_vectorized", ProgramUniformVariableDataType::Uint32}, {"epsilon", ProgramUniformVariableDataType::Float32}); private: bool has_bias_; - bool is_fp16_; bool simplified_; bool has_mean_output_; bool has_inv_std_dev_output_; + bool split_norm_dim_; }; template From 0b15200243c2522fb33a6b3d133176a0c6738a73 Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Thu, 28 Aug 2025 11:00:07 +0800 Subject: [PATCH 22/22] DequantizeLinear should support non-zero zero_point when input type is int32 (#25646) ### Description This PR makes DequantizeLinear support non-zero zero_point when input data type is int32. ### Motivation and Context For WebNN use case, we have some scenarios that input data type is int32 and the zero_point is not zero for DequantizeLinear. --- .../providers/cpu/quantization/quantize_linear.cc | 6 ++---- .../providers/cpu/tensor/quantize_linear_test.cc | 13 ++++++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index c691be6ffd0e8..dbf2be74f7362 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -520,14 +520,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const T* zero_point = x_zero_point ? x_zero_point->Data() : nullptr; #if !defined(DISABLE_FLOAT8_TYPES) - if constexpr (boost::mp11::mp_contains>, - T>::value) { + if constexpr (boost::mp11::mp_contains::value) { ORT_ENFORCE(zero_point == nullptr || std::all_of(zero_point, zero_point + x_zero_point->Shape().Size(), [](T zp) { return zp == T{0}; }), - "DequantizeLinear with type int32 or float8 should have no zero point or all zero points should be 0"); + "DequantizeLinear with type float8 should have no zero point or all zero points should be 0"); } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 8fdbf0060eaa0..0e17aa835028e 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -137,7 +137,7 @@ TEST(DequantizeLinearOpTest, Uint16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// scalar zero & scale with int8 +// scalar zero & scale with int32 TEST(DequantizeLinearOpTest, Int32) { OpTester test("DequantizeLinear", 10); std::vector dims{4}; @@ -147,6 +147,17 @@ TEST(DequantizeLinearOpTest, Int32) { test.Run(); } +// non-zero zero point with int32 +TEST(DequantizeLinearOpTest, Int32_Non_Zero_Zero_Point) { + OpTester test("DequantizeLinear", 10); + std::vector dims{4}; + test.AddInput("x", dims, {-30, -3, 100, 127}); + test.AddInput("x_scale", {}, {2.0f}, true); + test.AddInput("x_zero_point", {}, {1}, true); + test.AddOutput("y", dims, {-62.f, -8.f, 198.f, 252.f}); + test.Run(); +} + TEST(DequantizeLinearOpTest_BroadcastTensor, Int32) { OpTester test("DequantizeLinear", 13); test.AddInput("x", {4}, {-30, -3, 100, 127});