diff --git a/cmake/deps.txt b/cmake/deps.txt
index e1870bf2df0cf..078a66a4c4d85 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -31,7 +31,7 @@ googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f6
googlexnnpack;https://github.com/google/XNNPACK/archive/3cf85e705098622d59056dcb8f5f963ea7bb0a00.zip;6f6bbba627241f89463ca845febaf063982b34fe
json;https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip;5e88795165cc8590138d1f47ce94ee567b85b4d6
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14
-microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
+microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.250325.1.zip;826c8bd47c2258ec61b8b218e031e5b33d27f761
mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.19.1.zip;c5215b5697dcdfd71799f001b8c4054a6bba6b09
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 754669fffbf8d..4913d38939792 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1528,8 +1528,14 @@ endif()
onnxruntime_add_shared_library(onnxruntime_runtime_path_test_shared_library
${onnxruntime_runtime_path_test_shared_library_src})
- target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE
- onnxruntime_common cpuinfo ${CMAKE_DL_LIBS})
+ if (CMAKE_SYSTEM_NAME MATCHES "AIX")
+ target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE
+ onnxruntime_common ${CMAKE_DL_LIBS})
+ set_target_properties(onnxruntime_runtime_path_test_shared_library PROPERTIES AIX_SHARED_LIBRARY_ARCHIVE OFF)
+ else()
+ target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE
+ onnxruntime_common cpuinfo ${CMAKE_DL_LIBS})
+ endif()
target_include_directories(onnxruntime_runtime_path_test_shared_library PRIVATE ${ONNXRUNTIME_ROOT})
if(UNIX)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 92093ec5464f7..d7be243323c17 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -88,7 +88,7 @@ Do not modify directly.*
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
|||[11, 21]|**T** = tensor(float)|
|||[1, 10]|**T** = tensor(float)|
-|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int32)|
+|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int32)|
|ConvTranspose|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
|||[11, 21]|**T** = tensor(float)|
|||[1, 10]|**T** = tensor(float)|
diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h
index 24cc460a17fa9..983be1f9efd5c 100644
--- a/include/onnxruntime/core/framework/allocator.h
+++ b/include/onnxruntime/core/framework/allocator.h
@@ -38,6 +38,13 @@ struct OrtArenaCfg {
int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default
int initial_growth_chunk_size_bytes; // use -1 to allow ORT to choose the default
int64_t max_power_of_two_extend_bytes; // use -1 to allow ORT to choose the default
+ // Use a CudaMemPool-based arena if available (requires CUDA 11.2 or later)
+ int use_cuda_mempool = -1;
+ // Amount of reserved memory in bytes to hold onto before trying
+ // to release memory back to the OS.
+ uint64_t cuda_mempool_release_threshold = 0;
+ // Bytes to keep on shrink for the CudaMemPool arena. 0 attempts to release all unused memory; currently allocated blocks are not affected.
+ size_t cuda_mempool_bytes_to_keep_on_shrink = 0;
bool IsValid() {
return arena_extend_strategy >= -1 && arena_extend_strategy <= 1 &&
@@ -55,6 +62,9 @@ struct OrtArenaCfg {
static constexpr const char* InitialGrowthChunkSizeBytes = "arena.initial_growth_chunk_size_bytes";
static constexpr const char* MaxPowerOfTwoExtendBytes = "arena.max_power_of_two_extend_bytes";
static constexpr const char* MaxMem = "arena.max_mem";
+ static constexpr const char* UseCudaMemPool = "arena.use_cuda_mempool";
+ static constexpr const char* CudaMempoolReleaseThreshold = "arena.cuda_mempool_release_threshold";
+ static constexpr const char* CudaMempoolBytesToKeepOnShrink = "arena.cuda_mempool_bytes_to_keep_on_shrink";
};
static onnxruntime::common::Status FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg);
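A minimal sketch of how the new CUDA memory-pool fields might be set on an OrtArenaCfg (the field names come from the struct above; the chosen values are illustrative assumptions, not recommendations):

```cpp
OrtArenaCfg cfg{};
cfg.use_cuda_mempool = 1;                                   // opt in; -1 leaves the choice to ORT
cfg.cuda_mempool_release_threshold = 256ULL * 1024 * 1024;  // hold ~256 MB before releasing memory to the OS
cfg.cuda_mempool_bytes_to_keep_on_shrink = 0;               // 0: try to release all unused memory on shrink
```

The same settings could presumably also be supplied as string key-value pairs using the "arena.use_cuda_mempool", "arena.cuda_mempool_release_threshold", and "arena.cuda_mempool_bytes_to_keep_on_shrink" keys and parsed through the FromKeyValuePairs helper declared above.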
@@ -348,4 +358,13 @@ void AllocatorDefaultFree(void* p);
void* AllocatorDefaultAllocAligned(size_t size, size_t alignment);
void AllocatorDefaultFreeAligned(void* p, size_t alignment);
+class IArena : public IAllocator {
+ public:
+ using IAllocator::IAllocator;
+ virtual Status Shrink() = 0;
+ // Only implemented when IsStreamAware() returns true
+ virtual void ReleaseStreamBuffers(Stream* /*stream*/) {}
+ static IArena* SafeArenaCast(IAllocator* allocator);
+};
+
} // namespace onnxruntime
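A hedged usage sketch for the new IArena interface; "allocator" and "stream" are placeholder names for an IAllocator* and Stream* already obtained from the runtime, and the snippet assumes it runs inside a function returning Status:

```cpp
if (onnxruntime::IArena* arena = onnxruntime::IArena::SafeArenaCast(allocator)) {
  ORT_RETURN_IF_ERROR(arena->Shrink());   // release unused reserved memory back to the system
  if (allocator->IsStreamAware()) {
    arena->ReleaseStreamBuffers(stream);  // only meaningful for stream-aware arenas
  }
}
```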
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 02915f2f1882e..d1b652229e4b6 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -6591,6 +6591,23 @@ struct OrtApi {
* \since Version 1.24
*/
ORT_API_T(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info);
+
+ /** \brief Get all config entries from ::OrtKernelInfo.
+ *
+ * Gets all configuration entries from the ::OrtKernelInfo object as key-value pairs.
+ * Config entries are set on the ::OrtSessionOptions and are accessible in custom operator kernels.
+ *
+ * Used in the CreateKernel callback of an OrtCustomOp to access all session configuration entries
+ * during kernel construction.
+ *
+ * \param[in] info An instance of ::OrtKernelInfo.
+ * \param[out] out A pointer to a newly created OrtKeyValuePairs instance containing all config entries.
+ * Note: the user should call OrtApi::ReleaseKeyValuePairs.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ * \since Version 1.24
+ */
+ ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out);
};
/*
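A brief sketch of the new C API from inside an OrtCustomOp CreateKernel callback; "info" is assumed to be the OrtKernelInfo* handed to the callback, and error handling is simplified:

```cpp
const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
OrtKeyValuePairs* entries = nullptr;
OrtStatus* status = ort->KernelInfo_GetConfigEntries(info, &entries);
if (status != nullptr) {
  ort->ReleaseStatus(status);          // or propagate the failure from CreateKernel
} else {
  // ... read the session config entries the kernel needs ...
  ort->ReleaseKeyValuePairs(entries);  // the caller owns the returned instance
}
```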
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index d3a8856455c49..22708bbf06a3d 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2768,6 +2768,8 @@ struct KernelInfoImpl : Base<T> {
std::string GetNodeName() const;
Logger GetLogger() const;
+
+ KeyValuePairs GetConfigEntries() const;
};
} // namespace detail
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 8ee057f51eb20..5144418db2b58 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -2822,6 +2822,13 @@ inline Logger KernelInfoImpl<T>::GetLogger() const {
return Logger{out};
}
+template <typename T>
+inline KeyValuePairs KernelInfoImpl<T>::GetConfigEntries() const {
+ OrtKeyValuePairs* out = nullptr;
+ Ort::ThrowOnError(GetApi().KernelInfo_GetConfigEntries(this->p_, &out));
+ return KeyValuePairs{out};
+}
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
}
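The equivalent through the C++ wrapper is a one-liner; "info" is assumed to be the Ort::ConstKernelInfo available during kernel construction:

```cpp
Ort::KeyValuePairs config = info.GetConfigEntries();
// KeyValuePairs owns the underlying OrtKeyValuePairs, so no explicit
// ReleaseKeyValuePairs call is needed once it goes out of scope.
```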
diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h
index f6afd84dabc5e..5ea4261840299 100644
--- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h
@@ -9,6 +9,9 @@
// Key for the execution provider version string. This should be available for all plugin EPs.
static const char* const kOrtEpDevice_EpMetadataKey_Version = "version";
+// Key for the execution provider OS driver version.
+static const char* const kOrtEpDevice_EpMetadataKey_OSDriverVersion = "os_driver_version";
+
// Prefix for execution provider compatibility information stored in model metadata.
// Used when generating EP context models to store compatibility strings for each EP.
// Full key format: "ep_compatibility_info."
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index f4ccff1d7770d..9f2f0afa61604 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -1533,7 +1533,7 @@
// // "test_adam_multiple",
// // "test_adam",
"test_add_bcast",
- // "test_add_uint8",
+ "test_add_uint8",
"test_add",
"test_and_bcast3v1d",
"test_and_bcast3v2d",
@@ -1543,37 +1543,38 @@
"test_and2d",
"test_and3d",
"test_and4d",
- "test_argmax_default_axis_example_select_last_index",
+ // tests "test_arg*_select_last_index" are excluded because WebNN spec does not support select_last_index attribute.
+ // "test_argmax_default_axis_example_select_last_index",
"test_argmax_default_axis_example",
- "test_argmax_default_axis_random_select_last_index",
+ // "test_argmax_default_axis_random_select_last_index",
"test_argmax_default_axis_random",
- "test_argmax_keepdims_example_select_last_index",
+ // "test_argmax_keepdims_example_select_last_index",
"test_argmax_keepdims_example",
- "test_argmax_keepdims_random_select_last_index",
+ // "test_argmax_keepdims_random_select_last_index",
"test_argmax_keepdims_random",
- "test_argmax_negative_axis_keepdims_example_select_last_index",
+ // "test_argmax_negative_axis_keepdims_example_select_last_index",
"test_argmax_negative_axis_keepdims_example",
- "test_argmax_negative_axis_keepdims_random_select_last_index",
+ // "test_argmax_negative_axis_keepdims_random_select_last_index",
"test_argmax_negative_axis_keepdims_random",
- "test_argmax_no_keepdims_example_select_last_index",
+ // "test_argmax_no_keepdims_example_select_last_index",
"test_argmax_no_keepdims_example",
- "test_argmax_no_keepdims_random_select_last_index",
+ // "test_argmax_no_keepdims_random_select_last_index",
"test_argmax_no_keepdims_random",
- "test_argmin_default_axis_example_select_last_index",
+ // "test_argmin_default_axis_example_select_last_index",
"test_argmin_default_axis_example",
- "test_argmin_default_axis_random_select_last_index",
+ // "test_argmin_default_axis_random_select_last_index",
"test_argmin_default_axis_random",
- "test_argmin_keepdims_example_select_last_index",
+ // "test_argmin_keepdims_example_select_last_index",
"test_argmin_keepdims_example",
- "test_argmin_keepdims_random_select_last_index",
+ // "test_argmin_keepdims_random_select_last_index",
"test_argmin_keepdims_random",
- "test_argmin_negative_axis_keepdims_example_select_last_index",
+ // "test_argmin_negative_axis_keepdims_example_select_last_index",
"test_argmin_negative_axis_keepdims_example",
- "test_argmin_negative_axis_keepdims_random_select_last_index",
+ // "test_argmin_negative_axis_keepdims_random_select_last_index",
"test_argmin_negative_axis_keepdims_random",
- "test_argmin_no_keepdims_example_select_last_index",
+ // "test_argmin_no_keepdims_example_select_last_index",
"test_argmin_no_keepdims_example",
- "test_argmin_no_keepdims_random_select_last_index",
+ // "test_argmin_no_keepdims_random_select_last_index",
"test_argmin_no_keepdims_random",
// "test_asin_example",
// "test_asin",
@@ -1587,21 +1588,21 @@
// "test_averagepool_2d_ceil",
"test_averagepool_2d_default",
"test_averagepool_2d_pads_count_include_pad",
- "test_averagepool_2d_pads",
+ // "test_averagepool_2d_pads", // unsupported by TFLite backend.
"test_averagepool_2d_precomputed_pads_count_include_pad",
"test_averagepool_2d_precomputed_pads",
"test_averagepool_2d_precomputed_same_upper",
"test_averagepool_2d_precomputed_strides",
- "test_averagepool_2d_same_lower",
+ // "test_averagepool_2d_same_lower", // unsupported by TFLite backend.
"test_averagepool_2d_same_upper",
"test_averagepool_2d_strides",
// "test_averagepool_3d_default",
"test_basic_conv_with_padding",
"test_basic_conv_without_padding",
"test_basic_convinteger",
- "test_batchnorm_epsilon_training_mode",
+ // "test_batchnorm_epsilon_training_mode", // unsupported training_mode by WebNN.
"test_batchnorm_epsilon",
- "test_batchnorm_example_training_mode",
+ // "test_batchnorm_example_training_mode", // unsupported training_mode by WebNN.
"test_batchnorm_example",
// // "test_bernoulli_double_expanded",
// // "test_bernoulli_double",
@@ -1622,10 +1623,10 @@
// // "test_blackmanwindow_symmetric",
// // "test_blackmanwindow",
// // "test_cast_BFLOAT16_to_FLOAT",
- "test_cast_DOUBLE_to_FLOAT",
+ // "test_cast_DOUBLE_to_FLOAT",
// "test_cast_DOUBLE_to_FLOAT16",
// // "test_cast_FLOAT_to_BFLOAT16",
- "test_cast_FLOAT_to_DOUBLE",
+ // "test_cast_FLOAT_to_DOUBLE",
// // "test_cast_FLOAT_to_FLOAT16",
// // "test_cast_FLOAT_to_STRING",
// "test_cast_FLOAT16_to_DOUBLE",
@@ -1657,15 +1658,16 @@
// "test_celu",
"test_clip_default_inbounds",
"test_clip_default_int8_inbounds",
- "test_clip_default_int8_max",
- "test_clip_default_int8_min",
- "test_clip_default_max",
- "test_clip_default_min",
- "test_clip_example",
- "test_clip_inbounds",
- "test_clip_outbounds",
- "test_clip_splitbounds",
- "test_clip",
+ // "test_clip_default_int8_max",
+ // "test_clip_default_int8_min",
+ // tests "test_clip*" on opset > 10 are excluded because max and min are non-constant inputs.
+ "opset{7,8,9,10}/test_clip_default_max",
+ "opset{7,8,9,10}/test_clip_default_min",
+ "opset{7,8,9,10}/test_clip_example",
+ "opset{7,8,9,10}/test_clip_inbounds",
+ "opset{7,8,9,10}/test_clip_outbounds",
+ "opset{7,8,9,10}/test_clip_splitbounds",
+ "opset{7,8,9,10}/test_clip",
// // "test_compress_0",
// // "test_compress_1",
// // "test_compress_default_axis",
@@ -1690,32 +1692,33 @@
"test_convinteger_without_padding",
"test_convtranspose_1d",
// // "test_convtranspose_3d",
- // "test_convtranspose_autopad_same",
- "test_convtranspose_dilations",
+ "!(opset14)/test_convtranspose_autopad_same",
+ // "test_convtranspose_dilations", // unsupported by TFLite backend.
"test_convtranspose_kernel_shape",
"opset{9,17}/test_convtranspose_output_shape",
"test_convtranspose_pad",
- "test_convtranspose_pads",
+ // "test_convtranspose_pads", // unsupported by TFLite backend.
"test_convtranspose_with_kernel",
"test_convtranspose",
"test_cos_example",
"test_cos",
// "test_cosh_example",
// "test_cosh",
- "test_cumsum_1d_exclusive",
- "test_cumsum_1d_reverse_exclusive",
- "test_cumsum_1d_reverse",
- "test_cumsum_1d",
- "test_cumsum_2d_axis_0",
- "test_cumsum_2d_axis_1",
- "test_cumsum_2d_negative_axis",
+ // tests "test_cumsum*" are excluded because they use float64.
+ // "test_cumsum_1d_exclusive",
+ // "test_cumsum_1d_reverse_exclusive",
+ // "test_cumsum_1d_reverse",
+ // "test_cumsum_1d",
+ // "test_cumsum_2d_axis_0",
+ // "test_cumsum_2d_axis_1",
+ // "test_cumsum_2d_negative_axis",
// "test_depthtospace_crd_mode_example",
// "test_depthtospace_crd_mode",
// "test_depthtospace_dcr_mode",
// "test_depthtospace_example",
// "test_depthtospace",
- // // "test_dequantizelinear_axis",
- // // "test_dequantizelinear",
+ "test_dequantizelinear_axis",
+ "test_dequantizelinear",
// // "test_det_2d",
// // "test_det_nd",
// // "test_dft_axis",
@@ -1723,27 +1726,27 @@
// // "test_dft",
"test_div_bcast",
"test_div_example",
- // "test_div_uint8",
+ "test_div_uint8",
"test_div",
- // // "test_dropout_default_mask_ratio",
- // // "test_dropout_default_mask",
- // // "test_dropout_default_old",
- // // "test_dropout_default_ratio",
- // // "test_dropout_default",
- // // "test_dropout_random_old",
- // // "test_dropout_random",
+ "test_dropout_default_mask_ratio",
+ "test_dropout_default_mask",
+ "test_dropout_default_old",
+ "test_dropout_default_ratio",
+ "test_dropout_default",
+ "test_dropout_random_old",
+ "test_dropout_random",
// // "test_dynamic_slice_default_axes",
// // "test_dynamic_slice_end_out_of_bounds",
// // "test_dynamic_slice_neg",
// // "test_dynamic_slice_start_out_of_bounds",
// // "test_dynamic_slice",
- // // "test_dynamicquantizelinear_expanded",
- // // "test_dynamicquantizelinear_max_adjusted_expanded",
- // // "test_dynamicquantizelinear_max_adjusted",
- // // "test_dynamicquantizelinear_min_adjusted_expanded",
- // // "test_dynamicquantizelinear_min_adjusted",
- // // "test_dynamicquantizelinear",
- // "test_edge_pad",
+ "test_dynamicquantizelinear_expanded",
+ "test_dynamicquantizelinear_max_adjusted_expanded",
+ "test_dynamicquantizelinear_max_adjusted",
+ "test_dynamicquantizelinear_min_adjusted_expanded",
+ "test_dynamicquantizelinear_min_adjusted",
+ "test_dynamicquantizelinear",
+ // "opset{7,8,9,10}/test_edge_pad", // The edge padding model is unsupported by TFLite backend.
// "test_einsum_batch_diagonal",
// "test_einsum_batch_matmul",
// "test_einsum_inner_prod",
@@ -1754,7 +1757,7 @@
"test_elu",
"test_equal_bcast",
"test_equal",
- // "test_erf",
+ "test_erf",
"test_exp_example",
"test_exp",
// "test_expand_dim_changed",
@@ -1777,11 +1780,11 @@
"test_gather_1",
"test_gather_2d_indices",
"test_gather_negative_indices",
- "test_gather_elements_0",
- "test_gather_elements_1",
- "test_gather_elements_negative_indices",
+ // "test_gather_elements_0", // TFLite backend only supports constant indices.
+ // "test_gather_elements_1", // TFLite backend only supports constant indices.
+ // "test_gather_elements_negative_indices", // TFLite backend only supports constant indices.
"test_gathernd_example_float32",
- "test_gathernd_example_int32_batch_dim1",
+ // "test_gathernd_example_int32_batch_dim1",
"test_gathernd_example_int32",
"test_gemm_all_attributes",
"test_gemm_alpha",
@@ -1789,7 +1792,7 @@
"test_gemm_broadcast",
"test_gemm_default_matrix_bias",
"test_gemm_default_no_bias",
- // "test_gemm_default_scalar_bias",
+ "test_gemm_default_scalar_bias",
"test_gemm_default_single_elem_vector_bias",
"test_gemm_default_vector_bias",
"test_gemm_default_zero_bias",
@@ -1845,48 +1848,49 @@
// "test_if_opt",
"test_instancenorm_epsilon",
"test_instancenorm_example",
- // "test_isinf_negative",
- // "test_isinf_positive",
- // "test_isinf",
- // "test_isnan",
+ "test_isinf_negative",
+ "test_isinf_positive",
+ "test_isinf",
+ "test_isnan",
+ // tests "test_layernorm*" are excluded because they produce 3 outputs.
// "test_layer_normalization_2d_axis_negative_1_expanded",
- "test_layer_normalization_2d_axis_negative_1",
+ // "test_layer_normalization_2d_axis_negative_1",
// "test_layer_normalization_2d_axis_negative_2_expanded",
- "test_layer_normalization_2d_axis_negative_2",
+ // "test_layer_normalization_2d_axis_negative_2",
// "test_layer_normalization_2d_axis0_expanded",
- "test_layer_normalization_2d_axis0",
+ // "test_layer_normalization_2d_axis0",
// "test_layer_normalization_2d_axis1_expanded",
- "test_layer_normalization_2d_axis1",
+ // "test_layer_normalization_2d_axis1",
// "test_layer_normalization_3d_axis_negative_1_epsilon_expanded",
- "test_layer_normalization_3d_axis_negative_1_epsilon",
+ // "test_layer_normalization_3d_axis_negative_1_epsilon",
// "test_layer_normalization_3d_axis_negative_2_epsilon_expanded",
- "test_layer_normalization_3d_axis_negative_2_epsilon",
+ // "test_layer_normalization_3d_axis_negative_2_epsilon",
// "test_layer_normalization_3d_axis_negative_3_epsilon_expanded",
- "test_layer_normalization_3d_axis_negative_3_epsilon",
+ // "test_layer_normalization_3d_axis_negative_3_epsilon",
// "test_layer_normalization_3d_axis0_epsilon_expanded",
- "test_layer_normalization_3d_axis0_epsilon",
+ // "test_layer_normalization_3d_axis0_epsilon",
// "test_layer_normalization_3d_axis1_epsilon_expanded",
- "test_layer_normalization_3d_axis1_epsilon",
+ // "test_layer_normalization_3d_axis1_epsilon",
// "test_layer_normalization_3d_axis2_epsilon_expanded",
- "test_layer_normalization_3d_axis2_epsilon",
+ // "test_layer_normalization_3d_axis2_epsilon",
// "test_layer_normalization_4d_axis_negative_1_expanded",
- "test_layer_normalization_4d_axis_negative_1",
+ // "test_layer_normalization_4d_axis_negative_1",
// "test_layer_normalization_4d_axis_negative_2_expanded",
- "test_layer_normalization_4d_axis_negative_2",
+ // "test_layer_normalization_4d_axis_negative_2",
// "test_layer_normalization_4d_axis_negative_3_expanded",
- "test_layer_normalization_4d_axis_negative_3",
+ // "test_layer_normalization_4d_axis_negative_3",
// "test_layer_normalization_4d_axis_negative_4_expanded",
- "test_layer_normalization_4d_axis_negative_4",
+ // "test_layer_normalization_4d_axis_negative_4",
// "test_layer_normalization_4d_axis0_expanded",
- "test_layer_normalization_4d_axis0",
+ // "test_layer_normalization_4d_axis0",
// "test_layer_normalization_4d_axis1_expanded",
- "test_layer_normalization_4d_axis1",
+ // "test_layer_normalization_4d_axis1",
// "test_layer_normalization_4d_axis2_expanded",
- "test_layer_normalization_4d_axis2",
+ // "test_layer_normalization_4d_axis2",
// "test_layer_normalization_4d_axis3_expanded",
- "test_layer_normalization_4d_axis3",
+ // "test_layer_normalization_4d_axis3",
// "test_layer_normalization_default_axis_expanded",
- "test_layer_normalization_default_axis",
+ // "test_layer_normalization_default_axis",
"test_leakyrelu_default",
"test_leakyrelu_example",
"test_leakyrelu",
@@ -1912,42 +1916,42 @@
// // "test_logsoftmax_large_number",
// // "test_logsoftmax_negative_axis_expanded",
// // "test_logsoftmax_negative_axis",
- // "test_lrn_default",
- // "test_lrn",
+ "test_lrn_default",
+ "test_lrn",
// // "test_lstm_batchwise",
"test_lstm_defaults",
"test_lstm_with_initial_bias",
- "test_lstm_with_peepholes",
+ // "test_lstm_with_peepholes",
"test_matmul_2d",
"test_matmul_3d",
"test_matmul_4d",
- // // "test_matmulinteger",
+ "test_matmulinteger",
"test_max_example",
// "test_max_float16",
"test_max_float32",
- "test_max_float64",
+ // "test_max_float64",
// "test_max_int16",
- // "test_max_int32",
- // "test_max_int64",
- // "test_max_int8",
+ "test_max_int32",
+ "test_max_int64",
+ "test_max_int8",
"test_max_one_input",
"test_max_two_inputs",
// "test_max_uint16",
- // "test_max_uint32",
+ "test_max_uint32",
// "test_max_uint64",
- // "test_max_uint8",
+ "test_max_uint8",
// "test_maxpool_1d_default",
// "test_maxpool_2d_ceil",
"test_maxpool_2d_default",
- "test_maxpool_2d_dilations",
- "test_maxpool_2d_pads",
+ // "test_maxpool_2d_dilations", // unsupported by TFLite backend.
+ // "test_maxpool_2d_pads", // unsupported by TFLite backend.
"test_maxpool_2d_precomputed_pads",
"test_maxpool_2d_precomputed_same_upper",
"test_maxpool_2d_precomputed_strides",
- "test_maxpool_2d_same_lower",
+ // "test_maxpool_2d_same_lower", // unsupported by TFLite backend.
"test_maxpool_2d_same_upper",
"test_maxpool_2d_strides",
- // "test_maxpool_2d_uint8",
+ "test_maxpool_2d_uint8",
// "test_maxpool_3d_default",
// "test_maxpool_with_argmax_2d_precomputed_pads",
// "test_maxpool_with_argmax_2d_precomputed_strides",
@@ -1960,17 +1964,17 @@
"test_min_example",
// "test_min_float16",
"test_min_float32",
- "test_min_float64",
+ // "test_min_float64",
// "test_min_int16",
- // "test_min_int32",
- // "test_min_int64",
- // "test_min_int8",
+ "test_min_int32",
+ "test_min_int64",
+ "test_min_int8",
"test_min_one_input",
"test_min_two_inputs",
// "test_min_uint16",
- // "test_min_uint32",
+ "test_min_uint32",
// "test_min_uint64",
- // "test_min_uint8",
+ "test_min_uint8",
// "test_mod_bcast",
// "test_mod_broadcast",
// "test_mod_float_mixed_sign_example",
@@ -1992,9 +1996,9 @@
// // "test_momentum",
"test_mul_bcast",
"test_mul_example",
- // "test_mul_uint8",
+ "test_mul_uint8",
"test_mul",
- // "test_mvn_expanded",
+ "test_mvn_expanded",
// "test_mvn",
"test_neg_example",
"test_neg",
@@ -2110,17 +2114,17 @@
// "test_pow_types_float32_uint64",
// "test_pow_types_int",
// "test_pow_types_int32_float32",
- // "test_pow_types_int32_int32",
+ "test_pow_types_int32_int32",
// "test_pow_types_int64_float32",
- // "test_pow_types_int64_int64",
+ "test_pow_types_int64_int64",
"test_pow",
"test_prelu_broadcast",
"test_prelu_example",
// // "test_qlinearconv",
// // "test_qlinearmatmul_2D",
// // "test_qlinearmatmul_3D",
- // // "test_quantizelinear_axis",
- // // "test_quantizelinear",
+ "test_quantizelinear_axis",
+ "test_quantizelinear",
// "test_range_float_type_positive_delta_expanded",
// "test_range_float_type_positive_delta",
// "test_range_int32_type_negative_delta_expanded",
@@ -2129,23 +2133,24 @@
"test_reciprocal",
"test_reduce_l1_default_axes_keepdims_example",
"test_reduce_l1_default_axes_keepdims_random",
- "test_reduce_l1_do_not_keepdims_example",
- "test_reduce_l1_do_not_keepdims_random",
- "test_reduce_l1_keep_dims_example",
- "test_reduce_l1_keep_dims_random",
- "test_reduce_l1_negative_axes_keep_dims_example",
- "test_reduce_l1_negative_axes_keep_dims_random",
+ // tests "test_reduce_*" on opset > 13 are excluded because the axes is non-constant input.
+ "opset{7,8,9,10,11,12,13}/test_reduce_l1_do_not_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l1_do_not_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l1_keep_dims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l1_keep_dims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l1_negative_axes_keep_dims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l1_negative_axes_keep_dims_random",
"test_reduce_l2_default_axes_keepdims_example",
"test_reduce_l2_default_axes_keepdims_random",
- "test_reduce_l2_do_not_keepdims_example",
- "test_reduce_l2_do_not_keepdims_random",
- "test_reduce_l2_keep_dims_example",
- "test_reduce_l2_keep_dims_random",
- "test_reduce_l2_negative_axes_keep_dims_example",
- "test_reduce_l2_negative_axes_keep_dims_random",
- "test_reduce_log_sum_asc_axes",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l2_do_not_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l2_do_not_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l2_keep_dims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l2_keep_dims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l2_negative_axes_keep_dims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_l2_negative_axes_keep_dims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_log_sum_asc_axes",
"test_reduce_log_sum_default",
- "test_reduce_log_sum_desc_axes",
+ "opset{7,8,9,10,11,12,13}/test_reduce_log_sum_desc_axes",
// tests "test_reduce_log_sum_exp_*" on opset17/opset18 are excluded because they use float64.
"opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_example",
"opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_random",
@@ -2155,116 +2160,118 @@
"opset{7,8,9}/test_reduce_log_sum_exp_keepdims_random",
"opset11/test_reduce_log_sum_exp_negative_axes_keepdims_example",
"opset11/test_reduce_log_sum_exp_negative_axes_keepdims_random",
- "test_reduce_log_sum_negative_axes",
+ "opset{7,8,9,10,11,12,13}/test_reduce_log_sum_negative_axes",
"test_reduce_log_sum",
"test_reduce_max_default_axes_keepdim_example",
"test_reduce_max_default_axes_keepdims_random",
- "test_reduce_max_do_not_keepdims_example",
- "test_reduce_max_do_not_keepdims_random",
- "test_reduce_max_keepdims_example",
- "test_reduce_max_keepdims_random",
- "test_reduce_max_negative_axes_keepdims_example",
- "test_reduce_max_negative_axes_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_max_do_not_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_max_do_not_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_max_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_max_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_max_negative_axes_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_max_negative_axes_keepdims_random",
"test_reduce_mean_default_axes_keepdims_example",
"test_reduce_mean_default_axes_keepdims_random",
- "test_reduce_mean_do_not_keepdims_example",
- "test_reduce_mean_do_not_keepdims_random",
- "test_reduce_mean_keepdims_example",
- "test_reduce_mean_keepdims_random",
- "test_reduce_mean_negative_axes_keepdims_example",
- "test_reduce_mean_negative_axes_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_mean_do_not_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_mean_do_not_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_mean_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_mean_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_mean_negative_axes_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_mean_negative_axes_keepdims_random",
"test_reduce_min_default_axes_keepdims_example",
"test_reduce_min_default_axes_keepdims_random",
- "test_reduce_min_do_not_keepdims_example",
- "test_reduce_min_do_not_keepdims_random",
- "test_reduce_min_keepdims_example",
- "test_reduce_min_keepdims_random",
- "test_reduce_min_negative_axes_keepdims_example",
- "test_reduce_min_negative_axes_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_min_do_not_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_min_do_not_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_min_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_min_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_min_negative_axes_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_min_negative_axes_keepdims_random",
"test_reduce_prod_default_axes_keepdims_example",
"test_reduce_prod_default_axes_keepdims_random",
- "test_reduce_prod_do_not_keepdims_example",
- "test_reduce_prod_do_not_keepdims_random",
- "test_reduce_prod_keepdims_example",
- "test_reduce_prod_keepdims_random",
- "test_reduce_prod_negative_axes_keepdims_example",
- "test_reduce_prod_negative_axes_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_prod_do_not_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_prod_do_not_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_prod_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_prod_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_prod_negative_axes_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_prod_negative_axes_keepdims_random",
"test_reduce_sum_default_axes_keepdims_example",
"test_reduce_sum_default_axes_keepdims_random",
- "test_reduce_sum_do_not_keepdims_example",
- "test_reduce_sum_do_not_keepdims_random",
+ "opset{7,8,9,10,11}/test_reduce_sum_do_not_keepdims_example",
+ "opset{7,8,9,10,11}/test_reduce_sum_do_not_keepdims_random",
"test_reduce_sum_empty_axes_input_noop_example",
"test_reduce_sum_empty_axes_input_noop_random",
- "test_reduce_sum_keepdims_example",
- "test_reduce_sum_keepdims_random",
- "test_reduce_sum_negative_axes_keepdims_example",
- "test_reduce_sum_negative_axes_keepdims_random",
+ "opset{7,8,9,10,11}/test_reduce_sum_keepdims_example",
+ "opset{7,8,9,10,11}/test_reduce_sum_keepdims_random",
+ "opset{7,8,9,10,11}/test_reduce_sum_negative_axes_keepdims_example",
+ "opset{7,8,9,10,11}/test_reduce_sum_negative_axes_keepdims_random",
"test_reduce_sum_square_default_axes_keepdims_example",
"test_reduce_sum_square_default_axes_keepdims_random",
- "test_reduce_sum_square_do_not_keepdims_example",
- "test_reduce_sum_square_do_not_keepdims_random",
- "test_reduce_sum_square_keepdims_example",
- "test_reduce_sum_square_keepdims_random",
- "test_reduce_sum_square_negative_axes_keepdims_example",
- "test_reduce_sum_square_negative_axes_keepdims_random",
- // "test_reflect_pad",
+ "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_do_not_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_do_not_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_keepdims_random",
+ "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_negative_axes_keepdims_example",
+ "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_negative_axes_keepdims_random",
+ "opset{7,8,9,10}/test_reflect_pad",
"test_relu",
- "test_reshape_allowzero_reordered",
- "test_reshape_extended_dims",
- "test_reshape_negative_dim",
- "test_reshape_negative_extended_dims",
- "test_reshape_one_dim",
- "test_reshape_reduced_dims",
- "test_reshape_reordered_all_dims",
- "test_reshape_reordered_dims",
- "test_reshape_reordered_last_dims",
- "test_reshape_zero_and_negative_dim",
- "test_reshape_zero_dim",
- "test_resize_downsample_linear",
- "test_resize_downsample_nearest",
- "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside",
+ // tests "test_reshape*" are excluded because the shape is non-constant input.
+ // "test_reshape_allowzero_reordered",
+ // "test_reshape_extended_dims",
+ // "test_reshape_negative_dim",
+ // "test_reshape_negative_extended_dims",
+ // "test_reshape_one_dim",
+ // "test_reshape_reduced_dims",
+ // "test_reshape_reordered_all_dims",
+ // "test_reshape_reordered_dims",
+ // "test_reshape_reordered_last_dims",
+ // "test_reshape_zero_and_negative_dim",
+ // "test_reshape_zero_dim",
+ // tests "test_resize*" are excluded because scales and sizes are non-constant inputs.
+ // "test_resize_downsample_linear",
+ // "test_resize_downsample_nearest",
+ // "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside",
// "test_resize_downsample_scales_cubic_align_corners",
- "test_resize_downsample_scales_cubic",
+ // "test_resize_downsample_scales_cubic",
// "test_resize_downsample_scales_linear_align_corners",
- "test_resize_downsample_scales_linear",
- "test_resize_downsample_scales_nearest",
- "test_resize_downsample_sizes_cubic",
- "test_resize_downsample_sizes_linear_pytorch_half_pixel",
- "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn",
- "test_resize_downsample_sizes_nearest",
- "test_resize_nearest",
- "test_resize_tf_crop_and_resize",
- "test_resize_upsample_linear",
- "test_resize_upsample_nearest",
- "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside",
- "test_resize_upsample_scales_cubic_align_corners",
- "test_resize_upsample_scales_cubic_asymmetric",
- "test_resize_upsample_scales_cubic",
- "test_resize_upsample_scales_linear_align_corners",
- "test_resize_upsample_scales_linear",
- "test_resize_upsample_scales_nearest",
- "test_resize_upsample_sizes_cubic",
- "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel",
- "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners",
- "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
- "test_resize_upsample_sizes_nearest",
+ // "test_resize_downsample_scales_linear",
+ // "test_resize_downsample_scales_nearest",
+ // "test_resize_downsample_sizes_cubic",
+ // "test_resize_downsample_sizes_linear_pytorch_half_pixel",
+ // "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn",
+ // "test_resize_downsample_sizes_nearest",
+ // "test_resize_nearest",
+ // "test_resize_tf_crop_and_resize",
+ // "test_resize_upsample_linear",
+ // "test_resize_upsample_nearest",
+ // "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside",
+ // "test_resize_upsample_scales_cubic_align_corners",
+ // "test_resize_upsample_scales_cubic_asymmetric",
+ // "test_resize_upsample_scales_cubic",
+ // "test_resize_upsample_scales_linear_align_corners",
+ // "test_resize_upsample_scales_linear",
+ // "test_resize_upsample_scales_nearest",
+ // "test_resize_upsample_sizes_cubic",
+ // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel",
+ // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners",
+ // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
+ // "test_resize_upsample_sizes_nearest",
// // "test_reversesequence_batch",
// // "test_reversesequence_time",
// // "test_rnn_seq_length",
// // "test_roialign_aligned_false",
// // "test_roialign_aligned_true",
// // "test_roialign",
- // // "test_round",
+ "test_round",
// // "test_scan_sum",
// // "test_scan9_sum",
- "test_scatter_elements_with_axis",
- "test_scatter_elements_with_duplicate_indices",
- "test_scatter_elements_with_negative_indices",
- "test_scatter_elements_without_axis",
+ // "test_scatter_elements_with_axis", // TFLite backend does not support non-constant indices.
+ // "test_scatter_elements_with_duplicate_indices", // WebNN only supports reduction type 'none'.
+ // "test_scatter_elements_with_negative_indices", // TFLite backend does not support non-constant indices.
+ // "test_scatter_elements_without_axis", // TFLite backend does not support non-constant indices.
// // "test_scatter_with_axis",
// // "test_scatter_without_axis",
- "test_scatternd_add",
- "test_scatternd_multiply",
+ // "test_scatternd_add", // WebNN only supports reduction type 'none'.
+ // "test_scatternd_multiply", // WebNN only supports reduction type 'none'.
"test_scatternd",
// // "test_sce_mean_3d_expanded",
// // "test_sce_mean_3d_log_prob_expanded",
@@ -2365,14 +2372,15 @@
// "test_sinh",
// // "test_size_example",
// // "test_size",
- "test_slice_default_axes",
- "test_slice_default_steps",
- "test_slice_end_out_of_bounds",
- "test_slice_neg_steps",
- "test_slice_neg",
- "test_slice_negative_axes",
- "test_slice_start_out_of_bounds",
- "test_slice",
+ // tests "test_slice_*" on opset > 9 are excluded because starts, ends, axes and steps are non-constant inputs.
+ "opset{7,8,9}/test_slice_default_axes",
+ // "test_slice_default_steps",
+ "opset{7,8,9}/test_slice_end_out_of_bounds",
+ // "test_slice_neg_steps",
+ "opset{7,8,9}/test_slice_neg",
+ // "test_slice_negative_axes",
+ "opset{7,8,9}/test_slice_start_out_of_bounds",
+ "opset{7,8,9}/test_slice",
"test_softmax_axis_0_expanded",
"test_softmax_axis_0",
"test_softmax_axis_1_expanded",
@@ -2455,23 +2463,24 @@
"test_softmax_large_number",
"test_softmax_negative_axis_expanded",
"test_softmax_negative_axis",
- // // "test_softplus_example",
- // // "test_softplus",
- // // "test_softsign_example",
- // // "test_softsign",
+ "test_softplus_example",
+ "test_softplus",
+ "test_softsign_example",
+ "test_softsign",
// "test_spacetodepth_example",
// "test_spacetodepth",
- "test_split_equal_parts_1d",
- "test_split_equal_parts_2d",
+ // tests "test_split_*" on opset > 10 are excluded because the split input is non-constant input.
+ "opset{7,8,9,10}/test_split_equal_parts_1d",
+ "opset{7,8,9,10}/test_split_equal_parts_2d",
"test_split_equal_parts_default_axis",
- "test_split_variable_parts_1d",
- "test_split_variable_parts_2d",
- "test_split_variable_parts_default_axis",
- "test_split_zero_size_splits",
+ "opset{7,8,9,10}/test_split_variable_parts_1d",
+ "opset{7,8,9,10}/test_split_variable_parts_2d",
+ "opset{7,8,9,10}/test_split_variable_parts_default_axis",
+ // "test_split_zero_size_splits",
"test_sqrt_example",
"test_sqrt",
- "test_squeeze_negative_axes",
- "test_squeeze",
+ // "test_squeeze_negative_axes",
+ "opset{7,8,9,10}/test_squeeze",
// // "test_stft_with_window",
// // "test_stft",
// // "test_strnormalizer_export_monday_casesensintive_lower",
@@ -2482,7 +2491,7 @@
// // "test_strnormalizer_nostopwords_nochangecase",
"test_sub_bcast",
"test_sub_example",
- // "test_sub_uint8",
+ "test_sub_uint8",
"test_sub",
// "test_sum_example",
// "test_sum_one_input",
@@ -2501,8 +2510,9 @@
// "test_thresholdedrelu_default",
// "test_thresholdedrelu_example",
// "test_thresholdedrelu",
- "test_tile_precomputed",
- "test_tile",
+ // tests "test_tile*" are excluded because the repeats is non-constant input.
+ // "test_tile_precomputed",
+ // "test_tile",
// // "test_top_k_negative_axis",
// // "test_top_k_smallest",
// // "test_top_k",
@@ -2520,41 +2530,41 @@
"test_transpose_all_permutations_5",
"test_transpose_default",
// "test_tril_neg",
- // "test_tril_one_row_neg",
+ "test_tril_one_row_neg",
// "test_tril_out_neg",
// "test_tril_out_pos",
// "test_tril_pos",
// "test_tril_square_neg",
- // "test_tril_square",
+ "test_tril_square",
// "test_tril_zero",
- // "test_tril",
+ "test_tril",
// "test_triu_neg",
// "test_triu_one_row",
// "test_triu_out_neg_out",
// "test_triu_out_pos",
// "test_triu_pos",
// "test_triu_square_neg",
- // "test_triu_square",
+ "test_triu_square",
// "test_triu_zero",
- // "test_triu",
+ "test_triu",
// // "test_unique_not_sorted_without_axis",
// // "test_unique_sorted_with_axis_3d",
// // "test_unique_sorted_with_axis",
// // "test_unique_sorted_with_negative_axis",
// // "test_unique_sorted_without_axis",
- "test_unsqueeze_axis_0",
- "test_unsqueeze_axis_1",
- "test_unsqueeze_axis_2",
- "test_unsqueeze_axis_3",
- "test_unsqueeze_negative_axes",
- "test_unsqueeze_three_axes",
- "test_unsqueeze_two_axes",
- "test_unsqueeze_unsorted_axes",
- "test_unsqueeze",
+ // "test_unsqueeze_axis_0",
+ // "test_unsqueeze_axis_1",
+ // "test_unsqueeze_axis_2",
+ // "test_unsqueeze_axis_3",
+ // "test_unsqueeze_negative_axes",
+ // "test_unsqueeze_three_axes",
+ // "test_unsqueeze_two_axes",
+ // "test_unsqueeze_unsorted_axes",
+ "opset{7,8,9,10}/test_unsqueeze",
// "test_wrap_pad"
// "test_upsample_nearest",
"test_where_example",
- // "test_where_long_example",
+ "test_where_long_example",
"test_xor_bcast3v1d",
"test_xor_bcast3v2d",
"test_xor_bcast4v2d",
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 9b9d755498366..a5ab63d74df24 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -522,6 +522,9 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
const Tensor* head_sink, const Tensor* seqlen_k, int local_window_size) {
+ if (context.IsGraphCaptureEnabled()) {
+ ORT_NOT_IMPLEMENTED("Graph capture not implemented for non flash attention path");
+ }
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length =
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index 00f60142df159..606dbfde15c2c 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -16,6 +16,32 @@ namespace onnxruntime {
namespace contrib {
namespace webgpu {
+Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
+ const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform);
+ const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
+ const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform);
+ const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform);
+
+ const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform);
+ const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform);
+ const auto& present_value = sh.AddOutput("present_value", ShaderUsage::UseUniform);
+
+ if (prepare_indirect_dispatch_) {
+ sh.AddOutput("indirect_buffer", ShaderUsage::None);
+ }
+
+ return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
+ WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_),
+ WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_),
+ WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache),
+ WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv),
+ WGSL_TEMPLATE_VARIABLE(present_key, present_key),
+ WGSL_TEMPLATE_VARIABLE(present_value, present_value),
+ WGSL_TEMPLATE_VARIABLE(query, query),
+ WGSL_TEMPLATE_VARIABLE(seqlens, seqlens),
+ WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache));
+}
+
Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Expectations are
// qkv have same number of heads and hidden dimension (head size).
@@ -351,17 +377,54 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
- const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
+ const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k,
+ const Tensor* cos_cache, const Tensor* sin_cache) {
+ constexpr uint32_t tile_size = 64;
+
// Extract present_sequence_length directly from present_key tensor shape:
// (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled();
+ // Declare query_output at function scope to ensure it persists throughout the function
+ Tensor query_output;
+
+ // Create indirect dispatch buffer if using indirect dispatch
+ Tensor* indirect_buffer_ptr = nullptr;
+ Tensor indirect_buffer;
+
+ // Prepare indirect dispatch buffer for decode path with static KV cache
+ const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
+ parameters.past_present_share_buffer_ &&
+ seqlen_k != nullptr &&
+ context.IsGraphCaptureEnabled();
+ if (use_indirect_dispatch) {
+ const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
+ indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
+ indirect_buffer_ptr = &indirect_buffer;
+ }
+
+ const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr);
+
+ if (do_rotary) {
+ ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input.");
+ ORT_ENFORCE(parameters.past_present_share_buffer_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache.");
+
+ // Q points to the packed QKV tensor in this case, create query output tensor
+ query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
+
+ ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters,
+ Q, seqlen_k,
+ cos_cache, sin_cache,
+ &query_output, present_key, present_value,
+ indirect_buffer_ptr, tile_size));
+ Q = &query_output;
+ } else {
+ ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr));
+ }
+
if (parameters.sequence_length_ > 1) {
- const uint32_t tile_size = 64;
- // For encode path, use the original CopyKVCache without indirect dispatch preparation
- ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
bool has_attention_bias = attention_bias != nullptr;
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -406,29 +469,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
parameters.sequence_length_, present_sequence_length});
const TensorShape qk_shape(qk_dims);
Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape);
- constexpr uint32_t tile_size = 64;
const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size;
const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;
- // Determine if we should use indirect dispatch
- const bool use_indirect_dispatch = parameters.past_present_share_buffer_ &&
- seqlen_k != nullptr &&
- context.IsGraphCaptureEnabled();
-
- // Create indirect dispatch buffer if using indirect dispatch
- Tensor* indirect_buffer_ptr = nullptr;
- Tensor indirect_buffer;
- if (use_indirect_dispatch) {
- const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
- indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
- indirect_buffer_ptr = &indirect_buffer;
- // Use the fused CopyKVCache that also prepares the indirect dispatch buffer
- ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
- } else {
- // Use the original CopyKVCache without indirect dispatch preparation
- ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
- }
-
// The metadata is used to store the max and sum of each tile.
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_,
num_present_sequence_length_tile, 2});
@@ -467,6 +510,78 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const
((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0);
}
+Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
+ const WebgpuAttentionParameters& params,
+ const Tensor* packedQKV,
+ const Tensor* seqlen_k,
+ const Tensor* cos_cache,
+ const Tensor* sin_cache,
+ Tensor* query,
+ Tensor* present_key,
+ Tensor* present_value,
+ Tensor* indirect_buffer,
+ uint32_t tile_size) {
+ const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
+ const auto head_size = params.head_size_;
+
+ int components = 1;
+ // Currently we only support vectorization when RoPE is not interleaved
+ if (!params.rotary_interleaved_) {
+ if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) {
+ components = 4;
+ } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) {
+ components = 2;
+ }
+ }
+ // Adjust dimensions for vectorization
+ const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components;
+ const auto head_size_vec = head_size / components;
+
+ // Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim)
+ // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim)
+ // = head_size - half_rotary_dim
+ const auto work_per_head = head_size_vec - half_rotary_embedding_dim_vec;
+ auto dispatch_size = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head);
+
+ // Extract present_sequence_length from present_key tensor shape
+ const uint32_t present_sequence_length = gsl::narrow_cast<uint32_t>(present_key->Shape()[2]);
+
+ const bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
+
+ SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch);
+ program
+ .CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch)
+ .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
+ .AddInputs({
+ {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
+ {cos_cache, ProgramTensorMetadataDependency::Rank, components},
+ {sin_cache, ProgramTensorMetadataDependency::Rank, components},
+ });
+ program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components},
+ {present_key, ProgramTensorMetadataDependency::None, components},
+ {present_value, ProgramTensorMetadataDependency::None, components}});
+
+ if (prepare_indirect_dispatch) {
+ program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None});
+ }
+
+ program.AddUniformVariables({
+ {static_cast<uint32_t>(params.sequence_length_)},
+ {static_cast<uint32_t>(params.hidden_size_ / components)},
+ {static_cast<uint32_t>(params.kv_hidden_size_ / components)},
+ {static_cast<uint32_t>(params.num_heads_)},
+ {static_cast<uint32_t>(params.kv_num_heads_)},
+ {head_size_vec},
+ {half_rotary_embedding_dim_vec},
+ {present_sequence_length},
+ {tile_size},
+ {static_cast<uint32_t>(dispatch_size)},
+ });
+
+ program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
+ return context.RunProgram(program);
+}
+
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
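For orientation, a worked example of the dispatch sizing used by RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV, with assumed shapes (not values taken from this change):

```cpp
// Assumed: head_size = 128, cos_cache->Shape()[1] = half_rotary_dim = 32,
// non-interleaved RoPE, and both values divisible by 4, so components = 4.
constexpr uint32_t head_size = 128, half_rotary_dim = 32, components = 4;
constexpr uint32_t head_size_vec = head_size / components;              // 32
constexpr uint32_t half_rotary_dim_vec = half_rotary_dim / components;  // 8
// Each head needs half_rotary_dim_vec rotated lanes plus the untouched tail:
// work_per_head = half_rotary_dim_vec + (head_size_vec - 2 * half_rotary_dim_vec)
//               = head_size_vec - half_rotary_dim_vec
constexpr uint32_t work_per_head = head_size_vec - half_rotary_dim_vec;  // 24
// dispatch_size = batch_size * sequence_length * num_heads * work_per_head,
// and SetDispatchGroupSize rounds it up to whole WORKGROUP_SIZE-sized groups.
```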
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index 9599c10533351..a936a91695921 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -15,6 +15,32 @@ namespace webgpu {
using namespace onnxruntime::webgpu;
+class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program<SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram> {
+ public:
+ SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch)
+ : Program{"SplitPackedQKVWithRotaryEmbeddingAndCopyKV"},
+ interleaved_(interleaved),
+ prepare_indirect_dispatch_(prepare_indirect_dispatch) {}
+
+ Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+ WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+ {"sequence_length", ProgramUniformVariableDataType::Uint32},
+ {"hidden_size", ProgramUniformVariableDataType::Uint32},
+ {"kv_hidden_size", ProgramUniformVariableDataType::Uint32},
+ {"num_heads", ProgramUniformVariableDataType::Uint32},
+ {"kv_num_heads", ProgramUniformVariableDataType::Uint32},
+ {"head_size", ProgramUniformVariableDataType::Uint32},
+ {"half_rotary_dim", ProgramUniformVariableDataType::Uint32},
+ {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
+ {"tile_size", ProgramUniformVariableDataType::Uint32},
+ {"dispatch_size", ProgramUniformVariableDataType::Uint32});
+
+ private:
+ const bool interleaved_;
+ const bool prepare_indirect_dispatch_;
+};
+
class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
public:
CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH,
@@ -145,10 +171,24 @@ class FlashAttentionDecodeVxReduceProgram final : public ProgramShape().Size();
program
- .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank})
- .AddOutputs({{query, ProgramTensorMetadataDependency::Rank}, {key, ProgramTensorMetadataDependency::Rank}, {val, ProgramTensorMetadataDependency::Rank}})
+ .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank})
+ .AddOutputs({{query, ProgramTensorMetadataDependency::None}, {key, ProgramTensorMetadataDependency::None}, {val, ProgramTensorMetadataDependency::None}})
.AddUniformVariables({
{static_cast<uint32_t>(params.hidden_size_)},
{static_cast<uint32_t>(params.kv_hidden_size_)},
@@ -90,32 +90,46 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext&
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
const auto head_size = params.head_size_;
+ int components = 1;
+ // Currently we only support vectorization when RoPE is not interleaved
+ if (!params.rotary_interleaved_) {
+ if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) {
+ components = 4;
+ } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) {
+ components = 2;
+ }
+ }
+
+ // Adjust dimensions for vectorization
+ const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components;
+ const auto head_size_vec = head_size / components;
+
// Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim)
// work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim)
// = head_size - half_rotary_dim
- const auto work_per_head = head_size - half_rotary_embedding_dim;
- auto dispatch_size = static_cast(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head);
+ const auto work_per_head_vec = head_size_vec - half_rotary_embedding_dim_vec;
+ auto dispatch_size = static_cast(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head_vec);
SplitPackedQKVWithRotaryEmbeddingProgram program(params.rotary_interleaved_);
program
.CacheHint(params.rotary_interleaved_)
- .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank})
+ .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
.AddInputs({
- {seqlen_k, ProgramTensorMetadataDependency::Rank},
- {cos_cache, ProgramTensorMetadataDependency::Rank},
- {sin_cache, ProgramTensorMetadataDependency::Rank},
+ {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
+ {cos_cache, ProgramTensorMetadataDependency::Rank, components},
+ {sin_cache, ProgramTensorMetadataDependency::Rank, components},
})
- .AddOutputs({{query, ProgramTensorMetadataDependency::Rank},
- {key, ProgramTensorMetadataDependency::Rank},
- {val, ProgramTensorMetadataDependency::Rank}})
+ .AddOutputs({{query, ProgramTensorMetadataDependency::None, components},
+ {key, ProgramTensorMetadataDependency::None, components},
+ {val, ProgramTensorMetadataDependency::None, components}})
.AddUniformVariables({
{static_cast(params.sequence_length_)},
- {static_cast(params.hidden_size_)},
- {static_cast(params.kv_hidden_size_)},
+ {static_cast(params.hidden_size_ / components)},
+ {static_cast(params.kv_hidden_size_ / components)},
{static_cast(params.num_heads_)},
{static_cast(params.kv_num_heads_)},
- {static_cast(head_size)},
- {half_rotary_embedding_dim},
+ {head_size_vec},
+ {half_rotary_embedding_dim_vec},
{static_cast(dispatch_size)},
})
.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
@@ -177,15 +191,15 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
program
.CacheHint(params.rotary_interleaved_)
.AddInputs({
- {query_in, ProgramTensorMetadataDependency::Rank},
+ {query_in, ProgramTensorMetadataDependency::TypeAndRank},
{key_in, ProgramTensorMetadataDependency::Rank},
- {seqlen_k, ProgramTensorMetadataDependency::Rank},
+ {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
{cos_cache, ProgramTensorMetadataDependency::Rank},
{sin_cache, ProgramTensorMetadataDependency::Rank},
})
.AddOutputs({
- {query_out, ProgramTensorMetadataDependency::Rank},
- {key_out, ProgramTensorMetadataDependency::Rank},
+ {query_out, ProgramTensorMetadataDependency::None},
+ {key_out, ProgramTensorMetadataDependency::None},
})
.SetDispatchGroupSize((q_domain_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({
@@ -265,7 +279,26 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
Tensor qRotary;
Tensor kRotary;
+
+ // Use a sliding window if the total sequence exceeds the window's length.
+ bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_);
+ bool will_use_flash_attention = false;
+ if (head_sink == nullptr && !use_smooth_softmax_ && !use_sliding_window) {
+ // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking
+ WebgpuAttentionParameters temp_params = parameters;
+ temp_params.is_packed_qkv_ = false;
+ will_use_flash_attention = CanApplyFlashAttention(attention_bias, present_key, present_value, temp_params, context);
+ }
+
if (parameters.is_packed_qkv_ && do_rotary_) {
+      // Use the fully fused operation when FlashAttention and the static KV cache are both enabled.
+ if (will_use_flash_attention && parameters.past_present_share_buffer_) {
+ // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled
+ // query points to packed QKV, K and V are nullptr since they're not needed
+ return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value,
+ present_value, parameters, context, seqlen_k, cos_cache, sin_cache);
+ }
+ // Fused: splitQKV + rotary QK
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
@@ -279,8 +312,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
key = &kSplit;
value = &vSplit;
} else {
- // Original separate path
if (parameters.is_packed_qkv_) {
+ // splitQKV
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
@@ -292,6 +325,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
value = &vSplit;
}
if (do_rotary_) {
+ // rotary QK
qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
kRotary = context.CreateGPUTensor(key->DataType(), key->Shape());
ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, parameters,
@@ -304,11 +338,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
}
}
- // Use a sliding window if the total sequence exceeds the window's length.
- bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_);
- if (head_sink == nullptr && !use_smooth_softmax_ &&
- !use_sliding_window &&
- CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
+ if (will_use_flash_attention) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context, seqlen_k);
}
diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template
index b64448611079f..777be41ffb456 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template
@@ -36,11 +36,11 @@ $MAIN {
// Calculate actual indices in the head for i and j
#if interleaved
- let idx_i = in_head_idx;
- let idx_j = in_head_idx + 1u;
+ let idx_i = in_head_idx + in_head_idx;
+ let idx_j = idx_i + 1u;
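+      // i.e. the interleaved rotary pair sits at adjacent elements (2 * in_head_idx, 2 * in_head_idx + 1u).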
#else
let idx_i = in_head_idx;
- let idx_j = in_head_idx + uniforms.half_rotary_dim;
+ let idx_j = idx_i + uniforms.half_rotary_dim;
#endif
// Process Q pair
diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template
new file mode 100644
index 0000000000000..d6cb654afa756
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template
@@ -0,0 +1,111 @@
+#param interleaved
+#param prepare_indirect_dispatch
+
+#use guardAgainstOutOfBoundsWorkgroupSizes
+#use .setByIndices .getByIndices .getByOffset
+
+$MAIN {
+ guardAgainstOutOfBoundsWorkgroupSizes(uniforms.dispatch_size);
+
+ // Dispatch: batch * seq * num_heads * (half_rotary_dim + need_copy_dim)
+ // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim)
+ let work_per_head = uniforms.head_size - uniforms.half_rotary_dim;
+ let total_work = uniforms.num_heads * work_per_head;
+
+ let batch_idx = global_idx / (uniforms.sequence_length * total_work);
+ let remainder1 = global_idx % (uniforms.sequence_length * total_work);
+ let seq_idx = remainder1 / total_work;
+ let remainder2 = remainder1 % total_work;
+ let head_idx = remainder2 / work_per_head;
+ let in_head_idx = remainder2 % work_per_head;
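+  // Illustrative example (values assumed): head_size = 128, half_rotary_dim = 32 -> work_per_head = 96;
+  // lanes 0..31 of each head rotate an (i, j) pair below, lanes 32..95 copy elements 64..127 unchanged.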
+
+ // Calculate base offset in packed_qkv for this token
+ // Layout per token: [Q(hidden_size), K(kv_hidden_size), V(kv_hidden_size)]
+ let token_size = uniforms.hidden_size + 2u * uniforms.kv_hidden_size;
+ let base_offset = batch_idx * uniforms.sequence_length * token_size + seq_idx * token_size;
+
+ // Calculate position_id (needed for rotary embedding)
+ let seqlen_i = seqlens.getByOffset(batch_idx);
+ let seqlen = u32(seqlen_i);
+ let total_seqlen = seqlen + 1u;
+
+ let past_seqlen = total_seqlen - uniforms.sequence_length;
+ // `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value
+ let position_id = past_seqlen + seq_idx;
+
+#if prepare_indirect_dispatch
+ // Prepare indirect dispatch buffer for thread 0
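+  // The three entries are the (x, y, z) workgroup counts consumed later by an indirect dispatch.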
+ if (global_idx == 0u) {
+ let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size;
+ indirect_buffer[0] = num_total_seq_length_tile;
+ indirect_buffer[1] = uniforms.num_heads;
+ indirect_buffer[2] = 1u;
+ }
+#endif
+
+ if (in_head_idx < uniforms.half_rotary_dim) {
+ // Process a rotary pair (i, j)
+ let cos_v = cos_cache.getByIndices(vec2(position_id, in_head_idx));
+ let sin_v = sin_cache.getByIndices(vec2(position_id, in_head_idx));
+
+ // Calculate actual indices in the head for i and j
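+    // The interleaved layout keeps the rotary pair adjacent at (2*i, 2*i + 1); the non-interleaved
+    // (half-rotated) layout pairs element i with element i + half_rotary_dim.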
+#if interleaved
+ let idx_i = in_head_idx + in_head_idx;
+ let idx_j = idx_i + 1u;
+#else
+ let idx_i = in_head_idx;
+ let idx_j = idx_i + uniforms.half_rotary_dim;
+#endif
+
+ // Process Q pair
+ let q_base = base_offset + head_idx * uniforms.head_size;
+ let q_i_offset = q_base + idx_i;
+ let q_j_offset = q_base + idx_j;
+ let q_i = packed_qkv.getByOffset(q_i_offset);
+ let q_j = packed_qkv.getByOffset(q_j_offset);
+ let q_re = q_i * cos_v - q_j * sin_v;
+ let q_im = q_i * sin_v + q_j * cos_v;
+ query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_i), q_re);
+ query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_j), q_im);
+
+ // Process K and V pairs if within kv_num_heads
+ if (head_idx < uniforms.kv_num_heads) {
+ let k_base = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size;
+ let k_i_offset = k_base + idx_i;
+ let k_j_offset = k_base + idx_j;
+ let k_i = packed_qkv.getByOffset(k_i_offset);
+ let k_j = packed_qkv.getByOffset(k_j_offset);
+ let k_re = k_i * cos_v - k_j * sin_v;
+ let k_im = k_i * sin_v + k_j * cos_v;
+ // Write K directly to present_key cache
+ present_key.setByIndices(vec4(batch_idx, head_idx, position_id, idx_i), k_re);
+ present_key.setByIndices(vec4(batch_idx, head_idx, position_id, idx_j), k_im);
+
+ // V doesn't need rotary, just copy the pair to present_value cache
+ let v_base = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size;
+ let v_i = packed_qkv.getByOffset(v_base + idx_i);
+ let v_j = packed_qkv.getByOffset(v_base + idx_j);
+ present_value.setByIndices(vec4(batch_idx, head_idx, position_id, idx_i), v_i);
+ present_value.setByIndices(vec4(batch_idx, head_idx, position_id, idx_j), v_j);
+ }
+ } else {
+ // Process non-rotary elements (direct copy)
+ let actual_idx = uniforms.half_rotary_dim + in_head_idx;
+
+ // Copy Q
+ let q_offset = base_offset + head_idx * uniforms.head_size + actual_idx;
+ let q_data = packed_qkv.getByOffset(q_offset);
+ query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + actual_idx), q_data);
+
+ // Copy K and V if within kv_num_heads directly to present cache
+ if (head_idx < uniforms.kv_num_heads) {
+ let k_offset = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size + actual_idx;
+ let k_data = packed_qkv.getByOffset(k_offset);
+ present_key.setByIndices(vec4(batch_idx, head_idx, position_id, actual_idx), k_data);
+
+ let v_offset = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size + actual_idx;
+ let v_data = packed_qkv.getByOffset(v_offset);
+ present_value.setByIndices(vec4(batch_idx, head_idx, position_id, actual_idx), v_data);
+ }
+ }
+} // MAIN
diff --git a/onnxruntime/core/common/cpuid_info_vendor.cc b/onnxruntime/core/common/cpuid_info_vendor.cc
index d4d940eedfe28..8675f129da770 100644
--- a/onnxruntime/core/common/cpuid_info_vendor.cc
+++ b/onnxruntime/core/common/cpuid_info_vendor.cc
@@ -198,7 +198,7 @@ constexpr std::array kCpuVendorInfos{
CpuVendorInfo{cpuinfo_vendor_nvidia, "Nvidia", 0x10DE},
CpuVendorInfo{cpuinfo_vendor_apple, "Apple", 0x106B},
CpuVendorInfo{cpuinfo_vendor_arm, "ARM", 0x13B5},
-
+ CpuVendorInfo{cpuinfo_vendor_ibm, "IBM", 0x1014},
// TODO add more as needed
};
@@ -228,6 +228,9 @@ void CPUIDInfo::VendorInfoInit() {
}
}
#endif // defined(CPUINFO_SUPPORTED)
+#if defined(_AIX)
+ result = cpuinfo_vendor_ibm;
+#endif
return result;
}();
diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc
index 91b5b811a3529..a656abb098911 100644
--- a/onnxruntime/core/framework/allocator.cc
+++ b/onnxruntime/core/framework/allocator.cc
@@ -58,6 +58,18 @@ Status OrtArenaCfg::FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg&
ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.max_mem));
}
+ if (auto it = kvps_entries.find(ConfigKeyNames::UseCudaMemPool); it != kvps_entries.end()) {
+ ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.use_cuda_mempool));
+ }
+
+ if (auto it = kvps_entries.find(ConfigKeyNames::CudaMempoolReleaseThreshold); it != kvps_entries.end()) {
+ ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.cuda_mempool_release_threshold));
+ }
+
+ if (auto it = kvps_entries.find(ConfigKeyNames::CudaMempoolBytesToKeepOnShrink); it != kvps_entries.end()) {
+ ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.cuda_mempool_bytes_to_keep_on_shrink));
+ }
+
if (!cfg.IsValid()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid arena configuration. Please check the values provided.");
@@ -177,6 +189,16 @@ void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve
return alloc.Alloc(size);
}
+
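+// Casts an IAllocator to IArena. With RTTI available this is a checked dynamic_cast that returns
+// nullptr for non-arena allocators; in ORT_NO_RTTI builds it degrades to an unchecked static_cast,
+// so callers must only pass allocators already known to be arenas (e.g. alloc_type == OrtArenaAllocator).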
+IArena* IArena::SafeArenaCast(IAllocator* allocator) {
+#if !defined(ORT_NO_RTTI)
+  auto* result = dynamic_cast<IArena*>(allocator);
+ return result;
+#else
+  return static_cast<IArena*>(allocator);
+#endif
+}
+
} // namespace onnxruntime
std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) { return (out << info.ToString()); }
diff --git a/onnxruntime/core/framework/allocator_utils.cc b/onnxruntime/core/framework/allocator_utils.cc
index 8c4e74c4b1cc7..ee9cf5bb39ca0 100644
--- a/onnxruntime/core/framework/allocator_utils.cc
+++ b/onnxruntime/core/framework/allocator_utils.cc
@@ -52,14 +52,14 @@ AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) {
if (info.use_stream_aware_arena) {
#ifdef ORT_ENABLE_STREAM
return AllocatorPtr(
- std::make_unique(std::move(device_allocator),
- max_mem,
- arena_extend_str,
- initial_chunk_size_bytes,
- max_dead_bytes_per_chunk,
- initial_growth_chunk_size_bytes));
+ std::make_unique(std::move(device_allocator),
+ max_mem,
+ arena_extend_str,
+ initial_chunk_size_bytes,
+ max_dead_bytes_per_chunk,
+ initial_growth_chunk_size_bytes));
#else
- ORT_THROW("StreamAwareArena should be transparent to minimal build.");
+ ORT_THROW("StreamAwareBFCArena should be transparent to minimal build.");
#endif
} else {
return AllocatorPtr(
diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc
index 3a5af42d03cdd..cfe155986eff2 100644
--- a/onnxruntime/core/framework/bfc_arena.cc
+++ b/onnxruntime/core/framework/bfc_arena.cc
@@ -13,11 +13,10 @@ BFCArena::BFCArena(std::unique_ptr resource_allocator,
int max_dead_bytes_per_chunk,
int initial_growth_chunk_size_bytes,
int64_t max_power_of_two_extend_bytes)
- : IAllocator(OrtMemoryInfo(resource_allocator->Info().name.c_str(),
- OrtAllocatorType::OrtArenaAllocator,
- resource_allocator->Info().device,
- resource_allocator->Info().mem_type)),
- arena_type_(ArenaType::BaseArena),
+ : IArena(OrtMemoryInfo(resource_allocator->Info().name.c_str(),
+ OrtAllocatorType::OrtArenaAllocator,
+ resource_allocator->Info().device,
+ resource_allocator->Info().mem_type)),
device_allocator_(std::move(resource_allocator)),
free_chunks_list_(kInvalidChunkHandle),
next_allocation_id_(1),
@@ -827,13 +826,13 @@ void BFCArena::ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_fla
}
}
-StreamAwareArena::StreamAwareArena(std::unique_ptr resource_allocator,
- size_t total_memory,
- ArenaExtendStrategy arena_extend_strategy,
- int initial_chunk_size_bytes,
- int max_dead_bytes_per_chunk,
- int initial_growth_chunk_size_bytes,
- int64_t max_power_of_two_extend_bytes)
+StreamAwareBFCArena::StreamAwareBFCArena(std::unique_ptr resource_allocator,
+ size_t total_memory,
+ ArenaExtendStrategy arena_extend_strategy,
+ int initial_chunk_size_bytes,
+ int max_dead_bytes_per_chunk,
+ int initial_growth_chunk_size_bytes,
+ int64_t max_power_of_two_extend_bytes)
: BFCArena(std::move(resource_allocator),
total_memory,
arena_extend_strategy,
@@ -841,14 +840,13 @@ StreamAwareArena::StreamAwareArena(std::unique_ptr resource_allocato
max_dead_bytes_per_chunk,
initial_growth_chunk_size_bytes,
max_power_of_two_extend_bytes) {
- arena_type_ = ArenaType::StreamAwareArena;
}
-void* StreamAwareArena::AllocOnStream(size_t size, Stream* current_stream) {
+void* StreamAwareBFCArena::AllocOnStream(size_t size, Stream* current_stream) {
return AllocateRawInternal(size, false, current_stream);
}
-void StreamAwareArena::ReleaseStreamBuffers(Stream* stream) {
+void StreamAwareBFCArena::ReleaseStreamBuffers(Stream* stream) {
// since chunks on target stream will be reset to nullptr, trigger coalesce to see whether we can get bigger chunk.
ResetChunkOnTargetStream(stream, true);
}
diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h
index f3c0544124241..e3494853f7064 100644
--- a/onnxruntime/core/framework/bfc_arena.h
+++ b/onnxruntime/core/framework/bfc_arena.h
@@ -43,7 +43,7 @@ namespace onnxruntime {
#endif
#endif
-class StreamAwareArena;
+class StreamAwareBFCArena;
// A memory allocator that implements a 'best-fit with coalescing'
// algorithm. This is essentially a very simple version of Doug Lea's
// malloc (dlmalloc).
@@ -52,7 +52,7 @@ class StreamAwareArena;
// coalescing. One assumption we make is that the process using this
// allocator owns pretty much all of the memory, and that nearly
// all requests to allocate memory go through this interface.
-class BFCArena : public IAllocator {
+class BFCArena : public IArena {
public:
static const ArenaExtendStrategy DEFAULT_ARENA_EXTEND_STRATEGY = ArenaExtendStrategy::kNextPowerOfTwo;
static const int DEFAULT_INITIAL_CHUNK_SIZE_BYTES = 1 * 1024 * 1024;
@@ -61,11 +61,6 @@ class BFCArena : public IAllocator {
static const int64_t DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES = 1024 * 1024 * 1024; // 1GB
   static const size_t DEFAULT_MAX_MEM = std::numeric_limits<size_t>::max();
- enum ArenaType {
- BaseArena,
- StreamAwareArena,
- };
-
BFCArena(std::unique_ptr resource_allocator,
size_t total_memory,
ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY,
@@ -84,14 +79,6 @@ class BFCArena : public IAllocator {
// If p is NULL, no operation is performed.
void Free(void* p) override;
- // Frees all allocation regions in which no chunk is in use.
- // Does not free any reserved chunks.
- // Resets the size that the arena will grow by in the next allocation to
- // `initial_growth_chunk_size_bytes_` but ultimately all
- // future allocation sizes are determined by the arena growth strategy
- // and the allocation request.
- Status Shrink();
-
void* Reserve(size_t size) override;
void GetStats(AllocatorStats* stats) override;
@@ -100,7 +87,13 @@ class BFCArena : public IAllocator {
size_t AllocatedSize(const void* ptr);
- ArenaType GetArenaType() const { return arena_type_; }
+ // Frees all allocation regions in which no chunk is in use.
+ // Does not free any reserved chunks.
+ // Resets the size that the arena will grow by in the next allocation to
+ // `initial_growth_chunk_size_bytes_` but ultimately all
+ // future allocation sizes are determined by the arena growth strategy
+ // and the allocation request.
+ Status Shrink() override;
protected:
void* AllocateRawInternal(size_t num_bytes,
@@ -112,7 +105,6 @@ class BFCArena : public IAllocator {
// perform coalesce if coalesce_flag is true
void ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_flag);
#endif
- ArenaType arena_type_;
private:
void DeallocateRawInternal(void* ptr);
@@ -510,26 +502,22 @@ class BFCArena : public IAllocator {
};
#ifdef ORT_ENABLE_STREAM
-class StreamAwareArena : public BFCArena {
+class StreamAwareBFCArena : public BFCArena {
public:
- StreamAwareArena(std::unique_ptr resource_allocator,
- size_t total_memory,
- ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY,
- int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES,
- int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK,
- int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES,
- int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES);
+ StreamAwareBFCArena(std::unique_ptr resource_allocator,
+ size_t total_memory,
+ ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY,
+ int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES,
+ int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK,
+ int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES,
+ int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES);
bool IsStreamAware() const override { return true; }
// Standard alloc behavior. Returns valid pointer if size > 0 and memory was available. Otherwise returns nullptr.
void* AllocOnStream(size_t size, Stream* current_stream_id) override;
- void ReleaseStreamBuffers(Stream* stream);
-
- static StreamAwareArena* FromBFCArena(BFCArena& arena) {
- return arena.GetArenaType() == ArenaType::StreamAwareArena ? reinterpret_cast(&arena) : nullptr;
- }
+ void ReleaseStreamBuffers(Stream* stream) override;
};
#endif
#ifdef __GNUC__
diff --git a/onnxruntime/core/framework/device_stream_collection.cc b/onnxruntime/core/framework/device_stream_collection.cc
index 8d15e03c2e5ce..a32973ddb8c9e 100644
--- a/onnxruntime/core/framework/device_stream_collection.cc
+++ b/onnxruntime/core/framework/device_stream_collection.cc
@@ -21,7 +21,8 @@ struct DummyStream : Stream {
class DeviceStreamCollectionImpl {
public:
- DeviceStreamCollectionImpl(size_t num_streams, const AllocatorMap& allocators, bool is_main_graph) : num_streams_(num_streams), allocators_(allocators), is_main_graph_(is_main_graph) {
+ DeviceStreamCollectionImpl(size_t num_streams, const AllocatorMap& allocators, bool is_main_graph)
+ : num_streams_(num_streams), allocators_(allocators), is_main_graph_(is_main_graph) {
device_streams_.resize(num_streams, nullptr);
owned_streams_.reserve(num_streams);
root_stream_ = std::make_unique(nullptr, root_stream_device_);
@@ -32,13 +33,16 @@ class DeviceStreamCollectionImpl {
void ReleaseSingleStreamBuffers(Stream* stream) {
if (!stream) return;
- for (auto it : allocators_) {
+ for (const auto& it : allocators_) {
if (it.second->Info().device == stream->GetDevice() &&
it.second->Info().alloc_type == OrtArenaAllocator) {
- auto* arena_alloc = static_cast(it.second.get());
- auto* stream_aware_alloc = StreamAwareArena::FromBFCArena(*arena_alloc);
- if (stream_aware_alloc) {
- stream_aware_alloc->ReleaseStreamBuffers(stream);
+ if (it.second->IsStreamAware()) {
+ // Previously we only had one StreamAwareBFCArena. We need to guard
+ // against multiple allocators now.
+ auto* arena_alloc = IArena::SafeArenaCast(it.second.get());
+ if (arena_alloc) {
+ arena_alloc->ReleaseStreamBuffers(stream);
+ }
}
}
}
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 01ba492eb166e..8fb3dc63aa4d1 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -588,7 +588,7 @@ Status SessionState::PrepackConstantInitializedTensors(
// within this session. Or if the weight is not present on disk,
// we store the newly minted pre-packed data.
- AllocatorPtr session_cpu_alloc = GetAllocator(kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault));
+ AllocatorPtr session_initializer_alloc = GetInitializerAllocator(kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault));
PrePackedWeights weights_to_be_filled_in;
// The reason we invoke PrePack() before looking into the container for any pre-packed weight
// cached by another instance of the same op_type (for the same constant initializer) is because
@@ -596,7 +596,7 @@ Status SessionState::PrepackConstantInitializedTensors(
// pre-packed weight with the pre-packed weight generated by this instance of the same op_type because
// other static properties of the node like node attributes could play a role in the pre-packed
// weights' contents.
- ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_cpu_alloc,
+ ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_initializer_alloc,
is_packed,
&weights_to_be_filled_in));
diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
index fb3b1d1d29eec..487e1533f5967 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
@@ -332,8 +332,8 @@ static std::shared_ptr RhsPackWeightsBiasSme(const size_t co, const
const float* weights, const float* bias,
MLAS_THREADPOOL* ThreadPool)
{
- //cache of prepacked kai rhs weights and biases
- static std::unordered_map> rhs_cache;
+ // Cache of prepacked kai rhs weights and biases. thread_local to prevent interference from parallel sessions.
+ thread_local std::unordered_map> rhs_cache;
RhsCacheKey key = { co, ci, kh, kw, dilationh, dilationw, HashWeights(weights) };
@@ -474,8 +474,8 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s
auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool);
- //cache of computed lhs ptr offsets
- static std::unordered_map> lhs_ptrs_cache;
+ // Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions.
+ thread_local std::unordered_map> lhs_ptrs_cache;
std::shared_ptr lhs_ptrs;
if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {
diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc
index 34b6b2de64a92..aeddef0c5188f 100644
--- a/onnxruntime/core/platform/posix/env.cc
+++ b/onnxruntime/core/platform/posix/env.cc
@@ -556,6 +556,8 @@ class PosixEnv : public Env {
}
PathString GetRuntimePath() const override {
+// On AIX, dladdr is not supported.
+#if !defined(_AIX)
// Use dladdr() to look up the file that contains an address from this binary.
const void* const address_from_this_binary = reinterpret_cast(Env::Default);
@@ -568,7 +570,7 @@ class PosixEnv : public Env {
runtime_path.remove_filename();
return runtime_path;
}
-
+#endif
return PathString{};
}
diff --git a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc
index f3c6b18f8e753..dc2cec1852fed 100644
--- a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc
+++ b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc
@@ -28,8 +28,10 @@ ONNX_OPERATOR_KERNEL_EX(
10,
kCpuExecutionProvider,
KernelDefBuilder()
-        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
-        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T1", {DataTypeImpl::GetTensorType<uint8_t>(),
+                               DataTypeImpl::GetTensorType<int8_t>()})
+        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
+                               DataTypeImpl::GetTensorType<int8_t>()})
         .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
ConvInteger);
@@ -43,12 +45,12 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
if (num_inputs >= 3 && input_defs[2]->Exists()) {
     const auto* X_Zero_Point = context->Input<Tensor>(2);
     ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
-    input_offset = *(X_Zero_Point->Data<uint8_t>());
+    input_offset = *static_cast<const uint8_t*>(X_Zero_Point->DataRaw());
   }
   if (num_inputs >= 4 && input_defs[3]->Exists()) {
     const auto* W_Zero_Point = context->Input<Tensor>(3);
     ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
-    filter_offset = *(W_Zero_Point->Data<uint8_t>());
+    filter_offset = *static_cast<const uint8_t*>(W_Zero_Point->DataRaw());
}
const int64_t N = X->Shape()[0];
@@ -110,45 +112,82 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
-  const auto* Xdata = X->Data<uint8_t>();
-  const auto* Wdata = W->Data<uint8_t>();
+  const auto* Xdata = static_cast<const uint8_t*>(X->DataRaw());
+  const auto* Wdata = static_cast<const uint8_t*>(W->DataRaw());
+  bool X_is_signed = X->IsDataType<int8_t>();
   auto* Ydata = Y->MutableData<int32_t>();
for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
if (col_buffer_data != nullptr) {
if (kernel_rank == 2) {
- math::Im2col()(
- Xdata,
- C / conv_attrs_.group,
- input_shape[0],
- input_shape[1],
- kernel_shape[0],
- kernel_shape[1],
- dilations[0],
- dilations[1],
- pads[0],
- pads[1],
- pads[2],
- pads[3],
- strides[0],
- strides[1],
- col_buffer_data,
- input_offset);
+ if (X_is_signed) {
+ math::Im2col()(
+ reinterpret_cast(Xdata),
+ C / conv_attrs_.group,
+ input_shape[0],
+ input_shape[1],
+ kernel_shape[0],
+ kernel_shape[1],
+ dilations[0],
+ dilations[1],
+ pads[0],
+ pads[1],
+ pads[2],
+ pads[3],
+ strides[0],
+ strides[1],
+ reinterpret_cast(col_buffer_data),
+ static_cast(input_offset));
+ } else {
+ math::Im2col()(
+ Xdata,
+ C / conv_attrs_.group,
+ input_shape[0],
+ input_shape[1],
+ kernel_shape[0],
+ kernel_shape[1],
+ dilations[0],
+ dilations[1],
+ pads[0],
+ pads[1],
+ pads[2],
+ pads[3],
+ strides[0],
+ strides[1],
+ col_buffer_data,
+ input_offset);
+ }
} else {
- math::Im2col()(
- Xdata,
- input_shape.GetDims().data(),
- output_shape.GetDims().data(),
- kernel_dim,
- kernel_shape.data(),
- strides.data(),
- dilations.data(),
- pads.data(),
- static_cast(kernel_rank),
- col_buffer_data,
- false,
- input_offset);
+ if (X_is_signed) {
+ math::Im2col()(
+ reinterpret_cast(Xdata),
+ input_shape.GetDims().data(),
+ output_shape.GetDims().data(),
+ kernel_dim,
+ kernel_shape.data(),
+ strides.data(),
+ dilations.data(),
+ pads.data(),
+ static_cast(kernel_rank),
+ reinterpret_cast(col_buffer_data),
+ false,
+ static_cast(input_offset));
+ } else {
+ math::Im2col()(
+ Xdata,
+ input_shape.GetDims().data(),
+ output_shape.GetDims().data(),
+ kernel_dim,
+ kernel_shape.data(),
+ strides.data(),
+ dilations.data(),
+ pads.data(),
+ static_cast(kernel_rank),
+ col_buffer_data,
+ false,
+ input_offset);
+ }
}
}
@@ -156,12 +195,14 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
gemm_shape.M = static_cast(M / conv_attrs_.group);
gemm_shape.N = static_cast(output_image_size);
gemm_shape.K = static_cast(kernel_dim);
+      gemm_shape.AIsSigned = W->IsDataType<int8_t>();
+ gemm_shape.BIsSigned = X_is_signed;
MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
gemm_params.A = Wdata + group_id * W_offset;
gemm_params.lda = static_cast(kernel_dim);
gemm_params.ZeroPointA = filter_offset;
- gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data,
+ gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
gemm_params.ldb = static_cast(output_image_size);
gemm_params.ZeroPointB = &input_offset;
gemm_params.C = Ydata;
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 3816cc1f8f6b9..eff0801a00460 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -14,6 +14,7 @@
#include "core/providers/cuda/cuda_fwd.h"
#include "core/providers/cuda/gpu_data_transfer.h"
#include "core/providers/cuda/cuda_profiler.h"
+#include "core/providers/cuda/cuda_mempool_arena.h"
#include "core/session/onnxruntime_run_options_config_keys.h"
#ifndef USE_CUDA_MINIMAL
@@ -134,11 +135,10 @@ ONNX_OPERATOR_KERNEL_EX(
} // namespace cuda
-AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId device_id,
- size_t gpu_mem_limit,
- ArenaExtendStrategy arena_extend_strategy,
- CUDAExecutionProviderExternalAllocatorInfo external_allocator_info,
- const OrtArenaCfg* default_memory_arena_cfg) {
+AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(const CUDAAllocatorParams& cuda_allocator_params) {
+ ORT_ENFORCE(cuda_allocator_params.external_alloc_info != nullptr,
+ "CUDAAllocatorParams.external_alloc_info is nullptr.");
+ const auto& external_allocator_info = *(cuda_allocator_params.external_alloc_info);
if (external_allocator_info.UseExternalAllocator()) {
AllocatorCreationInfo default_memory_info(
[external_allocator_info](OrtDevice::DeviceId id) {
@@ -147,24 +147,59 @@ AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId devi
external_allocator_info.free,
external_allocator_info.empty_cache);
},
- device_id,
+ cuda_allocator_params.device_id,
false);
return CreateAllocator(default_memory_info);
} else {
- AllocatorCreationInfo default_memory_info(
- [](OrtDevice::DeviceId id) {
- return std::make_unique(id, CUDA);
- },
- device_id,
- true,
- {default_memory_arena_cfg ? *default_memory_arena_cfg
- : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)},
- // make it stream aware
- true);
-
- // CUDA malloc/free is expensive so always use an arena
- return CreateAllocator(default_memory_info);
+ const auto* arena_cfg = cuda_allocator_params.arena_cfg;
+ const bool cuda_mempool_requested = arena_cfg != nullptr && arena_cfg->use_cuda_mempool == 1;
+ bool use_cuda_mempool = cuda_mempool_requested && cuda::CudaMempoolArena::IsCudaVersionSupported();
+
+ if (cuda_mempool_requested && !use_cuda_mempool) {
+ LOGS_DEFAULT(WARNING)
+          << "CUDA memory pool requested but not supported on this device/driver. "
+ << "Falling back to default BFCArena with CUDA allocator.";
+ }
+
+ if (use_cuda_mempool) {
+ const bool cuda_graph_enabled = cuda_allocator_params.provider_info != nullptr &&
+ cuda_allocator_params.provider_info->enable_cuda_graph;
+
+ if (cuda_graph_enabled) {
+ LOGS_DEFAULT(WARNING)
+            << "CUDA Mempool Arena allocator is not compatible with requested CUDA Graph Capture. "
+ << "Falling back to default BFCArena with CUDA allocator.";
+ use_cuda_mempool = false;
+ }
+ }
+
+ if (use_cuda_mempool) {
+ auto device = OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA,
+ cuda_allocator_params.device_id);
+ auto mem_info = OrtMemoryInfo("CUDAMemPoolArena", OrtAllocatorType::OrtArenaAllocator, device, OrtMemTypeDefault);
+
+ auto mempool_allocator = std::make_shared(mem_info,
+ arena_cfg->cuda_mempool_release_threshold,
+ arena_cfg->cuda_mempool_bytes_to_keep_on_shrink,
+ cuda_allocator_params.logger);
+
+ return mempool_allocator;
+ } else {
+ AllocatorCreationInfo default_memory_info(
+ [](OrtDevice::DeviceId id) {
+ return std::make_unique(id, CUDA);
+ },
+ cuda_allocator_params.device_id,
+ true,
+ {arena_cfg ? *arena_cfg
+ : OrtArenaCfg(cuda_allocator_params.cuda_mem_threshold,
+ static_cast(cuda_allocator_params.arena_extend_strategy), -1, -1, -1, -1L)},
+ // make it stream aware
+ true);
+ // CUDA malloc/free is expensive so always use an arena
+ return CreateAllocator(default_memory_info);
+ }
}
}
@@ -3044,9 +3079,18 @@ std::vector CUDAExecutionProvider::CreatePreferredAllocators() {
return std::make_unique(device_id, CUDA_PINNED);
},
info_.device_id);
+
+ CUDAExecutionProvider::CUDAAllocatorParams params{};
+ params.device_id = info_.device_id;
+ params.cuda_mem_threshold = info_.gpu_mem_limit;
+ params.arena_extend_strategy = info_.arena_extend_strategy;
+ params.provider_info = &info_;
+ params.external_alloc_info = &info_.external_allocator_info;
+ params.arena_cfg = info_.default_memory_arena_cfg;
+ params.logger = GetLogger();
+
return std::vector{
- CreateCudaAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy,
- info_.external_allocator_info, info_.default_memory_arena_cfg),
+ CreateCudaAllocator(params),
CreateAllocator(pinned_memory_info),
};
}
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
index 57fde8146d929..751bbb90f8619 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
@@ -103,8 +103,17 @@ class CUDAExecutionProvider : public IExecutionProvider {
return CUDAExecutionProviderInfo::ToProviderOptions(info_);
}
- static AllocatorPtr CreateCudaAllocator(OrtDevice::DeviceId device_id, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,
- CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg);
+ struct CUDAAllocatorParams {
+ OrtDevice::DeviceId device_id = 0;
+    size_t cuda_mem_threshold = std::numeric_limits<size_t>::max();
+ ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo;
+ const CUDAExecutionProviderInfo* provider_info = nullptr;
+ const CUDAExecutionProviderExternalAllocatorInfo* external_alloc_info = nullptr;
+ const OrtArenaCfg* arena_cfg = nullptr;
+ const logging::Logger* logger = nullptr;
+ };
+
+ static AllocatorPtr CreateCudaAllocator(const CUDAAllocatorParams& cuda_allocator_params);
ITuningContext* GetTuningContext() const override;
diff --git a/onnxruntime/core/providers/cuda/cuda_mempool_arena.cc b/onnxruntime/core/providers/cuda/cuda_mempool_arena.cc
new file mode 100644
index 0000000000000..802867ec0d89b
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/cuda_mempool_arena.cc
@@ -0,0 +1,272 @@
+// Copyright (c) Microsoft.
+// Licensed under the MIT License.
+
+#include "cuda_mempool_arena.h"
+
+#include
+
+#include "core/providers/cuda/shared_inc/cuda_call.h" // ORT CudaCall helpers
+#include "core/providers/shared_library/provider_api.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+// ======== CudaMempoolArena ========
+
+CudaMempoolArena::CudaMempoolArena(const OrtMemoryInfo& memory_info,
+ uint64_t pool_release_threshold,
+ size_t bytes_to_keep_on_shrink,
+ const logging::Logger* logger)
+ : IArena(memory_info),
+ pool_release_threshold_(pool_release_threshold),
+ bytes_to_keep_on_shrink_(bytes_to_keep_on_shrink),
+ logger_(logger) {
+ if (logger_ == nullptr) {
+ logger_ = &::onnxruntime::logging::LoggingManager::DefaultLogger();
+ }
+
+ // Create a process-local device memory pool for device_id_.
+  // It is unclear when (or whether) a 'cudaMemAllocationTypeDevice' value for cudaMemPoolProps.allocType is available.
+
+ cudaMemPoolProps props{};
+  // 'Pinned' here is not the same as the pinned (host) allocator; cudaMemAllocationTypeDevice does not
+  // actually exist even though it appears in some internet docs.
+ props.allocType = cudaMemAllocationTypePinned;
+ props.handleTypes = cudaMemHandleTypeNone; // local to process
+ props.location.type = cudaMemLocationTypeDevice; // Device memory
+ props.location.id = this->Info().device.Id();
+
+ CUDA_CALL_THROW(cudaMemPoolCreate(&pool_, &props));
+
+ if (pool_release_threshold_ != 0) {
+ CUDA_CALL_THROW(cudaMemPoolSetAttribute(pool_, cudaMemPoolAttrReleaseThreshold,
+ &pool_release_threshold_));
+ }
+
+ LOGS(*logger_, INFO) << "CudaMempoolArena created on device " << this->Info().device.Id()
+ << " with pool_release_threshold=" << pool_release_threshold_
+ << " bytes_to_keep_on_shrink=" << bytes_to_keep_on_shrink_ << ".";
+
+ // Intentionally DO NOT call cudaDeviceSetMemPool(device_id_, pool_);
+ // All allocations explicitly target this pool via cudaMallocFromPoolAsync.
+}
+
+CudaMempoolArena::~CudaMempoolArena() {
+ // 1) Best-effort: enqueue frees for any remaining allocations on their recorded streams.
+ // No locking by design: destruction implies no concurrent access.
+ for (auto& kv : alloc_map_) {
+ void* p = kv.first;
+ const cudaStream_t s = kv.second.stream;
+ ORT_IGNORE_RETURN_VALUE(cudaFreeAsync(p, s)); // ignore errors in destructor
+ }
+
+ // 2) Synchronize all streams we know about (those that ever held allocations).
+ SyncAllKnownStreams_NoThrow();
+
+ // Now it is safe to drop our bookkeeping.
+ alloc_map_.clear();
+ stream_map_.clear();
+
+ // 3) Safety barrier: ensure any frees enqueued on destroyed/unknown streams are completed.
+ ORT_IGNORE_RETURN_VALUE(cudaDeviceSynchronize()); // ignore errors in destructor
+
+ // 4) Trim to zero and destroy the pool.
+ if (pool_) {
+ ORT_IGNORE_RETURN_VALUE(cudaMemPoolTrimTo(pool_, 0)); // best-effort
+ ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool_));
+ pool_ = nullptr;
+ }
+}
+
+void* CudaMempoolArena::Alloc(size_t size) {
+ if (size == 0) return nullptr;
+
+ void* p = nullptr;
+  constexpr const cudaStream_t kDefaultStream = static_cast<cudaStream_t>(0);
+ cudaError_t err = cudaMallocFromPoolAsync(&p, size, pool_, kDefaultStream);
+ if (err != cudaSuccess) {
+ ORT_THROW("CudaMempoolArena::Alloc: cudaMallocFromPoolAsync failed: ",
+ cudaGetErrorString(err), " (", static_cast(err), "), size=", size);
+ }
+
+ LOGS(*logger_, VERBOSE) << "CudaMempoolArena::Alloc: allocated "
+ << size << " bytes at " << p << " on default stream.";
+
+ // In case the default stream is busy.
+  // Synchronize in case the default stream is busy, so the allocation has completed before the pointer is used.
+
+ {
+ std::lock_guard lock(mutex_);
+ AllocationRecord rec{size, kDefaultStream};
+ alloc_map_.emplace(p, rec);
+ stream_map_[kDefaultStream].insert(p);
+
+ total_allocated_ += size;
+ in_use_bytes_ += size;
+ max_bytes_in_use_ = std::max(max_bytes_in_use_, in_use_bytes_);
+ max_alloc_size_ = std::max(max_alloc_size_, size);
+ ++num_allocs_;
+ }
+
+ return p;
+}
+
+void* CudaMempoolArena::AllocOnStream(size_t size, Stream* stream) {
+ if (size == 0) return nullptr;
+
+ void* p = nullptr;
+ const cudaStream_t s = ResolveCudaStream(stream);
+
+ cudaError_t err = cudaMallocFromPoolAsync(&p, size, pool_, s);
+ if (err != cudaSuccess) {
+ ORT_THROW("CudaMempoolArena::AllocOnStream: cudaMallocFromPoolAsync failed on stream=",
+ reinterpret_cast(s), ": ",
+ cudaGetErrorString(err), " (", static_cast(err), "), size=", size);
+ }
+
+ LOGS(*logger_, VERBOSE) << "CudaMempoolArena::AllocOnStream: allocated "
+ << size << " bytes at " << p << " on stream "
+ << reinterpret_cast(s) << ".";
+
+ {
+ std::lock_guard lock(mutex_);
+ AllocationRecord rec{size, s};
+ alloc_map_.emplace(p, rec);
+ stream_map_[s].insert(p);
+
+ total_allocated_ += size;
+ in_use_bytes_ += size;
+ max_bytes_in_use_ = std::max(max_bytes_in_use_, in_use_bytes_);
+ max_alloc_size_ = std::max(max_alloc_size_, size);
+ ++num_allocs_;
+ }
+
+ return p;
+}
+
+void CudaMempoolArena::Free(void* p) {
+ if (!p) return;
+
+  cudaStream_t s = static_cast<cudaStream_t>(0);
+ size_t sz = 0;
+
+ {
+ std::lock_guard lock(mutex_);
+ auto it = alloc_map_.find(p);
+ if (it == alloc_map_.end()) {
+ // Not owned by this allocator; ignore per ORT convention.
+ LOGS(*logger_, WARNING) << "CudaMempoolArena::Free: pointer "
+ << p << " not found in allocation map; ignoring.";
+ return;
+ }
+
+ s = it->second.stream;
+ sz = it->second.bytes;
+
+ alloc_map_.erase(it);
+
+ auto sit = stream_map_.find(s);
+ if (sit != stream_map_.end()) {
+ sit->second.erase(p);
+ if (sit->second.empty()) {
+ stream_map_.erase(sit);
+ }
+ }
+
+ in_use_bytes_ = (sz <= in_use_bytes_) ? (in_use_bytes_ - sz) : 0;
+ }
+
+ // Ordered free on the stream that allocated p
+ CUDA_CALL_THROW(cudaFreeAsync(p, s));
+}
+
+Status CudaMempoolArena::Shrink() {
+ // Trim the pool; live allocations are not affected.
+ ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemPoolTrimTo(pool_, bytes_to_keep_on_shrink_)));
+
+ size_t current_in_use = 0;
+ ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaMemPoolGetAttribute(pool_, cudaMemPoolAttrUsedMemCurrent,
+                                             &current_in_use)));
+
+ // Query current reserved size. cudaMemPoolAttrReservedMemCurrent
+ size_t reserved_size = 0;
+ if (CUDA_CALL(cudaMemPoolGetAttribute(pool_, cudaMemPoolAttrReservedMemCurrent,
+ &reserved_size))
+ .IsOK()) {
+ LOGS(*logger_, INFO) << "CudaMempoolArena::Shrink: pool current_in_use: " << current_in_use
+ << " reserved size after trim : " << reserved_size << " bytes.";
+ } else {
+ LOGS(*logger_, INFO) << "CudaMempoolArena pool has been shrunk; unable to query reserved size.";
+ }
+
+ // Right-size maps under lock.
+ std::lock_guard lock(mutex_);
+ MaybeRehashLocked();
+ ++num_arena_shrinkages_;
+ return Status::OK();
+}
+
+void CudaMempoolArena::GetStats(AllocatorStats* stats) {
+ if (!stats) return;
+ std::lock_guard lock(mutex_);
+ stats->num_allocs = num_allocs_;
+ stats->total_allocated_bytes = total_allocated_;
+ stats->bytes_in_use = in_use_bytes_;
+ stats->max_bytes_in_use = max_bytes_in_use_;
+ stats->num_arena_shrinkages = num_arena_shrinkages_;
+}
+
+cudaStream_t CudaMempoolArena::ResolveCudaStream(Stream* stream) noexcept {
+  if (!stream) return static_cast<cudaStream_t>(0);
+  return static_cast<cudaStream_t>(stream->GetHandle());
+}
+
+void CudaMempoolArena::MaybeRehashLocked() {
+ const size_t alloc_sz = alloc_map_.size();
+ const size_t stream_sz = stream_map_.size();
+ if (alloc_sz > 0) alloc_map_.reserve(alloc_sz);
+ if (stream_sz > 0) stream_map_.reserve(stream_sz);
+}
+
+void CudaMempoolArena::SyncAllKnownStreams_NoThrow() {
+ for (const auto& kv : stream_map_) {
+ const cudaStream_t s = kv.first;
+ ORT_IGNORE_RETURN_VALUE(cudaStreamSynchronize(s)); // ignore errors; device-wide sync follows
+ }
+}
+
+bool CudaMempoolArena::IsCudaVersionSupported() noexcept {
+ int ort_cuda_rt_version = 0;
+ cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version);
+ if (cuda_status != cudaSuccess) {
+ return false;
+ }
+
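+  // Stream-ordered memory pools (cudaMemPoolCreate / cudaMallocFromPoolAsync) require CUDA 11.2+ in both
+  // the runtime and the driver.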
+ if (ort_cuda_rt_version < 11020) {
+ return false;
+ }
+
+ int ort_cuda_driver_version = 0;
+ cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version);
+ if (cuda_status != cudaSuccess) {
+ return false;
+ }
+
+ if (ort_cuda_driver_version < 11020) {
+ return false;
+ }
+
+ // Check if the driver version supports the runtime version
+ if (ort_cuda_rt_version >= 12000 && ort_cuda_driver_version < 12000) {
+ return false;
+ }
+
+ if (ort_cuda_rt_version >= 13000 && ort_cuda_driver_version < 13000) {
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cuda_mempool_arena.h b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h
new file mode 100644
index 0000000000000..750cbaf93b6d4
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h
@@ -0,0 +1,177 @@
+// Copyright (c) Microsoft.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "core/common/common.h" // ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE, ORT_THROW/ENFORCE
+#include "core/common/inlined_containers.h" // InlinedHashMap, InlinedHashSet, InlinedVector
+#include "core/providers/cuda/cuda_stream_handle.h" // ORT Stream -> cudaStream_t
+#include "core/providers/shared_library/provider_api.h"
+
+namespace onnxruntime {
+namespace logging {
+class Logger;
+}
+namespace cuda {
+/**
+ * @brief Stream-aware CUDA allocator implemented on top of a private `cudaMemPool_t`.
+ * The purpose of this arena is to assist with memory allocations in environments where
+ * a single process hosts more than one CUDA session. This arena hosts a CUDA memory pool
+ * with tunable parameters that control its memory usage, and it releases memory back to
+ * the device according to the specified params. This is in contrast to the BFCArena, which only
+ * attempts to free memory on Shrink(), if configured, at the end of the run.
+ *
+ * ### Behavior
+ * - Creates a **process-local** CUDA mempool for a specific device (from `OrtMemoryInfo`).
+ * - All allocations use **`cudaMallocFromPoolAsync()`** on either the legacy default stream (0) or a
+ * caller-provided stream. The allocation stream is recorded for ordered free.
+ * - `Free()` enqueues **`cudaFreeAsync()`** on the recorded stream to
+ * respect CUDA's stream-ordered semantics.
+ * - `Shrink()` trims the pool with **`cudaMemPoolTrimTo(bytes_to_keep)`** and right-sizes the book-keeping maps
+ * under lock.
+ *
+ * ### Tuning
+ * - `pool_release_threshold`: if non-zero, sets `cudaMemPoolAttrReleaseThreshold`. **Recommended: 1 MB**, but the
+ *   optimal value must be determined experimentally per workload to balance memory consumption vs. performance.
+ * - `bytes_to_keep_on_shrink`: target size for `cudaMemPoolTrimTo()` on `Shrink()` (reflected in
+ *   `cudaMemPoolAttrReservedMemCurrent`). **Recommended: 10 MB.** This is only relevant if Shrink() is enabled.
+ *   It usually costs performance and, strictly speaking, is not necessary for CUDA mempools since they release
+ *   memory at synchronization points according to `pool_release_threshold`.
+ *
+ * ### Thread-safety
+ * - All updates to internal maps and statistics are guarded by an internal `std::mutex`.
+ *
+ * @note The allocator **does not** set the device default mempool and **does not** switch the current device.
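+ *
+ * ### Usage (illustrative sketch)
+ * A minimal construction/use sketch mirroring how the CUDA EP creates this arena; the threshold and
+ * shrink values below are example numbers, not requirements.
+ * @code
+ *   auto device = OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT,
+ *                           OrtDevice::VendorIds::NVIDIA, 0);  // device id 0
+ *   auto mem_info = OrtMemoryInfo("CUDAMemPoolArena", OrtAllocatorType::OrtArenaAllocator,
+ *                                 device, OrtMemTypeDefault);
+ *   auto arena = std::make_shared<CudaMempoolArena>(mem_info,
+ *                                                    1 << 20,    // pool_release_threshold: 1 MB
+ *                                                    10 << 20,   // bytes_to_keep_on_shrink: 10 MB
+ *                                                    nullptr);   // logger: falls back to the default logger
+ *   void* p = arena->Alloc(1024);  // stream-ordered allocation on the legacy default stream
+ *   arena->Free(p);                // cudaFreeAsync ordered on the stream recorded at allocation
+ * @endcode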
+ */
+class CudaMempoolArena final : public IArena {
+ public:
+ /**
+ * @brief Construct a `CudaMempoolArena` with a private CUDA mempool.
+ *
+ * @param memory_info `OrtMemoryInfo` whose device id selects the CUDA device.
+ * @param pool_release_threshold Optional release threshold (bytes) for `cudaMemPoolAttrReleaseThreshold`.
+ * If 0, the attribute is not set. **Recommended value: 1 MB.**
+ * @param bytes_to_keep_on_shrink Target size (bytes) for `cudaMemPoolTrimTo()` on `Shrink()`.
+ * @param logger Cuda EP Logger
+ *
+ * The created pool is process-local and is **not** set as the device default pool.
+ */
+ CudaMempoolArena(const OrtMemoryInfo& memory_info,
+ uint64_t pool_release_threshold,
+ size_t bytes_to_keep_on_shrink,
+ const logging::Logger* logger);
+
+ /**
+ * @brief Destructor:
+ * 1) Enqueues cudaFreeAsync() for any outstanding allocations.
+ * 2) Synchronizes all known streams (best-effort; ignores invalid handles).
+ * 3) Calls cudaDeviceSynchronize() as a final barrier to ensure queued frees complete.
+ * 4) Trims pool to zero and destroys it.
+ */
+ ~CudaMempoolArena() override;
+
+ // -------- IAllocator overrides --------
+
+ /**
+ * @brief Allocate @p size bytes using the legacy default CUDA stream (0).
+ * @return device pointer or nullptr when size == 0
+ * @throws on allocation failure
+ */
+ void* Alloc(size_t size) override;
+
+ /**
+ * @brief Allocate @p size bytes on the given ORT stream (uses `cudaMallocFromPoolAsync`).
+ * @return device pointer or nullptr when size == 0
+ * @throws on allocation failure
+ */
+ void* AllocOnStream(size_t size, Stream* stream) override;
+
+ /**
+ * @brief Enqueue an ordered async free on the stream that allocated @p p.
+ * No-op if @p p is null or not owned by this allocator.
+ */
+ void Free(void* p) override;
+
+ /**
+ * @brief Reserve @p size bytes; implemented in terms of `Alloc(size)`.
+   * This is done so that all memory, including initializers, is released when
+   * the session is torn down.
+ * @return device pointer or nullptr when size == 0
+ * @throws on allocation failure
+ */
+ void* Reserve(size_t size) override { return Alloc(size); }
+
+ /// @brief This allocator is stream-aware.
+ bool IsStreamAware() const override { return true; }
+
+ /// @brief Populate basic allocation statistics.
+ void GetStats(AllocatorStats* stats) override;
+
+ // -------- IArena overrides --------
+
+ /**
+ * @brief Enqueue `cudaFreeAsync()` for all allocations made on @p stream.
+   * We intentionally do not implement this method. Calling it would yank memory out from under live
+   * OrtValues (such as those allocated for bound outputs), leaving those output OrtValues invalid.
+   * When such OrtValues later try to release their memory, the entries are no longer found in the map:
+   * "CudaMempoolArena::Free: pointer 0000000203800400 not found in allocation map; ignoring".
+   * The reason this works with BFCArena is that it does not really release memory.
+ */
+ // void ReleaseStreamBuffers(Stream* stream) override;
+
+ /**
+ * @brief Trim the pool to `bytes_to_keep_on_shrink_` (configured at construction) using `cudaMemPoolTrimTo()`.
+   * Memory still allocated is not affected. Shrink() may cost performance and, unlike with BFCArena,
+   * this allocator does not require it: the CUDA mempool releases memory automatically
+   * according to pool_release_threshold_ set at construction.
+ * Also rehashes internal maps under lock to keep them reasonably sized.
+ */
+ Status Shrink() override;
+
+ static bool IsCudaVersionSupported() noexcept;
+
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CudaMempoolArena);
+
+ private:
+ /// Convert ORT `Stream*` to native `cudaStream_t`; null means legacy default (0).
+ static cudaStream_t ResolveCudaStream(Stream* stream) noexcept;
+
+ /// Rehash internal maps under lock; invoked only by `Shrink()`.
+ void MaybeRehashLocked();
+
+ /// Best-effort synchronization of all streams in stream_map_. Non-throwing; ignores errors.
+ void SyncAllKnownStreams_NoThrow();
+
+ struct AllocationRecord {
+ size_t bytes;
+ cudaStream_t stream; // stream on which allocation/free are ordered
+ };
+
+ // ---- Pool/context configuration (immutable) ----
+ uint64_t pool_release_threshold_;
+ size_t bytes_to_keep_on_shrink_;
+ size_t initial_pool_size_bytes_;
+ const logging::Logger* logger_;
+ cudaMemPool_t pool_{nullptr};
+
+ // ---- Bookkeeping (guarded by mutex_) ----
+ std::mutex mutex_;
+  InlinedHashMap<void*, AllocationRecord> alloc_map_;               // ptr -> record
+  InlinedHashMap<cudaStream_t, InlinedHashSet<void*>> stream_map_;  // stream -> ptrs
+
+ // ---- Stats (guarded by mutex_) ----
+ size_t total_allocated_ = 0;
+ size_t in_use_bytes_ = 0;
+ size_t max_bytes_in_use_ = 0;
+ size_t num_allocs_ = 0;
+ size_t num_arena_shrinkages_ = 0;
+ size_t max_alloc_size_ = 0;
+};
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc
index 3b361f155831b..70afba320576b 100644
--- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc
+++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc
@@ -181,7 +181,13 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA {
}
std::shared_ptr CreateCudaAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::CUDAExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override {
- return CUDAExecutionProvider::CreateCudaAllocator(device_id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg);
+ CUDAExecutionProvider::CUDAAllocatorParams params{};
+ params.device_id = device_id;
+ params.cuda_mem_threshold = gpu_mem_limit;
+ params.arena_extend_strategy = arena_extend_strategy;
+ params.external_alloc_info = &external_allocator_info;
+ params.arena_cfg = default_memory_arena_cfg;
+ return CUDAExecutionProvider::CreateCudaAllocator(params);
}
} g_info;
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc
new file mode 100644
index 0000000000000..619e3eaf5fad4
--- /dev/null
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc
@@ -0,0 +1,480 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "core/providers/qnn/ort_api.h"
+#include "core/providers/qnn/builder/qnn_utils.h"
+#include "core/providers/qnn/builder/op_builder_factory.h"
+#include "core/providers/qnn/builder/qnn_model_wrapper.h"
+#include "core/providers/qnn/builder/qnn_node_group/utils.h"
+
+namespace onnxruntime {
+namespace qnn {
+
+// Helper function to extract value from raw data based on QNN data type
+static Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type,
+ const uint8_t* raw_ptr,
+ double& value) {
+ switch (qnn_data_type) {
+    case QNN_DATATYPE_INT_8:
+    case QNN_DATATYPE_SFIXED_POINT_8: {
+      value = static_cast<double>(*reinterpret_cast<const int8_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_INT_16:
+    case QNN_DATATYPE_SFIXED_POINT_16: {
+      value = static_cast<double>(*reinterpret_cast<const int16_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_INT_32:
+    case QNN_DATATYPE_SFIXED_POINT_32: {
+      value = static_cast<double>(*reinterpret_cast<const int32_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_UINT_8:
+    case QNN_DATATYPE_UFIXED_POINT_8: {
+      value = static_cast<double>(*reinterpret_cast<const uint8_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_UINT_16:
+    case QNN_DATATYPE_UFIXED_POINT_16: {
+      value = static_cast<double>(*reinterpret_cast<const uint16_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_UINT_32:
+    case QNN_DATATYPE_UFIXED_POINT_32: {
+      value = static_cast<double>(*reinterpret_cast<const uint32_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_FLOAT_32: {
+      value = static_cast<double>(*reinterpret_cast<const float*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_FLOAT_16: {
+      value = static_cast<double>(reinterpret_cast<const MLFloat16*>(raw_ptr)->ToFloat());
+      break;
+    }
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Qnn Data Type: ", qnn_data_type, " not supported.");
+ }
+ return Status::OK();
+}
+
+// Helper function to extract a scalar float value from a constant initializer
+// Handles both float and quantized (INT type) constant inputs
+static std::optional<float> GetConstantInitializerFloatScalar(QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnitIODef& io_def) {
+ const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
+ const auto& name = io_def.node_arg.Name();
+
+ if (!graph_viewer.IsConstantInitializer(name, true)) {
+ return std::nullopt;
+ }
+
+ // Get tensor info to check if it's quantized
+ TensorInfo tensor_info = {};
+ if (!qnn_model_wrapper.GetTensorInfo(io_def, tensor_info).IsOK()) {
+ return std::nullopt;
+ }
+
+ // Must be an initializer
+ if (!tensor_info.is_initializer || !tensor_info.initializer_tensor) {
+ return std::nullopt;
+ }
+
+ // Unpack the initializer data
+ std::vector<uint8_t> unpacked_tensor;
+ if (!qnn_model_wrapper.UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor).IsOK()) {
+ return std::nullopt;
+ }
+
+ if (unpacked_tensor.empty()) {
+ return std::nullopt;
+ }
+
+ // Extract the value using GetValueOnQnnDataType
+ double extracted_value = 0.0;
+ if (!GetValueOnQnnDataType(tensor_info.qnn_data_type, unpacked_tensor.data(), extracted_value).IsOK()) {
+ return std::nullopt;
+ }
+
+ // Check if quantized and dequantize if needed
+ const bool is_quantized = tensor_info.quant_param.IsQuantized();
+ if (is_quantized) {
+ // For quantized tensors, dequantize the value
+ if (!tensor_info.quant_param.IsPerTensor()) {
+ return std::nullopt; // Only support per-tensor quantization
+ }
+
+ const Qnn_QuantizeParams_t& quant_param = tensor_info.quant_param.Get();
+ double dequantized_value = utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
+ quant_param.scaleOffsetEncoding.scale,
+ extracted_value);
+ return static_cast<float>(dequantized_value);
+ }
+
+ // For non-quantized tensors, return the extracted value directly
+ return static_cast<float>(extracted_value);
+}
+
+// Helper function to check if a constant initializer has the expected float value
+static bool IsInitializerWithExpectedValue(QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnitIODef& io_def,
+ float expected_value,
+ float tolerance = 1e-5f) {
+ std::optional<float> actual_value = GetConstantInitializerFloatScalar(qnn_model_wrapper, io_def);
+ if (!actual_value.has_value()) {
+ return false;
+ }
+
+ // Compare with expected value within tolerance
+ return std::fabs(actual_value.value() - expected_value) <= tolerance;
+}
+
+// Forward declaration.
+static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
+ gsl::span<const NodeUnit* const> node_units,
+ const NodeUnitIODef& root_input,
+ const NodeUnitIODef& final_output,
+ bool validate);
+
+// Helper function to validate on QNN
+static Status ValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
+ gsl::span<const NodeUnit* const> node_units,
+ const NodeUnitIODef& root_input,
+ const NodeUnitIODef& final_output) {
+ return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, true);
+}
+
+// Helper function to create on QNN
+static Status CreateOnQnn(QnnModelWrapper& qnn_model_wrapper,
+ gsl::span<const NodeUnit* const> node_units,
+ const NodeUnitIODef& root_input,
+ const NodeUnitIODef& final_output) {
+ return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, false);
+}
+
+// Gets the parent and child of the Erf node. Can handle the following sequences
+// - Parent -> Erf -> Child.
+// - Parent -> DQ -> Erf -> Q -> Child.
+//
+// Also returns the outputs of the Erf. For the sequence `DQ -> Erf -> Q`, returns the outputs of the Q.
+static bool GetErfParentAndChild(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& erf_node_unit,
+ const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+ const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
+ /*out*/ const NodeUnit*& parent_node_unit,
+ /*out*/ const NodeUnit*& child_node_unit,
+ /*out*/ const NodeUnit*& dq_node_unit,
+ /*out*/ const NodeUnit*& q_node_unit,
+ /*out*/ gsl::span<const NodeUnitIODef>& erf_outputs) {
+ const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
+
+ auto get_first_parent = [&](const NodeUnit& node_unit) -> const NodeUnit* {
+ const auto& inputs = node_unit.Inputs();
+ if (inputs.empty()) {
+ return nullptr;
+ }
+ return GetParentOfInput(graph_viewer, node_unit, inputs[0],
+ node_to_node_unit, node_unit_to_qnn_node_group);
+ };
+
+ auto get_first_child = [&](const NodeUnit& node_unit) -> const NodeUnit* {
+ const auto& outputs = node_unit.Outputs();
+ if (outputs.empty()) {
+ return nullptr;
+ }
+
+ return GetOnlyChildOfOutput(graph_viewer, node_unit, outputs[0],
+ node_to_node_unit, node_unit_to_qnn_node_group);
+ };
+
+ const NodeUnit* erf_parent_node_unit = get_first_parent(erf_node_unit);
+ if (erf_parent_node_unit == nullptr) {
+ return false;
+ }
+
+ const NodeUnit* erf_child_node_unit = get_first_child(erf_node_unit);
+ if (erf_child_node_unit == nullptr) {
+ return false;
+ }
+
+ if (erf_node_unit.UnitType() == NodeUnit::Type::SingleNode &&
+ erf_parent_node_unit->OpType() == "DequantizeLinear" &&
+ erf_child_node_unit->OpType() == "QuantizeLinear") {
+ // This is the explicit sequence DQ -> Erf -> Q.
+ // Look past the DQ and Q nodes to get the actual parent and child.
+ // We do this because ORT utils do not automatically group DQ -> Erf -> Q into a NodeUnit.
+ dq_node_unit = erf_parent_node_unit;
+ q_node_unit = erf_child_node_unit;
+ erf_parent_node_unit = get_first_parent(*erf_parent_node_unit);
+ erf_child_node_unit = get_first_child(*erf_child_node_unit);
+
+ erf_outputs = q_node_unit->Outputs();
+ } else {
+ erf_outputs = erf_node_unit.Outputs();
+ }
+
+ parent_node_unit = erf_parent_node_unit;
+ child_node_unit = erf_child_node_unit;
+ return parent_node_unit != nullptr && child_node_unit != nullptr;
+}
+
+std::unique_ptr<IQnnNodeGroup> GeluFusion::TryFusion(
+ QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnit& erf_node_unit,
+ const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+ const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
+ const logging::Logger& /*logger*/) {
+ if (erf_node_unit.OpType() != "Erf") {
+ return nullptr;
+ }
+
+ const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
+ const NodeUnit* div_node_unit = nullptr;
+ const NodeUnit* add_node_unit = nullptr;
+ const NodeUnit* dq_node_unit = nullptr;
+ const NodeUnit* q_node_unit = nullptr;
+ gsl::span<const NodeUnitIODef> erf_outputs;
+
+ if (!GetErfParentAndChild(qnn_model_wrapper, erf_node_unit, node_to_node_unit, node_unit_to_qnn_node_group,
+ div_node_unit, add_node_unit, dq_node_unit, q_node_unit, erf_outputs)) {
+ return nullptr;
+ }
+
+ // Erf must have a Div parent.
+ if (div_node_unit == nullptr || div_node_unit->OpType() != "Div") {
+ return nullptr;
+ }
+
+ // Div must have 2 inputs
+ const auto& div_inputs = div_node_unit->Inputs();
+ if (div_inputs.size() < 2) {
+ return nullptr;
+ }
+
+ // Check second input of Div is sqrt(2) ≈ 1.4142
+ if (!IsInitializerWithExpectedValue(qnn_model_wrapper, div_inputs[1], static_cast<float>(M_SQRT2))) {
+ return nullptr;
+ }
+
+ // Erf must have an Add child consuming its output
+ if (add_node_unit == nullptr || add_node_unit->OpType() != "Add") {
+ return nullptr;
+ }
+
+ // Add must have 2 inputs
+ const auto& add_inputs = add_node_unit->Inputs();
+ if (add_inputs.size() < 2) {
+ return nullptr;
+ }
+
+ // Check the other input node (e.g. not the Erf) is 1.0f
+ bool is_erf_first_input = (add_inputs[0].node_arg.Name() == erf_outputs[0].node_arg.Name());
+ const auto& add_const_input = add_inputs[is_erf_first_input ? 1 : 0];
+ if (!IsInitializerWithExpectedValue(qnn_model_wrapper, add_const_input, 1.0f)) {
+ return nullptr;
+ }
+
+ // Add must have a Mul child consuming its output
+ const auto& add_outputs = add_node_unit->Outputs();
+ if (add_outputs.empty()) {
+ return nullptr;
+ }
+
+ const NodeUnit* mul_node_unit = GetOnlyChildOfOutput(graph_viewer, *add_node_unit, add_outputs[0],
+ node_to_node_unit, node_unit_to_qnn_node_group);
+ if (mul_node_unit == nullptr || mul_node_unit->OpType() != "Mul") {
+ return nullptr;
+ }
+
+ // Now check which pattern we have
+ const auto& root_input_name = div_inputs[0].node_arg.Name();
+ const auto& mul_inputs = mul_node_unit->Inputs();
+
+ if (mul_inputs.size() < 2) {
+ return nullptr;
+ }
+
+ // Try to match Pattern 1: root -> Mul(0.5) -> ... -> Mul
+ // In this case, one input to the final Mul should be from a Mul node
+ const NodeUnit* mul2_node_unit = nullptr;
+
+ // Check if either input to mul_node_unit comes from a Mul node
+ for (size_t i = 0; i < 2; ++i) {
+ const auto& mul_input = mul_inputs[i];
+
+ const NodeUnit* producer_unit = GetParentOfInput(graph_viewer, *mul_node_unit, mul_input,
+ node_to_node_unit, node_unit_to_qnn_node_group);
+ if (producer_unit && producer_unit->OpType() == "Mul") {
+ const auto& mul2_inputs = producer_unit->Inputs();
+ if (mul2_inputs.size() >= 2) {
+ bool has_root_input = (mul2_inputs[0].node_arg.Name() == root_input_name ||
+ mul2_inputs[1].node_arg.Name() == root_input_name);
+ if (has_root_input) {
+ int root_index = (mul2_inputs[0].node_arg.Name() == root_input_name) ? 0 : 1;
+ const auto& mul_const_input = mul2_inputs[1 - root_index];
+
+ if (IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) {
+ mul2_node_unit = producer_unit;
+ break;
+ }
+ }
+ }
+ }
+ if (mul2_node_unit != nullptr) break;
+ }
+
+ std::vector<const NodeUnit*> node_units;
+ const NodeUnit* final_mul_node_unit = nullptr;
+
+ if (mul2_node_unit != nullptr) {
+ // Pattern 1: root -> Mul(0.5) -> ... -> Mul
+ if (dq_node_unit != nullptr) {
+ assert(q_node_unit != nullptr);
+ node_units = {div_node_unit, dq_node_unit, &erf_node_unit, q_node_unit, add_node_unit, mul2_node_unit,
+ mul_node_unit};
+ } else {
+ node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul2_node_unit, mul_node_unit};
+ }
+ final_mul_node_unit = mul_node_unit;
+ } else {
+ // Try Pattern 2: root -> ... -> Mul -> Mul(0.5)
+ // Check if one input to mul_node_unit is root
+ bool has_root_input = (mul_inputs[0].node_arg.Name() == root_input_name ||
+ mul_inputs[1].node_arg.Name() == root_input_name);
+
+ if (!has_root_input) {
+ return nullptr;
+ }
+
+ // mul_node_unit must have a Mul child consuming its output
+ const auto& mul_outputs = mul_node_unit->Outputs();
+ if (mul_outputs.empty()) {
+ return nullptr;
+ }
+
+ const NodeUnit* mul2_node_unit_pattern2 = GetOnlyChildOfOutput(graph_viewer, *mul_node_unit, mul_outputs[0],
+ node_to_node_unit, node_unit_to_qnn_node_group);
+ if (mul2_node_unit_pattern2 == nullptr || mul2_node_unit_pattern2->OpType() != "Mul") {
+ return nullptr;
+ }
+
+ // Verify this final Mul has 2 inputs
+ const auto& mul2_inputs = mul2_node_unit_pattern2->Inputs();
+ if (mul2_inputs.size() < 2) {
+ return nullptr;
+ }
+
+ // Check the constant input is 0.5f
+ int mul_const_input_index = 0;
+ if (mul2_inputs[0].node_arg.Name() == mul_outputs[0].node_arg.Name()) {
+ mul_const_input_index = 1;
+ }
+ const auto& mul_const_input = mul2_inputs[mul_const_input_index];
+ if (!IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) {
+ return nullptr;
+ }
+
+ // Pattern 2
+ if (dq_node_unit != nullptr) {
+ assert(q_node_unit != nullptr);
+ node_units = {div_node_unit, dq_node_unit, &erf_node_unit, q_node_unit, add_node_unit,
+ mul_node_unit, mul2_node_unit_pattern2};
+ } else {
+ node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul_node_unit, mul2_node_unit_pattern2};
+ }
+
+ final_mul_node_unit = mul2_node_unit_pattern2;
+ }
+
+ // Validate on QNN
+ const NodeUnitIODef& root_input = div_inputs[0];
+ const NodeUnitIODef& final_output = final_mul_node_unit->Outputs()[0];
+
+ if (Status status = ValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output);
+ !status.IsOK()) {
+ return nullptr;
+ }
+
+ return std::make_unique<GeluFusion>(std::move(node_units), &erf_node_unit);
+}
+
+GeluFusion::GeluFusion(std::vector<const NodeUnit*>&& node_units, const NodeUnit* target_node_unit)
+ : node_units_(std::move(node_units)), target_node_unit_(target_node_unit) {
+}
+
+Status GeluFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
+ ORT_RETURN_IF_NOT(!node_units_.empty(), "GeluFusion node_units_ is empty");
+ const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0];
+ const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0];
+ return ValidateOnQnn(qmw, node_units_, root_input, final_output);
+}
+
+Status GeluFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
+ ORT_RETURN_IF_NOT(!node_units_.empty(), "GeluFusion node_units_ is empty");
+ const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0];
+ const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0];
+ return CreateOnQnn(qmw, node_units_, root_input, final_output);
+}
+
+gsl::span<const NodeUnit* const> GeluFusion::GetNodeUnits() const {
+ return gsl::span<const NodeUnit* const>(node_units_.data(), node_units_.size());
+}
+
+const NodeUnit* GeluFusion::GetTargetNodeUnit() const {
+ return target_node_unit_;
+}
+
+static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
+ gsl::span<const NodeUnit* const> node_units,
+ const NodeUnitIODef& root_input,
+ const NodeUnitIODef& final_output,
+ bool validate) {
+ assert(node_units.size() >= 4);
+ const auto& node_name = utils::GetUniqueName(*node_units[0]);
+
+ QnnTensorWrapper input_tensor;
+ QnnTensorWrapper output_tensor;
+
+ ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(root_input, input_tensor));
+ ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(final_output, output_tensor));
+
+ if (validate) {
+ ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name,
+ QNN_OP_PACKAGE_NAME_QTI_AISW,
+ QNN_OP_GELU,
+ {input_tensor.GetQnnTensor()},
+ {output_tensor.GetQnnTensor()},
+ {}));
+ } else {
+ // Only add tensor wrappers if they don't already exist
+ if (!qnn_model_wrapper.IsQnnTensorWrapperExist(root_input.node_arg.Name())) {
+ ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input");
+ }
+ if (!qnn_model_wrapper.IsQnnTensorWrapperExist(final_output.node_arg.Name())) {
+ ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output");
+ }
+ ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name,
+ QNN_OP_PACKAGE_NAME_QTI_AISW,
+ QNN_OP_GELU,
+ {root_input.node_arg.Name()},
+ {final_output.node_arg.Name()},
+ {},
+ validate),
+ "Failed to add fused Gelu node.");
+ }
+
+ return Status::OK();
+}
+
+} // namespace qnn
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h
new file mode 100644
index 0000000000000..508b1fca48a67
--- /dev/null
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h
@@ -0,0 +1,73 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
+#include "core/providers/qnn/ort_api.h"
+
+namespace onnxruntime {
+namespace qnn {
+
+class QnnModelWrapper;
+
+///
+/// Represents a fusion of the Gelu pattern expanded into ONNX operators.
+/// This fusion handles two patterns:
+/// Pattern 1:
+///           +-------Mul(0.5)------------+
+///           |                           |
+///           |                           v
+/// [root] --> Div -----> Erf --> Add --> Mul ==>
+///            (B=1.4142...)       (1)
+///
+/// Pattern 2:
+///           +---------------------------+
+///           |                           |
+///           |                           v
+/// [root] --> Div -----> Erf --> Add --> Mul --> Mul ==>
+///            (B=1.4142...)       (1)            (0.5)
+///
+/// Both patterns are translated into a QNN Gelu operator.
+/// The contained NodeUnits can be of type SingleNode or QDQGroup (with Q-DQ nodes).
+/// The second inputs to Div, Add, and Mul Node Units should be constant.
+///
+class GeluFusion : public IQnnNodeGroup {
+ public:
+ GeluFusion(std::vector<const NodeUnit*>&& node_units, const NodeUnit* target_node_unit);
+ ORT_DISALLOW_COPY_AND_ASSIGNMENT(GeluFusion);
+
+ Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
+ Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
+ gsl::span<const NodeUnit* const> GetNodeUnits() const override;
+ const NodeUnit* GetTargetNodeUnit() const override;
+ std::string_view Type() const override { return "GeluFusion"; }
+
+ ///
+ /// Traverses graph to check if the given starting NodeUnit is part of a valid Gelu pattern.
+ /// If so, returns an IQnnNodeGroup that contains all the NodeUnits in the pattern.
+ ///
+ /// Used for validation and traverse/query the graph
+ /// Erf node unit that could be part of the sequence
+ /// Maps a Node to a NodeUnit.
+ /// Maps a NodeUnit to a IQnnNodeGroup.
+ ///
+ /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise
+ static std::unique_ptr<IQnnNodeGroup> TryFusion(
+ QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnit& erf_node_unit,
+ const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+ const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
+ const logging::Logger& logger);
+
+ private:
+ std::vector<const NodeUnit*> node_units_;
+ const NodeUnit* target_node_unit_;
+};
+
+} // namespace qnn
+} // namespace onnxruntime
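Both diagrams above are algebraically the same function, Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))); Pattern 1 applies the 0.5 on the root branch before the final Mul, while Pattern 2 applies it after. A standalone numerical sanity check, illustrative only and not part of the patch:

    // Both matched subgraphs compute Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
    #include <cmath>
    #include <cstdio>

    int main() {
      const double inv_sqrt2 = 1.0 / std::sqrt(2.0);  // the Div constant matched against M_SQRT2
      for (double x : {-3.0, -0.5, 0.0, 0.5, 3.0}) {
        const double erf_plus_one = std::erf(x * inv_sqrt2) + 1.0;  // Div -> Erf -> Add(1)
        const double pattern1 = (0.5 * x) * erf_plus_one;           // Mul(0.5) branch feeds the final Mul
        const double pattern2 = (x * erf_plus_one) * 0.5;           // final Mul(0.5) applied last
        const double gelu = 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
        std::printf("x=%+.2f  pattern1=%.9f  pattern2=%.9f  gelu=%.9f\n", x, pattern1, pattern2, gelu);
      }
      return 0;
    }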
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
index 368caa518b7ba..4297801ce4cdc 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
@@ -22,6 +22,7 @@
#include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/reshape_transpose_rank5.h"
+#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h"
#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/qnn/ort_api.h"
@@ -83,6 +84,7 @@ static std::unordered_map<std::string, std::vector<FusionFunc>> fusions = {
{"Gemm", {LowPowerBlockQuantizedGemmFusion::TryFusion, ReshapeGemmFusion::TryFusion}},
{"Mul", {ScaleSoftmaxFusion::TryFusion}},
{"Cast", {CastLoneQFusion::TryFusion}},
+ {"Erf", {GeluFusion::TryFusion}},
{"Reshape", {Rank6ToRank5Fusion::TryFusion}},
{"Transpose", {ChannelShuffleFusion::TryFusion}}};
@@ -119,9 +121,11 @@ static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
 const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
 const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
const logging::Logger& logger) {
- // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings and Reshape
+ // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except
+ // MatMul w/ LPBQ encodings, Erf and Reshape
if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode &&
starting_node_unit.OpType() != "MatMul" &&
+ starting_node_unit.OpType() != "Erf" &&
starting_node_unit.OpType() != "Reshape") {
return nullptr;
}
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc
index 10e1633e4b57d..7b77164a38545 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc
@@ -226,14 +226,92 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer,
return nullptr;
}
- // parent must not already be part of a QDQ NodeUnit (i.e., be standalone).
- if (p_parent_node_unit->UnitType() != NodeUnit::Type::SingleNode) {
+ return p_parent_node_unit;
+ }
+ return nullptr;
+}
+
+const NodeUnit* GetOnlyChildOfOutput(const GraphViewer& graph_viewer,
+ const NodeUnit& node_unit,
+ const NodeUnitIODef& output,
+ const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
+ const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& qnn_node_group_map) {
+ const Node* p_parent_node = nullptr;
+
+ for (auto node : node_unit.GetAllNodesInGroup()) {
+ for (auto node_output : node->OutputDefs()) {
+ if (node_output->Name() == output.node_arg.Name()) {
+ p_parent_node = node;
+ break;
+ }
+ }
+ // break the loop if producer node of output is found
+ if (p_parent_node != nullptr) {
+ break;
+ }
+ }
+
+ // return if the given output tensor is not produced by any node in the given node_unit
+ if (p_parent_node == nullptr) {
+ return nullptr;
+ }
+
+ const Node& parent_node = *p_parent_node;
+
+ if (graph_viewer.NodeProducesGraphOutput(parent_node)) {
+ // Node is producing a graph output
+ return nullptr;
+ }
+
+ // First pass: count how many children consume this specific output
+ int child_count = 0;
+ const NodeUnit* p_child_node_unit = nullptr;
+
+ for (auto edge = parent_node.OutputEdgesBegin(); edge != parent_node.OutputEdgesEnd(); ++edge) {
+ const Node& child_node = edge->GetNode();
+
+ // Check if this edge corresponds to the output we're looking for
+ bool is_matching_output = false;
+ for (auto child_input : child_node.InputDefs()) {
+ if (child_input->Name() == output.node_arg.Name()) {
+ is_matching_output = true;
+ break;
+ }
+ }
+
+ if (!is_matching_output) {
+ continue;
+ }
+
+ if (graph_viewer.GetNode(child_node.Index()) == nullptr) {
+ // Node is not in this GraphViewer
return nullptr;
}
- return p_parent_node_unit;
+ const auto child_node_unit_it = node_unit_map.find(&child_node);
+ if (child_node_unit_it == node_unit_map.end()) {
+ return nullptr;
+ }
+ const NodeUnit* current_child_node_unit = child_node_unit_it->second;
+
+ // Check if child node has already been handled. Should not be the case if the calling
+ // fusion function has been called in topological order, but check to be safe.
+ if (qnn_node_group_map.count(current_child_node_unit) != 0) {
+ return nullptr;
+ }
+
+ // Store the child node unit and increment count
+ p_child_node_unit = current_child_node_unit;
+ child_count++;
+
+ // If we found more than one child, return nullptr immediately
+ if (child_count > 1) {
+ return nullptr;
+ }
}
- return nullptr;
+
+ // Return the child only if there's exactly one child
+ return (child_count == 1) ? p_child_node_unit : nullptr;
}
} // namespace qnn
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
index 14e2a3f25e7db..b52cdd5fa3ec6 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
@@ -51,5 +51,11 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer,
 const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
 const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& qnn_node_group_map);
+const NodeUnit* GetOnlyChildOfOutput(const GraphViewer& graph_viewer,
+ const NodeUnit& node_unit,
+ const NodeUnitIODef& output,
+ const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
+ const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& qnn_node_group_map);
+
} // namespace qnn
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc
index d8e58f0d0a170..ebe71c6ccfacd 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.cc
+++ b/onnxruntime/core/providers/webgpu/compute_context.cc
@@ -20,5 +20,9 @@ const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const Co
return context.ep_.BufferManager();
}
+const SplitKConfig& ComputeContext::GetSplitKConfig() {
+ return webgpu_context_.GetSplitKConfig();
+}
+
} // namespace webgpu
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h
index 01cae1e337439..ed16f2f0a1345 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.h
+++ b/onnxruntime/core/providers/webgpu/compute_context.h
@@ -152,6 +152,13 @@ class ComputeContext final {
return webgpu_context_.Run(*this, program);
}
+ //
+ // Get Split-K configuration.
+ //
+ // `split_k_config_` won't be initialized until the first call to this method.
+ //
+ const SplitKConfig& GetSplitKConfig();
+
private:
WebGpuContext& webgpu_context_;
OpKernelContext& kernel_context_;
diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
index 7dd3b50c656f4..7cbc7f6a4a821 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
+++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@@ -13,7 +13,7 @@ namespace webgpu {
// which are used in the MatMulWriteFnSource function.
namespace {
-void HanldeMaybeHaveBiasForGEMM(ShaderHelper& shader,
+void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader,
const ShaderVariableHelper& output,
bool has_bias,
int c_components,
@@ -53,6 +53,70 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
<< output.SetByIndices("coords", "value") << "\n";
}
+void HandleMatMulWithSplitK(
+ ShaderHelper& shader,
+ ProgramVariableDataType output_variable_type) {
+ shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n";
+
+ // With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups,
+ // so we must add them with atomic built-in functions. Because currently WebGPU doesn't support
+ // atomic built-in functions on `f32` or `f16`, we implement the `atomicAdd` on `f32` and `f16`
+ // with `atomicLoad` and `atomicCompareExchangeWeak`:
+ // 1. Get `old_output_i32` from `output[offset]` with `atomicLoad`.
+ // 2. Convert `old_output_i32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`).
+ // 3. Add incoming `value` into `old_output_f32` or `old_output_vec2h`.
+ // 4. Convert the result of step 3 into `i32` values.
+ // 5. Try assigning the result of step 4 into `output[offset]` with `atomicCompareExchangeWeak`
+ // and `old_output_i32`. The assignment will fail if at this time `output[offset]` is not
+ // equal to `old_output_i32` (it is updated in another invocation). If the assignment fails
+ // we have to go to step 1 and repeat all the above steps.
+ switch (output_variable_type) {
+ case ProgramVariableDataType::Float32x4: {
+ shader.AdditionalImplementation() << R"(
+ let offset0 = i2o_output(coords) * 4u;
+ for (var i = 0u; i < 4u; i++) {
+ let offset = offset0 + i;
+ while (true) {
+ let old_output_i32 = atomicLoad(&output[offset]);
+ let old_output_f32 = bitcast<f32>(old_output_i32);
+ let new_output_f32 = old_output_f32 + value[i];
+ let new_output_i32 = bitcast<i32>(new_output_f32);
+ let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32);
+ if (output_compare_exchange.old_value == old_output_i32) {
+ break;
+ }
+ }
+ }
+)";
+ break;
+ }
+ case ProgramVariableDataType::Float16x4: {
+ shader.AdditionalImplementation() << R"(
+ let offset0 = i2o_output(coords) * 2u;
+ var vec2h_values : array<vec2h, 2>;
+ vec2h_values[0] = value.xy;
+ vec2h_values[1] = value.zw;
+ for (var i = 0u; i < 2u; i++) {
+ let offset = offset0 + i;
+ while (true) {
+ let old_output_i32 = atomicLoad(&output[offset]);
+ let old_output_vec2h = bitcast<vec2h>(old_output_i32);
+ let new_output_vec2h = old_output_vec2h + vec2h_values[i];
+ let new_output_i32 = bitcast<i32>(new_output_vec2h);
+ let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32);
+ if (output_compare_exchange.old_value == old_output_i32) {
+ break;
+ }
+ }
+ }
+)";
+ break;
+ }
+ default:
+ break;
+ }
+}
+
} // namespace
void MatMulReadFnSource(ShaderHelper& shader,
@@ -125,7 +189,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
int output_components,
bool c_is_scalar,
std::string activation_snippet,
- bool is_channels_last) {
+ bool is_channels_last,
+ bool use_split_k,
+ ProgramVariableDataType output_variable_type) {
shader.AdditionalImplementation()
<< "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n";
@@ -134,8 +200,17 @@ void MatMulWriteFnSource(ShaderHelper& shader,
shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n"
<< " var value = valueIn; \n";
- if (is_gemm) {
- HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar);
+ if (use_split_k) {
+ // Set output when MatMul is performed with Split-K.
+ // When Split-K is used in MatMul, the bias will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram`
+ // instead of here, so `has_bias` and `is_channels_last` are not used for Split-K. Note that we
+ // still need to handle `has_bias` (and `is_channels_last` in the future) in
+ // `MatMulFillBiasOrZeroBeforeSplitKProgram`.
+ ORT_ENFORCE(!has_bias, "Bias is not supported in MatMulProgram when Split-K is enabled.");
+ ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled.");
+ HandleMatMulWithSplitK(shader, output_variable_type);
+ } else if (is_gemm) {
+ HandleMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar);
} else {
HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last);
}
@@ -159,9 +234,6 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
uint32_t tile_inner,
bool split_k,
uint32_t split_dim_inner) {
- ORT_UNUSED_PARAMETER(split_k);
- ORT_UNUSED_PARAMETER(split_dim_inner);
-
const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type);
std::string write_data_to_sub_a_vec4_snippet =
@@ -208,14 +280,51 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
<< " let tileCol = i32(local_id.x);\n"
<< " let globalRow = i32(global_id.y) * rowPerThread;\n"
<< " let globalCol = i32(global_id.x);\n"
- << " let batch = i32(global_id.z);\n"
- << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
<< " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
<< " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
- << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
- << " var kStart = 0;\n"
<< " var acc: array, rowPerThread>;\n";
+ if (split_k) {
+ // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into
+ // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from
+ // `kSplitK * i32(global_id.z)`.
+ //
+ // For example: consider computing Y = (X * W + B) in one workgroup.
+ // Let kSplitK = 2, B = [d1, d2]
+ // Let X = [[a1 a1 b1 b1 c1 c1]  = [ A1 B1 C1 ],    W = [[a2 a2]  = [ A2
+ //          [a1 a1 b1 b1 c1 c1]]                         [a2 a2]      B2
+ //                                                       [b2 b2]      C2 ]
+ //                                                       [b2 b2]
+ //                                                       [c2 c2]
+ //                                                       [c2 c2]]
+ //
+ // With Split-K:
+ // 1. Initialize output Y with B in `MatMulFillBiasOrZeroBeforeSplitKProgram`: Y = [[d1, d2]
+ //                                                                                  [d1, d2]]
+ // 2. Split the original 1 workgroup into 3 workgroups (now `dispatch_z = 3` in API side)
+ // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2)
+ // Workgroup3: compute (C1 * C2)
+ // In each workgroup:
+ // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z`
+ // - When the computation in each workgroup is completed, add the result to Y with several
+ // atomic built-in functions in `HandleMatMulWithSplitK()`.
+ shader.MainFunctionBody()
+ << "const kSplitK = " << split_dim_inner << ";\n"
+ << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
+ << " var kStart = kSplitK * i32(global_id.z);\n"
+
+ // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate
+ // the index of split-k instead of batch.
+ << " let batch = 0;\n"
+ << " let batchIndices = 0u;\n";
+ } else {
+ shader.MainFunctionBody()
+ << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
+ << " var kStart = 0;\n"
+ << " let batch = i32(global_id.z);\n"
+ << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
+ }
+
// Loop over shared dimension.
shader.MainFunctionBody() << " let tileRowB = localRow * " << row_per_thread_b << ";\n";
diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h
index ed4cf997d2f00..7075debeb9952 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h
+++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h
@@ -24,7 +24,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
int output_components,
bool c_is_scalar,
std::string activation_snippet = "",
- bool is_channels_last = false);
+ bool is_channels_last = false,
+ bool use_split_k = false,
+ ProgramVariableDataType output_variable_type = ProgramVariableDataType::Float32x4);
// The two following functions are used to generate shader code for vec4 and scalar.
// It is used in GEMM, Matmul, and Conv.
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc
index cf4b9d3fae2d2..55c2c5773cc1f 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul.cc
+++ b/onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -161,14 +161,14 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
const auto* bias = context.Input(2);
inputs.push_back(bias);
}
- auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false);
- return context.RunProgram(program);
+ return ComputeMatMul(&context, Activation(), inputs, output_tensor, false);
}
-MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
- const TensorShape& input_a_reshape,
- const TensorShape& input_b_reshape) {
+Status ComputeMatMul(ComputeContext* context,
+ const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
+ const TensorShape& input_a_reshape,
+ const TensorShape& input_b_reshape) {
const auto* a = inputs[0];
const auto* b = inputs[1];
bool has_bias = inputs.size() > 2;
@@ -226,31 +226,97 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
 const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
- const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
- (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
+ uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
+ (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
const int components = is_vec4 ? 4 : 1;
const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});
- MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last};
- program
- .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last)
+ ProgramOutput output(output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components);
+ const Tensor* bias = has_bias ? inputs[2] : nullptr;
+ bool use_bias_in_matmul = has_bias;
+ uint32_t split_dim_inner = 1;
+
+ const SplitKConfig& split_k_config = context->GetSplitKConfig();
+ const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
+ if (need_split_k) {
+ ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
+ ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format.");
+ ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");
+
+ // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
+ const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp);
+ ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));
+
+ // `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
+ // `bias` again in `MatMulProgram`.
+ use_bias_in_matmul = false;
+
+ // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
+ // number of splits along `dim_inner`.
+ // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize
+ // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`.
+ split_dim_inner = split_k_config.GetSplitDimInner();
+ dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
+
+ // The output should be declared in atomic types in `MatMulProgram` for the use of atomic
+ // built-in functions.
+ output.is_atomic = true;
+ }
+
+ MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner};
+ matmul_program
+ .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
- .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
.AddIndices(outer_dims)
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
- .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z);
+ .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z)
+ .AddOutput(std::move(output));
- if (has_bias) {
+ if (use_bias_in_matmul) {
auto bias_components = is_channels_last ? components : 1;
- const auto* bias = inputs[2];
TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
- program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
+ matmul_program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
+ }
+
+ return context->RunProgram(matmul_program);
+}
+
+MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
+ const Tensor* bias,
+ Tensor* output,
+ const TensorShape& output_shape_vec4) {
+ const bool has_bias = bias != nullptr;
+
+ // Currently we only support bias in vec4 and channels last format for Split-K MatMul.
+ constexpr uint32_t bias_components = 4;
+ MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias);
+
+ const uint32_t dim_a_outer = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]);
+ const uint32_t dim_b_outer_vec4 = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]);
+
+ // Fill one value (currently always a vec4) per invocation. We use the default workgroup size (64)
+ // for this program.
+ const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4;
+ const uint32_t dispatch_x = (total_outputs_vec4 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
+
+ // To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar
+ // elements instead of vec4 units, while using `output_shape_vec4` directly as the output shape.
+ const uint32_t dim_b_outer = narrow<uint32_t>(dim_b_outer_vec4 * bias_components);
+ program.CacheHint(has_bias)
+ .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast<int>(bias_components)})
+ .AddUniformVariables({{dim_a_outer}, {dim_b_outer}})
+ .SetDispatchGroupSize(dispatch_x);
+
+ if (has_bias) {
+ const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
+ program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int>(bias_components)});
}
+
return program;
}
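As a concrete instance of the dispatch math above (numbers chosen for illustration, matching the Intel configuration in webgpu_utils.cc): with `split_dim_inner = 256` and `dim_inner = 1536`, `dispatch_z = (1536 + 256 - 1) / 256 = 6`, so six workgroups accumulate into every output tile. Inside the shader each of them runs `num_tiles = (256 - 1) / tileInner + 1 = 8` K-tiles (tileInner = 32), starting at `kStart = 256 * global_id.z`, i.e. 0, 256, ..., 1280. Without Split-K a single workgroup would instead run `num_tiles = (1536 - 1) / 32 + 1 = 48` tiles with `kStart = 0`.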
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h
index 8ab8c3a6ba2d0..0b65827be7f17 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul.h
+++ b/onnxruntime/core/providers/webgpu/math/matmul.h
@@ -14,9 +14,14 @@
namespace onnxruntime {
namespace webgpu {
-MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output, bool is_channels_last,
- const TensorShape& input_a_reshape = TensorShape(),
- const TensorShape& input_b_reshape = TensorShape());
+Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output, bool is_channels_last,
+ const TensorShape& input_a_reshape = TensorShape(),
+ const TensorShape& input_b_reshape = TensorShape());
+
+MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
+ const Tensor* bias,
+ Tensor* output,
+ const TensorShape& output_shape_vec4);
class MatMul final : public WebGpuKernel {
public:
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
index 585f8f1e011c4..4daabe8246aa7 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
+++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
@@ -14,25 +14,77 @@ namespace webgpu {
Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
- const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+
+ const bool need_split_k = NeedSplitK();
+ ShaderUsage output_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias;
+ if (need_split_k) {
+ // When Split-K is enabled, we will declare output as `atomic` to call atomic built-in
+ // functions on it, so we need below information to correctly compute the index on the output.
+ output_usage |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride;
+ }
+ const auto& output = shader.AddOutput("output", output_usage);
+
const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
if (has_bias_) {
shader.AddInput("bias", ShaderUsage::UseUniform);
}
std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t");
+ ProgramVariableDataType output_var_type = this->Outputs()[0].var_type;
// declare the read and write functions
MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_);
- MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_);
+ MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type);
std::string data_type = "a_element_t";
// generate the main function
if (is_vec4_) {
- ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims));
+ ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(
+ shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims,
+ /*transA = */ false, /*transB = */ false, /*alpha = */ 1.f, /*need_handle_matmul = */ true,
+ /*output_components = */ 4, /*tile_inner = */ 32, need_split_k, split_dim_inner_));
} else {
ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims));
}
return Status::OK();
}
+bool MatMulProgram::NeedSplitK() const {
+ return split_dim_inner_ > 1;
+}
+
+Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shader) const {
+ const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+
+ if (has_bias_) {
+ shader.AddInput("bias", ShaderUsage::UseUniform);
+ }
+
+ // Handle bias with `MatMulWriteFnSource()`.
+ // Here `use_split_k` is false because we just initialize `output` with bias.
+ // `use_split_k` is true only when we do the actual MatMul with Split-K.
+ // Currently we only support bias in vec4 and channels last format for Split-K MatMul.
+ MatMulWriteFnSource(
+ shader, output, has_bias_, /*is_gemm*/ false, /*c_components*/ 4, /*output_components*/ 4, /*c_is_scalar*/ false,
+ /*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false);
+
+ shader.MainFunctionBody() << R"(
+ let output_components = 4;
+ let output_id = i32(global_idx);
+
+ let dim_a_outer = i32(uniforms.dim_a_outer);
+ let dim_b_outer = i32(uniforms.dim_b_outer) / output_components;
+ if (output_id >= dim_a_outer * dim_b_outer) {
+ return;
+ }
+
+ let output_row = output_id / dim_b_outer;
+ let output_col = output_id % dim_b_outer;
+ let output_batch = 0;
+ let output_value = output_value_t();
+ mm_write(output_batch, output_row, output_col, output_value);
+)";
+
+ return Status::OK();
+}
+
} // namespace webgpu
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h
index 767fdd8802e5b..143ba61c99e13 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h
+++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h
@@ -13,24 +13,48 @@ namespace onnxruntime {
namespace webgpu {
class MatMulProgram final : public Program<MatMulProgram> {
public:
- MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false) : Program{"MatMul"},
- activation_(activation),
- has_bias_{bias},
- is_vec4_{is_vec4},
- elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()),
- is_channels_last_(is_channels_last) {}
+ MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false, uint32_t split_dim_inner = 1) : Program{"MatMul"},
+ activation_(activation),
+ has_bias_{bias},
+ is_vec4_{is_vec4},
+ elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()),
+ is_channels_last_(is_channels_last),
+ split_dim_inner_(split_dim_inner) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
{"dim_inner", ProgramUniformVariableDataType::Uint32});
+ bool NeedSplitK() const;
+
private:
const Activation activation_;
const bool has_bias_;
const bool is_vec4_;
const InlinedVector elements_per_thread_;
bool is_channels_last_ = false;
+ uint32_t split_dim_inner_ = 1;
+};
+
+// The program to initialize the output with 0 or bias before doing MatMul with Split-K. In Split-K,
+// we set the output values with `atomicLoad` and `atomicCompareExchangeWeak` instead of a direct
+// assignment (see the function `HandleMatMulWithSplitK()` in `gemm_utils.cc`), so we must initialize
+// the output with 0 or bias first to make sure `atomicLoad` won't return garbage data.
+class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program<MatMulFillBiasOrZeroBeforeSplitKProgram> {
+ public:
+ explicit MatMulFillBiasOrZeroBeforeSplitKProgram(bool has_bias)
+ : Program{"MatMul_Fill_Bias_Or_Zero_Before_Split_K"},
+ has_bias_(has_bias) {
+ }
+
+ Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+ WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
+ {"dim_b_outer", ProgramUniformVariableDataType::Uint32});
+
+ private:
+ bool has_bias_ = false;
};
} // namespace webgpu
diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc
index a2777979ae983..77fa46cb87518 100644
--- a/onnxruntime/core/providers/webgpu/nn/conv.cc
+++ b/onnxruntime/core/providers/webgpu/nn/conv.cc
@@ -200,8 +200,7 @@ Status Conv::ComputeInternal(ComputeContext& context
.AddUniformVariables({{output_size}, {static_cast<uint32_t>(matmul_output_shape[1])}, {static_cast<uint32_t>(matmul_output_shape[2])}, {static_cast<uint32_t>(K)}});
return context.RunProgram(program);
} else {
- MatMulProgram program = CreateMatMulProgram(activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]);
- return context.RunProgram(program);
+ return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]);
}
}
// Transpose weights
diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc
index 9e934e9eb5db7..aa0caab39a88e 100644
--- a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc
+++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/webgpu/nn/fuse_utils.h"
+#include "core/framework/op_kernel_info.h"
#include
namespace onnxruntime {
namespace webgpu {
diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h
index f5d2585bb9b45..fad7d3d145bc6 100644
--- a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h
+++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h
@@ -1,11 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include
-#include "core/providers/webgpu/webgpu_kernel.h"
+#include
+
+#include "core/common/status.h"
#pragma once
namespace onnxruntime {
+
+class OpKernelInfo;
+
namespace webgpu {
+
enum class ActivationKind {
None,
Relu,
diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc
index 5447966b91aa7..b08649cbd5d5b 100644
--- a/onnxruntime/core/providers/webgpu/shader_helper.cc
+++ b/onnxruntime/core/providers/webgpu/shader_helper.cc
@@ -472,7 +472,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha
}
ss << ": array<";
if (is_atomic) {
- if (output->type_ == ProgramVariableDataType::Float32) {
+ if (output->type_ == ProgramVariableDataType::Float32 || output->type_ == ProgramVariableDataType::Float16x4 || output->type_ == ProgramVariableDataType::Float32x4) {
ss << "atomic"; // emulate float atomic via i32
} else if (output->type_ == ProgramVariableDataType::Uint32) {
ss << "atomic";
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index 9af9cd455b5a4..28decb076951e 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -910,6 +910,13 @@ void WebGpuContext::ReleaseGraphResources(std::vector
+const SplitKConfig& WebGpuContext::GetSplitKConfig() {
+  if (!split_k_config_.has_value()) {
+    split_k_config_ = SplitKConfig::GetSplitKConfig(adapter_info_);
+  }
+  return split_k_config_.value();
+}
+
std::unordered_map WebGpuContextFactory::contexts_;
std::mutex WebGpuContextFactory::mutex_;
std::once_flag WebGpuContextFactory::init_default_flag_;
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index 1ead7b3a005bb..bd7dae75f2e2d 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -5,12 +5,14 @@
#include
#include
+#include <optional>
#include "core/providers/webgpu/webgpu_external_header.h"
#include "core/common/common.h"
#include "core/providers/webgpu/buffer_manager.h"
#include "core/providers/webgpu/program_manager.h"
+#include "core/providers/webgpu/webgpu_utils.h"
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
#include "core/providers/webgpu/webgpu_pix_frame_generator.h"
@@ -171,6 +173,13 @@ class WebGpuContext final {
Status Run(ComputeContext& context, const ProgramBase& program);
void OnRunEnd();
+ //
+ // Get Split-K configuration.
+ //
+ // `split_k_config_` won't be initialized until the first call to this method.
+ //
+ const SplitKConfig& GetSplitKConfig();
+
private:
enum class TimestampQueryType {
None = 0,
@@ -268,6 +277,8 @@ class WebGpuContext final {
uint32_t num_pending_dispatches_ = 0;
const uint32_t max_num_pending_dispatches_ = 16;
+ std::optional<SplitKConfig> split_k_config_;
+
// profiling
TimestampQueryType query_type_;
wgpu::QuerySet query_set_;
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index 3df194217933e..e0b84fef51f1f 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -878,9 +878,9 @@ std::vector> WebGpuExecutionProvider::GetCapa
const auto& inputs = node.InputDefs();
const auto& outputs = node.OutputDefs();
- // Current implementation does not support mask_index(input[3]), past(input[5]) and past_seq_len(input[6])
+ // Current implementation does not support mask_index(input[3]), past(input[4]) and past_seq_len(input[6])
FALLBACK_TO_CPU_IF_EXIST_INPUT(3);
- FALLBACK_TO_CPU_IF_EXIST_INPUT(5);
+ FALLBACK_TO_CPU_IF_EXIST_INPUT(4);
FALLBACK_TO_CPU_IF_EXIST_INPUT(6);
// Current implementation does not support present(output[1])
diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc
index 53b96dfe7a346..568d29a96cb88 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc
@@ -21,5 +21,64 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components
return TensorShape(shape_vector);
}
+SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) {
+ SplitKConfig config = {};
+
+ if (adapter_info.vendor == std::string_view{"intel"}) {
+ if (adapter_info.architecture == std::string_view{"xe-2lpg"} ||
+ adapter_info.architecture == std::string_view{"xe-2hpg"} ||
+ adapter_info.architecture == std::string_view{"xe-lpg"} ||
+ adapter_info.architecture == std::string_view{"gen-12hp"}) {
+ config.enable_split_k_ = true;
+
+ // The thresholds below have only been verified on the Intel GPUs above without any regressions.
+ // The proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may need to be
+ // reduced when we support a larger `dim_inner`, because a larger `dim_inner` brings more
+ // atomic calls for each output value.
+ config.split_dim_inner_ = 256;
+ config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2;
+ config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9;
+ config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f;
+ }
+ }
+ return config;
+}
+
+bool SplitKConfig::UseSplitK(
+ bool is_vec4,
+ ActivationKind activation_kind,
+ uint64_t batch_size,
+ bool is_channels_last,
+ uint32_t dim_a_outer,
+ uint32_t dim_b_outer,
+ uint32_t dim_inner) const {
+ if (!enable_split_k_) {
+ return false;
+ }
+
+ bool use_split_k = true;
+
+ // TODO: support the cases below.
+ use_split_k &= activation_kind == ActivationKind::None;
+ use_split_k &= is_vec4;
+ use_split_k &= batch_size == 1;
+ // Currently only `is_channels_last` is supported because we only generate vec4 shaders in
+ // `MatMulFillBiasOrZeroBeforeSplitKProgram`.
+ use_split_k &= is_channels_last;
+
+ // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
+ // `dim_b_outer`. Currently we use the ratio `(dim_a_outer * dim_b_outer) / dim_inner` as the
+ // metric to decide whether to use Split-K or not.
+ use_split_k &= (dim_inner >= min_dim_inner_with_split_k_);
+ use_split_k &= (dim_inner <= max_dim_inner_with_split_k_);
+ use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_);
+
+ return use_split_k;
+}
+
+uint32_t SplitKConfig::GetSplitDimInner() const {
+ return split_dim_inner_;
+}
+
} // namespace webgpu
} // namespace onnxruntime
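Plugging in the Intel configuration above: `split_dim_inner_ = 256`, so Split-K is only considered when 512 <= `dim_inner` <= 2304 and `(dim_a_outer * dim_b_outer) / dim_inner` <= 35. For example (illustrative shapes): `dim_a_outer = 64`, `dim_b_outer = 640`, `dim_inner = 1280` gives 64 * 640 / 1280 = 32 <= 35, so the MatMul is split into 1280 / 256 = 5 K-slices; `dim_a_outer = dim_b_outer = 512`, `dim_inner = 1024` gives 512 * 512 / 1024 = 256 > 35, so the regular single-pass MatMul is kept even though `dim_inner` is in range.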
diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h
index 86eb57f99f3b3..d45b9bf4dd119 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_utils.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h
@@ -7,6 +7,8 @@
#include "core/common/common.h"
#include "core/framework/tensor.h"
#include "core/framework/tensor_shape.h"
+#include "core/providers/webgpu/webgpu_external_header.h"
+#include "core/providers/webgpu/nn/fuse_utils.h"
namespace onnxruntime {
namespace webgpu {
@@ -89,5 +91,24 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c
return {new_data_type, new_shape, const_cast<void*>(tensor.DataRaw()), tensor.Location()};
}
+class SplitKConfig {
+ public:
+ static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info);
+
+ bool UseSplitK(
+ bool is_vec4, ActivationKind activation_kind, uint64_t batch_size,
+ bool is_channels_last, uint32_t dim_a_outer,
+ uint32_t dim_b_outer, uint32_t dim_inner) const;
+
+ uint32_t GetSplitDimInner() const;
+
+ private:
+ bool enable_split_k_ = false;
+ uint32_t split_dim_inner_ = 0;
+ uint32_t min_dim_inner_with_split_k_ = 0;
+ uint32_t max_dim_inner_with_split_k_ = 0;
+ float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f;
+};
+
} // namespace webgpu
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc
index 0b927075402fe..a29fbdb91e79f 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc
@@ -107,6 +107,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
ORT_RETURN_IF_NOT(GetShape(*input_defs[3], input_past_k_shape, logger), "Cannot get past_key shape");
NodeAttrHelper helper(node);
+ const int32_t local_window_size = helper.Get("local_window_size", -1);
const uint32_t kv_num_heads = helper.Get("kv_num_heads", 0);
const uint32_t num_heads = helper.Get("num_heads", 0);
@@ -290,18 +291,17 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
| |
+-------------------------------> Lesser <---------------------Transpose (1,0)
|
- 1 ---> Where <--- finfo_min (minimum value of FP32)
+ 1 ---> Where (attn_mask) <--- finfo_min (minimum value of FP32)
|
attention_bias
*/
- const std::vector<int32_t> mask_shape_ones_shape(batch_size * num_heads * qkv_sequence_length * past_sequence_length,
- 1);
- std::string mask_shape_ones_shape_name = "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(batch_size) +
- "_" + std::to_string(num_heads) + "_" + std::to_string(qkv_sequence_length) +
- "_" + std::to_string(past_sequence_length);
- emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant(
- ONNX_NAMESPACE::TensorProto_DataType_INT32, mask_shape_ones_shape_name, mask_shape_ones_shape,
- std::vector<uint32_t>({batch_size, num_heads, qkv_sequence_length, past_sequence_length}));
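+ // Instead of materializing a full int32 tensor of ones, create a scalar int32 constant and
+ // expand it to the mask shape.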
+ emscripten::val value_int_one_constant =
+ model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1, {1});
+
+ std::vector<uint32_t> mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length};
+ common_options.set("label", node.Name() + "_/GQA/GQA_mask_shape_ones/expand");
+ emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder().call<emscripten::val>(
+ "expand", value_int_one_constant, emscripten::val::array(mask_shape_ones_shape), common_options);
emscripten::val cumsum_options = emscripten::val::object();
cumsum_options.set("label", node.Name() + "_range_of_mask_shape");
@@ -315,7 +315,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
std::iota(pre_neq_right_data_range.begin(), pre_neq_right_data_range.end(), 1);
std::string pre_neq_right_data_range_name =
- "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(qkv_sequence_length);
+ "webnn_GQA_pre_neq_right_data_range_" + std::to_string(qkv_sequence_length);
emscripten::val pre_neq_right_data_range_constant = model_builder.CreateOrGetConstant(
ONNX_NAMESPACE::TensorProto_DataType_INT32, pre_neq_right_data_range_name, pre_neq_right_data_range,
std::vector<uint32_t>({qkv_sequence_length}));
@@ -333,10 +333,46 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b
emscripten::val neq_right =
model_builder.GetBuilder().call("transpose", expanded_neq_right, transpose_options);
- common_options.set("label", node.Name() + "_/GQA/attn_mask/condition");
- emscripten::val condition =
+ common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_1");
+ emscripten::val condition_1 =
model_builder.GetBuilder().call("lesser", neq_left, neq_right, common_options);
+ emscripten::val condition = condition_1;
+ // When local_window_size is not -1, build a new attention mask pattern that applies the sliding window:
+ /*
+ condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true)
+ | |
+ | Lesser <--- local_window_size
+ | |
+ LogicalAnd <----------------- condition_2
+ |
+ new attn_mask
+ */
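+ // Worked example (illustrative values, not from the model): if a query may attend to keys 0..4
+ // under the causal mask and local_window_size == 3, the reversed exclusive CumSum over that row
+ // is [4, 3, 2, 1, 0, ...]; Lesser(., 3) keeps keys 2..4, so the LogicalAnd restricts attention
+ // to the last 3 visible keys.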
+ if (local_window_size != -1) {
+ // Cast condition_1 so it can be fed to CumSum
+ common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cast");
+ emscripten::val casted_condition_1 =
+ model_builder.GetBuilder().call