From c00f69041ce4d8a4fecb7fbbc7d9f7ccf1591fda Mon Sep 17 00:00:00 2001
From: "Chunye Wang@AMD"
Date: Tue, 8 Jul 2025 01:25:20 +0800
Subject: [PATCH 01/10] Add a new ORT API `GetSessionOptionsConfigEntries`
 (#25277)

### Description
Add a new ORT API `GetSessionOptionsConfigEntries`, which returns all session option configuration entries as an `OrtKeyValuePairs` instance.

### Motivation and Context
#24887 allows plugin EPs to interface with ORT through a binary-stable interface. #24445 allows an EP to handle the extraction of EP options from the session option configurations. For an EP such as the VitisAI EP to comply with these requirements, a plugin EP needs to be able to access all config entries in a session option. For example:

```c++
OrtKeyValuePairs* kvps = nullptr;
const OrtApi& ort_api = Ort::GetApi();
OrtStatus* status = ort_api.GetSessionOptionsConfigEntries(session_option, &kvps);
if (status != nullptr) {
  Ort::ThrowOnError(status);  // converts the OrtStatus into an Ort::Exception and releases it
}
std::unique_ptr<OrtKeyValuePairs, decltype(ort_api.ReleaseKeyValuePairs)> config_entries(
    kvps, ort_api.ReleaseKeyValuePairs);

const char* const* keys = nullptr;
const char* const* values = nullptr;
size_t num_keys = 0;
// Get keys and values from the config entries
ort_api.GetKeyValuePairs(config_entries.get(), &keys, &values, &num_keys);
for (size_t i = 0; i < num_keys; ++i) {
  // process keys[i] and values[i]
}
```
---
 .../onnxruntime/core/session/onnxruntime_c_api.h | 12 ++++++++++++
 onnxruntime/core/session/abi_session_options.cc  | 15 +++++++++++++++
 onnxruntime/core/session/onnxruntime_c_api.cc    |  2 ++
 onnxruntime/core/session/ort_apis.h              |  2 ++
 4 files changed, 31 insertions(+)

diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 7cdcbb3bc76bf..86c0b60db2bc4 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -6074,6 +6074,18 @@ struct OrtApi {
    * \since Version 1.23
    */
   ORT_API2_STATUS(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out);
+
+  /** \brief Get Session configuration entries.
+   *
+   * \param[in] options The session options.
+   * \param[out] out A pointer to a newly created OrtKeyValuePairs instance.
+   *
+   * An OrtKeyValuePairs instance containing all session configuration entries.
+   * Note: the user should call OrtApi::ReleaseKeyValuePairs.
+   *
+   * \since Version 1.23.
+ */ + ORT_API2_STATUS(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out); }; /* diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 7a17423112144..3df6d37d63794 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -278,6 +278,21 @@ ORT_API_STATUS_IMPL(OrtApis::GetSessionConfigEntry, _In_ const OrtSessionOptions API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out) { + API_IMPL_BEGIN + if (options == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "options is nullptr"); + } + auto& config_options = options->value.config_options.GetConfigOptionsMap(); + auto kvps = std::make_unique(); + for (auto& kv : config_options) { + kvps->Add(kv.first.c_str(), kv.second.c_str()); + } + *out = reinterpret_cast(kvps.release()); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name, _In_ const OrtValue* val) { API_IMPL_BEGIN diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 2551fbc8b6099..e7f60fd48a14f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3632,6 +3632,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ReleaseSharedAllocator, &OrtApis::GetTensorData, + + &OrtApis::GetSessionOptionsConfigEntries, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 4c4ab07493237..cbacbfce0740d 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -690,4 +690,6 @@ ORT_API_STATUS_IMPL(ReleaseSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDe _In_ OrtDeviceMemoryType mem_type); ORT_API_STATUS_IMPL(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out); + +ORT_API_STATUS_IMPL(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out); } // namespace OrtApis From dafa7f9365b1ed9c9e27a9806fef6e7275c5d9c6 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 7 Jul 2025 11:23:46 -0700 Subject: [PATCH 02/10] fix webgpu dequantize_linear ut (#25271) --- .../core/providers/webgpu/quantization/quantize_linear.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc index 0305049e9b789..e7736c3f3afac 100644 --- a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc @@ -129,9 +129,6 @@ Status DequantizeLinear::ComputeInternal(ComputeContext& context) const { int64_t axis = (axis_ >= 0) ? 
axis_ : axis_ + x_shape.NumDimensions(); int max_components = GetMaxComponents(x_size); - if (max_components != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DequantizeLinear: components must be 4, but got ", max_components); - } // scaler - single scaler for all elements bool per_layer = x_scale_rank == 0 || (x_scale_rank == 1 && x_scale->Shape()[0] == 1); From 7b2e367454b4bfe7d3ee727730c2fcd7a669671d Mon Sep 17 00:00:00 2001 From: jing-bao Date: Tue, 8 Jul 2025 03:19:36 +0800 Subject: [PATCH 03/10] [webgpu] Optimize DP4AMatMulNBitsSmallMProgram for intel (#25192) ### Description This PR optimizes the Intel GPU path for the `DP4AMatMulNBitsSmallMProgram` by tuning `tile_size` and `tile_size_k_vec`. ### Motivation and Context With this change, we achieved >8% performance boost on Intel iGPUs (Xe-LP and Xe2-LPG) for phi-4-mini-accuracy4 model. --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 66364f7dab96e..02d02e824b357 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -596,6 +596,11 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor uint32_t tile_size_k_vec = 16; uint32_t tile_size = 32; + if (context.AdapterInfo().vendor == std::string_view{"intel"}) { + tile_size_k_vec = 32; + tile_size = 4; + } + DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size, nbits, has_zero_points}; uint32_t num_N_tile = (N + tile_size - 1) / tile_size; mul_program.SetWorkgroupSize(128); From 8645dd55006c1dfdf9cdc98a3b6e203f53d69fce Mon Sep 17 00:00:00 2001 From: Sophie Schoenmeyer <107952697+sophies927@users.noreply.github.com> Date: Mon, 7 Jul 2025 13:54:10 -0700 Subject: [PATCH 04/10] Migrate stale bot workflow to updateStaleIssues.yml policy (#21660) --- .github/policies/updateStaleIssues.yml | 65 ++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 .github/policies/updateStaleIssues.yml diff --git a/.github/policies/updateStaleIssues.yml b/.github/policies/updateStaleIssues.yml new file mode 100644 index 0000000000000..2a041a50f93a7 --- /dev/null +++ b/.github/policies/updateStaleIssues.yml @@ -0,0 +1,65 @@ +name: Update Stale Issues +description: Update stale issues +resource: repository +configuration: + resourceManagementConfiguration: + scheduledSearches: + - description: Apply stale label to open, unassigned issues that have not been updated in the last 30 days + frequencies: + - daily: + time: 15:00 + filters: + - isIssue + - isOpen + - isNotAssigned + - isNotLabeledWith: + label: contributions welcome + - isNotLabeledWith: + label: documentation + - isNotLabeledWith: + label: feature request + - isNotLabeledWith: + label: regression + - noActivitySince: + days: 30 + actions: + - addReply: + reply: "Applying stale label due to no activity in 30 days" + - addLabel: + label: stale + - description: Close open, unassigned issues labeled stale that have not been updated in the last 30 days + frequencies: + - daily: + time: 15:00 + filters: + - hasLabel: + label: stale + - isIssue + - isOpen + - isNotAssigned + - noActivitySince: + days: 30 + actions: + - addReply: + reply: "Closing issue due to no activity in 30 days" + - closeIssue + eventResponderTasks: + - description: Remove stale label if open stale issue is 
commented on + if: + - payloadType: Issue_Comment + - hasLabel: + label: stale + then: + - removeLabel: + label: stale + - description: Re-open stale issue if closed stale issue is commented on + if: + - payloadType: Issue_Comment + - and: + - not: + isOpen + - hasLabel: + label: stale + then: + - reopenIssue + From 15e75af26bb9adc709f051522a1f0ebceb3fee52 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Mon, 7 Jul 2025 15:51:14 -0700 Subject: [PATCH 05/10] Add RotaryEmbeddings(23) - CUDA (#25178) Follow up #24980 Fix https://github.com/microsoft/onnxruntime/issues/24556 Add ONNX RotaryEmbedding(23) following https://github.com/onnx/onnx/blob/main/docs/Operators.md#RotaryEmbedding. The PR uses contrib op RotaryEmbedding implementation under the hood. The main difference between this op and the contrib op is that the position_ids in ONNX RotaryEmbedding is optional. When it's not provided, cos_cache and sin_cache should be 3d. --- docs/OperatorKernels.md | 1 + .../providers/cpu/llm/rotary_embedding.cc | 12 +- .../providers/cuda/cuda_execution_provider.cc | 6 + .../providers/cuda/llm/rotary_embedding.cc | 85 +++++++++ .../providers/cuda/llm/rotary_embedding.h | 26 +++ .../cuda/llm/rotary_embedding_impl.cu | 169 ++++++++++++++++++ .../cuda/llm/rotary_embedding_impl.h | 51 ++++++ 7 files changed, 344 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/cuda/llm/rotary_embedding.cc create mode 100644 onnxruntime/core/providers/cuda/llm/rotary_embedding.h create mode 100644 onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu create mode 100644 onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bfb6e7c38ccb4..1ffcabee8cc10 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -828,6 +828,7 @@ Do not modify directly.* |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| +|RotaryEmbedding|*in* X:**T**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**M**
*out* Y:**T**|23+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|19+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc index 616374eee6ff1..f1b0d5850ab1f 100644 --- a/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc +++ b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc @@ -72,13 +72,13 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete const T* cos_data; const T* sin_data; int cache_offset; - if (position_ids_format == 0) { - cache_offset = (b * sequence_length + s) * half_rotary_emb_dim; - } else { - // Cache is (M, H/2) or (M, rotary_embedding_dim/2) - const int position_id = static_cast(position_ids[b * sequence_length + s]); - cache_offset = position_id * half_rotary_emb_dim; + // position_ids_format == 0 means position_ids is nullptr + // position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length) + int b_s_index = b * sequence_length + s; + if (position_ids_format != 0) { + b_s_index = static_cast(position_ids[b_s_index]); } + cache_offset = b_s_index * half_rotary_emb_dim; cos_data = cos_cache + cache_offset; sin_data = sin_cache + cache_offset; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index d9acb9ccdc30f..1f4c9fcdbc073 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1490,6 +1490,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16_BFloat16, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_MLFloat16, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_float, RMSNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding); #endif @@ -2480,6 +2483,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc b/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc new file mode 100644 index 0000000000000..f259c6021a82e --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cpu/llm/rotary_embedding_helper.h" +#include "core/providers/cuda/llm/rotary_embedding.h" +#include "core/providers/cuda/llm/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::rotary_embedding_helper; + +namespace onnxruntime { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbedding, \ + kOnnxDomain, \ + 23, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* cos_cache = context->Input(1); + const Tensor* sin_cache = context->Input(2); + const Tensor* position_ids = context->Input(3); // Optional, can be nullptr + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + num_heads, + rotary_embedding_dim, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + // Launch rotary embedding kernel + typedef typename ToCudaType::MappedType CudaT; + auto& device_prop = GetDeviceProp(); + + // Handle optional position_ids - pass nullptr if position_ids is null + const int64_t* position_ids_data = (position_ids != nullptr) ? position_ids->Data() : nullptr; + + return LaunchRotaryEmbeddingKernel( + Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(input->template Data()), + position_ids_data, + reinterpret_cast(cos_cache->template Data()), + reinterpret_cast(sin_cache->template Data()), + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.head_size, + parameters.rotary_embedding_dim, + parameters.max_sequence_length, + parameters.position_ids_format, + interleaved, + device_prop.maxThreadsPerBlock, + parameters.transposed); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding.h b/onnxruntime/core/providers/cuda/llm/rotary_embedding.h new file mode 100644 index 0000000000000..09bd2e7dfde15 --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class RotaryEmbedding final : public CudaKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads; + int rotary_embedding_dim; + int interleaved; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu new file mode 100644 index 0000000000000..eda049dfae9a8 --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu @@ -0,0 +1,169 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for rotary embeddings. +*/ + +#include "core/providers/cuda/llm/rotary_embedding_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace cuda { + +template +__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH + const T* input, // BxSxNxH + const T* cos_cache, // BxSx(H/2) or Mx(H/2) + const T* sin_cache, // BxSx(H/2) or Mx(H/2) + const int64_t* position_ids, // (0) or BxS + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int position_ids_format, + const bool interleaved, + int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous +) { + // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length + // Use .x in innermost loop to access global memory efficiently + + const int b = blockIdx.y; + const int s = blockIdx.x; + const int n = blockIdx.z; + + const int i = threadIdx.x; + + if (i >= head_size) { + return; + } + + const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; + T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; + + if (i >= rotary_embedding_dim) { + output_data[i] = input_data[i]; + return; + } + + // Cache is (M, H/2) + const int half_rotary_embedding_dim = rotary_embedding_dim / 2; + int cache_offset; + + // position_ids_format == 0 means position_ids is nullptr + // position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length) + int b_s_index = b * sequence_length + s; + if (position_ids_format != 0) { + b_s_index = static_cast(position_ids[b_s_index]); + } + cache_offset = b_s_index * half_rotary_embedding_dim; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + if (interleaved) { + cache_idx = (i / 2) % half_rotary_embedding_dim; + sign = (i % 2 == 0) ? -1 : 1; + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + } else { + cache_idx = i % half_rotary_embedding_dim; + sign = (i < half_rotary_embedding_dim) ? 
-1 : 1; + j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; +} + +template +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const T* cos_cache, const T* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, const bool is_input_bnsh_format) { + int4 in_strides; + int4 out_strides; + if (is_input_bnsh_format) { + // Semantic meaning of the strides: + // int in_head_stride = sequence_length * head_size; + // int out_head_stride = sequence_length * head_size; + // in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1}; + // out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1}; + // Simplify to: + in_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1}; + out_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1}; + } else { + // input is in bshn format + // int in_head_stride = head_size; + // int out_head_stride = head_size; + // Simplify to: + in_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1}; + out_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1}; + } + return LaunchRotaryEmbeddingKernel( + stream, output, input, position_ids, + cos_cache, sin_cache, batch_size, + sequence_length, num_heads, head_size, + rotary_embedding_dim, max_sequence_length, + position_ids_format, interleaved, + max_threads_per_block, + in_strides, out_strides); +} + +template +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const T* cos_cache, const T* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, + int4 in_strides, int4 out_strides // strides in bnsh coord +) { + // Note: Current implementation assumes head_size <= max_threads_per_block + // because head_size is currently large for LLaMA-2. For smaller head_size + // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` + // instead. This will require kernel changes to support. 
+ ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block"); + // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1) + ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must contiguous"); + + int tpb = (head_size + 31) / 32 * 32; + + const dim3 block(tpb); + const dim3 grid(sequence_length, batch_size, num_heads); + + assert(head_size <= max_threads_per_block); + RotaryEmbeddingBSNH<<>>(output, input, cos_cache, sin_cache, position_ids, sequence_length, + num_heads, head_size, rotary_embedding_dim, position_ids_format, + interleaved, in_strides, out_strides); + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, float* output, const float* input, + const int64_t* position_ids, const float* cos_cache, + const float* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, const bool is_input_bnsh_format); + +template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, half* output, const half* input, + const int64_t* position_ids, const half* cos_cache, + const half* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, const bool is_input_bnsh_format); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids, + const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length, + const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length, + const int position_ids_format, const bool interleaved, const int max_threads_per_block, + const bool is_input_bnsh_format); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h new file mode 100644 index 0000000000000..2389c0596014e --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int rotary_embedding_dim, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + const bool is_input_bnsh_format); + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int rotary_embedding_dim, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + int4 in_strides, + int4 out_strides); + +} // namespace cuda +} // namespace onnxruntime From f0097fcb30ba79fc1118b77c22c4e8501a195c6d Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Mon, 7 Jul 2025 15:51:51 -0700 Subject: [PATCH 06/10] Exclude EPContext Op from Common Subexpression Elimination graph optimization (#25296) In the context of a model containing EPContext nodes, it's highly unlikely that two EPContext nodes will produce the same results. Furthermore, the EquivalenceClass constructor includes the node and all its attributes in the hash calculation, which can be particularly time-consuming when the "ep_cache_context" attribute contains a large binary blob. Therefore, we exclude EPContext op from CSE. --- .../optimizer/common_subexpression_elimination.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc index 471e4ee7c03a3..8f78f9c3b6cc7 100644 --- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc +++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc @@ -421,7 +421,13 @@ Status CommonSubexpressionElimination::ApplyImpl(Graph& graph, bool& modified, i for (NodeIndex node_index : node_topology_list) { Node* node = graph.GetNode(node_index); - if (node == nullptr) + + // In the context of a model containing EPContext nodes, it's highly unlikely that two EPContext nodes will + // produce the same results. + // Furthermore, the EquivalenceClass constructor includes the node and all its attributes in the hash calculation, + // which can be particularly time-consuming when the "ep_cache_context" attribute contains a large binary blob. + // Therefore, EPContext nodes are excluded from this process. + if (node == nullptr || node->OpType() == "EPContext") continue; ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); @@ -471,7 +477,12 @@ Status CommonSubexpressionElimination::ApplyImpl(Graph& graph, bool& modified, i for (NodeIndex node_index : node_topology_list) { Node* node = graph.GetNode(node_index); - if (node == nullptr) + // In the context of a model containing EPContext nodes, it's highly unlikely that two EPContext nodes will + // produce the same results. 
+ // Furthermore, the EquivalenceClass constructor includes the node and all its attributes in the hash calculation, + // which can be particularly time-consuming when the "ep_cache_context" attribute contains a large binary blob. + // Therefore, EPContext nodes are excluded from this process. + if (node == nullptr || node->OpType() == "EPContext") continue; bool node_output_replaced = false; From 6d28e2d21615f491d782c30d3b91295a70738c89 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 7 Jul 2025 16:59:00 -0700 Subject: [PATCH 07/10] [webgpu] support smooth softmax for non-FA GQA implementation (#25285) ### Description support smooth softmax for non-FA GQA implementation This change depends on: - #25269 Work items: - [x] support smooth softmax - [x] support bias - [x] support head sink (per-head smooth softmax) The following will not be included in this PR: - support for FlashAttention - support sliding window --- .../cpu/bert/group_query_attention.cc | 2 + .../cpu/bert/group_query_attention_helper.h | 18 ++++++ .../contrib_ops/webgpu/bert/attention.cc | 62 +++++++++++++------ .../contrib_ops/webgpu/bert/attention.h | 14 +++-- .../webgpu/bert/attention_common.h | 3 +- .../webgpu/bert/flash_attention.cc | 2 + .../webgpu/bert/group_query_attention.cc | 24 +++++-- 7 files changed, 93 insertions(+), 32 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 9c7530f0126bb..a912bd6e6b43c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -54,6 +54,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* sin_cache = context->Input(8); const Tensor* position_ids = context->Input(9); const Tensor* attention_bias = context->Input(10); + const Tensor* head_sink = context->Input(11); GroupQueryAttentionParameters parameters = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -73,6 +74,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, attention_bias, + head_sink, parameters)); const int batch_size = parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 338c34acb3cfb..0f66119540b03 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -340,6 +340,7 @@ Status CheckInputs(const T* query, template Status CheckCustomAttentionInputs(const T* position_ids, const T* attention_bias, + const T* head_sink, const GroupQueryAttentionParameters& parameters) { if (position_ids != nullptr) { const auto& pos_ids_shape = position_ids->Shape(); @@ -377,6 +378,23 @@ Status CheckCustomAttentionInputs(const T* position_ids, } } + if (head_sink != nullptr) { + if (parameters.use_smooth_softmax) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_sink should not be provided when use_smooth_softmax is true."); + } + + const auto& head_sink_shape = head_sink->Shape(); + if (head_sink_shape.NumDimensions() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_sink must be a 1D tensor"); + } + + if (head_sink_shape[0] != parameters.num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, 
INVALID_ARGUMENT, + "head_sink dimension 0 must be equal to the num heads, got ", head_sink_shape[0]); + } + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 936b4483201ac..55bcf42f2f04b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -69,8 +69,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h return context.RunProgram(program); }; -void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) { - if (seqlen_k != nullptr) { +void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) { + if (has_seqlen_k) { ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n"; ss << "var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);\n"; } else { @@ -87,7 +87,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } - if (seqlen_k_ != nullptr) { + if (has_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::UseUniform); } shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); @@ -107,7 +107,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.N;\n"; std::ostringstream oss; - InitVarStub(oss, seqlen_k_); + InitVarStub(oss, has_seqlen_k_); shader.MainFunctionBody() << oss.str(); shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n"; if (has_present_key_) { @@ -182,7 +182,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_}; + components, parameters.is_first_prompt_, seqlen_k != nullptr, parameters.past_present_share_buffer_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -224,30 +224,44 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o } Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { - if (seqlen_k_) { + if (has_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::UseUniform); } + if (has_head_sink_) { + shader.AddInput("head_sink", ShaderUsage::UseUniform); + } shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AdditionalImplementation() << "var thread_max: array;\n" << "var thread_sum: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32")) << ";\n"; shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n" << "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n" + << "let head_idx = u32(workgroup_idx / sequence_length) % uniforms.num_heads;\n" << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; std::ostringstream oss; - InitVarStub(oss, seqlen_k_); + InitVarStub(oss, has_seqlen_k_); shader.MainFunctionBody() << oss.str() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n" - << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n" + << "let seq_causal_length = " << (has_seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n" << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" << "}\n" << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" - << "workgroupBarrier();\n" - << "var max_value = f32(-3.402823e+38f);\n" - << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << "workgroupBarrier();\n"; + + if (has_head_sink_) { + // Handle head sink + shader.MainFunctionBody() << "let sink_value: f32 = head_sink[head_idx];\n" + << "var max_value = sink_value;\n"; + } else if (use_smooth_softmax_) { + shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n"; + } else { + shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n"; + } + + shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" << " max_value = max(thread_max[i], max_value);\n" << "}\n" << "var sum_vector = f32_val_t(0);\n" @@ -259,8 +273,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var sum: f32 = 0;\n" << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" << " sum += thread_sum[i]\n;" - << "}\n" - << "if (sum == 0) {\n" + << "}\n"; + + if (has_head_sink_) { + shader.MainFunctionBody() << "sum += exp(sink_value - max_value);\n"; + } else if (use_smooth_softmax_) { + shader.MainFunctionBody() << "sum += exp(-max_value);\n"; + } + + shader.MainFunctionBody() << "if (sum == 0) {\n" << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n" << " }\n" @@ -270,7 +291,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" << " }\n" << "}\n"; - if (seqlen_k_) { + if (has_seqlen_k_) { shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n" << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" << "}\n"; @@ -280,7 +301,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& 
context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length, - const Tensor* seqlen_k, bool is_first_prompt) { + const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink) { const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1)); int work_group_size = 64; const int total_sequence_length_comp = (total_sequence_length + components - 1) / components; @@ -289,12 +310,15 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso } const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size; - InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k}; + InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr}; if (seqlen_k != nullptr) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); } + if (head_sink != nullptr) { + program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); + } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) - .CacheHint(work_group_size) + .CacheHint(work_group_size, use_smooth_softmax) .SetDispatchGroupSize(batch_size * num_heads * sequence_length) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, @@ -443,7 +467,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? 
parameters.past_sequence_length_ : 0; const int total_sequence_length = @@ -457,7 +481,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T parameters, past_sequence_length, total_sequence_length, seqlen_k)); ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, - parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_)); + parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink)); ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, parameters, past_sequence_length, total_sequence_length, seqlen_k)); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 7c0cb40cc7f93..e64ca3539c23d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool has_seqlen_k = false, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), has_seqlen_k_(has_seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -62,15 +62,15 @@ class AttentionProbsProgram final : public Program { bool has_attention_bias_; int tile_size_; int components_; - const Tensor* seqlen_k_; + bool has_seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; }; class InPlaceSoftmaxProgram final : public Program { public: - InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr) - : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) { + InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink) + : Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -86,7 +86,9 @@ class InPlaceSoftmaxProgram final : public Program { private: int work_group_size_; int components_; - const Tensor* seqlen_k_; + bool use_smooth_softmax_; + bool has_seqlen_k_; + bool has_head_sink_; }; class VxAttentionScoreProgram final : public Program { diff --git 
a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 06b9c88ce8993..9d4740ede7143 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -123,7 +123,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, + const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr); } // namespace webgpu } // namespace contrib diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index fabfbfbdd142e..c9e182bf10f2f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -382,6 +382,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // sum is the second term of the same expression : Σ_j=1:b e^(Xi[j]-Mi) // o_ratio is the part of the first term of o'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i) / d'_i // + + // TODO: support smooth softmax and head_sink shader.MainFunctionBody() << R"MAIN_FN( var local_max_temp = max(qk_1, qk_2); if (sg_size > 8) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index f002db108035f..f3334b13dc645 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -152,6 +152,9 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& const Tensor* total_seqlen_tensor = context.Input(6); const Tensor* cos_cache = context.Input(7); const Tensor* sin_cache = context.Input(8); + const Tensor* position_ids = context.Input(9); // TODO: support sliding window + const Tensor* attention_bias = context.Input(10); + const Tensor* head_sink = context.Input(11); GroupQueryAttentionParameters params = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -168,6 +171,13 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& total_seqlen_tensor, scale_, softcap_)); + params.use_smooth_softmax = use_smooth_softmax_; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, + attention_bias, + head_sink, + params)); + WebgpuAttentionParameters parameters(params); TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); @@ -184,8 +194,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_value = context.Output(2, present_kv_shape); parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); - if (!do_rotary_ && CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) { - return ApplyFlashAttention(query, key, value, nullptr /* attention_bias */, output, 
past_key, present_key, past_value, + if (!do_rotary_ && + head_sink == nullptr && !use_smooth_softmax_ && + CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { + return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); } @@ -224,8 +236,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_RETURN_IF_ERROR(TransferBSDToBNSH( context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format - return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlen_k); + return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context, head_sink, seqlen_k); } TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, @@ -241,8 +253,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_, value, nullptr, 0, &V)); - return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlen_k); + return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context, head_sink, seqlen_k); } } // namespace webgpu From 5ae3ee7cbe5fb4cabb80e8c9299e4b5888733a93 Mon Sep 17 00:00:00 2001 From: Ishwar Raut Date: Tue, 8 Jul 2025 06:01:19 +0530 Subject: [PATCH 08/10] 1. Fix Nv EP Build Break:wq (#25311) ### Description 1. Fix the Build Break in NV TRT RTX EP --- onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h index 1ab5e47a08523..3283aea7c33a2 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h @@ -53,7 +53,7 @@ class CUDAExternalAllocator : public CUDAAllocator { // TODO: add a default constructor class CUDAPinnedAllocator : public IAllocator { public: - CUDAPinnedAllocator(const char* name, OrtDevice::DeviceId device_id) + CUDAPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, From 7c18d896b033dba80113efb720bd8842f8c23e33 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 7 Jul 2025 22:13:39 -0700 Subject: [PATCH 09/10] Fix cuda 12.9 windows build (#25317) ### Description Fix Windows build with MSVC 17.14.7 and cuda 12.9.1. The build error was like: `CUDACOMPILE : nvcc error : 'cudafe++' died with status 0xC0000005 (ACCESS_VIOLATION)` The cause is unknown (maybe cudafe bug). The code change resolved the issue. I've verified it in two machines. 
--- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 2611dde238f48..bfbe1d81b1c15 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -18,7 +18,7 @@ #include #include -#include +#include // for CUDA_VERSION #include #include #include @@ -38,19 +38,12 @@ #include "moe_kernel.h" -#if CUDA_VERSION >= 11000 #include #include #include -#else -#include "cub/cub.cuh" -#include "cub/device/device_radix_sort.cuh" -#include "cub/util_type.cuh" -#endif namespace ort_fastertransformer { static constexpr int WARP_SIZE = 32; - // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. @@ -65,13 +58,6 @@ __launch_bounds__(TPB) __global__ const int thread_row_offset = blockIdx.x * num_cols; -#if CUDA_VERSION >= 12090 - ::cuda::std::plus sum; -#else - // Deprecated on CUDA 12.9 - cub::Sum sum; -#endif - float threadData(-FLT_MAX); // Don't touch finished rows. @@ -84,7 +70,12 @@ __launch_bounds__(TPB) __global__ threadData = max(static_cast(input[idx]), threadData); } +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090 + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::maximum()); +#else const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); +#endif + if (threadIdx.x == 0) { float_max = maxElem; } @@ -97,7 +88,12 @@ __launch_bounds__(TPB) __global__ threadData += exp((static_cast(input[idx]) - float_max)); } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090 + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::std::plus()); +#else + // Deprecated on CUDA 12.9 + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum()); +#endif if (threadIdx.x == 0) { normalizing_factor = 1.f / Z; @@ -993,6 +989,7 @@ void CutlassMoeFCRunner::get_total_rows_info(int64_t expe if (experts_start_index > 0) { total_past_rows = total_rows_before_expert_host_[experts_start_index - 1]; } + total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows; } From 0ccecf71f4347abac08660a984058da4bb07ae9c Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 7 Jul 2025 23:02:15 -0700 Subject: [PATCH 10/10] [EP ABI] Infer OrtDevice for plugin EP from registered OrtMemoryInfo (#25308) ### Description - Infer `OrtDevice` for a plugin EP from the registered `OrtMemoryInfo` for device memory. - Fix potential `nullptr` dereference when a `PluginExecutionProvider` tries to log a message without a valid logger. Now, constructing a `PluginExecutionProvider` requires passing a valid logger. ### Motivation and Context Address a `TODO` to properly set the `OrtDevice` for a `PluginExecutionProvider` instance. 
--- .../core/framework/execution_provider.h | 4 + .../session/ep_plugin_provider_interfaces.cc | 45 +++++- .../session/ep_plugin_provider_interfaces.h | 2 +- .../core/session/provider_policy_context.cc | 3 +- .../test/framework/ep_plugin_provider_test.cc | 148 +++++++++++++++++- 5 files changed, 189 insertions(+), 13 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 65a6ec304bda2..7df3368ad4e0b 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -79,6 +79,10 @@ class IExecutionProvider { : default_device_(device), type_{type} { } + IExecutionProvider(const std::string& type, OrtDevice device, const logging::Logger& logger) + : default_device_(device), type_{type}, logger_{&logger} { + } + /* default device for this ExecutionProvider */ diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc index cac91a4ec52d2..878a5384dfee7 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc @@ -53,11 +53,9 @@ PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_ ORT_THROW("Error creating execution provider: ", status.ToString()); } - auto ep_wrapper = std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)), - session_options, ep_factory_, devices_); - ep_wrapper->SetLogger(session_logger.ToInternal()); - - return ep_wrapper; + return std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)), + session_options, ep_factory_, devices_, + *session_logger.ToInternal()); } /// @@ -86,10 +84,43 @@ struct PluginEpMetaDefNameFunctor { // PluginExecutionProvider // +static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_devices) { + // Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU. + // If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all. + + ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices. + + const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info; + + // Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos + bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(), + [mem_a = device_memory_info](const OrtEpDevice* ep_device) { + const OrtMemoryInfo* mem_b = ep_device->device_memory_info; + + if (mem_a == mem_b) { + return true; // Point to the same OrtMemoryInfo instance. + } + + if (mem_a == nullptr || mem_b == nullptr) { + return false; // One is nullptr and the other is not. + } + + // Both non-null but point to different instances. Use operator==. + return *mem_a == *mem_b; + }); + if (!all_match) { + ORT_THROW("Error creating execution provider '", ep_devices[0]->ep_name, + "': expected all OrtEpDevice instances to use the same device_memory_info."); + } + + return device_memory_info != nullptr ? device_memory_info->device : OrtDevice(); +} + PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, - gsl::span ep_devices) - : IExecutionProvider(ep->GetName(ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins? 
+                                                 gsl::span<const OrtEpDevice* const> ep_devices,
+                                                 const logging::Logger& logger)
+    : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), logger),
       ort_ep_(std::move(ep)),
       ep_factory_(ep_factory),
       ep_devices_(ep_devices.begin(), ep_devices.end()) {
diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/ep_plugin_provider_interfaces.h
index 343d6c9ad464e..3ba3118fcaa36 100644
--- a/onnxruntime/core/session/ep_plugin_provider_interfaces.h
+++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.h
@@ -65,7 +65,7 @@ class PluginExecutionProvider : public IExecutionProvider {
  public:
   explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options,
                                    OrtEpFactory& ep_factory,
-                                   gsl::span<const OrtEpDevice* const> ep_devices);
+                                   gsl::span<const OrtEpDevice* const> ep_devices, const logging::Logger& logger);
   ~PluginExecutionProvider();

   std::vector<std::unique_ptr<ComputeCapability>>
diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc
index 6b54c33e9b10b..e8d62ab86f517 100644
--- a/onnxruntime/core/session/provider_policy_context.cc
+++ b/onnxruntime/core/session/provider_policy_context.cc
@@ -10,6 +10,7 @@

 #include "core/framework/error_code_helper.h"
 #include "core/session/abi_devices.h"
+#include "core/session/abi_logger.h"
 #include "core/session/ep_factory_internal.h"
 #include "core/session/ep_plugin_provider_interfaces.h"
 #include "core/session/inference_session.h"
@@ -355,7 +356,7 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or
                       info.ep_factory->CreateEp(info.ep_factory, info.hardware_devices.data(), info.ep_metadata.data(),
                                                 info.hardware_devices.size(), &options, &logger, &api_ep)));
     ep = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options,
-                                                   *info.ep_factory, info.devices);
+                                                   *info.ep_factory, info.devices, *logger.ToInternal());
   }

   return Status::OK();
diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc
index 36b7f2965b483..18bc9cf05b36d 100644
--- a/onnxruntime/test/framework/ep_plugin_provider_test.cc
+++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc
@@ -9,6 +9,7 @@
 #include "core/session/abi_devices.h"
 #include "core/session/onnxruntime_cxx_api.h"
 #include "test/util/include/asserts.h"
