diff --git a/cmake/deps.txt b/cmake/deps.txt
index e1870bf2df0cf..078a66a4c4d85 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -31,7 +31,7 @@ googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f6
 googlexnnpack;https://github.com/google/XNNPACK/archive/3cf85e705098622d59056dcb8f5f963ea7bb0a00.zip;6f6bbba627241f89463ca845febaf063982b34fe
 json;https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip;5e88795165cc8590138d1f47ce94ee567b85b4d6
 microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14
-microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
+microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.250325.1.zip;826c8bd47c2258ec61b8b218e031e5b33d27f761
 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.19.1.zip;c5215b5697dcdfd71799f001b8c4054a6bba6b09
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 754669fffbf8d..4913d38939792 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1528,8 +1528,14 @@ endif()
   onnxruntime_add_shared_library(onnxruntime_runtime_path_test_shared_library ${onnxruntime_runtime_path_test_shared_library_src})
-  target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE
-                        onnxruntime_common cpuinfo ${CMAKE_DL_LIBS})
+  if (CMAKE_SYSTEM_NAME MATCHES "AIX")
+    target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE
+                          onnxruntime_common ${CMAKE_DL_LIBS})
+    set_target_properties(onnxruntime_runtime_path_test_shared_library PROPERTIES AIX_SHARED_LIBRARY_ARCHIVE OFF)
+  else()
+    target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE
+                          onnxruntime_common cpuinfo ${CMAKE_DL_LIBS})
+  endif()
   target_include_directories(onnxruntime_runtime_path_test_shared_library PRIVATE ${ONNXRUNTIME_ROOT})
 
   if(UNIX)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 92093ec5464f7..d7be243323c17 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -88,7 +88,7 @@ Do not modify directly.*
 |Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
 |||[11, 21]|**T** = tensor(float)|
 |||[1, 10]|**T** = tensor(float)|
-|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br> **T2** = tensor(uint8)<br> **T3** = tensor(int32)|
+|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(int32)|
 |ConvTranspose|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
 |||[11, 21]|**T** = tensor(float)|
 |||[1, 10]|**T** = tensor(float)|
diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h
index 24cc460a17fa9..983be1f9efd5c 100644
--- a/include/onnxruntime/core/framework/allocator.h
+++ b/include/onnxruntime/core/framework/allocator.h
@@ -38,6 +38,13 @@ struct OrtArenaCfg {
   int max_dead_bytes_per_chunk;           // use -1 to allow ORT to choose the default
   int initial_growth_chunk_size_bytes;    // use -1 to allow ORT to choose the default
   int64_t max_power_of_two_extend_bytes;  // use -1 to allow ORT to choose the default
+  // Use CudaMemPool based arena if available (starting with cuda 11.2)
+  int use_cuda_mempool = -1;
+  // Amount of reserved memory in bytes to hold onto before trying
+  // to release memory back to the OS.
+  uint64_t cuda_mempool_release_threshold = 0;
+  // Bytes to keep on shrink for CudaMemPool, 0 is to attempt to release all, allocated space not affected.
+  size_t cuda_mempool_bytes_to_keep_on_shrink = 0;
 
   bool IsValid() {
     return arena_extend_strategy >= -1 && arena_extend_strategy <= 1 &&
@@ -55,6 +62,9 @@ struct OrtArenaCfg {
     static constexpr const char* InitialGrowthChunkSizeBytes = "arena.initial_growth_chunk_size_bytes";
     static constexpr const char* MaxPowerOfTwoExtendBytes = "arena.max_power_of_two_extend_bytes";
     static constexpr const char* MaxMem = "arena.max_mem";
+    static constexpr const char* UseCudaMemPool = "arena.use_cuda_mempool";
+    static constexpr const char* CudaMempoolReleaseThreshold = "arena.cuda_mempool_release_threshold";
+    static constexpr const char* CudaMempoolBytesToKeepOnShrink = "arena.cuda_mempool_bytes_to_keep_on_shrink";
   };
 
   static onnxruntime::common::Status FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg);
@@ -348,4 +358,13 @@ void AllocatorDefaultFree(void* p);
 void* AllocatorDefaultAllocAligned(size_t size, size_t alignment);
 void AllocatorDefaultFreeAligned(void* p, size_t alignment);
 
+class IArena : public IAllocator {
+ public:
+  using IAllocator::IAllocator;
+  virtual Status Shrink() = 0;
+  // Only implemented when IsStreamAware() returns true
+  virtual void ReleaseStreamBuffers(Stream* /*stream*/) {}
+  static IArena* SafeArenaCast(IAllocator* allocator);
+};
+
 }  // namespace onnxruntime
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 02915f2f1882e..d1b652229e4b6 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -6591,6 +6591,23 @@ struct OrtApi {
    * \since Version 1.24
    */
   ORT_API_T(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info);
+
+  /** \brief Get all config entries from ::OrtKernelInfo.
+   *
+   * Gets all configuration entries from the ::OrtKernelInfo object as key-value pairs.
+   * Config entries are set on the ::OrtSessionOptions and are accessible in custom operator kernels.
+   *
+   * Used in the CreateKernel callback of an OrtCustomOp to access all session configuration entries
+   * during kernel construction.
+   *
+   * \param[in] info An instance of ::OrtKernelInfo.
+   * \param[out] out A pointer to a newly created OrtKeyValuePairs instance containing all config entries.
+   *             Note: the user should call OrtApi::ReleaseKeyValuePairs.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   * \since Version 1.24
+   */
+  ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out);
 };
 
 /*
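To make the intent of the new `KernelInfo_GetConfigEntries` entry point concrete, here is a minimal usage sketch from a custom operator's `CreateKernel` path, written against the C++ wrapper added in `onnxruntime_cxx_api.h`/`onnxruntime_cxx_inline.h` below. The kernel type and the config key are hypothetical, and the `Ort::KeyValuePairs::GetValue` lookup is assumed to be available as in current ORT releases:

```cpp
#include <cstring>
#include "onnxruntime_cxx_api.h"

// Illustrative only: "MyCustomKernel" and the "my_custom_op.enable_fast_path" key are hypothetical.
struct MyCustomKernel {
  explicit MyCustomKernel(const OrtKernelInfo* info) {
    Ort::ConstKernelInfo kernel_info{info};

    // New in this change: all session config entries, returned as an owning Ort::KeyValuePairs
    // (the wrapper releases the underlying OrtKeyValuePairs automatically).
    Ort::KeyValuePairs entries = kernel_info.GetConfigEntries();

    // Assumption: Ort::KeyValuePairs exposes a GetValue() lookup that returns nullptr for missing keys.
    if (const char* value = entries.GetValue("my_custom_op.enable_fast_path")) {
      enable_fast_path_ = (std::strcmp(value, "1") == 0);
    }
  }

  bool enable_fast_path_ = false;
};
```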
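Likewise, a rough sketch of how the new CUDA memory-pool arena options could be supplied as string key/value pairs. The key names come from `OrtArenaCfg::ConfigKeyNames` in the `allocator.h` change above; the use of `Ort::KeyValuePairs` as the carrier, its `Add` call, and the concrete values are illustrative assumptions (`OrtArenaCfg::FromKeyValuePairs`, declared above, is the parser on the ORT side):

```cpp
#include "onnxruntime_cxx_api.h"

// Illustrative values only; the keys mirror OrtArenaCfg::ConfigKeyNames added in allocator.h above.
void ConfigureCudaMempoolArena(Ort::KeyValuePairs& arena_options) {
  arena_options.Add("arena.use_cuda_mempool", "1");                        // opt in to the cudaMemPool-backed arena
  arena_options.Add("arena.cuda_mempool_release_threshold", "268435456");  // keep ~256 MB reserved before releasing to the OS
  arena_options.Add("arena.cuda_mempool_bytes_to_keep_on_shrink", "0");    // attempt to release everything on Shrink()
}
```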
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index d3a8856455c49..22708bbf06a3d 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2768,6 +2768,8 @@ struct KernelInfoImpl : Base<T> {
   std::string GetNodeName() const;
 
   Logger GetLogger() const;
+
+  KeyValuePairs GetConfigEntries() const;
 };
 
 }  // namespace detail
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 8ee057f51eb20..5144418db2b58 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -2822,6 +2822,13 @@ inline Logger KernelInfoImpl<T>::GetLogger() const {
   return Logger{out};
 }
 
+template <typename T>
+inline KeyValuePairs KernelInfoImpl<T>::GetConfigEntries() const {
+  OrtKeyValuePairs* out = nullptr;
+  Ort::ThrowOnError(GetApi().KernelInfo_GetConfigEntries(this->p_, &out));
+  return KeyValuePairs{out};
+}
+
 inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
   Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
 }
diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h
index f6afd84dabc5e..5ea4261840299 100644
--- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h
@@ -9,6 +9,9 @@
 // Key for the execution provider version string. This should be available for all plugin EPs.
 static const char* const kOrtEpDevice_EpMetadataKey_Version = "version";
 
+// Key for the execution provider OS driver version.
+static const char* const kOrtEpDevice_EpMetadataKey_OSDriverVersion = "os_driver_version";
+
 // Prefix for execution provider compatibility information stored in model metadata.
 // Used when generating EP context models to store compatibility strings for each EP.
 // Full key format: "ep_compatibility_info."
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index f4ccff1d7770d..9f2f0afa61604 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -1533,7 +1533,7 @@
       // // "test_adam_multiple",
       // // "test_adam",
       "test_add_bcast",
-      // "test_add_uint8",
+      "test_add_uint8",
       "test_add",
       "test_and_bcast3v1d",
       "test_and_bcast3v2d",
@@ -1543,37 +1543,38 @@
       "test_and2d",
       "test_and3d",
       "test_and4d",
-      "test_argmax_default_axis_example_select_last_index",
+      // tests "test_arg*_select_last_index" are excluded because WebNN spec does not support select_last_index attribute.
+ // "test_argmax_default_axis_example_select_last_index", "test_argmax_default_axis_example", - "test_argmax_default_axis_random_select_last_index", + // "test_argmax_default_axis_random_select_last_index", "test_argmax_default_axis_random", - "test_argmax_keepdims_example_select_last_index", + // "test_argmax_keepdims_example_select_last_index", "test_argmax_keepdims_example", - "test_argmax_keepdims_random_select_last_index", + // "test_argmax_keepdims_random_select_last_index", "test_argmax_keepdims_random", - "test_argmax_negative_axis_keepdims_example_select_last_index", + // "test_argmax_negative_axis_keepdims_example_select_last_index", "test_argmax_negative_axis_keepdims_example", - "test_argmax_negative_axis_keepdims_random_select_last_index", + // "test_argmax_negative_axis_keepdims_random_select_last_index", "test_argmax_negative_axis_keepdims_random", - "test_argmax_no_keepdims_example_select_last_index", + // "test_argmax_no_keepdims_example_select_last_index", "test_argmax_no_keepdims_example", - "test_argmax_no_keepdims_random_select_last_index", + // "test_argmax_no_keepdims_random_select_last_index", "test_argmax_no_keepdims_random", - "test_argmin_default_axis_example_select_last_index", + // "test_argmin_default_axis_example_select_last_index", "test_argmin_default_axis_example", - "test_argmin_default_axis_random_select_last_index", + // "test_argmin_default_axis_random_select_last_index", "test_argmin_default_axis_random", - "test_argmin_keepdims_example_select_last_index", + // "test_argmin_keepdims_example_select_last_index", "test_argmin_keepdims_example", - "test_argmin_keepdims_random_select_last_index", + // "test_argmin_keepdims_random_select_last_index", "test_argmin_keepdims_random", - "test_argmin_negative_axis_keepdims_example_select_last_index", + // "test_argmin_negative_axis_keepdims_example_select_last_index", "test_argmin_negative_axis_keepdims_example", - "test_argmin_negative_axis_keepdims_random_select_last_index", + // "test_argmin_negative_axis_keepdims_random_select_last_index", "test_argmin_negative_axis_keepdims_random", - "test_argmin_no_keepdims_example_select_last_index", + // "test_argmin_no_keepdims_example_select_last_index", "test_argmin_no_keepdims_example", - "test_argmin_no_keepdims_random_select_last_index", + // "test_argmin_no_keepdims_random_select_last_index", "test_argmin_no_keepdims_random", // "test_asin_example", // "test_asin", @@ -1587,21 +1588,21 @@ // "test_averagepool_2d_ceil", "test_averagepool_2d_default", "test_averagepool_2d_pads_count_include_pad", - "test_averagepool_2d_pads", + // "test_averagepool_2d_pads", // unsupported by TFLite backend. "test_averagepool_2d_precomputed_pads_count_include_pad", "test_averagepool_2d_precomputed_pads", "test_averagepool_2d_precomputed_same_upper", "test_averagepool_2d_precomputed_strides", - "test_averagepool_2d_same_lower", + // "test_averagepool_2d_same_lower", // unsupported by TFLite backend. "test_averagepool_2d_same_upper", "test_averagepool_2d_strides", // "test_averagepool_3d_default", "test_basic_conv_with_padding", "test_basic_conv_without_padding", "test_basic_convinteger", - "test_batchnorm_epsilon_training_mode", + // "test_batchnorm_epsilon_training_mode", // unsupported training_mode by WebNN. "test_batchnorm_epsilon", - "test_batchnorm_example_training_mode", + // "test_batchnorm_example_training_mode", // unsupported training_mode by WebNN. 
"test_batchnorm_example", // // "test_bernoulli_double_expanded", // // "test_bernoulli_double", @@ -1622,10 +1623,10 @@ // // "test_blackmanwindow_symmetric", // // "test_blackmanwindow", // // "test_cast_BFLOAT16_to_FLOAT", - "test_cast_DOUBLE_to_FLOAT", + // "test_cast_DOUBLE_to_FLOAT", // "test_cast_DOUBLE_to_FLOAT16", // // "test_cast_FLOAT_to_BFLOAT16", - "test_cast_FLOAT_to_DOUBLE", + // "test_cast_FLOAT_to_DOUBLE", // // "test_cast_FLOAT_to_FLOAT16", // // "test_cast_FLOAT_to_STRING", // "test_cast_FLOAT16_to_DOUBLE", @@ -1657,15 +1658,16 @@ // "test_celu", "test_clip_default_inbounds", "test_clip_default_int8_inbounds", - "test_clip_default_int8_max", - "test_clip_default_int8_min", - "test_clip_default_max", - "test_clip_default_min", - "test_clip_example", - "test_clip_inbounds", - "test_clip_outbounds", - "test_clip_splitbounds", - "test_clip", + // "test_clip_default_int8_max", + // "test_clip_default_int8_min", + // tests "test_clip*" on opset > 10 are excluded because max and min are non-constant inputs. + "opset{7,8,9,10}/test_clip_default_max", + "opset{7,8,9,10}/test_clip_default_min", + "opset{7,8,9,10}/test_clip_example", + "opset{7,8,9,10}/test_clip_inbounds", + "opset{7,8,9,10}/test_clip_outbounds", + "opset{7,8,9,10}/test_clip_splitbounds", + "opset{7,8,9,10}/test_clip", // // "test_compress_0", // // "test_compress_1", // // "test_compress_default_axis", @@ -1690,32 +1692,33 @@ "test_convinteger_without_padding", "test_convtranspose_1d", // // "test_convtranspose_3d", - // "test_convtranspose_autopad_same", - "test_convtranspose_dilations", + "!(opset14)/test_convtranspose_autopad_same", + // "test_convtranspose_dilations", // unsupported by TFLite backend. "test_convtranspose_kernel_shape", "opset{9,17}/test_convtranspose_output_shape", "test_convtranspose_pad", - "test_convtranspose_pads", + // "test_convtranspose_pads", // unsupported by TFLite backend. "test_convtranspose_with_kernel", "test_convtranspose", "test_cos_example", "test_cos", // "test_cosh_example", // "test_cosh", - "test_cumsum_1d_exclusive", - "test_cumsum_1d_reverse_exclusive", - "test_cumsum_1d_reverse", - "test_cumsum_1d", - "test_cumsum_2d_axis_0", - "test_cumsum_2d_axis_1", - "test_cumsum_2d_negative_axis", + // tests "test_cumsum*" are excluded because they use float64. 
+ // "test_cumsum_1d_exclusive", + // "test_cumsum_1d_reverse_exclusive", + // "test_cumsum_1d_reverse", + // "test_cumsum_1d", + // "test_cumsum_2d_axis_0", + // "test_cumsum_2d_axis_1", + // "test_cumsum_2d_negative_axis", // "test_depthtospace_crd_mode_example", // "test_depthtospace_crd_mode", // "test_depthtospace_dcr_mode", // "test_depthtospace_example", // "test_depthtospace", - // // "test_dequantizelinear_axis", - // // "test_dequantizelinear", + "test_dequantizelinear_axis", + "test_dequantizelinear", // // "test_det_2d", // // "test_det_nd", // // "test_dft_axis", @@ -1723,27 +1726,27 @@ // // "test_dft", "test_div_bcast", "test_div_example", - // "test_div_uint8", + "test_div_uint8", "test_div", - // // "test_dropout_default_mask_ratio", - // // "test_dropout_default_mask", - // // "test_dropout_default_old", - // // "test_dropout_default_ratio", - // // "test_dropout_default", - // // "test_dropout_random_old", - // // "test_dropout_random", + "test_dropout_default_mask_ratio", + "test_dropout_default_mask", + "test_dropout_default_old", + "test_dropout_default_ratio", + "test_dropout_default", + "test_dropout_random_old", + "test_dropout_random", // // "test_dynamic_slice_default_axes", // // "test_dynamic_slice_end_out_of_bounds", // // "test_dynamic_slice_neg", // // "test_dynamic_slice_start_out_of_bounds", // // "test_dynamic_slice", - // // "test_dynamicquantizelinear_expanded", - // // "test_dynamicquantizelinear_max_adjusted_expanded", - // // "test_dynamicquantizelinear_max_adjusted", - // // "test_dynamicquantizelinear_min_adjusted_expanded", - // // "test_dynamicquantizelinear_min_adjusted", - // // "test_dynamicquantizelinear", - // "test_edge_pad", + "test_dynamicquantizelinear_expanded", + "test_dynamicquantizelinear_max_adjusted_expanded", + "test_dynamicquantizelinear_max_adjusted", + "test_dynamicquantizelinear_min_adjusted_expanded", + "test_dynamicquantizelinear_min_adjusted", + "test_dynamicquantizelinear", + // "opset{7,8,9,10}/test_edge_pad", // The edge padding model is unsupported by TFLite backend. // "test_einsum_batch_diagonal", // "test_einsum_batch_matmul", // "test_einsum_inner_prod", @@ -1754,7 +1757,7 @@ "test_elu", "test_equal_bcast", "test_equal", - // "test_erf", + "test_erf", "test_exp_example", "test_exp", // "test_expand_dim_changed", @@ -1777,11 +1780,11 @@ "test_gather_1", "test_gather_2d_indices", "test_gather_negative_indices", - "test_gather_elements_0", - "test_gather_elements_1", - "test_gather_elements_negative_indices", + // "test_gather_elements_0", // TFLite backend only supports constant indices. + // "test_gather_elements_1", // TFLite backend only supports constant indices. + // "test_gather_elements_negative_indices", // TFLite backend only supports constant indices. 
"test_gathernd_example_float32", - "test_gathernd_example_int32_batch_dim1", + // "test_gathernd_example_int32_batch_dim1", "test_gathernd_example_int32", "test_gemm_all_attributes", "test_gemm_alpha", @@ -1789,7 +1792,7 @@ "test_gemm_broadcast", "test_gemm_default_matrix_bias", "test_gemm_default_no_bias", - // "test_gemm_default_scalar_bias", + "test_gemm_default_scalar_bias", "test_gemm_default_single_elem_vector_bias", "test_gemm_default_vector_bias", "test_gemm_default_zero_bias", @@ -1845,48 +1848,49 @@ // "test_if_opt", "test_instancenorm_epsilon", "test_instancenorm_example", - // "test_isinf_negative", - // "test_isinf_positive", - // "test_isinf", - // "test_isnan", + "test_isinf_negative", + "test_isinf_positive", + "test_isinf", + "test_isnan", + // tests "test_layernorm*" are excluded because they produce 3 outputs. // "test_layer_normalization_2d_axis_negative_1_expanded", - "test_layer_normalization_2d_axis_negative_1", + // "test_layer_normalization_2d_axis_negative_1", // "test_layer_normalization_2d_axis_negative_2_expanded", - "test_layer_normalization_2d_axis_negative_2", + // "test_layer_normalization_2d_axis_negative_2", // "test_layer_normalization_2d_axis0_expanded", - "test_layer_normalization_2d_axis0", + // "test_layer_normalization_2d_axis0", // "test_layer_normalization_2d_axis1_expanded", - "test_layer_normalization_2d_axis1", + // "test_layer_normalization_2d_axis1", // "test_layer_normalization_3d_axis_negative_1_epsilon_expanded", - "test_layer_normalization_3d_axis_negative_1_epsilon", + // "test_layer_normalization_3d_axis_negative_1_epsilon", // "test_layer_normalization_3d_axis_negative_2_epsilon_expanded", - "test_layer_normalization_3d_axis_negative_2_epsilon", + // "test_layer_normalization_3d_axis_negative_2_epsilon", // "test_layer_normalization_3d_axis_negative_3_epsilon_expanded", - "test_layer_normalization_3d_axis_negative_3_epsilon", + // "test_layer_normalization_3d_axis_negative_3_epsilon", // "test_layer_normalization_3d_axis0_epsilon_expanded", - "test_layer_normalization_3d_axis0_epsilon", + // "test_layer_normalization_3d_axis0_epsilon", // "test_layer_normalization_3d_axis1_epsilon_expanded", - "test_layer_normalization_3d_axis1_epsilon", + // "test_layer_normalization_3d_axis1_epsilon", // "test_layer_normalization_3d_axis2_epsilon_expanded", - "test_layer_normalization_3d_axis2_epsilon", + // "test_layer_normalization_3d_axis2_epsilon", // "test_layer_normalization_4d_axis_negative_1_expanded", - "test_layer_normalization_4d_axis_negative_1", + // "test_layer_normalization_4d_axis_negative_1", // "test_layer_normalization_4d_axis_negative_2_expanded", - "test_layer_normalization_4d_axis_negative_2", + // "test_layer_normalization_4d_axis_negative_2", // "test_layer_normalization_4d_axis_negative_3_expanded", - "test_layer_normalization_4d_axis_negative_3", + // "test_layer_normalization_4d_axis_negative_3", // "test_layer_normalization_4d_axis_negative_4_expanded", - "test_layer_normalization_4d_axis_negative_4", + // "test_layer_normalization_4d_axis_negative_4", // "test_layer_normalization_4d_axis0_expanded", - "test_layer_normalization_4d_axis0", + // "test_layer_normalization_4d_axis0", // "test_layer_normalization_4d_axis1_expanded", - "test_layer_normalization_4d_axis1", + // "test_layer_normalization_4d_axis1", // "test_layer_normalization_4d_axis2_expanded", - "test_layer_normalization_4d_axis2", + // "test_layer_normalization_4d_axis2", // "test_layer_normalization_4d_axis3_expanded", - "test_layer_normalization_4d_axis3", + 
// "test_layer_normalization_4d_axis3", // "test_layer_normalization_default_axis_expanded", - "test_layer_normalization_default_axis", + // "test_layer_normalization_default_axis", "test_leakyrelu_default", "test_leakyrelu_example", "test_leakyrelu", @@ -1912,42 +1916,42 @@ // // "test_logsoftmax_large_number", // // "test_logsoftmax_negative_axis_expanded", // // "test_logsoftmax_negative_axis", - // "test_lrn_default", - // "test_lrn", + "test_lrn_default", + "test_lrn", // // "test_lstm_batchwise", "test_lstm_defaults", "test_lstm_with_initial_bias", - "test_lstm_with_peepholes", + // "test_lstm_with_peepholes", "test_matmul_2d", "test_matmul_3d", "test_matmul_4d", - // // "test_matmulinteger", + "test_matmulinteger", "test_max_example", // "test_max_float16", "test_max_float32", - "test_max_float64", + // "test_max_float64", // "test_max_int16", - // "test_max_int32", - // "test_max_int64", - // "test_max_int8", + "test_max_int32", + "test_max_int64", + "test_max_int8", "test_max_one_input", "test_max_two_inputs", // "test_max_uint16", - // "test_max_uint32", + "test_max_uint32", // "test_max_uint64", - // "test_max_uint8", + "test_max_uint8", // "test_maxpool_1d_default", // "test_maxpool_2d_ceil", "test_maxpool_2d_default", - "test_maxpool_2d_dilations", - "test_maxpool_2d_pads", + // "test_maxpool_2d_dilations", // unsupported by TFLite backend. + // "test_maxpool_2d_pads", // unsupported by TFLite backend. "test_maxpool_2d_precomputed_pads", "test_maxpool_2d_precomputed_same_upper", "test_maxpool_2d_precomputed_strides", - "test_maxpool_2d_same_lower", + // "test_maxpool_2d_same_lower", // unsupported by TFLite backend. "test_maxpool_2d_same_upper", "test_maxpool_2d_strides", - // "test_maxpool_2d_uint8", + "test_maxpool_2d_uint8", // "test_maxpool_3d_default", // "test_maxpool_with_argmax_2d_precomputed_pads", // "test_maxpool_with_argmax_2d_precomputed_strides", @@ -1960,17 +1964,17 @@ "test_min_example", // "test_min_float16", "test_min_float32", - "test_min_float64", + // "test_min_float64", // "test_min_int16", - // "test_min_int32", - // "test_min_int64", - // "test_min_int8", + "test_min_int32", + "test_min_int64", + "test_min_int8", "test_min_one_input", "test_min_two_inputs", // "test_min_uint16", - // "test_min_uint32", + "test_min_uint32", // "test_min_uint64", - // "test_min_uint8", + "test_min_uint8", // "test_mod_bcast", // "test_mod_broadcast", // "test_mod_float_mixed_sign_example", @@ -1992,9 +1996,9 @@ // // "test_momentum", "test_mul_bcast", "test_mul_example", - // "test_mul_uint8", + "test_mul_uint8", "test_mul", - // "test_mvn_expanded", + "test_mvn_expanded", // "test_mvn", "test_neg_example", "test_neg", @@ -2110,17 +2114,17 @@ // "test_pow_types_float32_uint64", // "test_pow_types_int", // "test_pow_types_int32_float32", - // "test_pow_types_int32_int32", + "test_pow_types_int32_int32", // "test_pow_types_int64_float32", - // "test_pow_types_int64_int64", + "test_pow_types_int64_int64", "test_pow", "test_prelu_broadcast", "test_prelu_example", // // "test_qlinearconv", // // "test_qlinearmatmul_2D", // // "test_qlinearmatmul_3D", - // // "test_quantizelinear_axis", - // // "test_quantizelinear", + "test_quantizelinear_axis", + "test_quantizelinear", // "test_range_float_type_positive_delta_expanded", // "test_range_float_type_positive_delta", // "test_range_int32_type_negative_delta_expanded", @@ -2129,23 +2133,24 @@ "test_reciprocal", "test_reduce_l1_default_axes_keepdims_example", "test_reduce_l1_default_axes_keepdims_random", - 
"test_reduce_l1_do_not_keepdims_example", - "test_reduce_l1_do_not_keepdims_random", - "test_reduce_l1_keep_dims_example", - "test_reduce_l1_keep_dims_random", - "test_reduce_l1_negative_axes_keep_dims_example", - "test_reduce_l1_negative_axes_keep_dims_random", + // tests "test_reduce_*" on opset > 13 are excluded because the axes is non-constant input. + "opset{7,8,9,10,11,12,13}/test_reduce_l1_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_keep_dims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_keep_dims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_negative_axes_keep_dims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_negative_axes_keep_dims_random", "test_reduce_l2_default_axes_keepdims_example", "test_reduce_l2_default_axes_keepdims_random", - "test_reduce_l2_do_not_keepdims_example", - "test_reduce_l2_do_not_keepdims_random", - "test_reduce_l2_keep_dims_example", - "test_reduce_l2_keep_dims_random", - "test_reduce_l2_negative_axes_keep_dims_example", - "test_reduce_l2_negative_axes_keep_dims_random", - "test_reduce_log_sum_asc_axes", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_keep_dims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_keep_dims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_negative_axes_keep_dims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_negative_axes_keep_dims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_log_sum_asc_axes", "test_reduce_log_sum_default", - "test_reduce_log_sum_desc_axes", + "opset{7,8,9,10,11,12,13}/test_reduce_log_sum_desc_axes", // tests "test_reduce_log_sum_exp_*" on opset17/opset18 are excluded because they use float64. 
"opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_example", "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_random", @@ -2155,116 +2160,118 @@ "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_random", "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_example", "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_random", - "test_reduce_log_sum_negative_axes", + "opset{7,8,9,10,11,12,13}/test_reduce_log_sum_negative_axes", "test_reduce_log_sum", "test_reduce_max_default_axes_keepdim_example", "test_reduce_max_default_axes_keepdims_random", - "test_reduce_max_do_not_keepdims_example", - "test_reduce_max_do_not_keepdims_random", - "test_reduce_max_keepdims_example", - "test_reduce_max_keepdims_random", - "test_reduce_max_negative_axes_keepdims_example", - "test_reduce_max_negative_axes_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_max_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_max_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_max_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_max_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_max_negative_axes_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_max_negative_axes_keepdims_random", "test_reduce_mean_default_axes_keepdims_example", "test_reduce_mean_default_axes_keepdims_random", - "test_reduce_mean_do_not_keepdims_example", - "test_reduce_mean_do_not_keepdims_random", - "test_reduce_mean_keepdims_example", - "test_reduce_mean_keepdims_random", - "test_reduce_mean_negative_axes_keepdims_example", - "test_reduce_mean_negative_axes_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_negative_axes_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_negative_axes_keepdims_random", "test_reduce_min_default_axes_keepdims_example", "test_reduce_min_default_axes_keepdims_random", - "test_reduce_min_do_not_keepdims_example", - "test_reduce_min_do_not_keepdims_random", - "test_reduce_min_keepdims_example", - "test_reduce_min_keepdims_random", - "test_reduce_min_negative_axes_keepdims_example", - "test_reduce_min_negative_axes_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_min_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_min_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_min_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_min_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_min_negative_axes_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_min_negative_axes_keepdims_random", "test_reduce_prod_default_axes_keepdims_example", "test_reduce_prod_default_axes_keepdims_random", - "test_reduce_prod_do_not_keepdims_example", - "test_reduce_prod_do_not_keepdims_random", - "test_reduce_prod_keepdims_example", - "test_reduce_prod_keepdims_random", - "test_reduce_prod_negative_axes_keepdims_example", - "test_reduce_prod_negative_axes_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_negative_axes_keepdims_example", + 
"opset{7,8,9,10,11,12,13}/test_reduce_prod_negative_axes_keepdims_random", "test_reduce_sum_default_axes_keepdims_example", "test_reduce_sum_default_axes_keepdims_random", - "test_reduce_sum_do_not_keepdims_example", - "test_reduce_sum_do_not_keepdims_random", + "opset{7,8,9,10,11}/test_reduce_sum_do_not_keepdims_example", + "opset{7,8,9,10,11}/test_reduce_sum_do_not_keepdims_random", "test_reduce_sum_empty_axes_input_noop_example", "test_reduce_sum_empty_axes_input_noop_random", - "test_reduce_sum_keepdims_example", - "test_reduce_sum_keepdims_random", - "test_reduce_sum_negative_axes_keepdims_example", - "test_reduce_sum_negative_axes_keepdims_random", + "opset{7,8,9,10,11}/test_reduce_sum_keepdims_example", + "opset{7,8,9,10,11}/test_reduce_sum_keepdims_random", + "opset{7,8,9,10,11}/test_reduce_sum_negative_axes_keepdims_example", + "opset{7,8,9,10,11}/test_reduce_sum_negative_axes_keepdims_random", "test_reduce_sum_square_default_axes_keepdims_example", "test_reduce_sum_square_default_axes_keepdims_random", - "test_reduce_sum_square_do_not_keepdims_example", - "test_reduce_sum_square_do_not_keepdims_random", - "test_reduce_sum_square_keepdims_example", - "test_reduce_sum_square_keepdims_random", - "test_reduce_sum_square_negative_axes_keepdims_example", - "test_reduce_sum_square_negative_axes_keepdims_random", - // "test_reflect_pad", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_negative_axes_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_negative_axes_keepdims_random", + "opset{7,8,9,10}/test_reflect_pad", "test_relu", - "test_reshape_allowzero_reordered", - "test_reshape_extended_dims", - "test_reshape_negative_dim", - "test_reshape_negative_extended_dims", - "test_reshape_one_dim", - "test_reshape_reduced_dims", - "test_reshape_reordered_all_dims", - "test_reshape_reordered_dims", - "test_reshape_reordered_last_dims", - "test_reshape_zero_and_negative_dim", - "test_reshape_zero_dim", - "test_resize_downsample_linear", - "test_resize_downsample_nearest", - "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", + // tests "test_reshape*" are excluded because the shape is non-constant input. + // "test_reshape_allowzero_reordered", + // "test_reshape_extended_dims", + // "test_reshape_negative_dim", + // "test_reshape_negative_extended_dims", + // "test_reshape_one_dim", + // "test_reshape_reduced_dims", + // "test_reshape_reordered_all_dims", + // "test_reshape_reordered_dims", + // "test_reshape_reordered_last_dims", + // "test_reshape_zero_and_negative_dim", + // "test_reshape_zero_dim", + // tests "test_resize*" are excluded because scales and sizes are non-constant inputs. 
+ // "test_resize_downsample_linear", + // "test_resize_downsample_nearest", + // "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", // "test_resize_downsample_scales_cubic_align_corners", - "test_resize_downsample_scales_cubic", + // "test_resize_downsample_scales_cubic", // "test_resize_downsample_scales_linear_align_corners", - "test_resize_downsample_scales_linear", - "test_resize_downsample_scales_nearest", - "test_resize_downsample_sizes_cubic", - "test_resize_downsample_sizes_linear_pytorch_half_pixel", - "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", - "test_resize_downsample_sizes_nearest", - "test_resize_nearest", - "test_resize_tf_crop_and_resize", - "test_resize_upsample_linear", - "test_resize_upsample_nearest", - "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", - "test_resize_upsample_scales_cubic_align_corners", - "test_resize_upsample_scales_cubic_asymmetric", - "test_resize_upsample_scales_cubic", - "test_resize_upsample_scales_linear_align_corners", - "test_resize_upsample_scales_linear", - "test_resize_upsample_scales_nearest", - "test_resize_upsample_sizes_cubic", - "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel", - "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners", - "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", - "test_resize_upsample_sizes_nearest", + // "test_resize_downsample_scales_linear", + // "test_resize_downsample_scales_nearest", + // "test_resize_downsample_sizes_cubic", + // "test_resize_downsample_sizes_linear_pytorch_half_pixel", + // "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", + // "test_resize_downsample_sizes_nearest", + // "test_resize_nearest", + // "test_resize_tf_crop_and_resize", + // "test_resize_upsample_linear", + // "test_resize_upsample_nearest", + // "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", + // "test_resize_upsample_scales_cubic_align_corners", + // "test_resize_upsample_scales_cubic_asymmetric", + // "test_resize_upsample_scales_cubic", + // "test_resize_upsample_scales_linear_align_corners", + // "test_resize_upsample_scales_linear", + // "test_resize_upsample_scales_nearest", + // "test_resize_upsample_sizes_cubic", + // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel", + // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners", + // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", + // "test_resize_upsample_sizes_nearest", // // "test_reversesequence_batch", // // "test_reversesequence_time", // // "test_rnn_seq_length", // // "test_roialign_aligned_false", // // "test_roialign_aligned_true", // // "test_roialign", - // // "test_round", + "test_round", // // "test_scan_sum", // // "test_scan9_sum", - "test_scatter_elements_with_axis", - "test_scatter_elements_with_duplicate_indices", - "test_scatter_elements_with_negative_indices", - "test_scatter_elements_without_axis", + // "test_scatter_elements_with_axis", // TFLite backend does not support non-constant indices. + // "test_scatter_elements_with_duplicate_indices", // WebNN only supports reduction type 'none'. + // "test_scatter_elements_with_negative_indices", // TFLite backend does not support non-constant indices. + // "test_scatter_elements_without_axis", // TFLite backend does not support non-constant indices. 
// // "test_scatter_with_axis", // // "test_scatter_without_axis", - "test_scatternd_add", - "test_scatternd_multiply", + // "test_scatternd_add", // WebNN only supports reduction type 'none'. + // "test_scatternd_multiply", // WebNN only supports reduction type 'none'. "test_scatternd", // // "test_sce_mean_3d_expanded", // // "test_sce_mean_3d_log_prob_expanded", @@ -2365,14 +2372,15 @@ // "test_sinh", // // "test_size_example", // // "test_size", - "test_slice_default_axes", - "test_slice_default_steps", - "test_slice_end_out_of_bounds", - "test_slice_neg_steps", - "test_slice_neg", - "test_slice_negative_axes", - "test_slice_start_out_of_bounds", - "test_slice", + // tests "test_slice_*" on opset > 9 are excluded because starts, ends, axes and steps are non-constant inputs. + "opset{7,8,9}/test_slice_default_axes", + // "test_slice_default_steps", + "opset{7,8,9}/test_slice_end_out_of_bounds", + // "test_slice_neg_steps", + "opset{7,8,9}/test_slice_neg", + // "test_slice_negative_axes", + "opset{7,8,9}/test_slice_start_out_of_bounds", + "opset{7,8,9}/test_slice", "test_softmax_axis_0_expanded", "test_softmax_axis_0", "test_softmax_axis_1_expanded", @@ -2455,23 +2463,24 @@ "test_softmax_large_number", "test_softmax_negative_axis_expanded", "test_softmax_negative_axis", - // // "test_softplus_example", - // // "test_softplus", - // // "test_softsign_example", - // // "test_softsign", + "test_softplus_example", + "test_softplus", + "test_softsign_example", + "test_softsign", // "test_spacetodepth_example", // "test_spacetodepth", - "test_split_equal_parts_1d", - "test_split_equal_parts_2d", + // tests "test_split_*" on opset > 10 are excluded because the split input is non-constant input. + "opset{7,8,9,10}/test_split_equal_parts_1d", + "opset{7,8,9,10}/test_split_equal_parts_2d", "test_split_equal_parts_default_axis", - "test_split_variable_parts_1d", - "test_split_variable_parts_2d", - "test_split_variable_parts_default_axis", - "test_split_zero_size_splits", + "opset{7,8,9,10}/test_split_variable_parts_1d", + "opset{7,8,9,10}/test_split_variable_parts_2d", + "opset{7,8,9,10}/test_split_variable_parts_default_axis", + // "test_split_zero_size_splits", "test_sqrt_example", "test_sqrt", - "test_squeeze_negative_axes", - "test_squeeze", + // "test_squeeze_negative_axes", + "opset{7,8,9,10}/test_squeeze", // // "test_stft_with_window", // // "test_stft", // // "test_strnormalizer_export_monday_casesensintive_lower", @@ -2482,7 +2491,7 @@ // // "test_strnormalizer_nostopwords_nochangecase", "test_sub_bcast", "test_sub_example", - // "test_sub_uint8", + "test_sub_uint8", "test_sub", // "test_sum_example", // "test_sum_one_input", @@ -2501,8 +2510,9 @@ // "test_thresholdedrelu_default", // "test_thresholdedrelu_example", // "test_thresholdedrelu", - "test_tile_precomputed", - "test_tile", + // tests "test_tile*" are excluded because the repeats is non-constant input. 
+ // "test_tile_precomputed", + // "test_tile", // // "test_top_k_negative_axis", // // "test_top_k_smallest", // // "test_top_k", @@ -2520,41 +2530,41 @@ "test_transpose_all_permutations_5", "test_transpose_default", // "test_tril_neg", - // "test_tril_one_row_neg", + "test_tril_one_row_neg", // "test_tril_out_neg", // "test_tril_out_pos", // "test_tril_pos", // "test_tril_square_neg", - // "test_tril_square", + "test_tril_square", // "test_tril_zero", - // "test_tril", + "test_tril", // "test_triu_neg", // "test_triu_one_row", // "test_triu_out_neg_out", // "test_triu_out_pos", // "test_triu_pos", // "test_triu_square_neg", - // "test_triu_square", + "test_triu_square", // "test_triu_zero", - // "test_triu", + "test_triu", // // "test_unique_not_sorted_without_axis", // // "test_unique_sorted_with_axis_3d", // // "test_unique_sorted_with_axis", // // "test_unique_sorted_with_negative_axis", // // "test_unique_sorted_without_axis", - "test_unsqueeze_axis_0", - "test_unsqueeze_axis_1", - "test_unsqueeze_axis_2", - "test_unsqueeze_axis_3", - "test_unsqueeze_negative_axes", - "test_unsqueeze_three_axes", - "test_unsqueeze_two_axes", - "test_unsqueeze_unsorted_axes", - "test_unsqueeze", + // "test_unsqueeze_axis_0", + // "test_unsqueeze_axis_1", + // "test_unsqueeze_axis_2", + // "test_unsqueeze_axis_3", + // "test_unsqueeze_negative_axes", + // "test_unsqueeze_three_axes", + // "test_unsqueeze_two_axes", + // "test_unsqueeze_unsorted_axes", + "opset{7,8,9,10}/test_unsqueeze", // "test_wrap_pad" // "test_upsample_nearest", "test_where_example", - // "test_where_long_example", + "test_where_long_example", "test_xor_bcast3v1d", "test_xor_bcast3v2d", "test_xor_bcast4v2d", diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 9b9d755498366..a5ab63d74df24 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -522,6 +522,9 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k, int local_window_size) { + if (context.IsGraphCaptureEnabled()) { + ORT_NOT_IMPLEMENTED("Graph capture not implemented for non flash attention path"); + } const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? 
parameters.past_sequence_length_ : 0; const int total_sequence_length = diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 00f60142df159..606dbfde15c2c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -16,6 +16,32 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform); + const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); + const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform); + const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform); + + const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform); + const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform); + const auto& present_value = sh.AddOutput("present_value", ShaderUsage::UseUniform); + + if (prepare_indirect_dispatch_) { + sh.AddOutput("indirect_buffer", ShaderUsage::None); + } + + return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template", + WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_), + WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_), + WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache), + WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv), + WGSL_TEMPLATE_VARIABLE(present_key, present_key), + WGSL_TEMPLATE_VARIABLE(present_value, present_value), + WGSL_TEMPLATE_VARIABLE(query, query), + WGSL_TEMPLATE_VARIABLE(seqlens, seqlens), + WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache)); +} + Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { // Expectations are // qkv have same number of heads and hidden dimension (head size). 
@@ -351,17 +377,54 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, - const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, + const Tensor* cos_cache, const Tensor* sin_cache) { + constexpr uint32_t tile_size = 64; + // Extract present_sequence_length directly from present_key tensor shape: // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size) const uint32_t present_sequence_length = static_cast(present_key->Shape()[2]); const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled(); + // Declare query_output at function scope to ensure it persists throughout the function + Tensor query_output; + + // Create indirect dispatch buffer if using indirect dispatch + Tensor* indirect_buffer_ptr = nullptr; + Tensor indirect_buffer; + + // Prepare indirect dispatch buffer for decode path with static KV cache + const bool use_indirect_dispatch = parameters.sequence_length_ == 1 && + parameters.past_present_share_buffer_ && + seqlen_k != nullptr && + context.IsGraphCaptureEnabled(); + if (use_indirect_dispatch) { + const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions + indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType(), indirect_buffer_shape); + indirect_buffer_ptr = &indirect_buffer; + } + + const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr); + + if (do_rotary) { + ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input."); + ORT_ENFORCE(parameters.past_present_share_buffer_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache."); + + // Q points to the packed QKV tensor in this case, create query output tensor + query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); + + ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters, + Q, seqlen_k, + cos_cache, sin_cache, + &query_output, present_key, present_value, + indirect_buffer_ptr, tile_size)); + Q = &query_output; + } else { + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr)); + } + if (parameters.sequence_length_ > 1) { - const uint32_t tile_size = 64; - // For encode path, use the original CopyKVCache without indirect dispatch preparation - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? 
seqlen_k : nullptr, nullptr)); bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; @@ -406,29 +469,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co parameters.sequence_length_, present_sequence_length}); const TensorShape qk_shape(qk_dims); Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape); - constexpr uint32_t tile_size = 64; const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; - // Determine if we should use indirect dispatch - const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && - seqlen_k != nullptr && - context.IsGraphCaptureEnabled(); - - // Create indirect dispatch buffer if using indirect dispatch - Tensor* indirect_buffer_ptr = nullptr; - Tensor indirect_buffer; - if (use_indirect_dispatch) { - const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions - indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType(), indirect_buffer_shape); - indirect_buffer_ptr = &indirect_buffer; - // Use the fused CopyKVCache that also prepares the indirect dispatch buffer - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr)); - } else { - // Use the original CopyKVCache without indirect dispatch preparation - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr)); - } - // The metadata is used to store the max and sum of each tile. 
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, num_present_sequence_length_tile, 2}); @@ -467,6 +510,78 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } +Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context, + const WebgpuAttentionParameters& params, + const Tensor* packedQKV, + const Tensor* seqlen_k, + const Tensor* cos_cache, + const Tensor* sin_cache, + Tensor* query, + Tensor* present_key, + Tensor* present_value, + Tensor* indirect_buffer, + uint32_t tile_size) { + const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); + const auto head_size = params.head_size_; + + int components = 1; + // Currently we only support vectorization when RoPE is not interleaved + if (!params.rotary_interleaved_) { + if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) { + components = 4; + } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) { + components = 2; + } + } + // Adjust dimensions for vectorization + const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components; + const auto head_size_vec = head_size / components; + + // Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim) + // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) + // = head_size - half_rotary_dim + const auto work_per_head = head_size_vec - half_rotary_embedding_dim_vec; + auto dispatch_size = static_cast(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head); + + // Extract present_sequence_length from present_key tensor shape + const uint32_t present_sequence_length = gsl::narrow_cast(present_key->Shape()[2]); + + const bool prepare_indirect_dispatch = (indirect_buffer != nullptr); + + SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch); + program + .CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components}) + .AddInputs({ + {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, + {cos_cache, ProgramTensorMetadataDependency::Rank, components}, + {sin_cache, ProgramTensorMetadataDependency::Rank, components}, + }); + program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, + {present_key, ProgramTensorMetadataDependency::None, components}, + {present_value, ProgramTensorMetadataDependency::None, components}}); + + if (prepare_indirect_dispatch) { + program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None}); + } + + program.AddUniformVariables({ + {static_cast(params.sequence_length_)}, + {static_cast(params.hidden_size_ / components)}, + {static_cast(params.kv_hidden_size_ / components)}, + {static_cast(params.num_heads_)}, + {static_cast(params.kv_num_heads_)}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, + {present_sequence_length}, + {tile_size}, + {static_cast(dispatch_size)}, + }); + + program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h 
index 9599c10533351..a936a91695921 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -15,6 +15,32 @@ namespace webgpu { using namespace onnxruntime::webgpu; +class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program { + public: + SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch) + : Program{"SplitPackedQKVWithRotaryEmbeddingAndCopyKV"}, + interleaved_(interleaved), + prepare_indirect_dispatch_(prepare_indirect_dispatch) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"kv_num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"half_rotary_dim", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"tile_size", ProgramUniformVariableDataType::Uint32}, + {"dispatch_size", ProgramUniformVariableDataType::Uint32}); + + private: + const bool interleaved_; + const bool prepare_indirect_dispatch_; +}; + class CopyKVCacheProgram final : public Program { public: CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, @@ -145,10 +171,24 @@ class FlashAttentionDecodeVxReduceProgram final : public ProgramShape().Size(); program - .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank}) - .AddOutputs({{query, ProgramTensorMetadataDependency::Rank}, {key, ProgramTensorMetadataDependency::Rank}, {val, ProgramTensorMetadataDependency::Rank}}) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank}) + .AddOutputs({{query, ProgramTensorMetadataDependency::None}, {key, ProgramTensorMetadataDependency::None}, {val, ProgramTensorMetadataDependency::None}}) .AddUniformVariables({ {static_cast(params.hidden_size_)}, {static_cast(params.kv_hidden_size_)}, @@ -90,32 +90,46 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; + int components = 1; + // Currently we only support vectorization when RoPE is not interleaved + if (!params.rotary_interleaved_) { + if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) { + components = 4; + } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) { + components = 2; + } + } + + // Adjust dimensions for vectorization + const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components; + const auto head_size_vec = head_size / components; + // Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim) // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) // = head_size - half_rotary_dim - const auto work_per_head = head_size - half_rotary_embedding_dim; - auto dispatch_size = static_cast(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head); + const auto work_per_head_vec = head_size_vec - half_rotary_embedding_dim_vec; + auto dispatch_size = static_cast(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head_vec); SplitPackedQKVWithRotaryEmbeddingProgram 
program(params.rotary_interleaved_); program .CacheHint(params.rotary_interleaved_) - .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank}) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components}) .AddInputs({ - {seqlen_k, ProgramTensorMetadataDependency::Rank}, - {cos_cache, ProgramTensorMetadataDependency::Rank}, - {sin_cache, ProgramTensorMetadataDependency::Rank}, + {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, + {cos_cache, ProgramTensorMetadataDependency::Rank, components}, + {sin_cache, ProgramTensorMetadataDependency::Rank, components}, }) - .AddOutputs({{query, ProgramTensorMetadataDependency::Rank}, - {key, ProgramTensorMetadataDependency::Rank}, - {val, ProgramTensorMetadataDependency::Rank}}) + .AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, + {key, ProgramTensorMetadataDependency::None, components}, + {val, ProgramTensorMetadataDependency::None, components}}) .AddUniformVariables({ {static_cast(params.sequence_length_)}, - {static_cast(params.hidden_size_)}, - {static_cast(params.kv_hidden_size_)}, + {static_cast(params.hidden_size_ / components)}, + {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {static_cast(head_size)}, - {half_rotary_embedding_dim}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, {static_cast(dispatch_size)}, }) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); @@ -177,15 +191,15 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, program .CacheHint(params.rotary_interleaved_) .AddInputs({ - {query_in, ProgramTensorMetadataDependency::Rank}, + {query_in, ProgramTensorMetadataDependency::TypeAndRank}, {key_in, ProgramTensorMetadataDependency::Rank}, - {seqlen_k, ProgramTensorMetadataDependency::Rank}, + {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, {cos_cache, ProgramTensorMetadataDependency::Rank}, {sin_cache, ProgramTensorMetadataDependency::Rank}, }) .AddOutputs({ - {query_out, ProgramTensorMetadataDependency::Rank}, - {key_out, ProgramTensorMetadataDependency::Rank}, + {query_out, ProgramTensorMetadataDependency::None}, + {key_out, ProgramTensorMetadataDependency::None}, }) .SetDispatchGroupSize((q_domain_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({ @@ -265,7 +279,26 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor qRotary; Tensor kRotary; + + // Use a sliding window if the total sequence exceeds the window's length. + bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_); + bool will_use_flash_attention = false; + if (head_sink == nullptr && !use_smooth_softmax_ && !use_sliding_window) { + // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking + WebgpuAttentionParameters temp_params = parameters; + temp_params.is_packed_qkv_ = false; + will_use_flash_attention = CanApplyFlashAttention(attention_bias, present_key, present_value, temp_params, context); + } + if (parameters.is_packed_qkv_ && do_rotary_) { + // Use the ultimate fused operation when FlashAttention and static KV cache is enabled. 
+ if (will_use_flash_attention && parameters.past_present_share_buffer_) { + // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled + // query points to packed QKV, K and V are nullptr since they're not needed + return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value, + present_value, parameters, context, seqlen_k, cos_cache, sin_cache); + } + // Fused: splitQKV + rotary QK qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); @@ -279,8 +312,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& key = &kSplit; value = &vSplit; } else { - // Original separate path if (parameters.is_packed_qkv_) { + // splitQKV qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); @@ -292,6 +325,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& value = &vSplit; } if (do_rotary_) { + // rotary QK qRotary = context.CreateGPUTensor(query->DataType(), query->Shape()); kRotary = context.CreateGPUTensor(key->DataType(), key->Shape()); ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, parameters, @@ -304,11 +338,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& } } - // Use a sliding window if the total sequence exceeds the window's length. 
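Note: for packed QKV with rotary embedding, the restructured ComputeInternal picks between a fully fused path and a split-plus-rotary path; a compressed decision sketch with illustrative names only (not part of this change):

enum class GqaPackedRotaryPath {
  kFusedSplitRotaryCopyKVIntoFlashAttention,  // split + RoPE + KV-cache copy handled inside ApplyFlashAttention
  kFusedSplitRotaryThenAttention,             // split + RoPE fused into one program, attention runs afterwards
};

static GqaPackedRotaryPath PickPath(bool will_use_flash_attention, bool past_present_share_buffer) {
  if (will_use_flash_attention && past_present_share_buffer) {
    return GqaPackedRotaryPath::kFusedSplitRotaryCopyKVIntoFlashAttention;
  }
  return GqaPackedRotaryPath::kFusedSplitRotaryThenAttention;
}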
- bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_); - if (head_sink == nullptr && !use_smooth_softmax_ && - !use_sliding_window && - CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { + if (will_use_flash_attention) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context, seqlen_k); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template index b64448611079f..777be41ffb456 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template @@ -36,11 +36,11 @@ $MAIN { // Calculate actual indices in the head for i and j #if interleaved - let idx_i = in_head_idx; - let idx_j = in_head_idx + 1u; + let idx_i = in_head_idx + in_head_idx; + let idx_j = idx_i + 1u; #else let idx_i = in_head_idx; - let idx_j = in_head_idx + uniforms.half_rotary_dim; + let idx_j = idx_i + uniforms.half_rotary_dim; #endif // Process Q pair diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template new file mode 100644 index 0000000000000..d6cb654afa756 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -0,0 +1,111 @@ +#param interleaved +#param prepare_indirect_dispatch + +#use guardAgainstOutOfBoundsWorkgroupSizes +#use .setByIndices .getByIndices .getByOffset + +$MAIN { + guardAgainstOutOfBoundsWorkgroupSizes(uniforms.dispatch_size); + + // Dispatch: batch * seq * num_heads * (half_rotary_dim + need_copy_dim) + // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) + let work_per_head = uniforms.head_size - uniforms.half_rotary_dim; + let total_work = uniforms.num_heads * work_per_head; + + let batch_idx = global_idx / (uniforms.sequence_length * total_work); + let remainder1 = global_idx % (uniforms.sequence_length * total_work); + let seq_idx = remainder1 / total_work; + let remainder2 = remainder1 % total_work; + let head_idx = remainder2 / work_per_head; + let in_head_idx = remainder2 % work_per_head; + + // Calculate base offset in packed_qkv for this token + // Layout per token: [Q(hidden_size), K(kv_hidden_size), V(kv_hidden_size)] + let token_size = uniforms.hidden_size + 2u * uniforms.kv_hidden_size; + let base_offset = batch_idx * uniforms.sequence_length * token_size + seq_idx * token_size; + + // Calculate position_id (needed for rotary embedding) + let seqlen_i = seqlens.getByOffset(batch_idx); + let seqlen = u32(seqlen_i); + let total_seqlen = seqlen + 1u; + + let past_seqlen = total_seqlen - uniforms.sequence_length; + // `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value + let position_id = past_seqlen + seq_idx; + +#if prepare_indirect_dispatch + // Prepare indirect dispatch buffer for thread 0 + if (global_idx == 0u) { + let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; + indirect_buffer[0] = num_total_seq_length_tile; + indirect_buffer[1] = uniforms.num_heads; + indirect_buffer[2] = 1u; + } +#endif + + if 
(in_head_idx < uniforms.half_rotary_dim) { + // Process a rotary pair (i, j) + let cos_v = cos_cache.getByIndices(vec2(position_id, in_head_idx)); + let sin_v = sin_cache.getByIndices(vec2(position_id, in_head_idx)); + + // Calculate actual indices in the head for i and j +#if interleaved + let idx_i = in_head_idx + in_head_idx; + let idx_j = idx_i + 1u; +#else + let idx_i = in_head_idx; + let idx_j = idx_i + uniforms.half_rotary_dim; +#endif + + // Process Q pair + let q_base = base_offset + head_idx * uniforms.head_size; + let q_i_offset = q_base + idx_i; + let q_j_offset = q_base + idx_j; + let q_i = packed_qkv.getByOffset(q_i_offset); + let q_j = packed_qkv.getByOffset(q_j_offset); + let q_re = q_i * cos_v - q_j * sin_v; + let q_im = q_i * sin_v + q_j * cos_v; + query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_i), q_re); + query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_j), q_im); + + // Process K and V pairs if within kv_num_heads + if (head_idx < uniforms.kv_num_heads) { + let k_base = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size; + let k_i_offset = k_base + idx_i; + let k_j_offset = k_base + idx_j; + let k_i = packed_qkv.getByOffset(k_i_offset); + let k_j = packed_qkv.getByOffset(k_j_offset); + let k_re = k_i * cos_v - k_j * sin_v; + let k_im = k_i * sin_v + k_j * cos_v; + // Write K directly to present_key cache + present_key.setByIndices(vec4(batch_idx, head_idx, position_id, idx_i), k_re); + present_key.setByIndices(vec4(batch_idx, head_idx, position_id, idx_j), k_im); + + // V doesn't need rotary, just copy the pair to present_value cache + let v_base = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size; + let v_i = packed_qkv.getByOffset(v_base + idx_i); + let v_j = packed_qkv.getByOffset(v_base + idx_j); + present_value.setByIndices(vec4(batch_idx, head_idx, position_id, idx_i), v_i); + present_value.setByIndices(vec4(batch_idx, head_idx, position_id, idx_j), v_j); + } + } else { + // Process non-rotary elements (direct copy) + let actual_idx = uniforms.half_rotary_dim + in_head_idx; + + // Copy Q + let q_offset = base_offset + head_idx * uniforms.head_size + actual_idx; + let q_data = packed_qkv.getByOffset(q_offset); + query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + actual_idx), q_data); + + // Copy K and V if within kv_num_heads directly to present cache + if (head_idx < uniforms.kv_num_heads) { + let k_offset = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size + actual_idx; + let k_data = packed_qkv.getByOffset(k_offset); + present_key.setByIndices(vec4(batch_idx, head_idx, position_id, actual_idx), k_data); + + let v_offset = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size + actual_idx; + let v_data = packed_qkv.getByOffset(v_offset); + present_value.setByIndices(vec4(batch_idx, head_idx, position_id, actual_idx), v_data); + } + } +} // MAIN diff --git a/onnxruntime/core/common/cpuid_info_vendor.cc b/onnxruntime/core/common/cpuid_info_vendor.cc index d4d940eedfe28..8675f129da770 100644 --- a/onnxruntime/core/common/cpuid_info_vendor.cc +++ b/onnxruntime/core/common/cpuid_info_vendor.cc @@ -198,7 +198,7 @@ constexpr std::array kCpuVendorInfos{ CpuVendorInfo{cpuinfo_vendor_nvidia, "Nvidia", 0x10DE}, CpuVendorInfo{cpuinfo_vendor_apple, "Apple", 0x106B}, CpuVendorInfo{cpuinfo_vendor_arm, "ARM", 0x13B5}, - + CpuVendorInfo{cpuinfo_vendor_ibm, "IBM", 0x1014}, // 
TODO add more as needed }; @@ -228,6 +228,9 @@ void CPUIDInfo::VendorInfoInit() { } } #endif // defined(CPUINFO_SUPPORTED) +#if defined(_AIX) + result = cpuinfo_vendor_ibm; +#endif return result; }(); diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 91b5b811a3529..a656abb098911 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -58,6 +58,18 @@ Status OrtArenaCfg::FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.max_mem)); } + if (auto it = kvps_entries.find(ConfigKeyNames::UseCudaMemPool); it != kvps_entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.use_cuda_mempool)); + } + + if (auto it = kvps_entries.find(ConfigKeyNames::CudaMempoolReleaseThreshold); it != kvps_entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.cuda_mempool_release_threshold)); + } + + if (auto it = kvps_entries.find(ConfigKeyNames::CudaMempoolBytesToKeepOnShrink); it != kvps_entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.cuda_mempool_bytes_to_keep_on_shrink)); + } + if (!cfg.IsValid()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid arena configuration. Please check the values provided."); @@ -177,6 +189,16 @@ void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve return alloc.Alloc(size); } + +IArena* IArena::SafeArenaCast(IAllocator* allocator) { +#if !defined(ORT_NO_RTTI) + auto* result = dynamic_cast(allocator); + return result; +#else + return static_cast(allocator); +#endif +} + } // namespace onnxruntime std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) { return (out << info.ToString()); } diff --git a/onnxruntime/core/framework/allocator_utils.cc b/onnxruntime/core/framework/allocator_utils.cc index 8c4e74c4b1cc7..ee9cf5bb39ca0 100644 --- a/onnxruntime/core/framework/allocator_utils.cc +++ b/onnxruntime/core/framework/allocator_utils.cc @@ -52,14 +52,14 @@ AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) { if (info.use_stream_aware_arena) { #ifdef ORT_ENABLE_STREAM return AllocatorPtr( - std::make_unique(std::move(device_allocator), - max_mem, - arena_extend_str, - initial_chunk_size_bytes, - max_dead_bytes_per_chunk, - initial_growth_chunk_size_bytes)); + std::make_unique(std::move(device_allocator), + max_mem, + arena_extend_str, + initial_chunk_size_bytes, + max_dead_bytes_per_chunk, + initial_growth_chunk_size_bytes)); #else - ORT_THROW("StreamAwareArena should be transparent to minimal build."); + ORT_THROW("StreamAwareBFCArena should be transparent to minimal build."); #endif } else { return AllocatorPtr( diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index 3a5af42d03cdd..cfe155986eff2 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -13,11 +13,10 @@ BFCArena::BFCArena(std::unique_ptr resource_allocator, int max_dead_bytes_per_chunk, int initial_growth_chunk_size_bytes, int64_t max_power_of_two_extend_bytes) - : IAllocator(OrtMemoryInfo(resource_allocator->Info().name.c_str(), - OrtAllocatorType::OrtArenaAllocator, - resource_allocator->Info().device, - resource_allocator->Info().mem_type)), - arena_type_(ArenaType::BaseArena), + : IArena(OrtMemoryInfo(resource_allocator->Info().name.c_str(), + OrtAllocatorType::OrtArenaAllocator, + resource_allocator->Info().device, + 
resource_allocator->Info().mem_type)), device_allocator_(std::move(resource_allocator)), free_chunks_list_(kInvalidChunkHandle), next_allocation_id_(1), @@ -827,13 +826,13 @@ void BFCArena::ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_fla } } -StreamAwareArena::StreamAwareArena(std::unique_ptr resource_allocator, - size_t total_memory, - ArenaExtendStrategy arena_extend_strategy, - int initial_chunk_size_bytes, - int max_dead_bytes_per_chunk, - int initial_growth_chunk_size_bytes, - int64_t max_power_of_two_extend_bytes) +StreamAwareBFCArena::StreamAwareBFCArena(std::unique_ptr resource_allocator, + size_t total_memory, + ArenaExtendStrategy arena_extend_strategy, + int initial_chunk_size_bytes, + int max_dead_bytes_per_chunk, + int initial_growth_chunk_size_bytes, + int64_t max_power_of_two_extend_bytes) : BFCArena(std::move(resource_allocator), total_memory, arena_extend_strategy, @@ -841,14 +840,13 @@ StreamAwareArena::StreamAwareArena(std::unique_ptr resource_allocato max_dead_bytes_per_chunk, initial_growth_chunk_size_bytes, max_power_of_two_extend_bytes) { - arena_type_ = ArenaType::StreamAwareArena; } -void* StreamAwareArena::AllocOnStream(size_t size, Stream* current_stream) { +void* StreamAwareBFCArena::AllocOnStream(size_t size, Stream* current_stream) { return AllocateRawInternal(size, false, current_stream); } -void StreamAwareArena::ReleaseStreamBuffers(Stream* stream) { +void StreamAwareBFCArena::ReleaseStreamBuffers(Stream* stream) { // since chunks on target stream will be reset to nullptr, trigger coalesce to see whether we can get bigger chunk. ResetChunkOnTargetStream(stream, true); } diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index f3c0544124241..e3494853f7064 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -43,7 +43,7 @@ namespace onnxruntime { #endif #endif -class StreamAwareArena; +class StreamAwareBFCArena; // A memory allocator that implements a 'best-fit with coalescing' // algorithm. This is essentially a very simple version of Doug Lea's // malloc (dlmalloc). @@ -52,7 +52,7 @@ class StreamAwareArena; // coalescing. One assumption we make is that the process using this // allocator owns pretty much all of the memory, and that nearly // all requests to allocate memory go through this interface. -class BFCArena : public IAllocator { +class BFCArena : public IArena { public: static const ArenaExtendStrategy DEFAULT_ARENA_EXTEND_STRATEGY = ArenaExtendStrategy::kNextPowerOfTwo; static const int DEFAULT_INITIAL_CHUNK_SIZE_BYTES = 1 * 1024 * 1024; @@ -61,11 +61,6 @@ class BFCArena : public IAllocator { static const int64_t DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES = 1024 * 1024 * 1024; // 1GB static const size_t DEFAULT_MAX_MEM = std::numeric_limits::max(); - enum ArenaType { - BaseArena, - StreamAwareArena, - }; - BFCArena(std::unique_ptr resource_allocator, size_t total_memory, ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, @@ -84,14 +79,6 @@ class BFCArena : public IAllocator { // If p is NULL, no operation is performed. void Free(void* p) override; - // Frees all allocation regions in which no chunk is in use. - // Does not free any reserved chunks. - // Resets the size that the arena will grow by in the next allocation to - // `initial_growth_chunk_size_bytes_` but ultimately all - // future allocation sizes are determined by the arena growth strategy - // and the allocation request. 
- Status Shrink(); - void* Reserve(size_t size) override; void GetStats(AllocatorStats* stats) override; @@ -100,7 +87,13 @@ class BFCArena : public IAllocator { size_t AllocatedSize(const void* ptr); - ArenaType GetArenaType() const { return arena_type_; } + // Frees all allocation regions in which no chunk is in use. + // Does not free any reserved chunks. + // Resets the size that the arena will grow by in the next allocation to + // `initial_growth_chunk_size_bytes_` but ultimately all + // future allocation sizes are determined by the arena growth strategy + // and the allocation request. + Status Shrink() override; protected: void* AllocateRawInternal(size_t num_bytes, @@ -112,7 +105,6 @@ class BFCArena : public IAllocator { // perform coalesce if coalesce_flag is true void ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_flag); #endif - ArenaType arena_type_; private: void DeallocateRawInternal(void* ptr); @@ -510,26 +502,22 @@ class BFCArena : public IAllocator { }; #ifdef ORT_ENABLE_STREAM -class StreamAwareArena : public BFCArena { +class StreamAwareBFCArena : public BFCArena { public: - StreamAwareArena(std::unique_ptr resource_allocator, - size_t total_memory, - ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, - int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES, - int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK, - int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES, - int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES); + StreamAwareBFCArena(std::unique_ptr resource_allocator, + size_t total_memory, + ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, + int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES, + int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK, + int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES, + int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES); bool IsStreamAware() const override { return true; } // Standard alloc behavior. Returns valid pointer if size > 0 and memory was available. Otherwise returns nullptr. void* AllocOnStream(size_t size, Stream* current_stream_id) override; - void ReleaseStreamBuffers(Stream* stream); - - static StreamAwareArena* FromBFCArena(BFCArena& arena) { - return arena.GetArenaType() == ArenaType::StreamAwareArena ? 
reinterpret_cast(&arena) : nullptr; - } + void ReleaseStreamBuffers(Stream* stream) override; }; #endif #ifdef __GNUC__ diff --git a/onnxruntime/core/framework/device_stream_collection.cc b/onnxruntime/core/framework/device_stream_collection.cc index 8d15e03c2e5ce..a32973ddb8c9e 100644 --- a/onnxruntime/core/framework/device_stream_collection.cc +++ b/onnxruntime/core/framework/device_stream_collection.cc @@ -21,7 +21,8 @@ struct DummyStream : Stream { class DeviceStreamCollectionImpl { public: - DeviceStreamCollectionImpl(size_t num_streams, const AllocatorMap& allocators, bool is_main_graph) : num_streams_(num_streams), allocators_(allocators), is_main_graph_(is_main_graph) { + DeviceStreamCollectionImpl(size_t num_streams, const AllocatorMap& allocators, bool is_main_graph) + : num_streams_(num_streams), allocators_(allocators), is_main_graph_(is_main_graph) { device_streams_.resize(num_streams, nullptr); owned_streams_.reserve(num_streams); root_stream_ = std::make_unique(nullptr, root_stream_device_); @@ -32,13 +33,16 @@ class DeviceStreamCollectionImpl { void ReleaseSingleStreamBuffers(Stream* stream) { if (!stream) return; - for (auto it : allocators_) { + for (const auto& it : allocators_) { if (it.second->Info().device == stream->GetDevice() && it.second->Info().alloc_type == OrtArenaAllocator) { - auto* arena_alloc = static_cast(it.second.get()); - auto* stream_aware_alloc = StreamAwareArena::FromBFCArena(*arena_alloc); - if (stream_aware_alloc) { - stream_aware_alloc->ReleaseStreamBuffers(stream); + if (it.second->IsStreamAware()) { + // Previously we only had one StreamAwareBFCArena. We need to guard + // against multiple allocators now. + auto* arena_alloc = IArena::SafeArenaCast(it.second.get()); + if (arena_alloc) { + arena_alloc->ReleaseStreamBuffers(stream); + } } } } diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 01ba492eb166e..8fb3dc63aa4d1 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -588,7 +588,7 @@ Status SessionState::PrepackConstantInitializedTensors( // within this session. Or if the weight is not present on disk, // we store the newly minted pre-packed data. - AllocatorPtr session_cpu_alloc = GetAllocator(kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); + AllocatorPtr session_initializer_alloc = GetInitializerAllocator(kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); PrePackedWeights weights_to_be_filled_in; // The reason we invoke PrePack() before looking into the container for any pre-packed weight // cached by another instance of the same op_type (for the same constant initializer) is because @@ -596,7 +596,7 @@ Status SessionState::PrepackConstantInitializedTensors( // pre-packed weight with the pre-packed weight generated by this instance of the same op_type because // other static properties of the node like node attributes could play a role in the pre-packed // weights' contents. 
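Note: with the IArena interface introduced in this change, callers no longer need the BFCArena-specific GetArenaType/FromBFCArena helpers. A minimal usage sketch; the two wrapper functions are hypothetical:

// Shrink an allocator only if it is actually an arena; other allocators are left alone.
onnxruntime::common::Status ShrinkIfArena(onnxruntime::IAllocator* allocator) {
  if (auto* arena = onnxruntime::IArena::SafeArenaCast(allocator)) {
    return arena->Shrink();
  }
  return onnxruntime::common::Status::OK();
}

// Stream-aware cleanup mirrors the DeviceStreamCollection change above.
void ReleaseIfStreamAwareArena(onnxruntime::IAllocator* allocator, onnxruntime::Stream* stream) {
  if (allocator->IsStreamAware()) {
    if (auto* arena = onnxruntime::IArena::SafeArenaCast(allocator)) {
      arena->ReleaseStreamBuffers(stream);
    }
  }
}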
- ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_cpu_alloc, + ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_initializer_alloc, is_packed, &weights_to_be_filled_in)); diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index fb3b1d1d29eec..487e1533f5967 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -332,8 +332,8 @@ static std::shared_ptr RhsPackWeightsBiasSme(const size_t co, const const float* weights, const float* bias, MLAS_THREADPOOL* ThreadPool) { - //cache of prepacked kai rhs weights and biases - static std::unordered_map> rhs_cache; + // Cache of prepacked kai rhs weights and biases. thread_local to prevent interference from parallel sessions. + thread_local std::unordered_map> rhs_cache; RhsCacheKey key = { co, ci, kh, kw, dilationh, dilationw, HashWeights(weights) }; @@ -474,8 +474,8 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool); - //cache of computed lhs ptr offsets - static std::unordered_map> lhs_ptrs_cache; + // Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions. + thread_local std::unordered_map> lhs_ptrs_cache; std::shared_ptr lhs_ptrs; if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) { diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 34b6b2de64a92..aeddef0c5188f 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -556,6 +556,8 @@ class PosixEnv : public Env { } PathString GetRuntimePath() const override { +// In AIX, dladdr is not supported. +#if !defined(_AIX) // Use dladdr() to look up the file that contains an address from this binary. 
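Note: the thread_local caches introduced above trade memory for isolation: each thread running a session gets its own cache, so parallel sessions cannot race on, or evict, each other's prepacked entries. A minimal sketch of the pattern with simplified, illustrative types:

#include <cstddef>
#include <memory>
#include <unordered_map>

struct PackedRhs { /* packed weights + bias, elided */ };

std::shared_ptr<PackedRhs> GetOrCreatePackedRhs(size_t cache_key) {
  // One cache per thread: lookups need no lock, and sessions running on
  // different threads cannot interfere with each other.
  thread_local std::unordered_map<size_t, std::shared_ptr<PackedRhs>> cache;
  if (auto it = cache.find(cache_key); it != cache.end()) {
    return it->second;
  }
  auto packed = std::make_shared<PackedRhs>();  // real code would pack the weights here
  cache.emplace(cache_key, packed);
  return packed;
}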
const void* const address_from_this_binary = reinterpret_cast(Env::Default); @@ -568,7 +570,7 @@ class PosixEnv : public Env { runtime_path.remove_filename(); return runtime_path; } - +#endif return PathString{}; } diff --git a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc index f3c6b18f8e753..dc2cec1852fed 100644 --- a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc @@ -28,8 +28,10 @@ ONNX_OPERATOR_KERNEL_EX( 10, kCpuExecutionProvider, KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) .TypeConstraint("T3", DataTypeImpl::GetTensorType()), ConvInteger); @@ -43,12 +45,12 @@ Status ConvInteger::Compute(OpKernelContext* context) const { if (num_inputs >= 3 && input_defs[2]->Exists()) { const auto* X_Zero_Point = context->Input(2); ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1."); - input_offset = *(X_Zero_Point->Data()); + input_offset = *static_cast(X_Zero_Point->DataRaw()); } if (num_inputs >= 4 && input_defs[3]->Exists()) { const auto* W_Zero_Point = context->Input(3); ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now."); - filter_offset = *(W_Zero_Point->Data()); + filter_offset = *static_cast(W_Zero_Point->DataRaw()); } const int64_t N = X->Shape()[0]; @@ -110,45 +112,82 @@ Status ConvInteger::Compute(OpKernelContext* context) const { concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); - const auto* Xdata = X->Data(); - const auto* Wdata = W->Data(); + const auto* Xdata = static_cast(X->DataRaw()); + const auto* Wdata = static_cast(W->DataRaw()); + bool X_is_signed = X->IsDataType(); auto* Ydata = Y->MutableData(); for (int image_id = 0; image_id < N; ++image_id) { for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) { if (col_buffer_data != nullptr) { if (kernel_rank == 2) { - math::Im2col()( - Xdata, - C / conv_attrs_.group, - input_shape[0], - input_shape[1], - kernel_shape[0], - kernel_shape[1], - dilations[0], - dilations[1], - pads[0], - pads[1], - pads[2], - pads[3], - strides[0], - strides[1], - col_buffer_data, - input_offset); + if (X_is_signed) { + math::Im2col()( + reinterpret_cast(Xdata), + C / conv_attrs_.group, + input_shape[0], + input_shape[1], + kernel_shape[0], + kernel_shape[1], + dilations[0], + dilations[1], + pads[0], + pads[1], + pads[2], + pads[3], + strides[0], + strides[1], + reinterpret_cast(col_buffer_data), + static_cast(input_offset)); + } else { + math::Im2col()( + Xdata, + C / conv_attrs_.group, + input_shape[0], + input_shape[1], + kernel_shape[0], + kernel_shape[1], + dilations[0], + dilations[1], + pads[0], + pads[1], + pads[2], + pads[3], + strides[0], + strides[1], + col_buffer_data, + input_offset); + } } else { - math::Im2col()( - Xdata, - input_shape.GetDims().data(), - output_shape.GetDims().data(), - kernel_dim, - kernel_shape.data(), - strides.data(), - dilations.data(), - pads.data(), - static_cast(kernel_rank), - col_buffer_data, - false, - input_offset); + if (X_is_signed) { + math::Im2col()( + reinterpret_cast(Xdata), + input_shape.GetDims().data(), + output_shape.GetDims().data(), + kernel_dim, + 
kernel_shape.data(), + strides.data(), + dilations.data(), + pads.data(), + static_cast(kernel_rank), + reinterpret_cast(col_buffer_data), + false, + static_cast(input_offset)); + } else { + math::Im2col()( + Xdata, + input_shape.GetDims().data(), + output_shape.GetDims().data(), + kernel_dim, + kernel_shape.data(), + strides.data(), + dilations.data(), + pads.data(), + static_cast(kernel_rank), + col_buffer_data, + false, + input_offset); + } } } @@ -156,12 +195,14 @@ Status ConvInteger::Compute(OpKernelContext* context) const { gemm_shape.M = static_cast(M / conv_attrs_.group); gemm_shape.N = static_cast(output_image_size); gemm_shape.K = static_cast(kernel_dim); + gemm_shape.AIsSigned = W->IsDataType(); + gemm_shape.BIsSigned = X_is_signed; MLAS_GEMM_QUANT_DATA_PARAMS gemm_params; gemm_params.A = Wdata + group_id * W_offset; gemm_params.lda = static_cast(kernel_dim); gemm_params.ZeroPointA = filter_offset; - gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data, + gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data; gemm_params.ldb = static_cast(output_image_size); gemm_params.ZeroPointB = &input_offset; gemm_params.C = Ydata; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3816cc1f8f6b9..eff0801a00460 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -14,6 +14,7 @@ #include "core/providers/cuda/cuda_fwd.h" #include "core/providers/cuda/gpu_data_transfer.h" #include "core/providers/cuda/cuda_profiler.h" +#include "core/providers/cuda/cuda_mempool_arena.h" #include "core/session/onnxruntime_run_options_config_keys.h" #ifndef USE_CUDA_MINIMAL @@ -134,11 +135,10 @@ ONNX_OPERATOR_KERNEL_EX( } // namespace cuda -AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId device_id, - size_t gpu_mem_limit, - ArenaExtendStrategy arena_extend_strategy, - CUDAExecutionProviderExternalAllocatorInfo external_allocator_info, - const OrtArenaCfg* default_memory_arena_cfg) { +AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(const CUDAAllocatorParams& cuda_allocator_params) { + ORT_ENFORCE(cuda_allocator_params.external_alloc_info != nullptr, + "CUDAAllocatorParams.external_alloc_info is nullptr."); + const auto& external_allocator_info = *(cuda_allocator_params.external_alloc_info); if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( [external_allocator_info](OrtDevice::DeviceId id) { @@ -147,24 +147,59 @@ AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId devi external_allocator_info.free, external_allocator_info.empty_cache); }, - device_id, + cuda_allocator_params.device_id, false); return CreateAllocator(default_memory_info); } else { - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId id) { - return std::make_unique(id, CUDA); - }, - device_id, - true, - {default_memory_arena_cfg ? 
*default_memory_arena_cfg - : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, - // make it stream aware - true); - - // CUDA malloc/free is expensive so always use an arena - return CreateAllocator(default_memory_info); + const auto* arena_cfg = cuda_allocator_params.arena_cfg; + const bool cuda_mempool_requested = arena_cfg != nullptr && arena_cfg->use_cuda_mempool == 1; + bool use_cuda_mempool = cuda_mempool_requested && cuda::CudaMempoolArena::IsCudaVersionSupported(); + + if (cuda_mempool_requested && !use_cuda_mempool) { + LOGS_DEFAULT(WARNING) + << "CUDA memory pool requested but not supported on this device/driver." + << "Falling back to default BFCArena with CUDA allocator."; + } + + if (use_cuda_mempool) { + const bool cuda_graph_enabled = cuda_allocator_params.provider_info != nullptr && + cuda_allocator_params.provider_info->enable_cuda_graph; + + if (cuda_graph_enabled) { + LOGS_DEFAULT(WARNING) + << "CUDA Mempool Arena allocator is not compatible with requested CUDA Graph Capture" + << "Falling back to default BFCArena with CUDA allocator."; + use_cuda_mempool = false; + } + } + + if (use_cuda_mempool) { + auto device = OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + cuda_allocator_params.device_id); + auto mem_info = OrtMemoryInfo("CUDAMemPoolArena", OrtAllocatorType::OrtArenaAllocator, device, OrtMemTypeDefault); + + auto mempool_allocator = std::make_shared(mem_info, + arena_cfg->cuda_mempool_release_threshold, + arena_cfg->cuda_mempool_bytes_to_keep_on_shrink, + cuda_allocator_params.logger); + + return mempool_allocator; + } else { + AllocatorCreationInfo default_memory_info( + [](OrtDevice::DeviceId id) { + return std::make_unique(id, CUDA); + }, + cuda_allocator_params.device_id, + true, + {arena_cfg ? 
*arena_cfg + : OrtArenaCfg(cuda_allocator_params.cuda_mem_threshold, + static_cast(cuda_allocator_params.arena_extend_strategy), -1, -1, -1, -1L)}, + // make it stream aware + true); + // CUDA malloc/free is expensive so always use an arena + return CreateAllocator(default_memory_info); + } } } @@ -3044,9 +3079,18 @@ std::vector CUDAExecutionProvider::CreatePreferredAllocators() { return std::make_unique(device_id, CUDA_PINNED); }, info_.device_id); + + CUDAExecutionProvider::CUDAAllocatorParams params{}; + params.device_id = info_.device_id; + params.cuda_mem_threshold = info_.gpu_mem_limit; + params.arena_extend_strategy = info_.arena_extend_strategy; + params.provider_info = &info_; + params.external_alloc_info = &info_.external_allocator_info; + params.arena_cfg = info_.default_memory_arena_cfg; + params.logger = GetLogger(); + return std::vector{ - CreateCudaAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy, - info_.external_allocator_info, info_.default_memory_arena_cfg), + CreateCudaAllocator(params), CreateAllocator(pinned_memory_info), }; } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 57fde8146d929..751bbb90f8619 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -103,8 +103,17 @@ class CUDAExecutionProvider : public IExecutionProvider { return CUDAExecutionProviderInfo::ToProviderOptions(info_); } - static AllocatorPtr CreateCudaAllocator(OrtDevice::DeviceId device_id, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, - CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); + struct CUDAAllocatorParams { + OrtDevice::DeviceId device_id = 0; + size_t cuda_mem_threshold = std::numeric_limits::max(); + ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo; + const CUDAExecutionProviderInfo* provider_info = nullptr; + const CUDAExecutionProviderExternalAllocatorInfo* external_alloc_info = nullptr; + const OrtArenaCfg* arena_cfg = nullptr; + const logging::Logger* logger = nullptr; + }; + + static AllocatorPtr CreateCudaAllocator(const CUDAAllocatorParams& cuda_allocator_params); ITuningContext* GetTuningContext() const override; diff --git a/onnxruntime/core/providers/cuda/cuda_mempool_arena.cc b/onnxruntime/core/providers/cuda/cuda_mempool_arena.cc new file mode 100644 index 0000000000000..802867ec0d89b --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_mempool_arena.cc @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft. +// Licensed under the MIT License. + +#include "cuda_mempool_arena.h" + +#include + +#include "core/providers/cuda/shared_inc/cuda_call.h" // ORT CudaCall helpers +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +namespace cuda { + +// ======== CudaMempoolArena ======== + +CudaMempoolArena::CudaMempoolArena(const OrtMemoryInfo& memory_info, + uint64_t pool_release_threshold, + size_t bytes_to_keep_on_shrink, + const logging::Logger* logger) + : IArena(memory_info), + pool_release_threshold_(pool_release_threshold), + bytes_to_keep_on_shrink_(bytes_to_keep_on_shrink), + logger_(logger) { + if (logger_ == nullptr) { + logger_ = &::onnxruntime::logging::LoggingManager::DefaultLogger(); + } + + // Create a process-local device memory pool for device_id_. 
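Note: the mempool-backed arena is only selected when the arena configuration opts in. A configuration sketch using the fields consumed by CreateCudaAllocator above; the values are placeholders to be tuned per workload:

OrtArenaCfg arena_cfg{};
arena_cfg.use_cuda_mempool = 1;                        // request the cudaMemPool-backed arena
arena_cfg.cuda_mempool_release_threshold = 1 << 20;    // hold roughly 1 MB before releasing memory to the OS
arena_cfg.cuda_mempool_bytes_to_keep_on_shrink = 0;    // Shrink() trims the pool as far as possible
// Pass arena_cfg through CUDAAllocatorParams::arena_cfg (or the equivalent provider options)
// when the CUDA execution provider creates its allocators.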
+ // 'cudaMemAllocationTypeDevice' (for cudaMemPoolProps.allocType) not clear when it is available + + cudaMemPoolProps props{}; + // Pinned is not the same as pinned allocator, cudaMemLocationTypeDevice actually does not exist + // even though is present in some internet docs. + props.allocType = cudaMemAllocationTypePinned; + props.handleTypes = cudaMemHandleTypeNone; // local to process + props.location.type = cudaMemLocationTypeDevice; // Device memory + props.location.id = this->Info().device.Id(); + + CUDA_CALL_THROW(cudaMemPoolCreate(&pool_, &props)); + + if (pool_release_threshold_ != 0) { + CUDA_CALL_THROW(cudaMemPoolSetAttribute(pool_, cudaMemPoolAttrReleaseThreshold, + &pool_release_threshold_)); + } + + LOGS(*logger_, INFO) << "CudaMempoolArena created on device " << this->Info().device.Id() + << " with pool_release_threshold=" << pool_release_threshold_ + << " bytes_to_keep_on_shrink=" << bytes_to_keep_on_shrink_ << "."; + + // Intentionally DO NOT call cudaDeviceSetMemPool(device_id_, pool_); + // All allocations explicitly target this pool via cudaMallocFromPoolAsync. +} + +CudaMempoolArena::~CudaMempoolArena() { + // 1) Best-effort: enqueue frees for any remaining allocations on their recorded streams. + // No locking by design: destruction implies no concurrent access. + for (auto& kv : alloc_map_) { + void* p = kv.first; + const cudaStream_t s = kv.second.stream; + ORT_IGNORE_RETURN_VALUE(cudaFreeAsync(p, s)); // ignore errors in destructor + } + + // 2) Synchronize all streams we know about (those that ever held allocations). + SyncAllKnownStreams_NoThrow(); + + // Now it is safe to drop our bookkeeping. + alloc_map_.clear(); + stream_map_.clear(); + + // 3) Safety barrier: ensure any frees enqueued on destroyed/unknown streams are completed. + ORT_IGNORE_RETURN_VALUE(cudaDeviceSynchronize()); // ignore errors in destructor + + // 4) Trim to zero and destroy the pool. + if (pool_) { + ORT_IGNORE_RETURN_VALUE(cudaMemPoolTrimTo(pool_, 0)); // best-effort + ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool_)); + pool_ = nullptr; + } +} + +void* CudaMempoolArena::Alloc(size_t size) { + if (size == 0) return nullptr; + + void* p = nullptr; + constexpr const cudaStream_t kDefaultStream = static_cast(0); + cudaError_t err = cudaMallocFromPoolAsync(&p, size, pool_, kDefaultStream); + if (err != cudaSuccess) { + ORT_THROW("CudaMempoolArena::Alloc: cudaMallocFromPoolAsync failed: ", + cudaGetErrorString(err), " (", static_cast(err), "), size=", size); + } + + LOGS(*logger_, VERBOSE) << "CudaMempoolArena::Alloc: allocated " + << size << " bytes at " << p << " on default stream."; + + // In case the default stream is busy. 
+ ORT_IGNORE_RETURN_VALUE(cudaStreamSynchronize(kDefaultStream)); + + { + std::lock_guard lock(mutex_); + AllocationRecord rec{size, kDefaultStream}; + alloc_map_.emplace(p, rec); + stream_map_[kDefaultStream].insert(p); + + total_allocated_ += size; + in_use_bytes_ += size; + max_bytes_in_use_ = std::max(max_bytes_in_use_, in_use_bytes_); + max_alloc_size_ = std::max(max_alloc_size_, size); + ++num_allocs_; + } + + return p; +} + +void* CudaMempoolArena::AllocOnStream(size_t size, Stream* stream) { + if (size == 0) return nullptr; + + void* p = nullptr; + const cudaStream_t s = ResolveCudaStream(stream); + + cudaError_t err = cudaMallocFromPoolAsync(&p, size, pool_, s); + if (err != cudaSuccess) { + ORT_THROW("CudaMempoolArena::AllocOnStream: cudaMallocFromPoolAsync failed on stream=", + reinterpret_cast(s), ": ", + cudaGetErrorString(err), " (", static_cast(err), "), size=", size); + } + + LOGS(*logger_, VERBOSE) << "CudaMempoolArena::AllocOnStream: allocated " + << size << " bytes at " << p << " on stream " + << reinterpret_cast(s) << "."; + + { + std::lock_guard lock(mutex_); + AllocationRecord rec{size, s}; + alloc_map_.emplace(p, rec); + stream_map_[s].insert(p); + + total_allocated_ += size; + in_use_bytes_ += size; + max_bytes_in_use_ = std::max(max_bytes_in_use_, in_use_bytes_); + max_alloc_size_ = std::max(max_alloc_size_, size); + ++num_allocs_; + } + + return p; +} + +void CudaMempoolArena::Free(void* p) { + if (!p) return; + + cudaStream_t s = static_cast(0); + size_t sz = 0; + + { + std::lock_guard lock(mutex_); + auto it = alloc_map_.find(p); + if (it == alloc_map_.end()) { + // Not owned by this allocator; ignore per ORT convention. + LOGS(*logger_, WARNING) << "CudaMempoolArena::Free: pointer " + << p << " not found in allocation map; ignoring."; + return; + } + + s = it->second.stream; + sz = it->second.bytes; + + alloc_map_.erase(it); + + auto sit = stream_map_.find(s); + if (sit != stream_map_.end()) { + sit->second.erase(p); + if (sit->second.empty()) { + stream_map_.erase(sit); + } + } + + in_use_bytes_ = (sz <= in_use_bytes_) ? (in_use_bytes_ - sz) : 0; + } + + // Ordered free on the stream that allocated p + CUDA_CALL_THROW(cudaFreeAsync(p, s)); +} + +Status CudaMempoolArena::Shrink() { + // Trim the pool; live allocations are not affected. + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemPoolTrimTo(pool_, bytes_to_keep_on_shrink_))); + + size_t current_in_use = 0; + ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaMemPoolGetAttribute(pool_, cudaMemPoolAttrUsedMemCurrent, + ¤t_in_use))); + + // Query current reserved size. cudaMemPoolAttrReservedMemCurrent + size_t reserved_size = 0; + if (CUDA_CALL(cudaMemPoolGetAttribute(pool_, cudaMemPoolAttrReservedMemCurrent, + &reserved_size)) + .IsOK()) { + LOGS(*logger_, INFO) << "CudaMempoolArena::Shrink: pool current_in_use: " << current_in_use + << " reserved size after trim : " << reserved_size << " bytes."; + } else { + LOGS(*logger_, INFO) << "CudaMempoolArena pool has been shrunk; unable to query reserved size."; + } + + // Right-size maps under lock. 
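Note: Free() relies on CUDA's stream-ordered semantics: because the free is enqueued on the stream that produced the allocation, work already queued on that stream still sees valid memory. A raw-CUDA sketch of the contract this arena builds on (pool, stream and num_bytes assumed in scope; error handling omitted):

void* p = nullptr;
cudaMallocFromPoolAsync(&p, num_bytes, pool, stream);  // allocation becomes valid in stream order
cudaMemsetAsync(p, 0, num_bytes, stream);              // uses p; ordered after the allocation
cudaFreeAsync(p, stream);                              // memory is reclaimed only after prior work on `stream` completes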
+ std::lock_guard lock(mutex_); + MaybeRehashLocked(); + ++num_arena_shrinkages_; + return Status::OK(); +} + +void CudaMempoolArena::GetStats(AllocatorStats* stats) { + if (!stats) return; + std::lock_guard lock(mutex_); + stats->num_allocs = num_allocs_; + stats->total_allocated_bytes = total_allocated_; + stats->bytes_in_use = in_use_bytes_; + stats->max_bytes_in_use = max_bytes_in_use_; + stats->num_arena_shrinkages = num_arena_shrinkages_; +} + +cudaStream_t CudaMempoolArena::ResolveCudaStream(Stream* stream) noexcept { + if (!stream) return static_cast(0); + return static_cast(stream->GetHandle()); +} + +void CudaMempoolArena::MaybeRehashLocked() { + const size_t alloc_sz = alloc_map_.size(); + const size_t stream_sz = stream_map_.size(); + if (alloc_sz > 0) alloc_map_.reserve(alloc_sz); + if (stream_sz > 0) stream_map_.reserve(stream_sz); +} + +void CudaMempoolArena::SyncAllKnownStreams_NoThrow() { + for (const auto& kv : stream_map_) { + const cudaStream_t s = kv.first; + ORT_IGNORE_RETURN_VALUE(cudaStreamSynchronize(s)); // ignore errors; device-wide sync follows + } +} + +bool CudaMempoolArena::IsCudaVersionSupported() noexcept { + int ort_cuda_rt_version = 0; + cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); + if (cuda_status != cudaSuccess) { + return false; + } + + if (ort_cuda_rt_version < 11020) { + return false; + } + + int ort_cuda_driver_version = 0; + cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); + if (cuda_status != cudaSuccess) { + return false; + } + + if (ort_cuda_driver_version < 11020) { + return false; + } + + // Check if the driver version supports the runtime version + if (ort_cuda_rt_version >= 12000 && ort_cuda_driver_version < 12000) { + return false; + } + + if (ort_cuda_rt_version >= 13000 && ort_cuda_driver_version < 13000) { + return false; + } + + return true; +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_mempool_arena.h b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h new file mode 100644 index 0000000000000..750cbaf93b6d4 --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" // ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE, ORT_THROW/ENFORCE +#include "core/common/inlined_containers.h" // InlinedHashMap, InlinedHashSet, InlinedVector +#include "core/providers/cuda/cuda_stream_handle.h" // ORT Stream -> cudaStream_t +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +namespace logging { +class Logger; +} +namespace cuda { +/** + * @brief Stream-aware CUDA allocator implemented on top of a private `cudaMemPool_t`. + * The purpose of this arena is to assist with memory allocations in environments where + * a single process is hosting more than one cuda session. This arena hosts cuda memory pool + * which has some tunable parameters to control its memory usage and de-allocates memory back to + * the device according to the specified params. This is opposite to the BFCArena which only + * attempts to free memory on Shrink() if configured at the end of the run. + * + * ### Behavior + * - Creates a **process-local** CUDA mempool for a specific device (from `OrtMemoryInfo`). + * - All allocations use **`cudaMallocFromPoolAsync()`** on either the legacy default stream (0) or a + * caller-provided stream. 
The allocation stream is recorded for ordered free. + * - `Free()` enqueue **`cudaFreeAsync()`** on the recorded stream to + * respect CUDA's stream-ordered semantics. + * - `Shrink()` trims the pool with **`cudaMemPoolTrimTo(bytes_to_keep)`** and right-sizes the book-keeping maps + * under lock. + * + * ### Tuning + * - `pool_release_threshold`: if non-zero, sets `cudaMemPoolAttrReleaseThreshold`. **Recommended: 1 MB.**, but + * must be experimentally determined based on workload for optimal memory consumption vs performance. + * `cudaMemPoolAttrReservedMemCurrent`. **Recommended: 10 MB.** + * - `bytes_to_keep_on_shrink`: target size for `cudaMemPoolTrimTo()` on `Shrink()`. This is only relevant + * if Shrink() is enabled. It usually costs performance, and strictly speaking is not necessary for cuda mempools + * since they release memory on at synchronous points according to `pool_release_threshold`. + * + * ### Thread-safety + * - All updates to internal maps and statistics are guarded by an internal `std::mutex`. + * + * @note The allocator **does not** set the device default mempool and **does not** switch the current device. + */ +class CudaMempoolArena final : public IArena { + public: + /** + * @brief Construct a `CudaMempoolArena` with a private CUDA mempool. + * + * @param memory_info `OrtMemoryInfo` whose device id selects the CUDA device. + * @param pool_release_threshold Optional release threshold (bytes) for `cudaMemPoolAttrReleaseThreshold`. + * If 0, the attribute is not set. **Recommended value: 1 MB.** + * @param bytes_to_keep_on_shrink Target size (bytes) for `cudaMemPoolTrimTo()` on `Shrink()`. + * @param logger Cuda EP Logger + * + * The created pool is process-local and is **not** set as the device default pool. + */ + CudaMempoolArena(const OrtMemoryInfo& memory_info, + uint64_t pool_release_threshold, + size_t bytes_to_keep_on_shrink, + const logging::Logger* logger); + + /** + * @brief Destructor: + * 1) Enqueues cudaFreeAsync() for any outstanding allocations. + * 2) Synchronizes all known streams (best-effort; ignores invalid handles). + * 3) Calls cudaDeviceSynchronize() as a final barrier to ensure queued frees complete. + * 4) Trims pool to zero and destroys it. + */ + ~CudaMempoolArena() override; + + // -------- IAllocator overrides -------- + + /** + * @brief Allocate @p size bytes using the legacy default CUDA stream (0). + * @return device pointer or nullptr when size == 0 + * @throws on allocation failure + */ + void* Alloc(size_t size) override; + + /** + * @brief Allocate @p size bytes on the given ORT stream (uses `cudaMallocFromPoolAsync`). + * @return device pointer or nullptr when size == 0 + * @throws on allocation failure + */ + void* AllocOnStream(size_t size, Stream* stream) override; + + /** + * @brief Enqueue an ordered async free on the stream that allocated @p p. + * No-op if @p p is null or not owned by this allocator. + */ + void Free(void* p) override; + + /** + * @brief Reserve @p size bytes; implemented in terms of `Alloc(size)`. + * This is done so all the memory is gone including initializers when + * the session is torn down. + * @return device pointer or nullptr when size == 0 + * @throws on allocation failure + */ + void* Reserve(size_t size) override { return Alloc(size); } + + /// @brief This allocator is stream-aware. + bool IsStreamAware() const override { return true; } + + /// @brief Populate basic allocation statistics. 
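Note: for illustration, a construction-and-use sketch matching the constructor and the Alloc/Free contract documented above; the device id and sizes are placeholders, and a null logger falls back to the default logger:

OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, /*device_id*/ 0);
OrtMemoryInfo mem_info("CUDAMemPoolArena", OrtAllocatorType::OrtArenaAllocator, device, OrtMemTypeDefault);

auto arena = std::make_shared<onnxruntime::cuda::CudaMempoolArena>(
    mem_info,
    /*pool_release_threshold=*/1 << 20,   // roughly 1 MB, the recommended starting point above
    /*bytes_to_keep_on_shrink=*/0,        // Shrink() releases as much unused memory as possible
    /*logger=*/nullptr);

void* p = arena->Alloc(4096);  // stream-ordered on the legacy default stream
arena->Free(p);                // enqueues cudaFreeAsync on the recording stream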
+ void GetStats(AllocatorStats* stats) override; + + // -------- IArena overrides -------- + + /** + * @brief Enqueue `cudaFreeAsync()` for all allocations made on @p stream. + * we intentionally do not implement this method. The call to this method + * will yank memory from under live OrtValues such as allocated for output + * bound and the resulting output OrtValue will not be valid. + * Then when the OrtValues attempt to release memory those entries are not found + * in the map: CudaMempoolArena::Free: pointer 0000000203800400 not found in allocation map; ignoring + * The reason this works with BFCArena is because it does not really release memory. + */ + // void ReleaseStreamBuffers(Stream* stream) override; + + /** + * @brief Trim the pool to `bytes_to_keep_on_shrink_` (configured at construction) using `cudaMemPoolTrimTo()`. + * Memory still allocated is not affected. Shrink() may affect your performance and contrary to BFCArena + * This allocator does not need Shrink. Cuda mempool is capable of releasing memory automatically + * according to pool_release_threshold_ set at construction. + * Also rehashes internal maps under lock to keep them reasonably sized. + */ + Status Shrink() override; + + static bool IsCudaVersionSupported() noexcept; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CudaMempoolArena); + + private: + /// Convert ORT `Stream*` to native `cudaStream_t`; null means legacy default (0). + static cudaStream_t ResolveCudaStream(Stream* stream) noexcept; + + /// Rehash internal maps under lock; invoked only by `Shrink()`. + void MaybeRehashLocked(); + + /// Best-effort synchronization of all streams in stream_map_. Non-throwing; ignores errors. + void SyncAllKnownStreams_NoThrow(); + + struct AllocationRecord { + size_t bytes; + cudaStream_t stream; // stream on which allocation/free are ordered + }; + + // ---- Pool/context configuration (immutable) ---- + uint64_t pool_release_threshold_; + size_t bytes_to_keep_on_shrink_; + size_t initial_pool_size_bytes_; + const logging::Logger* logger_; + cudaMemPool_t pool_{nullptr}; + + // ---- Bookkeeping (guarded by mutex_) ---- + std::mutex mutex_; + InlinedHashMap alloc_map_; // ptr -> record + InlinedHashMap> stream_map_; // stream -> ptrs + + // ---- Stats (guarded by mutex_) ---- + size_t total_allocated_ = 0; + size_t in_use_bytes_ = 0; + size_t max_bytes_in_use_ = 0; + size_t num_allocs_ = 0; + size_t num_arena_shrinkages_ = 0; + size_t max_alloc_size_ = 0; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 3b361f155831b..70afba320576b 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -181,7 +181,13 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA { } std::shared_ptr CreateCudaAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::CUDAExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { - return CUDAExecutionProvider::CreateCudaAllocator(device_id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); + CUDAExecutionProvider::CUDAAllocatorParams params{}; + params.device_id = device_id; + params.cuda_mem_threshold = gpu_mem_limit; + params.arena_extend_strategy = arena_extend_strategy; + params.external_alloc_info = 
&external_allocator_info; + params.arena_cfg = default_memory_arena_cfg; + return CUDAExecutionProvider::CreateCudaAllocator(params); } } g_info; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc new file mode 100644 index 0000000000000..619e3eaf5fad4 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc @@ -0,0 +1,480 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +namespace onnxruntime { +namespace qnn { + +// Helper function to extract value from raw data based on QNN data type +static Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const uint8_t* raw_ptr, + double& value) { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_FLOAT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_FLOAT_16: { + value = static_cast(reinterpret_cast(raw_ptr)->ToFloat()); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Qnn Data Type: ", qnn_data_type, " not supported."); + } + return Status::OK(); +} + +// Helper function to extract a scalar float value from a constant initializer +// Handles both float and quantized (INT type) constant inputs +static std::optional GetConstantInitializerFloatScalar(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& io_def) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const auto& name = io_def.node_arg.Name(); + + if (!graph_viewer.IsConstantInitializer(name, true)) { + return std::nullopt; + } + + // Get tensor info to check if it's quantized + TensorInfo tensor_info = {}; + if (!qnn_model_wrapper.GetTensorInfo(io_def, tensor_info).IsOK()) { + return std::nullopt; + } + + // Must be an initializer + if (!tensor_info.is_initializer || !tensor_info.initializer_tensor) { + return std::nullopt; + } + + // Unpack the initializer data + std::vector unpacked_tensor; + if (!qnn_model_wrapper.UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor).IsOK()) { + return std::nullopt; + } + + if (unpacked_tensor.empty()) { + return std::nullopt; + } + + // Extract the value using GetValueOnQnnDataType + double extracted_value = 0.0; + if (!GetValueOnQnnDataType(tensor_info.qnn_data_type, 
unpacked_tensor.data(), extracted_value).IsOK()) { + return std::nullopt; + } + + // Check if quantized and dequantize if needed + const bool is_quantized = tensor_info.quant_param.IsQuantized(); + if (is_quantized) { + // For quantized tensors, dequantize the value + if (!tensor_info.quant_param.IsPerTensor()) { + return std::nullopt; // Only support per-tensor quantization + } + + const Qnn_QuantizeParams_t& quant_param = tensor_info.quant_param.Get(); + double dequantized_value = utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, + extracted_value); + return static_cast(dequantized_value); + } + + // For non-quantized tensors, return the extracted value directly + return static_cast(extracted_value); +} + +// Helper function to check if a constant initializer has the expected float value +static bool IsInitializerWithExpectedValue(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& io_def, + float expected_value, + float tolerance = 1e-5f) { + std::optional actual_value = GetConstantInitializerFloatScalar(qnn_model_wrapper, io_def); + if (!actual_value.has_value()) { + return false; + } + + // Compare with expected value within tolerance + return std::fabs(actual_value.value() - expected_value) <= tolerance; +} + +// Forward declaration. +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output, + bool validate); + +// Helper function to validate on QNN +static Status ValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output) { + return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, true); +} + +// Helper function to create on QNN +static Status CreateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output) { + return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, false); +} + +// Gets the parent and child of the Erf node. Can handle the following sequences +// - Parent -> Erf -> Child. +// - Parent -> DQ -> Erf -> Q -> Child. +// +// Also returns the outputs of the Erf. For the sequence `DQ -> Erf -> Q`, returns the outputs of the Q. 
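Note: the subgraph this fusion matches computes the exact (erf-based) GELU, which is why the constant checks below look for sqrt(2), 1.0 and 0.5. A standalone scalar reference, not part of this change:

#include <cmath>

// Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
// matched as: Div(x, sqrt(2)) -> Erf -> Add(1) -> Mul(0.5 * x) or Mul(0.5)
inline float GeluReference(float x) {
  return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}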
+static bool GetErfParentAndChild(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& erf_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + /*out*/ const NodeUnit*& parent_node_unit, + /*out*/ const NodeUnit*& child_node_unit, + /*out*/ const NodeUnit*& dq_node_unit, + /*out*/ const NodeUnit*& q_node_unit, + /*out*/ gsl::span& erf_outputs) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + auto get_first_parent = [&](const NodeUnit& node_unit) -> const NodeUnit* { + const auto& inputs = node_unit.Inputs(); + if (inputs.empty()) { + return nullptr; + } + return GetParentOfInput(graph_viewer, node_unit, inputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + }; + + auto get_first_child = [&](const NodeUnit& node_unit) -> const NodeUnit* { + const auto& outputs = node_unit.Outputs(); + if (outputs.empty()) { + return nullptr; + } + + return GetOnlyChildOfOutput(graph_viewer, node_unit, outputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + }; + + const NodeUnit* erf_parent_node_unit = get_first_parent(erf_node_unit); + if (erf_parent_node_unit == nullptr) { + return false; + } + + const NodeUnit* erf_child_node_unit = get_first_child(erf_node_unit); + if (erf_child_node_unit == nullptr) { + return false; + } + + if (erf_node_unit.UnitType() == NodeUnit::Type::SingleNode && + erf_parent_node_unit->OpType() == "DequantizeLinear" && + erf_child_node_unit->OpType() == "QuantizeLinear") { + // This is the explicit sequence DQ -> Erf -> Q. + // Look past the DQ and Q nodes to get the actual parent and child. + // We do this because ORT utils do not automatically group DQ -> Erf -> Q into a NodeUnit. + dq_node_unit = erf_parent_node_unit; + q_node_unit = erf_child_node_unit; + erf_parent_node_unit = get_first_parent(*erf_parent_node_unit); + erf_child_node_unit = get_first_child(*erf_child_node_unit); + + erf_outputs = q_node_unit->Outputs(); + } else { + erf_outputs = erf_node_unit.Outputs(); + } + + parent_node_unit = erf_parent_node_unit; + child_node_unit = erf_child_node_unit; + return parent_node_unit != nullptr && child_node_unit != nullptr; +} + +std::unique_ptr GeluFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& erf_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& /*logger*/) { + if (erf_node_unit.OpType() != "Erf") { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const NodeUnit* div_node_unit = nullptr; + const NodeUnit* add_node_unit = nullptr; + const NodeUnit* dq_node_unit = nullptr; + const NodeUnit* q_node_unit = nullptr; + gsl::span erf_outputs; + + if (!GetErfParentAndChild(qnn_model_wrapper, erf_node_unit, node_to_node_unit, node_unit_to_qnn_node_group, + div_node_unit, add_node_unit, dq_node_unit, q_node_unit, erf_outputs)) { + return nullptr; + } + + // Erf must have a Div parent. 
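+  // Together, the checks below match the standard Gelu decomposition
+  //   Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+  // so the Div constant must be sqrt(2), the Add constant 1.0, and the outer Mul constant 0.5.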
+ if (div_node_unit == nullptr || div_node_unit->OpType() != "Div") { + return nullptr; + } + + // Div must have 2 inputs + const auto& div_inputs = div_node_unit->Inputs(); + if (div_inputs.size() < 2) { + return nullptr; + } + + // Check second input of Div is sqrt(2) ≈ 1.4142 + if (!IsInitializerWithExpectedValue(qnn_model_wrapper, div_inputs[1], static_cast(M_SQRT2))) { + return nullptr; + } + + // Erf must have an Add child consuming its output + if (add_node_unit == nullptr || add_node_unit->OpType() != "Add") { + return nullptr; + } + + // Add must have 2 inputs + const auto& add_inputs = add_node_unit->Inputs(); + if (add_inputs.size() < 2) { + return nullptr; + } + + // Check the other input node (e.g. not the Erf) is 1.0f + bool is_erf_first_input = (add_inputs[0].node_arg.Name() == erf_outputs[0].node_arg.Name()); + const auto& add_const_input = add_inputs[is_erf_first_input ? 1 : 0]; + if (!IsInitializerWithExpectedValue(qnn_model_wrapper, add_const_input, 1.0f)) { + return nullptr; + } + + // Add must have a Mul child consuming its output + const auto& add_outputs = add_node_unit->Outputs(); + if (add_outputs.empty()) { + return nullptr; + } + + const NodeUnit* mul_node_unit = GetOnlyChildOfOutput(graph_viewer, *add_node_unit, add_outputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + if (mul_node_unit == nullptr || mul_node_unit->OpType() != "Mul") { + return nullptr; + } + + // Now check which pattern we have + const auto& root_input_name = div_inputs[0].node_arg.Name(); + const auto& mul_inputs = mul_node_unit->Inputs(); + + if (mul_inputs.size() < 2) { + return nullptr; + } + + // Try to match Pattern 1: root -> Mul(0.5) -> ... -> Mul + // In this case, one input to the final Mul should be from a Mul node + const NodeUnit* mul2_node_unit = nullptr; + + // Check if either input to mul_node_unit comes from a Mul node + for (size_t i = 0; i < 2; ++i) { + const auto& mul_input = mul_inputs[i]; + + const NodeUnit* producer_unit = GetParentOfInput(graph_viewer, *mul_node_unit, mul_input, + node_to_node_unit, node_unit_to_qnn_node_group); + if (producer_unit && producer_unit->OpType() == "Mul") { + const auto& mul2_inputs = producer_unit->Inputs(); + if (mul2_inputs.size() >= 2) { + bool has_root_input = (mul2_inputs[0].node_arg.Name() == root_input_name || + mul2_inputs[1].node_arg.Name() == root_input_name); + if (has_root_input) { + int root_index = (mul2_inputs[0].node_arg.Name() == root_input_name) ? 0 : 1; + const auto& mul_const_input = mul2_inputs[1 - root_index]; + + if (IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) { + mul2_node_unit = producer_unit; + break; + } + } + } + } + if (mul2_node_unit != nullptr) break; + } + + std::vector node_units; + const NodeUnit* final_mul_node_unit = nullptr; + + if (mul2_node_unit != nullptr) { + // Pattern 1: root -> Mul(0.5) -> ... -> Mul + if (dq_node_unit != nullptr) { + assert(q_node_unit != nullptr); + node_units = {div_node_unit, dq_node_unit, &erf_node_unit, q_node_unit, add_node_unit, mul2_node_unit, + mul_node_unit}; + } else { + node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul2_node_unit, mul_node_unit}; + } + final_mul_node_unit = mul_node_unit; + } else { + // Try Pattern 2: root -> ... 
-> Mul -> Mul(0.5) + // Check if one input to mul_node_unit is root + bool has_root_input = (mul_inputs[0].node_arg.Name() == root_input_name || + mul_inputs[1].node_arg.Name() == root_input_name); + + if (!has_root_input) { + return nullptr; + } + + // mul_node_unit must have a Mul child consuming its output + const auto& mul_outputs = mul_node_unit->Outputs(); + if (mul_outputs.empty()) { + return nullptr; + } + + const NodeUnit* mul2_node_unit_pattern2 = GetOnlyChildOfOutput(graph_viewer, *mul_node_unit, mul_outputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + if (mul2_node_unit_pattern2 == nullptr || mul2_node_unit_pattern2->OpType() != "Mul") { + return nullptr; + } + + // Verify this final Mul has 2 inputs + const auto& mul2_inputs = mul2_node_unit_pattern2->Inputs(); + if (mul2_inputs.size() < 2) { + return nullptr; + } + + // Check the constant input is 0.5f + int mul_const_input_index = 0; + if (mul2_inputs[0].node_arg.Name() == mul_outputs[0].node_arg.Name()) { + mul_const_input_index = 1; + } + const auto& mul_const_input = mul2_inputs[mul_const_input_index]; + if (!IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) { + return nullptr; + } + + // Pattern 2 + if (dq_node_unit != nullptr) { + assert(q_node_unit != nullptr); + node_units = {div_node_unit, dq_node_unit, &erf_node_unit, q_node_unit, add_node_unit, + mul_node_unit, mul2_node_unit_pattern2}; + } else { + node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul_node_unit, mul2_node_unit_pattern2}; + } + + final_mul_node_unit = mul2_node_unit_pattern2; + } + + // Validate on QNN + const NodeUnitIODef& root_input = div_inputs[0]; + const NodeUnitIODef& final_output = final_mul_node_unit->Outputs()[0]; + + if (Status status = ValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(std::move(node_units), &erf_node_unit); +} + +GeluFusion::GeluFusion(std::vector&& node_units, const NodeUnit* target_node_unit) + : node_units_(std::move(node_units)), target_node_unit_(target_node_unit) { +} + +Status GeluFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const { + ORT_RETURN_IF_NOT(!node_units_.empty(), "GeluFusion node_units_ is empty"); + const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0]; + const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0]; + return ValidateOnQnn(qmw, node_units_, root_input, final_output); +} + +Status GeluFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const { + ORT_RETURN_IF_NOT(!node_units_.empty(), "GeluFusion node_units_ is empty"); + const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0]; + const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0]; + return CreateOnQnn(qmw, node_units_, root_input, final_output); +} + +gsl::span GeluFusion::GetNodeUnits() const { + return gsl::span(node_units_.data(), node_units_.size()); +} + +const NodeUnit* GeluFusion::GetTargetNodeUnit() const { + return target_node_unit_; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output, + bool validate) { + assert(node_units.size() >= 4); + const auto& node_name = utils::GetUniqueName(*node_units[0]); + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(root_input, input_tensor)); + 
ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(final_output, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_GELU, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + // Only add tensor wrappers if they don't already exist + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(root_input.node_arg.Name())) { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + } + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(final_output.node_arg.Name())) { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + } + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_GELU, + {root_input.node_arg.Name()}, + {final_output.node_arg.Name()}, + {}, + validate), + "Failed to add fused Gelu node."); + } + + return Status::OK(); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h new file mode 100644 index 0000000000000..508b1fca48a67 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of the Gelu pattern expanded into ONNX operators. +/// This fusion handles two patterns: +/// Pattern 1: +/// +-------Mul(0.5)---------------------+ +/// | | +/// | v +/// [root] --> Div -----> Erf --> Add --> Mul ==> +/// (B=1.4142...) (1) +/// +/// Pattern 2: +/// +------------------------------------+ +/// | | +/// | v +/// [root] --> Div -----> Erf --> Add --> Mul -->Mul ==> +/// (B=1.4142...) (1) (0.5) +/// +/// Both patterns are translated into a QNN Gelu operator. +/// The contained NodeUnits can be of type SingleNode or QDQGroup (with Q-DQ nodes). +/// The second inputs to Div, Add, and Mul Node Units should be constant. +/// +class GeluFusion : public IQnnNodeGroup { + public: + GeluFusion(std::vector&& node_units, const NodeUnit* target_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(GeluFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "GeluFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid Gelu pattern. + /// If so, returns a IQnnNodeGroup that contains all the NodeUnits in the pattern. + /// + /// Used for validation and traverse/query the graph + /// Erf node unit that could be part of the sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. 
+ /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& erf_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::vector node_units_; + const NodeUnit* target_node_unit_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 368caa518b7ba..4297801ce4cdc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -22,6 +22,7 @@ #include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_transpose_rank5.h" +#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/ort_api.h" @@ -83,6 +84,7 @@ static std::unordered_map> fusions = { {"Gemm", {LowPowerBlockQuantizedGemmFusion::TryFusion, ReshapeGemmFusion::TryFusion}}, {"Mul", {ScaleSoftmaxFusion::TryFusion}}, {"Cast", {CastLoneQFusion::TryFusion}}, + {"Erf", {GeluFusion::TryFusion}}, {"Reshape", {Rank6ToRank5Fusion::TryFusion}}, {"Transpose", {ChannelShuffleFusion::TryFusion}}}; @@ -119,9 +121,11 @@ static std::unique_ptr TryQnnFusions( const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings and Reshape + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except + // MatMul w/ LPBQ encodings, Erf and Reshape if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode && starting_node_unit.OpType() != "MatMul" && + starting_node_unit.OpType() != "Erf" && starting_node_unit.OpType() != "Reshape") { return nullptr; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index 10e1633e4b57d..7b77164a38545 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -226,14 +226,92 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer, return nullptr; } - // parent must not already be part of a QDQ NodeUnit (i.e., be standalone). 
- if (p_parent_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return p_parent_node_unit; + } + return nullptr; +} + +const NodeUnit* GetOnlyChildOfOutput(const GraphViewer& graph_viewer, + const NodeUnit& node_unit, + const NodeUnitIODef& output, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map) { + const Node* p_parent_node = nullptr; + + for (auto node : node_unit.GetAllNodesInGroup()) { + for (auto node_output : node->OutputDefs()) { + if (node_output->Name() == output.node_arg.Name()) { + p_parent_node = node; + break; + } + } + // break the loop if producer node of output is found + if (p_parent_node != nullptr) { + break; + } + } + + // return if the given output tensor is not produced by any node in the given node_unit + if (p_parent_node == nullptr) { + return nullptr; + } + + const Node& parent_node = *p_parent_node; + + if (graph_viewer.NodeProducesGraphOutput(parent_node)) { + // Node is producing a graph output + return nullptr; + } + + // First pass: count how many children consume this specific output + int child_count = 0; + const NodeUnit* p_child_node_unit = nullptr; + + for (auto edge = parent_node.OutputEdgesBegin(); edge != parent_node.OutputEdgesEnd(); ++edge) { + const Node& child_node = edge->GetNode(); + + // Check if this edge corresponds to the output we're looking for + bool is_matching_output = false; + for (auto child_input : child_node.InputDefs()) { + if (child_input->Name() == output.node_arg.Name()) { + is_matching_output = true; + break; + } + } + + if (!is_matching_output) { + continue; + } + + if (graph_viewer.GetNode(child_node.Index()) == nullptr) { + // Node is not in this GraphViewer return nullptr; } - return p_parent_node_unit; + const auto child_node_unit_it = node_unit_map.find(&child_node); + if (child_node_unit_it == node_unit_map.end()) { + return nullptr; + } + const NodeUnit* current_child_node_unit = child_node_unit_it->second; + + // Check if child node has already been handled. Should not be the case if the calling + // fusion function has been called in topological order, but check to be safe. + if (qnn_node_group_map.count(current_child_node_unit) != 0) { + return nullptr; + } + + // Store the child node unit and increment count + p_child_node_unit = current_child_node_unit; + child_count++; + + // If we found more than one child, return nullptr immediately + if (child_count > 1) { + return nullptr; + } } - return nullptr; + + // Return the child only if there's exactly one child + return (child_count == 1) ? 
p_child_node_unit : nullptr; } } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h index 14e2a3f25e7db..b52cdd5fa3ec6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -51,5 +51,11 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const std::unordered_map& qnn_node_group_map); +const NodeUnit* GetOnlyChildOfOutput(const GraphViewer& graph_viewer, + const NodeUnit& node_unit, + const NodeUnitIODef& output, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map); + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index d8e58f0d0a170..ebe71c6ccfacd 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -20,5 +20,9 @@ const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const Co return context.ep_.BufferManager(); } +const SplitKConfig& ComputeContext::GetSplitKConfig() { + return webgpu_context_.GetSplitKConfig(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 01cae1e337439..ed16f2f0a1345 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -152,6 +152,13 @@ class ComputeContext final { return webgpu_context_.Run(*this, program); } + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. + // + const SplitKConfig& GetSplitKConfig(); + private: WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 7dd3b50c656f4..7cbc7f6a4a821 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -13,7 +13,7 @@ namespace webgpu { // which are used in the MatMulWriteFnSource function. namespace { -void HanldeMaybeHaveBiasForGEMM(ShaderHelper& shader, +void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader, const ShaderVariableHelper& output, bool has_bias, int c_components, @@ -53,6 +53,70 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader, << output.SetByIndices("coords", "value") << "\n"; } +void HandleMatMulWithSplitK( + ShaderHelper& shader, + ProgramVariableDataType output_variable_type) { + shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n"; + + // With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups, + // so we must add them with atomic built-in functions. Because currently WebGPU doesn't support + // atomic built-in functions on `f32` or `f16`, we implement the `atomicAdd` on `f32` and `f16` + // with `atomicLoad` and `atomicCompareExchangeWeak`: + // 1. Get `old_output_i32` from `output[offset]` with `atomicLoad`. + // 2. Convert `old_output_i32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`). + // 3. Add incoming `value` into `old_output_f32` or `old_output_vec2h`. + // 4. Convert the result of step 3 into `i32` values. + // 5. 
Try assigning the result of step 4 into `output[offset]` with `atomicCompareExchangeWeak` + // and `old_output_i32`. The assignment will fail if at this time `output[offset]` is not + // equal to `old_output_i32` (it is updated in another invocation). If the assignment fails + // we have to go to step 1 and repeat all the above steps. + switch (output_variable_type) { + case ProgramVariableDataType::Float32x4: { + shader.AdditionalImplementation() << R"( + let offset0 = i2o_output(coords) * 4u; + for (var i = 0u; i < 4u; i++) { + let offset = offset0 + i; + while (true) { + let old_output_i32 = atomicLoad(&output[offset]); + let old_output_f32 = bitcast(old_output_i32); + let new_output_f32 = old_output_f32 + value[i]; + let new_output_i32 = bitcast(new_output_f32); + let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); + if (output_compare_exchange.old_value == old_output_i32) { + break; + } + } + } +)"; + break; + } + case ProgramVariableDataType::Float16x4: { + shader.AdditionalImplementation() << R"( + let offset0 = i2o_output(coords) * 2u; + var vec2h_values : array; + vec2h_values[0] = value.xy; + vec2h_values[1] = value.zw; + for (var i = 0u; i < 2u; i++) { + let offset = offset0 + i; + while (true) { + let old_output_i32 = atomicLoad(&output[offset]); + let old_output_vec2h = bitcast(old_output_i32); + let new_output_vec2h = old_output_vec2h + vec2h_values[i]; + let new_output_i32 = bitcast(new_output_vec2h); + let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); + if (output_compare_exchange.old_value == old_output_i32) { + break; + } + } + } +)"; + break; + } + default: + break; + } +} + } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -125,7 +189,9 @@ void MatMulWriteFnSource(ShaderHelper& shader, int output_components, bool c_is_scalar, std::string activation_snippet, - bool is_channels_last) { + bool is_channels_last, + bool use_split_k, + ProgramVariableDataType output_variable_type) { shader.AdditionalImplementation() << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n"; @@ -134,8 +200,17 @@ void MatMulWriteFnSource(ShaderHelper& shader, shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n" << " var value = valueIn; \n"; - if (is_gemm) { - HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); + if (use_split_k) { + // Set output when MatMul is performed with Split-K. + // When Split-K is used in MatMul, the bias will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram` + // instead of here, so `has_bias` and `is_channels_last` is not used for Split-K. Note that we + // still need to handle `has_bias` (and `is_channels_last` in the future) in + // `MatMulFillBiasOrZeroBeforeSplitKProgram`. 
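+    // `HandleMatMulWithSplitK` accumulates `value` into the output with an atomicCompareExchangeWeak
+    // loop instead of a plain store, because several workgroups contribute partial sums to the same
+    // output element.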
+ ORT_ENFORCE(!has_bias, "Bias is not supported in MatMulProgram when Split-K is enabled."); + ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled."); + HandleMatMulWithSplitK(shader, output_variable_type); + } else if (is_gemm) { + HandleMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); } else { HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last); } @@ -159,9 +234,6 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, uint32_t tile_inner, bool split_k, uint32_t split_dim_inner) { - ORT_UNUSED_PARAMETER(split_k); - ORT_UNUSED_PARAMETER(split_dim_inner); - const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type); std::string write_data_to_sub_a_vec4_snippet = @@ -208,14 +280,51 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << " let tileCol = i32(local_id.x);\n" << " let globalRow = i32(global_id.y) * rowPerThread;\n" << " let globalCol = i32(global_id.x);\n" - << " let batch = i32(global_id.z);\n" - << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" - << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" - << " var kStart = 0;\n" << " var acc: array, rowPerThread>;\n"; + if (split_k) { + // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into + // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from + // `kSplitK * i32(global_id.z)`. + // + // For example: considering computing Y = (X * W + B) in one workgroup. + // Let kSplitK = 2, B = [d1, d2] + // Let X = [[a1 a1 b1 b1 c1 c1] = [ A1 B1 C1 ], W = [[a2 a2] = [ A2 + // [a1 a1 b1 b1 c1 c1]] [a2 a2] B2 + // [b2 b2] C2 ] + // [b2 b2] + // [c2 c2] + // [c2 c2]] + // + // With Split-K: + // 1. Initialize output Y with B in `MatMulFillBiasOrZeroBeforeSplitKProgram`: Y = [[d1, d2] + // [d1, d2]] + // 2. Split the original 1 workgroup into 3 workgroups (now `dispatch_z = 3` in API side) + // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) + // Workgroup3: compute (C1 * C2) + // In each workgroup: + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` + // - When the computation in each workgroup is completed, add the result to Y with several + // atomic built-in functions in `HandleMatMulWithSplitK()`. + shader.MainFunctionBody() + << "const kSplitK = " << split_dim_inner << ";\n" + << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" + << " var kStart = kSplitK * i32(global_id.z);\n" + + // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate + // the index of split-k instead of batch. + << " let batch = 0;\n" + << " let batchIndices = 0u;\n"; + } else { + shader.MainFunctionBody() + << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" + << " var kStart = 0;\n" + << " let batch = i32(global_id.z);\n" + << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); + } + // Loop over shared dimension. 
shader.MainFunctionBody() << " let tileRowB = localRow * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h index ed4cf997d2f00..7075debeb9952 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h @@ -24,7 +24,9 @@ void MatMulWriteFnSource(ShaderHelper& shader, int output_components, bool c_is_scalar, std::string activation_snippet = "", - bool is_channels_last = false); + bool is_channels_last = false, + bool use_split_k = false, + ProgramVariableDataType output_variable_type = ProgramVariableDataType::Float32x4); // The two following functions are used to generate shader code for vec4 and scalar. // It is used in GEMM, Matmul, and Conv. diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index cf4b9d3fae2d2..55c2c5773cc1f 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -161,14 +161,14 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { const auto* bias = context.Input(2); inputs.push_back(bias); } - auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false); - return context.RunProgram(program); + return ComputeMatMul(&context, Activation(), inputs, output_tensor, false); } -MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, - const TensorShape& input_a_reshape, - const TensorShape& input_b_reshape) { +Status ComputeMatMul(ComputeContext* context, + const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, + const TensorShape& input_a_reshape, + const TensorShape& input_b_reshape) { const auto* a = inputs[0]; const auto* b = inputs[1]; bool has_bias = inputs.size() > 2; @@ -226,31 +226,97 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) / (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1])); - const uint32_t dispatch_z = narrow((static_cast(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / - (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); + uint32_t dispatch_z = narrow((static_cast(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / + (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); const int components = is_vec4 ? 4 : 1; const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components); const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components); const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); - MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last}; - program - .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last) + ProgramOutput output(output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components); + const Tensor* bias = has_bias ? 
inputs[2] : nullptr; + bool use_bias_in_matmul = has_bias; + uint32_t split_dim_inner = 1; + + const SplitKConfig& split_k_config = context->GetSplitKConfig(); + const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + if (need_split_k) { + ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); + ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format."); + ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); + + // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled. + const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp); + ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program)); + + // `bias` has been handled in the execution of `fill_bias_program` so we don't need to set + // `bias` again in `MatMulProgram`. + use_bias_in_matmul = false; + + // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the + // number of splits along `dim_inner`. + // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize + // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. + split_dim_inner = split_k_config.GetSplitDimInner(); + dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; + + // The output should be declared in atomic types in `MatMulProgram` for the use of atomic + // built-in functions. + output.is_atomic = true; + } + + MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner}; + matmul_program + .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) - .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z); + .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) + .AddOutput(std::move(output)); - if (has_bias) { + if (use_bias_in_matmul) { auto bias_components = is_channels_last ? components : 1; - const auto* bias = inputs[2]; TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); - program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); + matmul_program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); + } + + return context->RunProgram(matmul_program); +} + +MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( + const Tensor* bias, + Tensor* output, + const TensorShape& output_shape_vec4) { + const bool has_bias = bias != nullptr; + + // Currently we only support bias in vec4 and channels last format for Split-K MatMul. 
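+  // Illustrative sizing (assumed shapes): for a 128 x 256 f32 output, dim_b_outer_vec4 = 256 / 4 = 64,
+  // so total_outputs_vec4 = 128 * 64 = 8192 and, with the default workgroup size of 64,
+  // dispatch_x = 8192 / 64 = 128 workgroups in the computation below.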
+ constexpr uint32_t bias_components = 4; + MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias); + + const uint32_t dim_a_outer = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]); + const uint32_t dim_b_outer_vec4 = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]); + + // Fill one value (currently only vec4) per invocation. Now we use default workgroup size (64) for + // this program. + const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4; + const uint32_t dispatch_x = (total_outputs_vec4 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE; + + // To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar + // instead of vec4, while use `output_shape_vec4` directly as the output shape. + const uint32_t dim_b_outer = narrow(dim_b_outer_vec4 * bias_components); + program.CacheHint(has_bias) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast(bias_components)}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}}) + .SetDispatchGroupSize(dispatch_x); + + if (has_bias) { + const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast(bias_components)}); } + return program; } diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 8ab8c3a6ba2d0..0b65827be7f17 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -14,9 +14,14 @@ namespace onnxruntime { namespace webgpu { -MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, - const TensorShape& input_a_reshape = TensorShape(), - const TensorShape& input_b_reshape = TensorShape()); +Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, + const TensorShape& input_a_reshape = TensorShape(), + const TensorShape& input_b_reshape = TensorShape()); + +MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( + const Tensor* bias, + Tensor* output, + const TensorShape& output_shape_vec4); class MatMul final : public WebGpuKernel { public: diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 585f8f1e011c4..4daabe8246aa7 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -14,25 +14,77 @@ namespace webgpu { Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + const bool need_split_k = NeedSplitK(); + ShaderUsage output_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias; + if (need_split_k) { + // When Split-K is enabled, we will declare output as `atomic` to call atomic 
built-in + // functions on it, so we need below information to correctly compute the index on the output. + output_usage |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; + } + const auto& output = shader.AddOutput("output", output_usage); + const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); } std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); + ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; // declare the read and write functions MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_); - MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_); + MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type); std::string data_type = "a_element_t"; // generate the main function if (is_vec4_) { - ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); + ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source( + shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims, + /*transA = */ false, /*transB = */ false, /*alpha = */ 1.f, /*need_handle_matmul = */ true, + /*output_components = */ 4, /*tile_inner = */ 32, need_split_k, split_dim_inner_)); } else { ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); } return Status::OK(); } +bool MatMulProgram::NeedSplitK() const { + return split_dim_inner_ > 1; +} + +Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + + // Handle bias with `MatMulWriteFnSource()`. + // Here `use_split_k` is false because we just initialize `output` with bias. + // `use_split_k` is true only when we do the actual MatMul with Split-K. + // Currently we only support bias in vec4 and channels last format for Split-K MatMul. 
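+  // Each invocation writes one vec4 through `mm_write`: the bias value when `has_bias_` is true
+  // (the bias is added inside `mm_write`), otherwise zero (`output_value_t()` zero-initializes).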
+ MatMulWriteFnSource( + shader, output, has_bias_, /*is_gemm*/ false, /*c_components*/ 4, /*output_components*/ 4, /*c_is_scalar*/ false, + /*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false); + + shader.MainFunctionBody() << R"( + let output_components = 4; + let output_id = i32(global_idx); + + let dim_a_outer = i32(uniforms.dim_a_outer); + let dim_b_outer = i32(uniforms.dim_b_outer) / output_components; + if (output_id >= dim_a_outer * dim_b_outer) { + return; + } + + let output_row = output_id / dim_b_outer; + let output_col = output_id % dim_b_outer; + let output_batch = 0; + let output_value = output_value_t(); + mm_write(output_batch, output_row, output_col, output_value); +)"; + + return Status::OK(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 767fdd8802e5b..143ba61c99e13 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -13,24 +13,48 @@ namespace onnxruntime { namespace webgpu { class MatMulProgram final : public Program { public: - MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false) : Program{"MatMul"}, - activation_(activation), - has_bias_{bias}, - is_vec4_{is_vec4}, - elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), - is_channels_last_(is_channels_last) {} + MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false, uint32_t split_dim_inner = 1) : Program{"MatMul"}, + activation_(activation), + has_bias_{bias}, + is_vec4_{is_vec4}, + elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), + is_channels_last_(is_channels_last), + split_dim_inner_(split_dim_inner) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, {"dim_inner", ProgramUniformVariableDataType::Uint32}); + bool NeedSplitK() const; + private: const Activation activation_; const bool has_bias_; const bool is_vec4_; const InlinedVector elements_per_thread_; bool is_channels_last_ = false; + uint32_t split_dim_inner_ = 1; +}; + +// The program to initialize the output with 0 or bias before doing MatMul with Split-K. In Split-K, +// we set the output values with `atomicLoad` and `atomicCompareExchangeWeak` instead of a direct +// assignment (see the function `HandleMatMulWithSplitK()` in `gemm_utils.cc`), so we must initialize +// the output with 0 or bias first to make sure `atomicLoad` won't return garbage data. 
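+// This program is dispatched once by `ComputeMatMul()` (see matmul.cc), immediately before the
+// Split-K enabled `MatMulProgram` runs.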
+class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program { + public: + explicit MatMulFillBiasOrZeroBeforeSplitKProgram(bool has_bias) + : Program{"MatMul_Fill_Bias_Or_Zero_Before_Split_K"}, + has_bias_(has_bias) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}); + + private: + bool has_bias_ = false; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index a2777979ae983..77fa46cb87518 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -200,8 +200,7 @@ Status Conv::ComputeInternal(ComputeContext& context .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast(K)}}); return context.RunProgram(program); } else { - MatMulProgram program = CreateMatMulProgram(activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); - return context.RunProgram(program); + return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); } } // Transpose weights diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc index 9e934e9eb5db7..aa0caab39a88e 100644 --- a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/framework/op_kernel_info.h" #include namespace onnxruntime { namespace webgpu { diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h index f5d2585bb9b45..fad7d3d145bc6 100644 --- a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h @@ -1,11 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#include -#include "core/providers/webgpu/webgpu_kernel.h" +#include + +#include "core/common/status.h" #pragma once namespace onnxruntime { + +class OpKernelInfo; + namespace webgpu { + enum class ActivationKind { None, Relu, diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 5447966b91aa7..b08649cbd5d5b 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -472,7 +472,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } ss << ": array<"; if (is_atomic) { - if (output->type_ == ProgramVariableDataType::Float32) { + if (output->type_ == ProgramVariableDataType::Float32 || output->type_ == ProgramVariableDataType::Float16x4 || output->type_ == ProgramVariableDataType::Float32x4) { ss << "atomic"; // emulate float atomic via i32 } else if (output->type_ == ProgramVariableDataType::Uint32) { ss << "atomic"; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 9af9cd455b5a4..28decb076951e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -910,6 +910,13 @@ void WebGpuContext::ReleaseGraphResources(std::vector WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; std::once_flag WebGpuContextFactory::init_default_flag_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 1ead7b3a005bb..bd7dae75f2e2d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -5,12 +5,14 @@ #include #include +#include #include "core/providers/webgpu/webgpu_external_header.h" #include "core/common/common.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/webgpu_utils.h" #if defined(ENABLE_PIX_FOR_WEBGPU_EP) #include "core/providers/webgpu/webgpu_pix_frame_generator.h" @@ -171,6 +173,13 @@ class WebGpuContext final { Status Run(ComputeContext& context, const ProgramBase& program); void OnRunEnd(); + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. 
+ // + const SplitKConfig& GetSplitKConfig(); + private: enum class TimestampQueryType { None = 0, @@ -268,6 +277,8 @@ class WebGpuContext final { uint32_t num_pending_dispatches_ = 0; const uint32_t max_num_pending_dispatches_ = 16; + std::optional split_k_config_; + // profiling TimestampQueryType query_type_; wgpu::QuerySet query_set_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 3df194217933e..e0b84fef51f1f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -878,9 +878,9 @@ std::vector> WebGpuExecutionProvider::GetCapa const auto& inputs = node.InputDefs(); const auto& outputs = node.OutputDefs(); - // Current implementation does not support mask_index(input[3]), past(input[5]) and past_seq_len(input[6]) + // Current implementation does not support mask_index(input[3]), past(input[4]) and past_seq_len(input[6]) FALLBACK_TO_CPU_IF_EXIST_INPUT(3); - FALLBACK_TO_CPU_IF_EXIST_INPUT(5); + FALLBACK_TO_CPU_IF_EXIST_INPUT(4); FALLBACK_TO_CPU_IF_EXIST_INPUT(6); // Current implementation does not support present(output[1]) diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 53b96dfe7a346..568d29a96cb88 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -21,5 +21,64 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components return TensorShape(shape_vector); } +SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) { + SplitKConfig config = {}; + + if (adapter_info.vendor == std::string_view{"intel"}) { + if (adapter_info.architecture == std::string_view{"xe-2lpg"} || + adapter_info.architecture == std::string_view{"xe-2hpg"} || + adapter_info.architecture == std::string_view{"xe-lpg"} || + adapter_info.architecture == std::string_view{"gen-12hp"}) { + config.enable_split_k_ = true; + + // Below thresholds are only verified on the above Intel GPUs without any regressions. The + // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be + // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more + // atomic calls for each output value. + config.split_dim_inner_ = 256; + config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; + config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; + config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; + } + } + return config; +} + +bool SplitKConfig::UseSplitK( + bool is_vec4, + ActivationKind activation_kind, + uint64_t batch_size, + bool is_channels_last, + uint32_t dim_a_outer, + uint32_t dim_b_outer, + uint32_t dim_inner) const { + if (!enable_split_k_) { + return false; + } + + bool use_split_k = true; + + // TODO: support the cases below. + use_split_k &= activation_kind == ActivationKind::None; + use_split_k &= is_vec4; + use_split_k &= batch_size == 1; + // Now `is_channels_last` is only supported because we only generate vec4 shaders in + // `MatMulFillBiasOrZeroBeforeSplitKProgram`. + use_split_k &= is_channels_last; + + // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and + // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and + // `dim_inner)` as the metric to decide whether to use Split-K or not. 
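+  // For example, with the Intel configuration above (min_dim_inner = 512, max_dim_inner = 2304,
+  // ratio limit = 35), a MatMul with dim_a_outer = 64, dim_b_outer = 256 and dim_inner = 1024
+  // qualifies, assuming the other requirements (vec4, no activation, batch_size == 1,
+  // channels-last) are met: 1024 is within [512, 2304] and (64 * 256) / 1024 = 16 <= 35.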
+ use_split_k &= (dim_inner >= min_dim_inner_with_split_k_); + use_split_k &= (dim_inner <= max_dim_inner_with_split_k_); + use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_); + + return use_split_k; +} + +uint32_t SplitKConfig::GetSplitDimInner() const { + return split_dim_inner_; +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 86eb57f99f3b3..d45b9bf4dd119 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -7,6 +7,8 @@ #include "core/common/common.h" #include "core/framework/tensor.h" #include "core/framework/tensor_shape.h" +#include "core/providers/webgpu/webgpu_external_header.h" +#include "core/providers/webgpu/nn/fuse_utils.h" namespace onnxruntime { namespace webgpu { @@ -89,5 +91,24 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; } +class SplitKConfig { + public: + static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info); + + bool UseSplitK( + bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, + bool is_channels_last, uint32_t dim_a_outer, + uint32_t dim_b_outer, uint32_t dim_inner) const; + + uint32_t GetSplitDimInner() const; + + private: + bool enable_split_k_ = false; + uint32_t split_dim_inner_ = 0; + uint32_t min_dim_inner_with_split_k_ = 0; + uint32_t max_dim_inner_with_split_k_ = 0; + float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f; +}; + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc index 0b927075402fe..a29fbdb91e79f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc @@ -107,6 +107,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b ORT_RETURN_IF_NOT(GetShape(*input_defs[3], input_past_k_shape, logger), "Cannot get past_key shape"); NodeAttrHelper helper(node); + const int32_t local_window_size = helper.Get("local_window_size", -1); const uint32_t kv_num_heads = helper.Get("kv_num_heads", 0); const uint32_t num_heads = helper.Get("num_heads", 0); @@ -290,18 +291,17 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b | | +-------------------------------> Lesser <---------------------Transpose (1,0) | - 1 ---> Where <--- finfo_min (minimum value of FP32) + 1 ---> Where (attn_mask) <--- finfo_min (minimum value of FP32) | attention_bias */ - const std::vector mask_shape_ones_shape(batch_size * num_heads * qkv_sequence_length * past_sequence_length, - 1); - std::string mask_shape_ones_shape_name = "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(batch_size) + - "_" + std::to_string(num_heads) + "_" + std::to_string(qkv_sequence_length) + - "_" + std::to_string(past_sequence_length); - emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant( - ONNX_NAMESPACE::TensorProto_DataType_INT32, mask_shape_ones_shape_name, mask_shape_ones_shape, - std::vector({batch_size, num_heads, qkv_sequence_length, past_sequence_length})); + emscripten::val value_int_one_constant = + 
model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1, {1}); + + std::vector mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length}; + common_options.set("label", node.Name() + "_/GQA/GQA_mask_shape_ones/expand"); + emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder().call( + "expand", value_int_one_constant, emscripten::val::array(mask_shape_ones_shape), common_options); emscripten::val cumsum_options = emscripten::val::object(); cumsum_options.set("label", node.Name() + "_range_of_mask_shape"); @@ -315,7 +315,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b std::iota(pre_neq_right_data_range.begin(), pre_neq_right_data_range.end(), 1); std::string pre_neq_right_data_range_name = - "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(qkv_sequence_length); + "webnn_GQA_pre_neq_right_data_range_" + std::to_string(qkv_sequence_length); emscripten::val pre_neq_right_data_range_constant = model_builder.CreateOrGetConstant( ONNX_NAMESPACE::TensorProto_DataType_INT32, pre_neq_right_data_range_name, pre_neq_right_data_range, std::vector({qkv_sequence_length})); @@ -333,10 +333,46 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b emscripten::val neq_right = model_builder.GetBuilder().call("transpose", expanded_neq_right, transpose_options); - common_options.set("label", node.Name() + "_/GQA/attn_mask/condition"); - emscripten::val condition = + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_1"); + emscripten::val condition_1 = model_builder.GetBuilder().call("lesser", neq_left, neq_right, common_options); + emscripten::val condition = condition_1; + // For local window size not equal to -1, new attention mask pattern for applying sliding window + /* + condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true) + | | + | Lesser <--- local_window_size + | | + LogicalAnd <----------------- condition_2 + | + new attn_mask + */ + if (local_window_size != -1) { + // Cast condition + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cast"); + emscripten::val casted_condition_1 = + model_builder.GetBuilder().call("cast", condition_1, emscripten::val("int32"), common_options); + + cumsum_options = emscripten::val::object(); + cumsum_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cumsum"); + cumsum_options.set("exclusive", true); + cumsum_options.set("reversed", true); + emscripten::val neq_left_2 = model_builder.GetBuilder().call( + "cumulativeSum", casted_condition_1, gsl::narrow(3), cumsum_options); + + emscripten::val local_window_size_constant = + model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_INT32, local_window_size, {1}); + + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2"); + emscripten::val condition_2 = + model_builder.GetBuilder().call("lesser", neq_left_2, local_window_size_constant, common_options); + + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition/and"); + condition = model_builder.GetBuilder().call( + "logicalAnd", condition_1, condition_2, common_options); + } + emscripten::val value_one_constant = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1, {1}); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 9bc6c8d0a96a1..6c6c589ffcad4 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ 
b/onnxruntime/core/session/custom_ops.cc @@ -755,6 +755,21 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* i }); } +ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + const auto& config_options_map = op_info->GetConfigOptions().GetConfigOptionsMap(); + + auto kvps = std::make_unique(); + for (const auto& kv : config_options_map) { + kvps->Add(kv.first.c_str(), kv.second.c_str()); + } + + *out = kvps.release(); + return nullptr; + }); +} + ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) { if (count_or_bytes == 0) { *out = nullptr; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 14f0892687ad1..4d4dea9cb444c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3642,7 +3642,7 @@ common::Status InferenceSession::ValidateAndParseShrinkArenaString(const std::st void InferenceSession::ShrinkMemoryArenas(gsl::span arenas_to_shrink) { for (auto& alloc : arenas_to_shrink) { - auto status = static_cast(alloc.get())->Shrink(); + auto status = static_cast(alloc.get())->Shrink(); if (!status.IsOK()) { LOGS(*session_logger_, WARNING) << "Unable to shrink arena: " << alloc->Info().ToString() diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 546b11ae580d5..394f69bb15b19 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2293,6 +2293,12 @@ ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfgV2, _In_reads_(num_keys) const char* cfg->initial_growth_chunk_size_bytes = static_cast(arena_config_values[i]); } else if (strcmp(arena_config_keys[i], "max_power_of_two_extend_bytes") == 0) { cfg->max_power_of_two_extend_bytes = static_cast(arena_config_values[i]); + } else if (strcmp(arena_config_keys[i], "use_cuda_mempool") == 0) { + cfg->use_cuda_mempool = static_cast(arena_config_values[i]); + } else if (strcmp(arena_config_keys[i], "cuda_mempool_release_threshold") == 0) { + cfg->cuda_mempool_release_threshold = static_cast(arena_config_values[i]); + } else if (strcmp(arena_config_keys[i], "cuda_mempool_bytes_to_keep_on_shrink") == 0) { + cfg->cuda_mempool_bytes_to_keep_on_shrink = static_cast(arena_config_values[i]); } else { std::ostringstream oss; oss << "Invalid key found: " << arena_config_keys[i]; @@ -4231,6 +4237,7 @@ static constexpr OrtApi ort_api_1_to_24 = { // End of Version 23 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::TensorTypeAndShape_HasShape, + &OrtApis::KernelInfo_GetConfigEntries, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
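For context (not part of this change set), a minimal sketch of how a custom operator kernel could consume the new `KernelInfo_GetConfigEntries` entry point through the C++ wrapper that the tests further down exercise; the `MyKernel` struct and the `session.disable_cpu_ep_fallback` key are illustrative assumptions only:

```cpp
// Sketch only: mirrors the Ort::ConstKernelInfo::GetConfigEntries() usage shown in
// shape_inference_test.cc below; MyKernel and the queried key are hypothetical.
#include <string>
#include "onnxruntime_cxx_api.h"

struct MyKernel {
  explicit MyKernel(const OrtKernelInfo* info) {
    Ort::ConstKernelInfo kernel_info{info};
    Ort::KeyValuePairs config_entries = kernel_info.GetConfigEntries();

    // GetValue() returns nullptr for keys that were never set on the session options.
    if (const char* value = config_entries.GetValue("session.disable_cpu_ep_fallback")) {
      disable_cpu_fallback_ = std::string(value) == "1";
    }
  }

  bool disable_cpu_fallback_ = false;
};
```

Callers going through the raw C API instead receive an owned `OrtKeyValuePairs*` (the implementation above hands out `kvps.release()`), so they are responsible for releasing it.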
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index f016bb3215330..c0e4d32ac0167 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -751,4 +751,7 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); + +ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); + } // namespace OrtApis diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index dcb4a495c23c5..045dc98a3501e 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -527,6 +527,7 @@ void Im2col::operator()( template struct Im2col; template struct Im2col; +template struct Im2col; template void Im2col::operator()( diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md index 66b8467bda335..51a405613bea1 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md @@ -1,5 +1,8 @@ # ONNXRuntime EP Context Model Generation with Weight Sharing +> [!NOTE] +> This tool is deprecated. Please use the public ONNX Runtime Python APIs to compile models with resource sharing. Refer to the example Python script at the end of this document. + [EP context with weight sharing design doc](https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html#epcontext-with-weight-sharing) OnnxRuntime provides the ep_weight_sharing_ctx_gen tool to automate the weight-sharing workflow. This tool handles the entire process. This tool is specifically designed for weight sharing scenarios, streamlining the EPContext model generation process. @@ -13,6 +16,23 @@ Example: ./ep_weight_sharing_ctx_gen -e qnn -i "soc_model|60 htp_graph_finalizat Options: -e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn', 'tensorrt', 'openvino', 'vitisai'. Default: 'qnn'. + -p [plugin_ep_config_json_file]: Specify JSON configuration file for a plugin EP. Takes precedence over the '-e' and '-i' options. + + Example JSON configuration that selects plugin EP devices via name: + { + "ep_library_registration_name": "example_plugin_ep", + "ep_library_path": "example_plugin_ep.dll", + "selected_ep_name": "example_plugin_ep", + "default_ep_options": { "key": "value" } + } + + Example JSON configuration that selects plugin EP devices via index: + { + "ep_library_registration_name": "example_plugin_ep", + "ep_library_path": "example_plugin_ep.dll", + "selected_ep_device_indices": [ 0 ], + "default_ep_options": { "key": "value" } + } -v: Show verbose information. -C: Specify session configuration entries as key-value pairs: -C "| |" Refer to onnxruntime_session_options_config_keys.h for valid keys and values. @@ -36,3 +56,49 @@ Options: -h: help ``` + +# Example: Use Python APIs to compile models with resource sharing +Use of the public ORT Python APIs is now recommended for compiling models with resource (e.g., "weight") sharing. +The following snippet shows an example that compiles two models using an example plugin EP. + +```Python +import onnxruntime +import os + +def main(): + ep_name = "example_ep" + ep_lib_path = "example_plugin_ep.dll" + + onnxruntime.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + # Find one or more EP devices that correspond to the EP of interest. 
+ # In this example, we pick the first one. + ep_device = next((d for d in onnxruntime.get_ep_devices() if d.ep_name == ep_name), None) + + # These are the names/paths to the input and output models. + input_models = ["model_0.onnx", "model_1.onnx"] + output_models = ["model_0_ctx.onnx", "model_1_ctx.onnx"] + + num_models = len(input_models) + session_options = onnxruntime.SessionOptions() + provider_options = {} # Empty for this example + + # Set option that tells EP to share resources (e.g., weights) across sessions. + session_options.add_session_config_entry("ep.share_ep_contexts", "1") + session_options.add_provider_for_devices([ep_device], provider_options) + + # Compile individual models + for i in range(len(input_models)): + if i == num_models - 1: + # Tell EP that this is the last compiling session that will be sharing resources. + session_options.add_session_config_entry("ep.stop_share_ep_contexts", "1") + + model_compiler = onnxruntime.ModelCompiler( + session_options, + input_models[i], + embed_compiled_data_into_model=False, + ) + model_compiler.compile_to_file(output_models[i]) + + onnxruntime.unregister_execution_provider_library(ep_name) +``` diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc index cecf5575d42a5..15bce163ba16a 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc @@ -4,6 +4,7 @@ #include "command_args_parser.h" #include +#include #include #include #include @@ -21,6 +22,7 @@ #include #include +#include "nlohmann/json.hpp" #include "test_configuration.h" namespace onnxruntime { @@ -35,6 +37,23 @@ namespace qnnctxgen { "\n" "Options:\n" "\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn', 'tensorrt', 'openvino', 'vitisai'. Default: 'qnn'.\n" + "\t-p [plugin_ep_config_json_file]: Specify JSON configuration file for a plugin EP. Takes precedence over the '-e' and '-i' options.\n" + "\n" + "\t Example JSON configuration that selects plugin EP devices via EP name:\n" + "\t {\n" + "\t \"ep_library_registration_name\": \"example_plugin_ep\",\n" + "\t \"ep_library_path\": \"example_plugin_ep.dll\",\n" + "\t \"selected_ep_name\": \"example_plugin_ep\",\n" + "\t \"default_ep_options\": { \"key\": \"value\" }\n" + "\t }\n" + "\n" + "\t Example JSON configuration that selects plugin EP devices via index:\n" + "\t {\n" + "\t \"ep_library_registration_name\": \"example_plugin_ep\",\n" + "\t \"ep_library_path\": \"example_plugin_ep.dll\",\n" + "\t \"selected_ep_device_indices\": [ 0 ],\n" + "\t \"default_ep_options\": { \"key\": \"value\" }\n" + "\t }\n" "\t-v: Show verbose information.\n" "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. 
\n" @@ -58,6 +77,7 @@ namespace qnnctxgen { "\n" "\t-h: help\n"); } + #ifdef _WIN32 static const ORTCHAR_T* delimiter = L","; #else @@ -110,9 +130,63 @@ static bool ParseSessionConfigs(const std::string& configs_string, return true; } +static bool ParsePluginEpConfig(const std::string& json_file_path, PluginEpConfig& config_out) { + using json = nlohmann::json; + bool success = true; + + ORT_TRY { + std::ifstream ifs{json_file_path}; + if (!ifs) { + std::cerr << "ERROR: Failed to open plugin EP configuration file at path: " + << json_file_path.c_str() << std::endl; + return false; + } + + std::string content(std::istreambuf_iterator{ifs}, + std::istreambuf_iterator{}); + PluginEpConfig config{}; + const auto parsed_json = json::parse(content); + + // required keys + parsed_json.at("ep_library_registration_name").get_to(config.ep_library_registration_name); + parsed_json.at("ep_library_path").get_to(config.ep_library_path); + + // optional keys + config.default_ep_options = parsed_json.value("default_ep_options", {}); + config.selected_ep_name = parsed_json.value("selected_ep_name", {}); + config.selected_ep_device_indices = + parsed_json.value("selected_ep_device_indices", {}); + + if (config.selected_ep_name.empty() == config.selected_ep_device_indices.empty()) { + std::cerr << "ERROR: Plugin EP configuration must specify exactly one of 'selected_ep_name' " + << "or 'selected_ep_device_indices'" << std::endl; + return false; + } + + config_out = std::move(config); + return success; + } + ORT_CATCH(const json::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + std::string kExampleValidJsonStr = + "{\n" + " \"ep_library_registration_name\": \"example_plugin_ep\",\n" + " \"ep_library_path\": \"/path/to/example_plugin_ep.dll\",\n" + " \"selected_ep_name\": \"example_plugin_ep\"\n" + "}"; + + success = false; + std::cerr << "ERROR: JSON parse error: " << e.what() << std::endl; + std::cerr << "This is an example valid JSON configuration:\n" + << kExampleValidJsonStr.c_str() << std::endl; + }); + } + return success; +} + /*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("e:p:o:u:i:C:vh"))) != -1) { switch (ch) { case 'e': if (!CompareCString(optarg, ORT_TSTR("qnn"))) { @@ -128,6 +202,20 @@ static bool ParseSessionConfigs(const std::string& configs_string, return false; } break; + case 'p': { +#ifdef _MSC_VER + std::string plugin_ep_config_file_path = ToUTF8String(optarg); +#else + std::string plugin_ep_config_file_path = optarg; +#endif + PluginEpConfig plugin_ep_config{}; + if (!ParsePluginEpConfig(plugin_ep_config_file_path, plugin_ep_config)) { + return false; + } + + test_config.machine_config.plugin_ep_config = std::move(plugin_ep_config); + break; + } case 'v': test_config.run_config.f_verbose = true; break; @@ -202,6 +290,11 @@ static bool ParseSessionConfigs(const std::string& configs_string, argc -= optind; argv += optind; + if (argc == 0) { + std::cerr << "ERROR: Did not specify model paths" << std::endl; + return false; + } + ParsePaths(argv[0], test_config.model_file_paths); return true; diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/example_plugin_ep_config.json b/onnxruntime/test/ep_weight_sharing_ctx_gen/example_plugin_ep_config.json new file mode 100644 index 0000000000000..f8967d1831582 --- /dev/null +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/example_plugin_ep_config.json @@ -0,0 +1,6 @@ +{ + 
"ep_library_registration_name": "example_plugin_ep", + "ep_library_path": "example_plugin_ep.dll", + "selected_ep_name": "example_plugin_ep", + "default_ep_options": { "option_key": "option_value" } +} diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc index 18abe1eb131d8..3f2cda26fe9df 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc @@ -10,6 +10,7 @@ // onnx dependencies #include "onnx/onnx_pb.h" +#include #include using namespace onnxruntime; @@ -81,6 +82,72 @@ static void UpdateEpContextModel(const std::vector> } } +using PluginEpLibraryRegistrationHandle = std::unique_ptr>; + +static PluginEpLibraryRegistrationHandle RegisterPluginEpLibrary(Ort::Env& env, + const std::string& ep_library_registration_name, + const std::basic_string& ep_library_path) { + env.RegisterExecutionProviderLibrary(ep_library_registration_name.c_str(), ep_library_path); + + auto unregister_ep_library = [&env, registration_name = ep_library_registration_name](void* p) { + if (p == nullptr) { + return; + } + + ORT_TRY { + env.UnregisterExecutionProviderLibrary(registration_name.c_str()); + } + ORT_CATCH(const Ort::Exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + std::cerr << "Failed to unregister EP library with name '" << registration_name << "': " + << e.what() << std::endl; + }); + } + }; + + // Set `handle_value` to something not equal to nullptr. The particular value doesn't really matter. + // We are just using the unique_ptr deleter to unregister the EP library. + void* const handle_value = reinterpret_cast(0x1); + return PluginEpLibraryRegistrationHandle{handle_value, unregister_ep_library}; +} + +static bool SetPluginEpSessionOptions(Ort::Env& env, Ort::SessionOptions& session_options, + const qnnctxgen::PluginEpConfig& config, + PluginEpLibraryRegistrationHandle& plugin_ep_library_registration_handle) { + auto lib_registration_handle = RegisterPluginEpLibrary(env, config.ep_library_registration_name, + ToPathString(config.ep_library_path)); + + std::vector ep_devices = env.GetEpDevices(); + std::vector selected_ep_devices{}; + + if (!config.selected_ep_device_indices.empty()) { + for (const auto idx : config.selected_ep_device_indices) { + if (idx >= ep_devices.size()) { + std::cerr << "ERROR: Selected EP device index is out of range (max is " << ep_devices.size() - 1 << "): " + << idx << std::endl; + return false; + } + + selected_ep_devices.push_back(ep_devices[idx]); + } + } else { + std::copy_if(ep_devices.begin(), ep_devices.end(), std::back_inserter(selected_ep_devices), + [&selected_ep_name = std::as_const(config.selected_ep_name)](Ort::ConstEpDevice ep_device) { + return ep_device.EpName() == selected_ep_name; + }); + } + + if (selected_ep_devices.empty()) { + std::cerr << "ERROR: No EP devices were selected" << std::endl; + return false; + } + + session_options.AppendExecutionProvider_V2(env, selected_ep_devices, config.default_ep_options); + plugin_ep_library_registration_handle = std::move(lib_registration_handle); + + return true; +} + #ifdef _WIN32 int real_main(int argc, wchar_t* argv[]) { #else @@ -98,6 +165,7 @@ int real_main(int argc, char* argv[]) { Ort::Env env(logging_level, "ep_weight_sharing"); ORT_TRY { + PluginEpLibraryRegistrationHandle plugin_ep_library_registration_handle{}; Ort::SessionOptions so; so.SetLogId("ep_weight_sharing_ctx_gen_session_logger"); // Set default session option to dump EPContext model with non-embed mode @@ -136,7 
+204,14 @@ int real_main(int argc, char* argv[]) { // The context binary file generated later includes all graphs from previous models { std::string provider_name_ = test_config.machine_config.provider_type_name; - if (provider_name_ == onnxruntime::kQnnExecutionProvider) { + + if (const auto& plugin_ep_config = test_config.machine_config.plugin_ep_config; plugin_ep_config.has_value()) { + if (!SetPluginEpSessionOptions(env, so, *plugin_ep_config, plugin_ep_library_registration_handle)) { + std::cerr << "ERROR: Failed to initialize session for plugin EP " + << test_config.machine_config.plugin_ep_config->ep_library_path << std::endl; + return 1; + } + } else if (provider_name_ == onnxruntime::kQnnExecutionProvider) { #ifdef USE_QNN so.AppendExecutionProvider("QNN", provider_options); #else diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h index 198d03211f561..6dfb7b60ddc27 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -14,8 +15,25 @@ namespace onnxruntime { namespace qnnctxgen { +// Configuration for initializing the dynamic plugin EP infrastructure. +struct PluginEpConfig { + std::string ep_library_registration_name{}; + std::string ep_library_path{}; + + // Note: Exactly one of `selected_ep_name` or `selected_ep_device_indices` should be set. + // An empty value for either means it is unset. + + // Specifies the EP devices matching this EP name as the selected EP devices. + std::string selected_ep_name{}; + // Specifies the selected EP devices by their indices. + std::vector selected_ep_device_indices{}; + + std::unordered_map default_ep_options{}; +}; + struct MachineConfig { std::string provider_type_name{onnxruntime::kQnnExecutionProvider}; + std::optional plugin_ep_config = std::nullopt; }; struct RunConfig { diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index 9ded9d2bfeac0..5a50998af584f 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -339,7 +339,7 @@ struct StreamMock : public Stream { #ifdef ORT_ENABLE_STREAM TEST(StreamAwareArenaTest, TwoStreamAllocation) { - StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30); + StreamAwareBFCArena a(std::unique_ptr(new CPUAllocator()), 1 << 30); CheckStats(&a, 0, 0, 0, 0); OrtDevice tmp; @@ -451,7 +451,7 @@ TEST(BFCArenaTest, TestExtendStrategy) { 0, true, config}; auto allocator = CreateAllocator(device_info); size_t block_size = 1 << 20; // 1MB - BFCArena& a = *static_cast(allocator.get()); + auto& a = *allocator; a.Alloc(block_size); AllocatorStats stats; a.GetStats(&stats); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 8f6ed6f55c11a..aca345fccdc01 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -388,10 +388,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, std::vector expected_values_mul_y_2 = {174, 216, 258, 102, 128, 154, 30, 40, 50}; // Now run - st = session_object.Run(run_options, *io_binding.get()); - - std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; - ASSERT_TRUE(st.IsOK()); + ASSERT_STATUS_OK(session_object.Run(run_options, *io_binding)); if 
((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider)) || (output_device && output_device->Type() == OrtDevice::GPU)) { @@ -402,21 +399,19 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, auto& rtensor = outputs.front().Get(); auto element_type = rtensor.DataType(); auto& shape = rtensor.Shape(); - std::unique_ptr cpu_tensor = std::make_unique(element_type, shape, cpu_alloc); + Tensor cpu_tensor(element_type, shape, cpu_alloc); #ifdef USE_CUDA - st = GetProviderInfo_CUDA().CreateGPUDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); #endif #ifdef USE_ROCM - st = GetProviderInfo_ROCM().CreateGPUDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); + st = GetProviderInfo_ROCM().CreateGPUDataTransfer()->CopyTensor(rtensor, cpu_tensor); #endif #ifdef USE_WEBGPU - st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); #endif ASSERT_TRUE(st.IsOK()); OrtValue ml_value; - ml_value.Init(cpu_tensor.release(), - DataTypeImpl::GetType(), - DataTypeImpl::GetType()->GetDeleteFunc()); + Tensor::InitOrtValue(std::move(cpu_tensor), ml_value); VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y); #endif } else { @@ -2230,7 +2225,7 @@ TEST(InferenceSessionTests, TestArenaShrinkageAfterRun) { auto cuda_alloc = session_object.GetAllocator(mem_info); AllocatorStats alloc_stats; - static_cast(cuda_alloc.get())->GetStats(&alloc_stats); + cuda_alloc->GetStats(&alloc_stats); #ifdef ENABLE_TRAINING // In training builds, initializers are allocated using the Reserve() call which // will not cause an arena extension @@ -2250,7 +2245,7 @@ TEST(InferenceSessionTests, TestArenaShrinkageAfterRun) { RunOptions run_options_1; RunModel(session_object, run_options_1); - static_cast(cuda_alloc.get())->GetStats(&alloc_stats); + cuda_alloc->GetStats(&alloc_stats); // The arena would have made 2 more extensions as part of servicing memory requests within Run() // 1) - To take the solitary feed to cuda memory @@ -2274,7 +2269,7 @@ TEST(InferenceSessionTests, TestArenaShrinkageAfterRun) { "gpu:0")); RunModel(session_object, run_options_2); - static_cast(cuda_alloc.get())->GetStats(&alloc_stats); + cuda_alloc->GetStats(&alloc_stats); // The arena would have made no extensions in this Run() as the freed memory after the first Run() // will be re-used diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 9bdc0898c81c1..ed2b98e5280b5 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -405,7 +405,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { // One reserve call should have been made (for allocating memory for the sole initializer in the model) ASSERT_EQ(1, alloc_stats.num_reserves); - // This counter comes from Reserve(). The actual call for arena based allocator went to StreamAwareArena instance + // This counter comes from Reserve(). 
The actual call for arena based allocator went to StreamAwareBFCArena instance ASSERT_EQ(1, alloc_stats.num_allocs); } } diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index 2d5c3a43ee8ed..37c3825101ba4 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -78,7 +78,18 @@ TEST_F(ShapeInferenceTest, BasicTest) { namespace { struct MyCustomKernelWithOptionalInput { - MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) { + MyCustomKernelWithOptionalInput(const OrtKernelInfo* info) { + Ort::ConstKernelInfo k_info(info); + + Ort::KeyValuePairs kvp = k_info.GetConfigEntries(); + + EXPECT_NE(nullptr, kvp.GetValue("session.inter_op.allow_spinning")); + EXPECT_STREQ("0", kvp.GetValue("session.inter_op.allow_spinning")); + + EXPECT_NE(nullptr, kvp.GetValue("session.intra_op.allow_spinning")); + EXPECT_STREQ("0", kvp.GetValue("session.intra_op.allow_spinning")); + + EXPECT_EQ(nullptr, kvp.GetValue("__not__exist__")); } OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const { @@ -143,6 +154,8 @@ TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) { SessionOptions sess_opts; sess_opts.inter_op_param.thread_pool_size = 1; sess_opts.intra_op_param.thread_pool_size = 1; + ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.inter_op.allow_spinning", "0")); + ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.intra_op.allow_spinning", "0")); InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2}; ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains))); diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 8960898f036fc..2c9377d48f0c4 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -15,6 +15,8 @@ #include #include #include +#include +#include #include "test_configuration.h" #include "strings_helper.h" @@ -24,6 +26,7 @@ #include "absl/flags/usage.h" #include "absl/flags/usage_config.h" #include "absl/flags/reflection.h" +#include "absl/strings/str_split.h" static const onnxruntime::perftest::PerformanceTestConfig& DefaultPerformanceTestConfig() { static onnxruntime::perftest::PerformanceTestConfig default_config{}; @@ -149,8 +152,10 @@ ABSL_FLAG(std::string, C, "", "Refer to onnxruntime_session_options_config_keys.h for valid keys and values. 
\n" "[Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n"); ABSL_FLAG(std::string, R, "", "Allows user to register custom op by .so or .dll file."); -ABSL_FLAG(bool, A, DefaultPerformanceTestConfig().run_config.enable_cpu_mem_arena, "Disables memory arena."); -ABSL_FLAG(bool, M, DefaultPerformanceTestConfig().run_config.enable_memory_pattern, "Disables memory pattern."); +ABSL_FLAG(bool, A, !DefaultPerformanceTestConfig().run_config.enable_cpu_mem_arena, "Disables memory arena."); +ABSL_FLAG(std::string, shrink_arena_between_runs, "", "When arena is enabled call Shrink for specified devices 'cpu:0;gpu:0'"); +ABSL_FLAG(std::string, enable_cuda_mempool, "", "When cuda is enabled use CudaMempoolArena with params 'pool_release_threshold;bytes_to_keep_on_shrink'"); +ABSL_FLAG(bool, M, !DefaultPerformanceTestConfig().run_config.enable_memory_pattern, "Disables memory pattern."); ABSL_FLAG(bool, s, DefaultPerformanceTestConfig().run_config.f_dump_statistics, "Shows statistics result, like P75, P90. If no result_file provided this defaults to on."); ABSL_FLAG(bool, v, DefaultPerformanceTestConfig().run_config.f_verbose, "Shows verbose information."); ABSL_FLAG(bool, I, DefaultPerformanceTestConfig().run_config.generate_model_input_binding, "Generates tensor input binding. Free dimensions are treated as 1 unless overridden using -f."); @@ -261,10 +266,34 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a } // -M - test_config.run_config.enable_memory_pattern = absl::GetFlag(FLAGS_M); + test_config.run_config.enable_memory_pattern = !absl::GetFlag(FLAGS_M); // -A - test_config.run_config.enable_cpu_mem_arena = absl::GetFlag(FLAGS_A); + test_config.run_config.enable_cpu_mem_arena = !absl::GetFlag(FLAGS_A); + + // --shrink_arena_between_runs + if (test_config.run_config.enable_cpu_mem_arena) { + auto shrink_spec = absl::GetFlag(FLAGS_shrink_arena_between_runs); + test_config.run_config.run_config_entries.emplace( + kOrtRunOptionsConfigEnableMemoryArenaShrinkage, + std::move(shrink_spec)); + } + + // --enable_cuda_mempool + { + auto cuda_mempool_spec = absl::GetFlag(FLAGS_enable_cuda_mempool); + if (!cuda_mempool_spec.empty()) { + // Split the string with ';' separator in two parts + std::vector parts = absl::StrSplit(cuda_mempool_spec, ';'); + if (parts.size() == 2U) { + test_config.run_config.cuda_mempool_arena_config = { + std::move(parts[0]), std::move(parts[1])}; + } else { + std::cerr << "Invalid format for --enable_cuda_mempool. 
" + << "Expected format : " << std::endl; + } + } + } // -e { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 0f2da07c69d85..cb40a9beafeee 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -10,9 +10,11 @@ #include #include #include +#include #include #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/providers/cuda/cuda_provider_options.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" #include "core/providers/dnnl/dnnl_provider_options.h" #include @@ -45,12 +47,15 @@ RunTiming OnnxRuntimeTestSession::Run() { auto& input = test_inputs_.at(id); auto start = std::chrono::high_resolution_clock::now(); Ort::RunOptions run_options; + for (const auto& kv : run_config_entries_) { + run_options.AddConfigEntry(kv.first.c_str(), kv.second.c_str()); + } + RunTiming timing; if (CUDA == device_memory_name_) { run_options.AddConfigEntry(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "1"); Ort::IoBinding io_binding(session_); - const OrtMemoryInfo* mem_info; - Ort::ThrowOnError(Ort::GetApi().AllocatorGetInfo(allocator_, &mem_info)); + auto mem_info = allocator_.GetInfo(); for (size_t i = 0; i < input_names_.size(); ++i) { io_binding.BindInput(input_names_[i], input[i]); @@ -76,7 +81,11 @@ RunTiming OnnxRuntimeTestSession::Run() { OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device& rd, const PerformanceTestConfig& performance_test_config, const TestModelInfo& m) - : rand_engine_(rd()), input_names_(m.GetInputCount()), input_names_str_(m.GetInputCount()), input_length_(m.GetInputCount()) { + : rand_engine_(rd()), + input_names_(m.GetInputCount()), + input_names_str_(m.GetInputCount()), + input_length_(m.GetInputCount()), + run_config_entries_(performance_test_config.run_config.run_config_entries) { Ort::SessionOptions session_options; // Add EP devices if any (created by plugin EP) @@ -218,8 +227,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name_ == onnxruntime::kCudaExecutionProvider) { #ifdef USE_CUDA - Ort::CUDAProviderOptions cuda_options; + Ort::CUDAProviderOptions cuda_options; const char* config_val = nullptr; switch (performance_test_config.run_config.cudnn_conv_algo) { case 0: @@ -249,6 +258,24 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } cuda_options.Update(provider_options); + if (performance_test_config.run_config.cuda_mempool_arena_config) { + // Enable and configure cuda_mempool arena + const size_t release_threshold = + static_cast(std::atoll(performance_test_config.run_config.cuda_mempool_arena_config->release_threshold.c_str())); + const size_t bytes_to_keep_on_shrink = + static_cast(std::atoll(performance_test_config.run_config.cuda_mempool_arena_config->bytes_to_keep.c_str())); + // Create a map of properties for the arena configuration + std::unordered_map arena_config_map = { + {"use_cuda_mempool", 1U}, + {"cuda_mempool_bytes_to_keep_on_shrink", bytes_to_keep_on_shrink}, + {"cuda_mempool_release_threshold", release_threshold}, + }; + // Must be kept alive while session is alive + Ort::ArenaCfg cuda_arena_cfg(arena_config_map); + cuda_mempool_arena_cfg_ = std::move(cuda_arena_cfg); + (*cuda_options).default_memory_arena_cfg = cuda_mempool_arena_cfg_; + } + 
session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); #else ORT_THROW("CUDA is not supported in this build\n"); @@ -1014,7 +1041,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); } custom_allocator_ = Ort::Allocator(session_, memory_info); - allocator_ = custom_allocator_; + // Switch to custom + allocator_ = Ort::UnownedAllocator(custom_allocator_); // free dimensions are treated as 1 if not overridden transform_fcn = [](int64_t input) { return (input == -1) ? -input : input; }; diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index ada467824ca18..743db63b7b43c 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -42,9 +42,11 @@ class OnnxRuntimeTestSession : public TestSession { Ort::Session session_{nullptr}; std::mt19937 rand_engine_; std::uniform_int_distribution dist_; - OrtAllocator* allocator_ = Ort::AllocatorWithDefaultOptions(); + Ort::AllocatorWithDefaultOptions default_allocator_; // Note: custom_allocator_, if used, must outlive the `Ort::Value`s allocated with it in test_inputs_ and outputs_. + // and must be declared before them to ensure it is destroyed after them. Ort::Allocator custom_allocator_{nullptr}; + Ort::UnownedAllocator allocator_{default_allocator_}; std::vector> test_inputs_; std::vector outputs_; std::vector output_names_; @@ -56,9 +58,11 @@ class OnnxRuntimeTestSession : public TestSession { const int input_length_; std::string provider_name_; std::string device_memory_name_; // Device memory type name to use from the list in allocator.h + const std::unordered_map& run_config_entries_; #if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_NV) cudaStream_t stream_; // Device stream if required by IO bindings #endif + Ort::ArenaCfg cuda_mempool_arena_cfg_{nullptr}; }; } // namespace perftest diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index c982a8daadc9d..1d8ad77096ef3 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -63,6 +64,7 @@ struct RunConfig { bool set_denormal_as_zero{false}; std::basic_string ep_runtime_config_string; std::unordered_map session_config_entries; + std::unordered_map run_config_entries; std::map free_dim_name_overrides; std::map free_dim_denotation_overrides; std::string intra_op_thread_affinities; @@ -75,6 +77,11 @@ struct RunConfig { bool compile_ep_context{false}; std::basic_string compile_model_path; bool compile_binary_embed{false}; + struct CudaMempoolArenaConfig { + std::string release_threshold; + std::string bytes_to_keep; + }; + std::optional cuda_mempool_arena_config; }; struct PerformanceTestConfig { diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 8382258bf39b4..0847c15ba7cc6 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -3,9 +3,10 @@ #include "core/mlas/inc/mlas.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML) || defined(USE_XNNPACK) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML) || defined(USE_XNNPACK) || defined(USE_WEBGPU) #include "gtest/gtest.h" +#include 
"test/common/random_generator.h" #include "test/providers/provider_test_utils.h" #include "default_providers.h" @@ -39,12 +40,13 @@ please add the EP to the excluded_providers list. void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const vector& expected_output, const vector& expected_output_shape, bool weight_is_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", - int opset = 11) { + int opset = 11, + float rel_error = 0.002f) { std::unique_ptr tester; if (!attributes.activation.empty()) { tester = std::make_unique("NhwcFusedConv", 1, onnxruntime::kMSDomain); @@ -84,7 +86,7 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, if (inputs.size() >= 4) tester->AddInput(szNames[3], input_shapes[3], inputs[3]); - tester->AddOutput("Y", expected_output_shape, expected_output, /*no sort*/ false, 0.002f, 0.0f); + tester->AddOutput("Y", expected_output_shape, expected_output, /*no sort*/ false, rel_error, 0.0f); std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported @@ -424,6 +426,118 @@ TEST(ConvFp16Test, Conv2D_2) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvFp16Test, Conv2D_MatMul_SplitK_No_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + + RandomValueGenerator random{1234}; + vector X_float32(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W_float32(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + + vector X = FloatsToMLFloat16s(X_float32); + vector W = FloatsToMLFloat16s(W_float32); + + // Calculate expected output values + vector expected_vals_float32; + expected_vals_float32.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum{}; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index].ToFloat() * W[w_index].ToFloat(); + } + int y_index = n * M + m; + expected_vals_float32[y_index] = sum; + } + } + vector expected_vals = FloatsToMLFloat16s(expected_vals_float32); + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, + OpTester::ExpectResult::kExpectSuccess, "", 11); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true, + OpTester::ExpectResult::kExpectSuccess, "", 11); +} + +TEST(ConvFp16Test, Conv2D_MatMul_SplitK_With_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + vector B_shape = {N}; + + RandomValueGenerator 
random{1234}; + vector X_float32(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W_float32(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + vector B_float32(random.Gaussian(AsSpan(B_shape), 0.0f, 0.25f)); + + vector X = FloatsToMLFloat16s(X_float32); + vector W = FloatsToMLFloat16s(W_float32); + vector B = FloatsToMLFloat16s(B_float32); + + // Calculate expected output values + vector expected_vals_float32; + expected_vals_float32.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum{}; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index].ToFloat() * W[w_index].ToFloat(); + } + sum += B[n].ToFloat(); + int y_index = n * M + m; + expected_vals_float32[y_index] = sum; + } + } + vector expected_vals = FloatsToMLFloat16s(expected_vals_float32); + + // Using a higher relative error threshold for the Linux arm64 bots + constexpr float rel_error = 0.02f; + TestConvFp16Op( + attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, + OpTester::ExpectResult::kExpectSuccess, "", 11, rel_error); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvFp16Op( + attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, + OpTester::ExpectResult::kExpectSuccess, "", 11, rel_error); +} + TEST(ConvFp16Test, Conv2D_Bias_1) { ConvOpAndTestAttributes attrs = { "", // auto_pad diff --git a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc index c98d9e28b2f46..8155ac41318f6 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc @@ -254,6 +254,595 @@ TEST(ConvIntegerTest, WithStride3_2D_u8u8) { test.Run(); } +TEST(ConvIntegerTest, WithoutPadding_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {-9, -9, + -9, -9}); + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {12, 16, + 24, 28}); + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {-9, -9, + -9, -9}); + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {1, 1, 1, 1}); + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {1, 3, 5, 3, + 5, 12, 16, 9, + 11, 24, 28, 15, + 7, 15, 17, 9}); + test.Run(); +} + +TEST(ConvIntegerTest, WithGroup_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + std::vector w_dims{3, 1, 2, 2}; + test.AddInput("w", w_dims, + {-9, -8, + -8, -9, + -7, -6, + -6, -7, + -5, -4, + -4, -5}); + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute("group", static_cast(3)); + std::vector y_dims{1, 3, 4, 4}; + test.AddOutput("y", y_dims, + {1, 4, 7, 6, + 6, 18, 24, 15, + 15, 36, 42, 24, + 14, 23, 26, 9, + 30, 73, 80, 48, + 79, 168, 182, 96, + 100, 
210, 224, 117, + 64, 116, 123, 54, + 95, 214, 225, 126, + 224, 462, 484, 249, + 257, 528, 550, 282, + 150, 281, 292, 135}); + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_3D_u8s8) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect."; + } + + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + std::vector w_dims{1, 1, 2, 2, 2}; + test.AddInput("w", w_dims, + {-9, -9, + -9, -9, + -9, -9, + -9, -9}); + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {1, 1, 1, 1, 1, 1}); + std::vector y_dims{1, 1, 4, 4, 4}; + test.AddOutput("y", y_dims, + {1, 3, 5, 3, + 5, 12, 16, 9, + 11, 24, 28, 15, + 7, 15, 17, 9, + 11, 24, 28, 15, + 28, 60, 68, 36, + 40, 84, 92, 48, + 23, 48, 52, 27, + 29, 60, 64, 33, + 64, 132, 140, 72, + 76, 156, 164, 84, + 41, 84, 88, 45, + 19, 39, 41, 21, + 41, 84, 88, 45, + 47, 96, 100, 51, + 25, 51, 53, 27}); + test.Run(); +} + +TEST(ConvIntegerTest, WithStride2_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {-9, -8, -9, + -8, -7, -8, + -9, -8, -9}); + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute>("strides", {2, 2}); + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {33, 62, 84, 75, + 224, 330, 360, 282, + 444, 630, 660, 502, + 453, 642, 664, 495}); + // Exercise the (stride_w = 2) path inside Math::Im2col. + test.Run(); +} + +TEST(ConvIntegerTest, WithStride3_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {-9, -8, -9, + -8, -7, -8, + -9, -8, -9}); + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {2, 2, 1, 1}); + test.AddAttribute>("strides", {3, 3}); + std::vector y_dims{1, 1, 3, 3}; + test.AddOutput("y", y_dims, + {0, 8, 20, + 80, 330, 375, + 200, 780, 825}); + // Exercise the (stride_w > 2) path inside Math::Im2col. 
+ test.Run(); +} + +TEST(ConvIntegerTest, WithoutPadding_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {-1, 2, -3, + 4, -5, 6, + -7, 8, -9}); + + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {1, -2, + 3, -4}); + + test.AddInput("x_zero_point", {}, {0}); + test.AddInput("w_zero_point", {}, {0}); + + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {27, -31, + -39, 43}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {-1, 2, -3, + 4, -5, 6, + -7, 8, -9}); + + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {1, -2, + 3, -4}); + + test.AddInput("x_zero_point", {}, {0}); + test.AddInput("w_zero_point", {}, {0}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {4, -11, 18, -9, + -14, 27, -31, 15, + 20, -39, 43, -21, + 14, -23, 26, -9}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithGroup_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + + std::vector w_dims{3, 1, 2, 2}; + test.AddInput("w", w_dims, + {-9, -8, + -8, -9, + -7, -6, + -6, -7, + -5, -4, + -4, -5}); + + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute("group", static_cast(3)); + + std::vector y_dims{1, 3, 4, 4}; + test.AddOutput("y", y_dims, + {1, 4, 7, 6, + 6, 18, 24, 15, + 15, 36, 42, 24, + 14, 23, 26, 9, + 30, 73, 80, 48, + 79, 168, 182, 96, + 100, 210, 224, 117, + 64, 116, 123, 54, + 95, 214, 225, 126, + 224, 462, 484, 249, + 257, 528, 550, 282, + 150, 281, 292, 135}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_3D_s8s8) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect."; + } + + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + + std::vector w_dims{1, 1, 2, 2, 2}; + test.AddInput("w", w_dims, + {-9, -9, + -9, -9, + -9, -9, + -9, -9}); + + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + + test.AddAttribute>("pads", {1, 1, 1, 1, 1, 1}); + + std::vector y_dims{1, 1, 4, 4, 4}; + test.AddOutput("y", y_dims, + {1, 3, 5, 3, + 5, 12, 16, 9, + 11, 24, 28, 15, + 7, 15, 17, 9, + 11, 24, 28, 15, + 28, 60, 68, 36, + 40, 84, 92, 48, + 23, 48, 52, 27, + 29, 60, 64, 33, + 64, 132, 140, 72, + 76, 156, 164, 84, + 41, 84, 88, 45, + 19, 39, 41, 21, + 41, 84, 88, 45, + 47, 96, 100, 51, + 25, 51, 53, 27}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithStride2_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {-9, -8, -9, + -8, -7, -8, + -9, -8, -9}); + 
+ test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {-10}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute>("strides", {2, 2}); + + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {33, 62, 84, 75, + 224, 330, 360, 282, + 444, 630, 660, 502, + 453, 642, 664, 495}); + // Exercise the (stride_w = 2) path inside Math::Im2col. + + test.Run(); +} + +TEST(ConvIntegerTest, WithStride3_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {-9, -8, -9, + -8, -7, -8, + -9, -8, -9}); + + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {-10}); + + test.AddAttribute>("pads", {2, 2, 1, 1}); + test.AddAttribute>("strides", {3, 3}); + + std::vector y_dims{1, 1, 3, 3}; + test.AddOutput("y", y_dims, + {0, 8, 20, + 80, 330, 375, + 200, 780, 825}); + // Exercise the (stride_w > 2) path inside Math::Im2col. + + test.Run(); +} + +TEST(ConvIntegerTest, WithoutPadding_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {-1, 2, -3, + 4, -5, 6, + -7, 8, -9}); + + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {1, 2, + 3, 4}); + + test.AddInput("x_zero_point", {}, {0}); + test.AddInput("w_zero_point", {}, {0}); + + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {-5, 5, + 5, -5}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {-1, 2, -3, + 4, -5, 6, + -7, 8, -9}); + + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {1, 2, + 3, 4}); + + test.AddInput("x_zero_point", {}, {0}); + test.AddInput("w_zero_point", {}, {0}); + test.AddAttribute>("pads", {1, 1, 1, 1}); + + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {-4, 5, -6, -9, + 14, -5, 5, 15, + -20, 5, -5, -21, + -14, 9, -10, -9}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithGroup_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + + std::vector w_dims{3, 1, 2, 2}; + test.AddInput("w", w_dims, + {11, 12, + 12, 11, + 13, 14, + 14, 13, + 15, 16, + 16, 15}); + + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {10}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute("group", static_cast(3)); + + std::vector y_dims{1, 3, 4, 4}; + test.AddOutput("y", y_dims, + {1, 4, 7, 6, + 6, 18, 24, 15, + 15, 36, 42, 24, + 14, 23, 26, 9, + 30, 73, 80, 48, + 79, 168, 182, 96, + 100, 210, 224, 117, + 64, 116, 123, 54, + 95, 214, 225, 126, + 224, 462, 484, 249, + 257, 528, 550, 282, + 150, 281, 292, 135}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithStride2_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + + std::vector 
w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {11, 12, 11, + 12, 13, 12, + 11, 12, 11}); + + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {10}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute>("strides", {2, 2}); + + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {33, 62, 84, 75, + 224, 330, 360, 282, + 444, 630, 660, 502, + 453, 642, 664, 495}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithStride3_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {11, 12, 11, + 12, 13, 12, + 11, 12, 11}); + + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {10}); + + test.AddAttribute>("pads", {2, 2, 1, 1}); + test.AddAttribute>("strides", {3, 3}); + + std::vector y_dims{1, 1, 3, 3}; + test.AddOutput("y", y_dims, + {0, 8, 20, + 80, 330, 375, + 200, 780, 825}); + + test.Run(); +} + TEST(ConvIntegerTest, NoXZeroPoint) { OpTester test("ConvInteger", 10); std::vector x_dims{1, 1, 3, 3}; diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 7c84aefa1c01f..4efbb8cfd5c19 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/graph/constants.h" #include "gtest/gtest.h" +#include "test/common/random_generator.h" #include "test/providers/provider_test_utils.h" using namespace std; @@ -23,7 +24,7 @@ struct ConvOpAndTestAttributes { void TestConvOp(const ConvOpAndTestAttributes& attributes, const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const vector& expected_output, const vector& expected_output_shape, bool weight_is_initializer = false, optional epsilon = optional(), @@ -535,6 +536,103 @@ TEST(ConvTest, Conv2D_AutoPad2) { TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvTest, Conv2D_MatMul_SplitK_No_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + + // Fill X and W + RandomValueGenerator random{1234}; + vector X(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + + // Calculate expected output values + vector expected_vals; + expected_vals.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum = 0.0f; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index] * W[w_index]; + } + int y_index = n * M + m; + expected_vals[y_index] = sum; + } + } + + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, 
expected_vals, Y_shape, true); +} + +TEST(ConvTest, Conv2D_MatMul_SplitK_With_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + vector B_shape = {N}; + + // Fill X, W and B + RandomValueGenerator random{1234}; + vector X(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + vector B(random.Gaussian(AsSpan(B_shape), 0.0f, 0.25f)); + + // Calculate expected output values + vector expected_vals; + expected_vals.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum = 0.0f; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index] * W[w_index]; + } + sum += B[n]; + int y_index = n * M + m; + expected_vals[y_index] = sum; + } + } + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + // Conv10 TEST(ConvTest, Conv3D_1) { ConvOpAndTestAttributes attrs = { diff --git a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc new file mode 100644 index 0000000000000..70c7a5b2bcdcb --- /dev/null +++ b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc @@ -0,0 +1,306 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_CUDA + +#include +#include + +#include + +#include "core/common/inlined_containers.h" // InlinedVector +#include "core/framework/allocator.h" // OrtMemoryInfo, IAllocator, AllocatorStats, onnxruntime::CUDA +#include "core/framework/execution_provider.h" +#include "core/framework/stream_handles.h" // onnxruntime::Stream (interface) +#include "core/providers/cuda/cuda_provider_options.h" +#include "core/providers/cuda/cuda_provider_factory.h" +#include "core/providers/cuda/cuda_provider_factory_creator.h" +#include "test/util/include/asserts.h" + +namespace onnxruntime { +namespace test { + +// --------- Helpers --------- + +static bool IsCudaMemPoolSupported() { + int ort_cuda_rt_version = 0; + cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); + if (cuda_status != cudaSuccess) { + return false; + } + + if (ort_cuda_rt_version < 11020) { + return false; + } + + int ort_cuda_driver_version = 0; + cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); + if (cuda_status != cudaSuccess) { + return false; + } + + if (ort_cuda_driver_version < 11020) { + return false; + } + + // Check if the driver version supports the runtime version + if (ort_cuda_rt_version >= 12000 && ort_cuda_driver_version < 12000) { + return false; + } + + if (ort_cuda_rt_version >= 13000 && ort_cuda_driver_version < 13000) { + return false; + } + + // Creating a cuda mempool in some pipelines fails with + // CUDA failure 801: operation not supported ; GPU=0 ; hostname=af14bbb1c000000 ; + // Even though CUDA version may be 12.8 possibly due to the driver. 
+  cudaMemPoolProps props{};
+  // Pinned here is not the same as the pinned (host) allocator; it is the only valid
+  // allocType for a device mem pool (there is no cudaMemAllocationTypeDevice, even
+  // though it appears in some internet docs).
+  props.allocType = cudaMemAllocationTypePinned;
+  props.handleTypes = cudaMemHandleTypeNone;        // local to process
+  props.location.type = cudaMemLocationTypeDevice;  // Device memory
+  props.location.id = 0;                            // test device 0
+  cudaMemPool_t pool;
+  auto cuda_error = cudaMemPoolCreate(&pool, &props);
+  if (cuda_error != cudaSuccess) {
+    return false;
+  }
+  cuda_error = cudaMemPoolDestroy(pool);
+
+  return true;
+}
+
+static ::cudaStream_t NewCudaStream() {
+  ::cudaStream_t s{};
+  const cudaError_t st = ::cudaStreamCreate(&s);
+  EXPECT_EQ(st, cudaSuccess);
+  return s;
+}
+
+static void DestroyCudaStream(::cudaStream_t s) {
+  if (s) (void)::cudaStreamDestroy(s);
+}
+
+static void TouchDevice(void* p, size_t bytes, ::cudaStream_t s, unsigned char value = 0xAB) {
+  ASSERT_NE(p, nullptr);
+  ASSERT_EQ(::cudaSuccess, ::cudaMemsetAsync(p, static_cast<int>(value), bytes, s));
+}
+
+// --------- Test parameters ---------
+
+struct MPArenaParams {
+  uint64_t release_threshold = 1ull << 20;  // 1 MB (recommended in allocator docs)
+  size_t bytes_to_keep = 4ull << 20;        // 4 MB (small trim target for tests)
+};
+
+OrtArenaCfg CreateArenaCfgFromParams(const MPArenaParams& params) {
+  OrtArenaCfg cfg;
+  cfg.initial_chunk_size_bytes = 0;  // Make BFCArena for CUDAPinned not allocate anything here
+  cfg.use_cuda_mempool = 1;          // Key switch
+  cfg.cuda_mempool_release_threshold = params.release_threshold;
+  cfg.cuda_mempool_bytes_to_keep_on_shrink = params.bytes_to_keep;
+  return cfg;
+}
+
+std::unique_ptr<IExecutionProvider> CreateCudaExecutionProvider(OrtArenaCfg& arena_cfg) {
+  OrtCUDAProviderOptionsV2 cuda_options;
+  cuda_options.device_id = 0;  // single-device tests
+  cuda_options.default_memory_arena_cfg = &arena_cfg;
+  cuda_options.do_copy_in_default_stream = true;
+  cuda_options.use_tf32 = false;
+  if (auto factory = CudaProviderFactoryCreator::Create(&cuda_options))
+    return factory->CreateProvider();
+  return nullptr;
+}
+
+AllocatorPtr GetCudaMempoolArena(IExecutionProvider& cuda_ep) {
+  auto allocators = cuda_ep.CreatePreferredAllocators();
+  EXPECT_EQ(allocators.size(), 2u);
+  const auto& mem_info = allocators[0]->Info();
+  EXPECT_EQ("CUDAMemPoolArena", mem_info.name);
+  return allocators[0];
+}
+
+// --------- Minimal test Stream adapter ---------
+//
+// Adapts a cudaStream_t to ORT's Stream interface.
+// If your Stream interface has additional pure virtuals on the work branch,
+// add trivial overrides here (returning defaults / no-ops) so tests compile.
+class TestCudaStream final : public onnxruntime::Stream {
+ public:
+  TestCudaStream(::cudaStream_t s, const OrtDevice& device) : Stream(s, device) {}
+
+  ~TestCudaStream() {
+    DestroyCudaStream(static_cast<::cudaStream_t>(GetHandle()));
+  }
+
+  void* GetHandle() const {
+    // ORT expects GetHandle() to return the native handle (cast to void*).
+ return Stream::GetHandle(); + } +}; + +// --------- Test fixture --------- + +class CudaMempoolArenaTest : public ::testing::Test { + protected: + void SetUp() override { + if (!IsCudaMemPoolSupported()) { + GTEST_SKIP() << "CUDA memory pools not supported on this device/driver."; + } + + const auto& logger = onnxruntime::logging::LoggingManager::DefaultLogger(); + orig_severity_ = logger.GetSeverity(); + orig_verbosity_ = logger.VLOGMaxLevel(); + logging::LoggingManager::SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + logging::LoggingManager::SetDefaultLoggerVerbosity(0); + cuda_ep_ = CreateCudaExecutionProvider(arena_cfg_); + cuda_ep_->SetLogger(&logger); + arena_ = GetCudaMempoolArena(*cuda_ep_); + mem_info_ = arena_->Info(); + } + + void TearDown() override { + arena_.reset(); + cuda_ep_.reset(); + ::cudaDeviceSynchronize(); + logging::LoggingManager::SetDefaultLoggerSeverity(orig_severity_); + logging::LoggingManager::SetDefaultLoggerVerbosity(orig_verbosity_); + } + + logging::Severity orig_severity_; + int orig_verbosity_; + OrtArenaCfg arena_cfg_ = CreateArenaCfgFromParams(MPArenaParams()); + std::unique_ptr cuda_ep_; + AllocatorPtr arena_; + OrtMemoryInfo mem_info_; +}; + +// --------- Tests --------- + +TEST_F(CudaMempoolArenaTest, AllocAndFree_OnDefaultStream) { + const size_t kBytes = 1 << 20; // 1 MB + void* p = arena_->Alloc(kBytes); + ASSERT_NE(p, nullptr); + + // default (legacy) stream 0 + ASSERT_EQ(::cudaSuccess, ::cudaMemsetAsync(p, 0xCD, kBytes, /*stream=*/0)); + arena_->Free(p); + + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); + + onnxruntime::AllocatorStats stats{}; + arena_->GetStats(&stats); + EXPECT_GE(stats.num_allocs, 1u); +} + +TEST_F(CudaMempoolArenaTest, AllocOnTwoStreams_OrderedFree) { + const size_t kBytes = 2 << 20; // 2 MB + + ::cudaStream_t s0 = NewCudaStream(); + ::cudaStream_t s1 = NewCudaStream(); + { + TestCudaStream ort_s0(s0, mem_info_.device); + TestCudaStream ort_s1(s1, mem_info_.device); + + void* p0 = arena_->AllocOnStream(kBytes, &ort_s0); + void* p1 = arena_->AllocOnStream(kBytes, &ort_s1); + ASSERT_NE(p0, nullptr); + ASSERT_NE(p1, nullptr); + + TouchDevice(p0, kBytes, s0, 0x11); + TouchDevice(p1, kBytes, s1, 0x22); + + // Enqueue ordered frees (no sync needed here). + arena_->Free(p0); + arena_->Free(p1); + + // Ensure queued frees completed on each stream. + ASSERT_EQ(::cudaSuccess, ::cudaStreamSynchronize(s0)); + ASSERT_EQ(::cudaSuccess, ::cudaStreamSynchronize(s1)); + + // Destroy streams here + } + + ASSERT_EQ(::cudaSuccess, ::cudaGetLastError()); +} + +TEST_F(CudaMempoolArenaTest, Shrink_TrimsPool_And_AllowsFurtherUse) { + const size_t kBytes = 2 << 20; + + InlinedVector ptrs; + for (size_t i = 0; i < ptrs.capacity(); ++i) { + void* p = arena_->Alloc(kBytes); + ASSERT_NE(p, nullptr); + ASSERT_EQ(::cudaSuccess, ::cudaMemsetAsync(p, 0xEF, kBytes, /*stream=*/0)); + ptrs.push_back(p); + } + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); + + for (void* p : ptrs) { + arena_->Free(p); + } + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); + + // Trim and sanity-check future allocations still work. 
+ auto* arena_cast = IArena::SafeArenaCast(arena_.get()); + ASSERT_STATUS_OK(arena_cast->Shrink()); + + void* p_check = arena_->Alloc(kBytes); + ASSERT_NE(p_check, nullptr); + arena_->Free(p_check); + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); +} + +TEST_F(CudaMempoolArenaTest, Reserve_DelegatesToAlloc) { + const size_t kBytes = 512 * 1024; + void* p = arena_->Reserve(kBytes); + ASSERT_NE(p, nullptr); + arena_->Free(p); + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); +} + +// Validates allocator dtor guarantees completion of queued frees even when +// streams are destroyed prior to allocator destruction. +TEST_F(CudaMempoolArenaTest, Destructor_CompletesQueuedFrees_EvenIfStreamDestroyed) { + const size_t kBytes = 1 << 20; + ::cudaStream_t s = NewCudaStream(); + + { + auto cuda_prov = CreateCudaExecutionProvider(arena_cfg_); + cuda_prov->SetLogger(&onnxruntime::logging::LoggingManager::DefaultLogger()); + auto alloc = GetCudaMempoolArena(*cuda_ep_); + { + TestCudaStream ort_s(s, mem_info_.device); + + InlinedVector ptrs; + for (size_t i = 0; i < ptrs.capacity(); ++i) { + void* p = alloc->AllocOnStream(kBytes, &ort_s); + ASSERT_NE(p, nullptr); + TouchDevice(p, kBytes, s); + ptrs.push_back(p); + } + + for (void* p : ptrs) { + alloc->Free(p); + } + + // Destroy the stream *before* the frees have a chance to run. + } + + // arena goes out of scope here; its destructor must: + // - sync known streams (best-effort), + // - device-wide synchronize as a safety net, + // - then trim and destroy the pool. + } + + ASSERT_EQ(::cudaSuccess, ::cudaGetLastError()); + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); +} + +} // namespace test +} // namespace onnxruntime + +#endif // USE_CUDA diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc new file mode 100644 index 0000000000000..e28cf00aa070b --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc @@ -0,0 +1,407 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +namespace { + +// Helper function to build GELU Pattern 1: root -> Mul -> Div -> Erf -> Add -> Mul +// Pattern 1: +// +-------Mul(0.5)---------------------+ +// | | +// | v +// [root] --> Div -----> Erf --> Add --> Mul ==> +// (B=1.4142...) 
(1) +GetTestModelFn BuildGeluPattern1TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + + // Create Mul(0.5) branch: input * 0.5 + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* mul_half_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input, half_initializer}, {mul_half_output}); + + // Create main branch: input / sqrt(2) + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input, sqrt2_initializer}, {div_output}); + + // Erf + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_output}, {erf_output}); + + // Add 1.0 + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_output, one_initializer}, {add_output}); + + // Final Mul: (add_output) * (mul_half_output) + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Mul", {add_output, mul_half_output}, {output}); + }; +} + +// Helper function to build GELU Pattern 2: Mul(0.5) after the main sequence +// Pattern 2: +// +------------------------------------+ +// | | +// | v +// [root] --> Div -----> Erf --> Add --> Mul -->Mul ==> +// (B=1.4142...) (1) (0.5) +GetTestModelFn BuildGeluPattern2TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + + // Main branch: input / sqrt(2) + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input, sqrt2_initializer}, {div_output}); + + // Erf + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_output}, {erf_output}); + + // Add 1.0 + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_output, one_initializer}, {add_output}); + + // Mul with input: input * add_output + NodeArg* mul_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input, add_output}, {mul_output}); + + // Final Mul with 0.5: mul_output * 0.5 + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Mul", {mul_output, half_initializer}, {output}); + }; +} + +// Helper function to build QDQ GELU Pattern 1 +template +GetTestQDQModelFn BuildQDQGeluPattern1TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder, std::vector>& output_qparams) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + + // Quantize input once + NodeArg* input_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input, input_qparams.scale, input_qparams.zero_point, input_q); + + // Create quantized constants with individual quantization parameters + // For scalar constants, use range [0, value] to ensure proper 
quantization + QuantParams sqrt2_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, sqrt_2)); + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* sqrt2_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(sqrt2_initializer, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_q); + + QuantParams one_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, one)); + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* one_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(one_initializer, one_qparams.scale, one_qparams.zero_point, one_q); + + QuantParams half_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, half)); + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* half_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(half_initializer, half_qparams.scale, half_qparams.zero_point, half_q); + + NodeArg* input_dq_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_1); + NodeArg* sqrt2_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(sqrt2_q, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_dq); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input_dq_1, sqrt2_dq}, {div_output}); + NodeArg* div_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(div_output, input_qparams.scale, input_qparams.zero_point, div_q); + + // DQ -> Erf -> Q + NodeArg* div_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(div_q, input_qparams.scale, input_qparams.zero_point, div_dq); + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_dq}, {erf_output}); + NodeArg* erf_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(erf_output, input_qparams.scale, input_qparams.zero_point, erf_q); + + // DQ -> Add -> Q + NodeArg* erf_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(erf_q, input_qparams.scale, input_qparams.zero_point, erf_dq); + NodeArg* one_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(one_q, one_qparams.scale, one_qparams.zero_point, one_dq); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_dq, one_dq}, {add_output}); + NodeArg* add_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(add_output, input_qparams.scale, input_qparams.zero_point, add_q); + + // DQ -> Mul (with input) -> Q + NodeArg* input_dq_2 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_2); + NodeArg* add_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(add_q, input_qparams.scale, input_qparams.zero_point, add_dq); + NodeArg* mul_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input_dq_2, add_dq}, {mul_output}); + NodeArg* mul_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(mul_output, input_qparams.scale, input_qparams.zero_point, mul_q); + + // Final DQ -> Mul (with 0.5) -> Q + NodeArg* mul_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(mul_q, input_qparams.scale, input_qparams.zero_point, mul_dq); + NodeArg* half_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(half_q, half_qparams.scale, half_qparams.zero_point, half_dq); + NodeArg* mul_final_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {mul_dq, half_dq}, {mul_final_output}); + + // Add 
output QDQ + AddQDQNodePairWithOutputAsGraphOutput(builder, mul_final_output, output_qparams[0].scale, + output_qparams[0].zero_point); + }; +} + +// Helper function to build QDQ GELU Pattern 2 +template +GetTestQDQModelFn BuildQDQGeluPattern2TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder, std::vector>& output_qparams) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + + // Quantize input once + NodeArg* input_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input, input_qparams.scale, input_qparams.zero_point, input_q); + + // Create quantized constants with individual quantization parameters + // For scalar constants, use range [0, value] to ensure proper quantization + QuantParams sqrt2_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, sqrt_2)); + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* sqrt2_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(sqrt2_initializer, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_q); + + QuantParams one_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, one)); + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* one_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(one_initializer, one_qparams.scale, one_qparams.zero_point, one_q); + + QuantParams half_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, half)); + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* half_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(half_initializer, half_qparams.scale, half_qparams.zero_point, half_q); + + // Main branch: DQ -> Div -> Q + NodeArg* input_dq_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_1); + NodeArg* sqrt2_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(sqrt2_q, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_dq); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input_dq_1, sqrt2_dq}, {div_output}); + NodeArg* div_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(div_output, input_qparams.scale, input_qparams.zero_point, div_q); + + // DQ -> Erf -> Q + NodeArg* div_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(div_q, input_qparams.scale, input_qparams.zero_point, div_dq); + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_dq}, {erf_output}); + NodeArg* erf_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(erf_output, input_qparams.scale, input_qparams.zero_point, erf_q); + + // DQ -> Add -> Q + NodeArg* erf_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(erf_q, input_qparams.scale, input_qparams.zero_point, erf_dq); + NodeArg* one_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(one_q, one_qparams.scale, one_qparams.zero_point, one_dq); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_dq, one_dq}, {add_output}); + NodeArg* add_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(add_output, input_qparams.scale, input_qparams.zero_point, add_q); + + // DQ -> Mul (with input) -> Q + NodeArg* input_dq_2 = builder.MakeIntermediate(); + 
builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_2); + NodeArg* add_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(add_q, input_qparams.scale, input_qparams.zero_point, add_dq); + NodeArg* mul_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input_dq_2, add_dq}, {mul_output}); + NodeArg* mul_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(mul_output, input_qparams.scale, input_qparams.zero_point, mul_q); + + // Final DQ -> Mul (with 0.5) + NodeArg* mul_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(mul_q, input_qparams.scale, input_qparams.zero_point, mul_dq); + NodeArg* half_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(half_q, half_qparams.scale, half_qparams.zero_point, half_dq); + NodeArg* mul_final_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {mul_dq, half_dq}, {mul_final_output}); + + // Add output QDQ + AddQDQNodePairWithOutputAsGraphOutput(builder, mul_final_output, output_qparams[0].scale, + output_qparams[0].zero_point); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +// Test GELU Pattern 1 with float32 model (for baseline comparison) +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_Float32) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); +} + +// Test GELU Pattern 2 with float32 model (for baseline comparison) +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_Float32) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); +} + +// Test GELU Pattern 1 with larger input shape +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_LargeInput) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 128, 768}, false, -1.5f, 1.5f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/2e-3f); +} + +// Test GELU Pattern 2 with larger input shape +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_LargeInput) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 128, 768}, false, -1.5f, 1.5f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/2e-3f); +} + +// Test GELU Pattern 1 with 3D input +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_3D) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 16, 32}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); +} + +// Test GELU Pattern 2 with 3D input 
+TEST_F(QnnHTPBackendTests, GeluFusionPattern2_3D) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 16, 32}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); +} + +// Test GELU Pattern 1 with 2D input (typical for linear layers) +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_2D) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({32, 512}, false, -1.5f, 1.5f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/2e-3f); +} + +// Test GELU Pattern 2 with 2D input (typical for linear layers) +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_2D) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({32, 512}, false, -1.5f, 1.5f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/2e-3f); +} + +// Test GELU Pattern 1 with QDQ +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_QDQ_U8) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + TestQDQModelAccuracy(BuildGeluPattern1TestCase(input_def), + BuildQDQGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +// Test GELU Pattern 2 with QDQ +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_QDQ_U8) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + TestQDQModelAccuracy(BuildGeluPattern2TestCase(input_def), + BuildQDQGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/python/helper.py b/onnxruntime/test/python/helper.py index 2a2c3fc9b4532..99960640fe92e 100644 --- a/onnxruntime/test/python/helper.py +++ b/onnxruntime/test/python/helper.py @@ -1,4 +1,5 @@ import os +import sys def get_name(name): @@ -13,3 +14,14 @@ def get_name(name): if os.path.exists(res): return res raise FileNotFoundError(f"Unable to find '{name}' or '{rel}' or '{res}'") + + +def get_shared_library_filename_for_platform(base_name): + if sys.platform.startswith("win"): + return base_name + ".dll" + + if sys.platform.startswith("darwin"): + return "lib" + base_name + ".dylib" + + # Else, assume linux + return "lib" + base_name + ".so" diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index e46cdb4f98850..c60307d3c0116 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -10,7 +10,7 @@ import onnx from autoep_helper import AutoEpTestCase -from helper import get_name +from helper import get_name, get_shared_library_filename_for_platform import onnxruntime as onnxrt from onnxruntime.capi.onnxruntime_pybind11_state import Fail, 
ModelRequiresCompilation @@ -53,6 +53,52 @@ def test_compile_with_files_prefer_npu_policy(self): self.assertTrue(os.path.exists(output_model_path)) self.unregister_execution_provider_library(ep_name) + def test_compile_shared_resources_plugin_ep(self): + """ + Test compiling two example models using weight sharing (via example plugin EP) + """ + ep_lib_path = get_shared_library_filename_for_platform("example_plugin_ep") + try: + ep_lib_path = get_name(ep_lib_path) + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + ep_name = "example_ep" + self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + ep_device = next((d for d in onnxrt.get_ep_devices() if d.ep_name == ep_name), None) + self.assertIsNotNone(ep_device) + + input_models = [get_name("add_mul_add.onnx"), get_name("mul_1.onnx")] + output_models = [ + os.path.join(self._tmp_dir_path, "output_model_0_ctx.onnx"), + os.path.join(self._tmp_dir_path, "output_model_1_ctx.onnx"), + ] + + num_models = len(input_models) + session_options = onnxrt.SessionOptions() + + # Set option that tells EP to share resources (e.g., weights) across sessions. The example plugin EP + # doesn't actually do anything special, but we do this to test the API + session_options.add_session_config_entry("ep.share_ep_contexts", "1") + session_options.add_provider_for_devices([ep_device], {}) + + # Compile individual models + for i in range(num_models): + if i == num_models - 1: + # Tell EP that this is the last session that will be sharing resources. + session_options.add_session_config_entry("ep.stop_share_ep_contexts", "1") + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_models[i], + embed_compiled_data_into_model=False, + ) + model_compiler.compile_to_file(output_models[i]) + self.assertTrue(os.path.exists(output_models[i])) + + self.unregister_execution_provider_library(ep_name) + def test_compile_with_ep_selection_delegate(self): """ Tests compiling a model (to/from files) using an EP selection delegate callback. 
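As a companion to the test_compile_shared_resources_plugin_ep test added above, the following is a minimal, purely illustrative sketch of the consumption side: loading the two compiled (EP-context) models for inference with the same plugin EP selection. The model paths, the "example_ep" name, and the assumption that "ep.share_ep_contexts" is also honored when loading already-compiled models are assumptions for illustration only; the plugin EP library is assumed to have been registered beforehand (the test above does this through its AutoEpTestCase helper).

    # Illustrative sketch only, not part of this change.
    import onnxruntime as onnxrt

    # Pick the same plugin EP device that was used for compilation (assumes it is registered).
    ep_device = next(d for d in onnxrt.get_ep_devices() if d.ep_name == "example_ep")

    load_options = onnxrt.SessionOptions()
    # Assumption: the same config key used at compile time also enables sharing at load time.
    load_options.add_session_config_entry("ep.share_ep_contexts", "1")
    load_options.add_provider_for_devices([ep_device], {})

    sessions = [
        onnxrt.InferenceSession("output_model_0_ctx.onnx", sess_options=load_options),
        onnxrt.InferenceSession("output_model_1_ctx.onnx", sess_options=load_options),
    ]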
diff --git a/onnxruntime/test/shared_lib/test_runtime_path.cc b/onnxruntime/test/shared_lib/test_runtime_path.cc index 621d006a8659a..f004c96041ef6 100644 --- a/onnxruntime/test/shared_lib/test_runtime_path.cc +++ b/onnxruntime/test/shared_lib/test_runtime_path.cc @@ -21,7 +21,11 @@ bool IsDirectorySeparator(PathChar c) { } } // namespace +#if !defined(_AIX) TEST(GetRuntimePathFromSharedLibraryTest, Basic) { +#else +TEST(GetRuntimePathFromSharedLibraryTest, DISABLED_Basic) { +#endif const auto* runtime_path_cstr = OrtTestGetSharedLibraryRuntimePath(); ASSERT_NE(runtime_path_cstr, nullptr); diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index f1599b6843fb5..3f34e7ae37538 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -46,12 +46,12 @@ jobs: clean: true submodules: none - - task: UsePythonVersion@0 - displayName: Use Python 3.12 - inputs: - versionSpec: 3.12 - ${{ if eq(parameters.OnnxruntimeArch, 'aarch64') }}: - architecture: arm64 + - ${{ if eq(parameters.OnnxruntimeArch, 'x64') }}: + # Only need to install Python on x64 agents as Python is pre-installed on arm64 agents + - task: UsePythonVersion@0 + displayName: Use Python 3.12 + inputs: + versionSpec: 3.12 - template: set-version-number-variables-step.yml - ${{ if eq(parameters.OnnxruntimeArch, 'x64') }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml b/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml index 166b03f6b55e1..749b6093cf9d3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml +++ b/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml @@ -28,12 +28,13 @@ parameters: displayName: Architecture type: string #default: 'linux-x64' - + steps: - task: PythonScript@0 inputs: scriptSource: 'filePath' scriptPath: 'tools/ci_build/linux_java_copy_strip_binary.py' + pythonInterpreter: 'python3' arguments: >- --binary-dir $(Build.BinariesDirectory) --build-config ${{parameters.buildConfig}} @@ -47,4 +48,3 @@ steps: inputs: targetPath: '$(Build.BinariesDirectory)/${{parameters.artifactName}}' artifactName: 'drop-${{parameters.artifactName}}' - diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml index 907563cb77242..0bc0a94fdd6e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml @@ -53,6 +53,7 @@ stages: - task: UsePythonVersion@0 inputs: versionSpec: '3.13' + architecture: arm64 addToPath: true - script: | @@ -77,7 +78,7 @@ stages: cd temp find $(Build.ArtifactStagingDirectory) -name '*.zip' -exec unzip {} \; rm -rf $(Build.ArtifactStagingDirectory)/*; - find . -type d -name 'onnxruntime-osx-*' -exec tar -czf {}.tgz {} \; + find . -type d -name 'onnxruntime-osx-*' -exec tar -czf {}.tgz {} \; ls -l mv *.tgz $(Build.ArtifactStagingDirectory) displayName: 'Unzip Signed Files and Repackage to TGZ'
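The CUDA mempool arena tests above configure the new OrtArenaCfg fields programmatically before creating the CUDA execution provider. As a rough sketch of how an embedder working inside the onnxruntime codebase might opt into the mempool-backed arena outside the test harness, the following mirrors CreateArenaCfgFromParams and CreateCudaExecutionProvider from cuda_mempool_arena_test.cc; the threshold and trim values are arbitrary illustrative choices, and this is a sketch under those assumptions rather than a prescribed API usage.

    // Illustrative sketch (not part of this change): enable the cudaMemPool-based arena by
    // passing an OrtArenaCfg with the new fields to the CUDA EP, mirroring the test helpers.
    #include <memory>
    #include "core/framework/allocator.h"                           // OrtArenaCfg
    #include "core/framework/execution_provider.h"                  // IExecutionProvider
    #include "core/providers/cuda/cuda_provider_options.h"          // OrtCUDAProviderOptionsV2
    #include "core/providers/cuda/cuda_provider_factory_creator.h"  // CudaProviderFactoryCreator

    std::unique_ptr<onnxruntime::IExecutionProvider> MakeMempoolCudaEp() {
      static OrtArenaCfg arena_cfg;      // must outlive the provider options that point to it
      arena_cfg.initial_chunk_size_bytes = 0;
      arena_cfg.use_cuda_mempool = 1;    // opt into the cudaMemPool-based arena
      arena_cfg.cuda_mempool_release_threshold = 1ull << 20;        // example: hold ~1 MB before releasing
      arena_cfg.cuda_mempool_bytes_to_keep_on_shrink = 4ull << 20;  // example: keep ~4 MB resident on Shrink()

      OrtCUDAProviderOptionsV2 cuda_options;
      cuda_options.device_id = 0;
      cuda_options.default_memory_arena_cfg = &arena_cfg;

      if (auto factory = onnxruntime::CudaProviderFactoryCreator::Create(&cuda_options)) {
        return factory->CreateProvider();
      }
      return nullptr;
    }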