+#include "test/util/include/test_environment.h"

 namespace onnxruntime::test {
@@ -56,6 +57,32 @@ struct TestOrtEpFactory : ::OrtEpFactory {

 static TestOrtEpFactory g_test_ort_ep_factory{};

+std::unique_ptr<OrtHardwareDevice> MakeTestOrtHardwareDevice(OrtHardwareDeviceType type) {
+  auto hw_device = std::make_unique<OrtHardwareDevice>();
+  hw_device->type = type;
+  hw_device->vendor_id = 0xBE57;
+  hw_device->device_id = 0;
+  hw_device->vendor = "Contoso";
+  return hw_device;
+}
+
+std::unique_ptr<OrtEpDevice> MakeTestOrtEpDevice(const OrtHardwareDevice* hardware_device,
+                                                 const OrtMemoryInfo* device_memory_info = nullptr,
+                                                 const OrtMemoryInfo* host_accessible_memory_info = nullptr) {
+  auto ep_device = std::make_unique<OrtEpDevice>();
+  ep_device->ep_name = "TestOrtEp";
+  ep_device->ep_vendor = "Contoso";
+  ep_device->device = hardware_device;
+  ep_device->ep_factory = &g_test_ort_ep_factory;
+  ep_device->device_memory_info = device_memory_info;
+  ep_device->host_accessible_memory_info = host_accessible_memory_info;
+  return ep_device;
+}
+
+OrtDevice MakeTestOrtDevice(OrtDevice::DeviceType device_type, OrtDevice::MemoryType memory_type) {
+  return OrtDevice(device_type, memory_type, /*vendor_id*/ 0xBE57, /*device_id*/ 0, /*alignment*/ 16);
+}
+
 struct MakeTestOrtEpResult {
  std::unique_ptr<IExecutionProvider> ep;  // the IExecutionProvider wrapping the TestOrtEp
   gsl::not_null<TestOrtEp*> ort_ep;  // the wrapped TestOrtEp, owned by `ep`
@@ -63,17 +90,25 @@ struct MakeTestOrtEpResult {

 // Creates an IExecutionProvider that wraps a TestOrtEp.
 // The TestOrtEp is also exposed so that tests can manipulate its function pointers directly.
-MakeTestOrtEpResult MakeTestOrtEp() {
+MakeTestOrtEpResult MakeTestOrtEp(std::vector<const OrtEpDevice*> ep_devices = {}) {
+  // Default OrtHardwareDevice and OrtEpDevice used if the caller does not explicitly provide ep_devices.
+  static std::unique_ptr<OrtHardwareDevice> ort_hw_device = MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU);
+  static std::unique_ptr<OrtEpDevice> ort_ep_device = MakeTestOrtEpDevice(ort_hw_device.get());
+
   auto ort_ep_raw = std::make_unique<TestOrtEp>().release();
   auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory});
   auto ort_session_options = Ort::SessionOptions{};

-  auto ort_ep_device = OrtEpDevice{};
-  std::vector<const OrtEpDevice*> ep_devices{&ort_ep_device};
+  if (ep_devices.empty()) {
+    ep_devices.push_back(ort_ep_device.get());
+  }
+
+  auto& logging_manager = DefaultLoggingManager();

   auto ep = std::make_unique<PluginExecutionProvider>(std::move(ort_ep),
                                                       *static_cast(ort_session_options),
                                                       g_test_ort_ep_factory,
-                                                      ep_devices);
+                                                      ep_devices,
+                                                      logging_manager.DefaultLogger());

   auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw};
   return result;
@@ -177,4 +212,109 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) {
 #endif  // !defined(ORT_NO_EXCEPTIONS)
 }

+TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) {
+  // 1 OrtEpDevice without a device_memory_info.
+  // PluginExecutionProvider should decide to use a default OrtDevice.
+  {
+    auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU);
+    auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()};
+
+    auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+    ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice());
+  }
+
+  // 1 OrtEpDevice with a device_memory_info.
+  // PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info.
+  {
+    auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT);
+    auto ort_memory_info = std::make_unique<OrtMemoryInfo>("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator,
+                                                           ort_device, OrtMemTypeDefault);
+
+    auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+    auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(),
+                                                             /*device_memory_info*/ ort_memory_info.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()};
+
+    auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+    ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device);
+  }
+
+  // 2 OrtEpDevice instances with the same device_memory_info.
+  // PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info.
+  {
+    auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT);
+    auto ort_memory_info = std::make_unique<OrtMemoryInfo>("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator,
+                                                           ort_device, OrtMemTypeDefault);
+
+    auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+    auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU);
+    auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info.get());
+    auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()};
+
+    auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+    ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device);
+  }
+
+  // 2 OrtEpDevice instances with different (but equivalent) device_memory_info pointers.
+  // PluginExecutionProvider should decide to use an OrtDevice that is equal to the devices used by both
+  // device_memory_info pointers.
+  {
+    auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT);
+    auto ort_memory_info_0 = std::make_unique<OrtMemoryInfo>("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator,
+                                                             ort_device, OrtMemTypeDefault);
+    auto ort_memory_info_1 = std::make_unique<OrtMemoryInfo>("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator,
+                                                             ort_device, OrtMemTypeDefault);
+
+    auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+    auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU);
+    auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_0.get());
+    auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_1.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()};
+
+    auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+    ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device);
+  }
+
+  // 1 OrtEpDevice with only a host_accessible_memory_info.
+  // PluginExecutionProvider should decide to use a default OrtDevice (CPU).
+  {
+    auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE);
+    auto ort_memory_info = std::make_unique<OrtMemoryInfo>("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator,
+                                                           ort_device, OrtMemTypeDefault);
+
+    auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+    auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(),
+                                                             /*device_memory_info*/ nullptr,
+                                                             /*host_accessible_memory_info*/ ort_memory_info.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()};
+
+    auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+    ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice());
+  }
+
+#if !defined(ORT_NO_EXCEPTIONS)
+  // 2 OrtEpDevice instances with DIFFERENT device_memory_info instances.
+  // Should throw an exception on construction of PluginExecutionProvider.
+  {
+    auto ort_device_gpu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT);
+    auto ort_memory_info_gpu = std::make_unique<OrtMemoryInfo>("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator,
+                                                               ort_device_gpu, OrtMemTypeDefault);
+
+    auto ort_device_npu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT);
+    auto ort_memory_info_npu = std::make_unique<OrtMemoryInfo>("TestOrtEp NPU", OrtAllocatorType::OrtDeviceAllocator,
+                                                               ort_device_npu, OrtMemTypeDefault);
+
+    auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+    auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU);
+    auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_gpu.get());
+    auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_npu.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()};
+
+    ASSERT_THROW(test_plugin_ep::MakeTestOrtEp(ep_devices), OnnxRuntimeException);
+  }
+#endif  // !defined(ORT_NO_EXCEPTIONS)
+}
+
 }  // namespace onnxruntime::test