diff --git a/.github/policies/updateStaleIssues.yml b/.github/policies/updateStaleIssues.yml
new file mode 100644
index 0000000000000..2a041a50f93a7
--- /dev/null
+++ b/.github/policies/updateStaleIssues.yml
@@ -0,0 +1,65 @@
+name: Update Stale Issues
+description: Update stale issues
+resource: repository
+configuration:
+  resourceManagementConfiguration:
+    scheduledSearches:
+      - description: Apply stale label to open, unassigned issues that have not been updated in the last 30 days
+        frequencies:
+          - daily:
+              time: 15:00
+        filters:
+          - isIssue
+          - isOpen
+          - isNotAssigned
+          - isNotLabeledWith:
+              label: contributions welcome
+          - isNotLabeledWith:
+              label: documentation
+          - isNotLabeledWith:
+              label: feature request
+          - isNotLabeledWith:
+              label: regression
+          - noActivitySince:
+              days: 30
+        actions:
+          - addReply:
+              reply: "Applying stale label due to no activity in 30 days"
+          - addLabel:
+              label: stale
+      - description: Close open, unassigned issues labeled stale that have not been updated in the last 30 days
+        frequencies:
+          - daily:
+              time: 15:00
+        filters:
+          - hasLabel:
+              label: stale
+          - isIssue
+          - isOpen
+          - isNotAssigned
+          - noActivitySince:
+              days: 30
+        actions:
+          - addReply:
+              reply: "Closing issue due to no activity in 30 days"
+          - closeIssue
+    eventResponderTasks:
+      - description: Remove stale label if an open stale issue is commented on
+        if:
+          - payloadType: Issue_Comment
+          - hasLabel:
+              label: stale
+        then:
+          - removeLabel:
+              label: stale
+      - description: Reopen a closed stale issue if it is commented on
+        if:
+          - payloadType: Issue_Comment
+          - and:
+              - not:
+                  isOpen
+              - hasLabel:
+                  label: stale
+        then:
+          - reopenIssue
+
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index bfb6e7c38ccb4..1ffcabee8cc10 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -828,6 +828,7 @@ Do not modify directly.*
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
+|RotaryEmbedding|*in* X:**T**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**M**<br> *out* Y:**T**|23+|**M** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
|Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Scan|*in* initial_state_and_scan_inputs:**V**<br> *out* final_state_and_scan_outputs:**V**<br><br>or<br><br>*in* sequence_lens:**I**<br> *in* initial_state_and_scan_inputs:**V**<br> *out* final_state_and_scan_outputs:**V**|19+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index 65a6ec304bda2..7df3368ad4e0b 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -79,6 +79,10 @@ class IExecutionProvider {
: default_device_(device), type_{type} {
}
+ IExecutionProvider(const std::string& type, OrtDevice device, const logging::Logger& logger)
+ : default_device_(device), type_{type}, logger_{&logger} {
+ }
+
/*
default device for this ExecutionProvider
*/
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 7cdcbb3bc76bf..86c0b60db2bc4 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -6074,6 +6074,18 @@ struct OrtApi {
* \since Version 1.23
*/
ORT_API2_STATUS(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out);
+
+  /** \brief Get all session configuration entries.
+   *
+   * \param[in] options The session options.
+   * \param[out] out A pointer to a newly created OrtKeyValuePairs instance containing
+   *                 all session configuration entries.
+   *
+   * \note The caller must release the returned instance using OrtApi::ReleaseKeyValuePairs.
+   *
+   * \since Version 1.23
+   */
+ ORT_API2_STATUS(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out);
};
/*
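Reviewer note: a minimal usage sketch for the new API (not part of the patch). It assumes `g_ort` is the `OrtApi*` obtained from `OrtGetApiBase()->GetApi(ORT_API_VERSION)` and that the existing `OrtKeyValuePairs` accessors (`GetKeyValuePairs`, `ReleaseKeyValuePairs`) are used to consume the result; error handling is abbreviated.

```cpp
// Sketch only: enumerate all session config entries, then release them.
OrtKeyValuePairs* entries = nullptr;
OrtStatus* status = g_ort->GetSessionOptionsConfigEntries(session_options, &entries);
if (status == nullptr) {
  const char* const* keys = nullptr;
  const char* const* values = nullptr;
  size_t num_entries = 0;
  g_ort->GetKeyValuePairs(entries, &keys, &values, &num_entries);
  for (size_t i = 0; i < num_entries; ++i) {
    printf("%s=%s\n", keys[i], values[i]);
  }
  g_ort->ReleaseKeyValuePairs(entries);  // required, per the doc comment above
} else {
  g_ort->ReleaseStatus(status);
}
```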
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 9c7530f0126bb..a912bd6e6b43c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -54,6 +54,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
  const Tensor* sin_cache = context->Input<Tensor>(8);
  const Tensor* position_ids = context->Input<Tensor>(9);
  const Tensor* attention_bias = context->Input<Tensor>(10);
+  const Tensor* head_sink = context->Input<Tensor>(11);
GroupQueryAttentionParameters parameters = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
@@ -73,6 +74,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
attention_bias,
+ head_sink,
parameters));
const int batch_size = parameters.batch_size;
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index 338c34acb3cfb..0f66119540b03 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -340,6 +340,7 @@ Status CheckInputs(const T* query,
template <typename T>
Status CheckCustomAttentionInputs(const T* position_ids,
const T* attention_bias,
+ const T* head_sink,
const GroupQueryAttentionParameters& parameters) {
if (position_ids != nullptr) {
const auto& pos_ids_shape = position_ids->Shape();
@@ -377,6 +378,23 @@ Status CheckCustomAttentionInputs(const T* position_ids,
}
}
+ if (head_sink != nullptr) {
+ if (parameters.use_smooth_softmax) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_sink should not be provided when use_smooth_softmax is true.");
+ }
+
+ const auto& head_sink_shape = head_sink->Shape();
+ if (head_sink_shape.NumDimensions() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_sink must be a 1D tensor");
+ }
+
+ if (head_sink_shape[0] != parameters.num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_sink dimension 0 must be equal to the num heads, got ", head_sink_shape[0]);
+ }
+ }
+
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index 2611dde238f48..bfbe1d81b1c15 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -18,7 +18,7 @@
#include
#include
-#include <cuda.h>
+#include <cuda.h>  // for CUDA_VERSION
#include
#include
#include
@@ -38,19 +38,12 @@
#include "moe_kernel.h"
-#if CUDA_VERSION >= 11000
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
-#else
-#include "cub/cub.cuh"
-#include "cub/device/device_radix_sort.cuh"
-#include "cub/util_type.cuh"
-#endif
namespace ort_fastertransformer {
static constexpr int WARP_SIZE = 32;
-
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
@@ -65,13 +58,6 @@ __launch_bounds__(TPB) __global__
const int thread_row_offset = blockIdx.x * num_cols;
-#if CUDA_VERSION >= 12090
- ::cuda::std::plus sum;
-#else
- // Deprecated on CUDA 12.9
- cub::Sum sum;
-#endif
-
float threadData(-FLT_MAX);
// Don't touch finished rows.
@@ -84,7 +70,12 @@ __launch_bounds__(TPB) __global__
threadData = max(static_cast<float>(input[idx]), threadData);
}
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
+ const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::maximum());
+#else
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
+#endif
+
if (threadIdx.x == 0) {
float_max = maxElem;
}
@@ -97,7 +88,12 @@ __launch_bounds__(TPB) __global__
threadData += exp((static_cast<float>(input[idx]) - float_max));
}
- const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
+ const auto Z = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::std::plus());
+#else
+ // Deprecated on CUDA 12.9
+ const auto Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum());
+#endif
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
@@ -993,6 +989,7 @@ void CutlassMoeFCRunner::get_total_rows_info(int64_t expe
if (experts_start_index > 0) {
total_past_rows = total_rows_before_expert_host_[experts_start_index - 1];
}
+
total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows;
}
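Reviewer note: the version-gating pattern above in isolation, as a sketch (not part of the patch). It assumes the libcu++ functors (`::cuda::maximum`, `::cuda::std::plus`) that replace the `cub::Max`/`cub::Sum` functors deprecated in CUDA 12.9.

```cpp
#include <cfloat>
#include <cub/cub.cuh>
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
#include <cuda/functional>      // ::cuda::maximum
#include <cuda/std/functional>  // ::cuda::std::plus
#endif

// Sketch: per-row max via a block reduction, choosing the reduction functor
// by toolkit version exactly as the softmax kernel above does.
template <int TPB>
__global__ void RowMaxKernel(const float* input, float* row_max, int num_cols) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp_storage;

  float thread_data = -FLT_MAX;
  for (int c = threadIdx.x; c < num_cols; c += TPB) {
    thread_data = fmaxf(thread_data, input[blockIdx.x * num_cols + c]);
  }
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
  const float max_elem = BlockReduce(tmp_storage).Reduce(thread_data, ::cuda::maximum());
#else
  const float max_elem = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max());  // deprecated on CUDA 12.9
#endif
  if (threadIdx.x == 0) row_max[blockIdx.x] = max_elem;
}
```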
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 936b4483201ac..55bcf42f2f04b 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -69,8 +69,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
return context.RunProgram(program);
};
-void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) {
- if (seqlen_k != nullptr) {
+void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) {
+ if (has_seqlen_k) {
ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
ss << "var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);\n";
} else {
@@ -87,7 +87,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (has_attention_bias_) {
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
}
- if (seqlen_k_ != nullptr) {
+ if (has_seqlen_k_) {
shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
}
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
@@ -107,7 +107,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
std::ostringstream oss;
- InitVarStub(oss, seqlen_k_);
+ InitVarStub(oss, has_seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n";
if (has_present_key_) {
@@ -182,7 +182,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
- components, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_};
+ components, parameters.is_first_prompt_, seqlen_k != nullptr, parameters.past_present_share_buffer_};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
@@ -224,30 +224,44 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
}
Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
- if (seqlen_k_) {
+ if (has_seqlen_k_) {
shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
}
+ if (has_head_sink_) {
+ shader.AddInput("head_sink", ShaderUsage::UseUniform);
+ }
shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AdditionalImplementation() << "var thread_max: array;\n"
<< "var thread_sum: array;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n";
shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n"
<< "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n"
+ << "let head_idx = u32(workgroup_idx / sequence_length) % uniforms.num_heads;\n"
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
std::ostringstream oss;
- InitVarStub(oss, seqlen_k_);
+ InitVarStub(oss, has_seqlen_k_);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n"
- << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
+ << "let seq_causal_length = " << (has_seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
<< "}\n"
<< "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n"
- << "workgroupBarrier();\n"
- << "var max_value = f32(-3.402823e+38f);\n"
- << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
+ << "workgroupBarrier();\n";
+
+ if (has_head_sink_) {
+ // Handle head sink
+ shader.MainFunctionBody() << "let sink_value: f32 = head_sink[head_idx];\n"
+ << "var max_value = sink_value;\n";
+ } else if (use_smooth_softmax_) {
+ shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n";
+ } else {
+ shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n";
+ }
+
+ shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
<< " max_value = max(thread_max[i], max_value);\n"
<< "}\n"
<< "var sum_vector = f32_val_t(0);\n"
@@ -259,8 +273,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "var sum: f32 = 0;\n"
<< "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
<< " sum += thread_sum[i]\n;"
- << "}\n"
- << "if (sum == 0) {\n"
+ << "}\n";
+
+ if (has_head_sink_) {
+ shader.MainFunctionBody() << "sum += exp(sink_value - max_value);\n";
+ } else if (use_smooth_softmax_) {
+ shader.MainFunctionBody() << "sum += exp(-max_value);\n";
+ }
+
+ shader.MainFunctionBody() << "if (sum == 0) {\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n"
<< " }\n"
@@ -270,7 +291,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
<< " }\n"
<< "}\n";
- if (seqlen_k_) {
+ if (has_seqlen_k_) {
shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n"
<< " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n"
<< "}\n";
@@ -280,7 +301,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
}
Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length,
- const Tensor* seqlen_k, bool is_first_prompt) {
+ const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink) {
const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1));
int work_group_size = 64;
const int total_sequence_length_comp = (total_sequence_length + components - 1) / components;
@@ -289,12 +310,15 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
}
const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;
- InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
+ InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr};
if (seqlen_k != nullptr) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
+ if (head_sink != nullptr) {
+ program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
+ }
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
- .CacheHint(work_group_size)
+ .CacheHint(work_group_size, use_smooth_softmax)
.SetDispatchGroupSize(batch_size * num_heads * sequence_length)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
@@ -443,7 +467,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
- WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
+ WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length =
@@ -457,7 +481,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
parameters, past_sequence_length, total_sequence_length, seqlen_k));
ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
- parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_));
+ parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink));
ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
parameters, past_sequence_length, total_sequence_length, seqlen_k));
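Reviewer note: to make the shader changes easier to check, here is a scalar CPU sketch (names hypothetical, not part of the patch) of the softmax variants being wired in. The sink value, or the implicit zero logit in the smooth-softmax case, joins the running max and contributes one extra `exp` term to the normalizer, but no probability is emitted for it.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Scalar model of InPlaceSoftmaxProgram's three modes for one row of logits.
// `sink` is head_sink[head_idx] when a head_sink input is present, else nullptr.
void SoftmaxRow(std::vector<float>& scores, bool use_smooth_softmax, const float* sink) {
  float max_value = -3.402823e+38f;  // matches the shader's f32 lowest value
  if (sink != nullptr) {
    max_value = *sink;               // sink joins the running max
  } else if (use_smooth_softmax) {
    max_value = 0.0f;                // implicit zero logit
  }
  for (float s : scores) max_value = std::max(max_value, s);

  float sum = 0.0f;
  for (float s : scores) sum += std::exp(s - max_value);
  if (sink != nullptr) {
    sum += std::exp(*sink - max_value);  // mirrors "sum += exp(sink_value - max_value)"
  } else if (use_smooth_softmax) {
    sum += std::exp(-max_value);         // mirrors "sum += exp(-max_value)"
  }

  for (float& s : scores) s = std::exp(s - max_value) / sum;
}
```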
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h
index 7c0cb40cc7f93..e64ca3539c23d 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
- bool has_attention_bias, int tile_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
- : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+ bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool has_seqlen_k = false, bool past_present_share_buffer = false)
+ : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), has_seqlen_k_(has_seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -62,15 +62,15 @@ class AttentionProbsProgram final : public Program {
bool has_attention_bias_;
int tile_size_;
int components_;
- const Tensor* seqlen_k_;
+ bool has_seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
};
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
public:
- InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr)
- : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) {
+ InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink)
+ : Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -86,7 +86,9 @@ class InPlaceSoftmaxProgram final : public Program {
private:
int work_group_size_;
int components_;
- const Tensor* seqlen_k_;
+ bool use_smooth_softmax_;
+ bool has_seqlen_k_;
+ bool has_head_sink_;
};
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
index 06b9c88ce8993..9d4740ede7143 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -123,7 +123,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
- WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
+ WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
+ const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr);
} // namespace webgpu
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index fabfbfbdd142e..c9e182bf10f2f 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -382,6 +382,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
// sum is the second term of the same expression : Σ_j=1:b e^(Xi[j]-Mi)
// o_ratio is the part of the first term of o'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i) / d'_i
//
+
+ // TODO: support smooth softmax and head_sink
shader.MainFunctionBody() << R"MAIN_FN(
var local_max_temp = max(qk_1, qk_2);
if (sg_size > 8)
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
index f002db108035f..f3334b13dc645 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -152,6 +152,9 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
  const Tensor* total_seqlen_tensor = context.Input<Tensor>(6);
  const Tensor* cos_cache = context.Input<Tensor>(7);
  const Tensor* sin_cache = context.Input<Tensor>(8);
+  const Tensor* position_ids = context.Input<Tensor>(9);  // TODO: support sliding window
+  const Tensor* attention_bias = context.Input<Tensor>(10);
+  const Tensor* head_sink = context.Input<Tensor>(11);
GroupQueryAttentionParameters params = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
@@ -168,6 +171,13 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
total_seqlen_tensor,
scale_,
softcap_));
+ params.use_smooth_softmax = use_smooth_softmax_;
+
+ ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
+ attention_bias,
+ head_sink,
+ params));
+
WebgpuAttentionParameters parameters(params);
TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
@@ -184,8 +194,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
Tensor* present_value = context.Output(2, present_kv_shape);
parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();
- if (!do_rotary_ && CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) {
- return ApplyFlashAttention(query, key, value, nullptr /* attention_bias */, output, past_key, present_key, past_value,
+ if (!do_rotary_ &&
+ head_sink == nullptr && !use_smooth_softmax_ &&
+ CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
+ return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
}
@@ -224,8 +236,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q));
if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format
- return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key,
- present_value, parameters, context, seqlen_k);
+ return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
+ present_value, parameters, context, head_sink, seqlen_k);
}
TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_,
@@ -241,8 +253,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape);
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
parameters.v_head_size_, value, nullptr, 0, &V));
- return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key,
- present_value, parameters, context, seqlen_k);
+ return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
+ present_value, parameters, context, head_sink, seqlen_k);
}
} // namespace webgpu
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
index 66364f7dab96e..02d02e824b357 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
@@ -596,6 +596,11 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
uint32_t tile_size_k_vec = 16;
uint32_t tile_size = 32;
+ if (context.AdapterInfo().vendor == std::string_view{"intel"}) {
+ tile_size_k_vec = 32;
+ tile_size = 4;
+ }
+
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size, nbits, has_zero_points};
uint32_t num_N_tile = (N + tile_size - 1) / tile_size;
mul_program.SetWorkgroupSize(128);
diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc
index 471e4ee7c03a3..8f78f9c3b6cc7 100644
--- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc
+++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc
@@ -421,7 +421,13 @@ Status CommonSubexpressionElimination::ApplyImpl(Graph& graph, bool& modified, i
for (NodeIndex node_index : node_topology_list) {
Node* node = graph.GetNode(node_index);
- if (node == nullptr)
+
+ // In the context of a model containing EPContext nodes, it's highly unlikely that two EPContext nodes will
+ // produce the same results.
+ // Furthermore, the EquivalenceClass constructor includes the node and all its attributes in the hash calculation,
+ // which can be particularly time-consuming when the "ep_cache_context" attribute contains a large binary blob.
+ // Therefore, EPContext nodes are excluded from this process.
+ if (node == nullptr || node->OpType() == "EPContext")
continue;
ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
@@ -471,7 +477,12 @@ Status CommonSubexpressionElimination::ApplyImpl(Graph& graph, bool& modified, i
for (NodeIndex node_index : node_topology_list) {
Node* node = graph.GetNode(node_index);
- if (node == nullptr)
+ // In the context of a model containing EPContext nodes, it's highly unlikely that two EPContext nodes will
+ // produce the same results.
+ // Furthermore, the EquivalenceClass constructor includes the node and all its attributes in the hash calculation,
+ // which can be particularly time-consuming when the "ep_cache_context" attribute contains a large binary blob.
+ // Therefore, EPContext nodes are excluded from this process.
+ if (node == nullptr || node->OpType() == "EPContext")
continue;
bool node_output_replaced = false;
diff --git a/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc
index 616374eee6ff1..f1b0d5850ab1f 100644
--- a/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc
+++ b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc
@@ -72,13 +72,13 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
const T* cos_data;
const T* sin_data;
int cache_offset;
- if (position_ids_format == 0) {
- cache_offset = (b * sequence_length + s) * half_rotary_emb_dim;
- } else {
- // Cache is (M, H/2) or (M, rotary_embedding_dim/2)
-    const int position_id = static_cast<int>(position_ids[b * sequence_length + s]);
- cache_offset = position_id * half_rotary_emb_dim;
+ // position_ids_format == 0 means position_ids is nullptr
+ // position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length)
+ int b_s_index = b * sequence_length + s;
+ if (position_ids_format != 0) {
+    b_s_index = static_cast<int>(position_ids[b_s_index]);
}
+ cache_offset = b_s_index * half_rotary_emb_dim;
cos_data = cos_cache + cache_offset;
sin_data = sin_cache + cache_offset;
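Reviewer note: a worked example of the unified indexing above, with made-up values, showing the two `position_ids_format` paths (sketch, not part of the patch).

```cpp
#include <cstdint>

// Sketch of the refactored cache indexing; mirrors the code above.
int CacheOffset(const int64_t* position_ids, int position_ids_format,
                int b, int s, int sequence_length, int half_rotary_emb_dim) {
  int b_s_index = b * sequence_length + s;
  if (position_ids_format != 0) {
    b_s_index = static_cast<int>(position_ids[b_s_index]);  // (M, H/2) cache row
  }
  return b_s_index * half_rotary_emb_dim;
}
// With b = 1, s = 2, sequence_length = 4, half_rotary_emb_dim = 32:
//   format 0:                       cache_offset = (1*4 + 2) * 32 = 192
//   format 1, position_ids[6] == 9: cache_offset = 9 * 32        = 288
```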
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index d9acb9ccdc30f..1f4c9fcdbc073 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1490,6 +1490,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16_BFloat16, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_MLFloat16, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_float, RMSNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding);
#endif
@@ -2480,6 +2483,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding)>,
#endif
};
diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc b/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc
new file mode 100644
index 0000000000000..f259c6021a82e
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc
@@ -0,0 +1,85 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cpu/llm/rotary_embedding_helper.h"
+#include "core/providers/cuda/llm/rotary_embedding.h"
+#include "core/providers/cuda/llm/rotary_embedding_impl.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::rotary_embedding_helper;
+
+namespace onnxruntime {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ RotaryEmbedding, \
+ kOnnxDomain, \
+ 23, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("M", DataTypeImpl::GetTensorType()), \
+ RotaryEmbedding);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
+
+template <typename T>
+RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
+  rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
+  num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
+  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+}
+
+template <typename T>
+Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* cos_cache = context->Input<Tensor>(1);
+  const Tensor* sin_cache = context->Input<Tensor>(2);
+  const Tensor* position_ids = context->Input<Tensor>(3);  // Optional, can be nullptr
+
+ RotaryParameters parameters = {};
+ ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input,
+ position_ids,
+ cos_cache,
+ sin_cache,
+ num_heads,
+ rotary_embedding_dim,
+                                                           &parameters));
+
+ Tensor* output = context->Output(0, input->Shape());
+
+ // Launch rotary embedding kernel
+  typedef typename ToCudaType<T>::MappedType CudaT;
+ auto& device_prop = GetDeviceProp();
+
+ // Handle optional position_ids - pass nullptr if position_ids is null
+  const int64_t* position_ids_data = (position_ids != nullptr) ? position_ids->Data<int64_t>() : nullptr;
+
+  return LaunchRotaryEmbeddingKernel<CudaT>(
+      Stream(context),
+      reinterpret_cast<CudaT*>(output->template MutableData<T>()),
+      reinterpret_cast<const CudaT*>(input->template Data<T>()),
+      position_ids_data,
+      reinterpret_cast<const CudaT*>(cos_cache->template Data<T>()),
+      reinterpret_cast<const CudaT*>(sin_cache->template Data<T>()),
+ parameters.batch_size,
+ parameters.sequence_length,
+ parameters.num_heads,
+ parameters.head_size,
+ parameters.rotary_embedding_dim,
+ parameters.max_sequence_length,
+ parameters.position_ids_format,
+ interleaved,
+ device_prop.maxThreadsPerBlock,
+ parameters.transposed);
+}
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding.h b/onnxruntime/core/providers/cuda/llm/rotary_embedding.h
new file mode 100644
index 0000000000000..09bd2e7dfde15
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding.h
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class RotaryEmbedding final : public CudaKernel {
+ public:
+ RotaryEmbedding(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ protected:
+ int num_heads;
+ int rotary_embedding_dim;
+ int interleaved;
+};
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu
new file mode 100644
index 0000000000000..eda049dfae9a8
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu
@@ -0,0 +1,169 @@
+/*
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+*/
+
+/*
+Kernel implementation for rotary embeddings.
+*/
+
+#include "core/providers/cuda/llm/rotary_embedding_impl.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include <cuda_fp16.h>
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace cuda {
+
+template <typename T>
+__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
+ const T* input, // BxSxNxH
+ const T* cos_cache, // BxSx(H/2) or Mx(H/2)
+ const T* sin_cache, // BxSx(H/2) or Mx(H/2)
+ const int64_t* position_ids, // (0) or BxS
+ const int sequence_length, const int num_heads, const int head_size,
+ const int rotary_embedding_dim, const int position_ids_format,
+ const bool interleaved,
+ int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous
+) {
+ // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
+ // Use .x in innermost loop to access global memory efficiently
+
+ const int b = blockIdx.y;
+ const int s = blockIdx.x;
+ const int n = blockIdx.z;
+
+ const int i = threadIdx.x;
+
+ if (i >= head_size) {
+ return;
+ }
+
+ const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y;
+ T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y;
+
+ if (i >= rotary_embedding_dim) {
+ output_data[i] = input_data[i];
+ return;
+ }
+
+ // Cache is (M, H/2)
+ const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
+ int cache_offset;
+
+ // position_ids_format == 0 means position_ids is nullptr
+ // position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length)
+ int b_s_index = b * sequence_length + s;
+ if (position_ids_format != 0) {
+    b_s_index = static_cast<int>(position_ids[b_s_index]);
+ }
+ cache_offset = b_s_index * half_rotary_embedding_dim;
+ const T* cos_data = cos_cache + cache_offset;
+ const T* sin_data = sin_cache + cache_offset;
+
+ int cache_idx = 0;
+ T sign = 0;
+ int j = 0;
+ if (interleaved) {
+ cache_idx = (i / 2) % half_rotary_embedding_dim;
+ sign = (i % 2 == 0) ? -1 : 1;
+ j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
+ } else {
+ cache_idx = i % half_rotary_embedding_dim;
+ sign = (i < half_rotary_embedding_dim) ? -1 : 1;
+ j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
+ }
+ output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+}
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
+ const T* cos_cache, const T* sin_cache, const int batch_size,
+ const int sequence_length, const int num_heads, const int head_size,
+ const int rotary_embedding_dim, const int max_sequence_length,
+ const int position_ids_format, const bool interleaved,
+ const int max_threads_per_block, const bool is_input_bnsh_format) {
+ int4 in_strides;
+ int4 out_strides;
+ if (is_input_bnsh_format) {
+ // Semantic meaning of the strides:
+ // int in_head_stride = sequence_length * head_size;
+ // int out_head_stride = sequence_length * head_size;
+ // in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1};
+ // out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1};
+ // Simplify to:
+ in_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1};
+ out_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1};
+ } else {
+ // input is in bshn format
+ // int in_head_stride = head_size;
+ // int out_head_stride = head_size;
+ // Simplify to:
+ in_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1};
+ out_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1};
+ }
+ return LaunchRotaryEmbeddingKernel(
+ stream, output, input, position_ids,
+ cos_cache, sin_cache, batch_size,
+ sequence_length, num_heads, head_size,
+ rotary_embedding_dim, max_sequence_length,
+ position_ids_format, interleaved,
+ max_threads_per_block,
+ in_strides, out_strides);
+}
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
+ const T* cos_cache, const T* sin_cache, const int batch_size,
+ const int sequence_length, const int num_heads, const int head_size,
+ const int rotary_embedding_dim, const int /*max_sequence_length*/,
+ const int position_ids_format, const bool interleaved,
+ const int max_threads_per_block,
+ int4 in_strides, int4 out_strides // strides in bnsh coord
+) {
+ // Note: Current implementation assumes head_size <= max_threads_per_block
+ // because head_size is currently large for LLaMA-2. For smaller head_size
+ // and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
+ // instead. This will require kernel changes to support.
+  ORT_ENFORCE(head_size <= max_threads_per_block, "head_size must be <= max_threads_per_block");
+  // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1)
+  ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must be contiguous");
+
+ int tpb = (head_size + 31) / 32 * 32;
+
+ const dim3 block(tpb);
+ const dim3 grid(sequence_length, batch_size, num_heads);
+
+ assert(head_size <= max_threads_per_block);
+  RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
+ num_heads, head_size, rotary_embedding_dim, position_ids_format,
+ interleaved, in_strides, out_strides);
+ return CUDA_CALL(cudaGetLastError());
+}
+
+template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float* output, const float* input,
+ const int64_t* position_ids, const float* cos_cache,
+ const float* sin_cache, const int batch_size,
+ const int sequence_length, const int num_heads, const int head_size,
+ const int rotary_embedding_dim, const int max_sequence_length,
+ const int position_ids_format, const bool interleaved,
+ const int max_threads_per_block, const bool is_input_bnsh_format);
+
+template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half* output, const half* input,
+ const int64_t* position_ids, const half* cos_cache,
+ const half* sin_cache, const int batch_size,
+ const int sequence_length, const int num_heads, const int head_size,
+ const int rotary_embedding_dim, const int max_sequence_length,
+ const int position_ids_format, const bool interleaved,
+ const int max_threads_per_block, const bool is_input_bnsh_format);
+
+template Status LaunchRotaryEmbeddingKernel<BFloat16>(
+ cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids,
+ const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length,
+ const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length,
+ const int position_ids_format, const bool interleaved, const int max_threads_per_block,
+ const bool is_input_bnsh_format);
+
+} // namespace cuda
+} // namespace onnxruntime
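Reviewer note: a host-side replica of the kernel's pairing logic (sketch only, not part of the patch) that prints which partner element `j`, sign, and cache row each lane `i` uses in the interleaved and non-interleaved layouts.

```cpp
#include <cstdio>

int main() {
  const int rotary_embedding_dim = 8;  // small on purpose
  const int half = rotary_embedding_dim / 2;
  for (int interleaved = 0; interleaved <= 1; ++interleaved) {
    std::printf("interleaved=%d\n", interleaved);
    for (int i = 0; i < rotary_embedding_dim; ++i) {
      int cache_idx, j, sign;
      if (interleaved) {
        cache_idx = (i / 2) % half;        // pairs (0,1), (2,3), ...
        sign = (i % 2 == 0) ? -1 : 1;
        j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
      } else {
        cache_idx = i % half;              // pairs (i, i + half)
        sign = (i < half) ? -1 : 1;
        j = (i + half) % rotary_embedding_dim;
      }
      std::printf("  i=%d  j=%d  sign=%+d  cache_idx=%d\n", i, j, sign, cache_idx);
    }
  }
  return 0;
}
```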
diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h
new file mode 100644
index 0000000000000..2389c0596014e
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+ cudaStream_t stream,
+ T* output,
+ const T* input,
+ const int64_t* position_ids,
+ const T* cos_cache,
+ const T* sin_cache,
+ const int batch_size,
+ const int sequence_length,
+ const int num_heads,
+ const int head_size,
+ const int rotary_embedding_dim,
+ const int max_sequence_length,
+ const int position_ids_format,
+ const bool interleaved,
+ const int max_threads_per_block,
+ const bool is_input_bnsh_format);
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+ cudaStream_t stream,
+ T* output,
+ const T* input,
+ const int64_t* position_ids,
+ const T* cos_cache,
+ const T* sin_cache,
+ const int batch_size,
+ const int sequence_length,
+ const int num_heads,
+ const int head_size,
+ const int rotary_embedding_dim,
+ const int max_sequence_length,
+ const int position_ids_format,
+ const bool interleaved,
+ const int max_threads_per_block,
+ int4 in_strides,
+ int4 out_strides);
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h
index 1ab5e47a08523..3283aea7c33a2 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h
@@ -53,7 +53,7 @@ class CUDAExternalAllocator : public CUDAAllocator {
// TODO: add a default constructor
class CUDAPinnedAllocator : public IAllocator {
public:
- CUDAPinnedAllocator(const char* name, OrtDevice::DeviceId device_id)
+ CUDAPinnedAllocator(OrtDevice::DeviceId device_id, const char* name)
: IAllocator(
OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA,
diff --git a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc
index 0305049e9b789..e7736c3f3afac 100644
--- a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc
+++ b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc
@@ -129,9 +129,6 @@ Status DequantizeLinear::ComputeInternal(ComputeContext& context) const {
int64_t axis = (axis_ >= 0) ? axis_ : axis_ + x_shape.NumDimensions();
int max_components = GetMaxComponents(x_size);
- if (max_components != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DequantizeLinear: components must be 4, but got ", max_components);
- }
// scaler - single scaler for all elements
bool per_layer = x_scale_rank == 0 || (x_scale_rank == 1 && x_scale->Shape()[0] == 1);
diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc
index 7a17423112144..3df6d37d63794 100644
--- a/onnxruntime/core/session/abi_session_options.cc
+++ b/onnxruntime/core/session/abi_session_options.cc
@@ -278,6 +278,21 @@ ORT_API_STATUS_IMPL(OrtApis::GetSessionConfigEntry, _In_ const OrtSessionOptions
API_IMPL_END
}
+ORT_API_STATUS_IMPL(OrtApis::GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out) {
+ API_IMPL_BEGIN
+ if (options == nullptr) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "options is nullptr");
+ }
+ auto& config_options = options->value.config_options.GetConfigOptionsMap();
+  auto kvps = std::make_unique<OrtKeyValuePairs>();
+ for (auto& kv : config_options) {
+ kvps->Add(kv.first.c_str(), kv.second.c_str());
+ }
+  *out = reinterpret_cast<OrtKeyValuePairs*>(kvps.release());
+ return nullptr;
+ API_IMPL_END
+}
+
ORT_API_STATUS_IMPL(OrtApis::AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name,
_In_ const OrtValue* val) {
API_IMPL_BEGIN
diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc
index cac91a4ec52d2..878a5384dfee7 100644
--- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc
+++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc
@@ -53,11 +53,9 @@ PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_
ORT_THROW("Error creating execution provider: ", status.ToString());
}
-  auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)),
- session_options, ep_factory_, devices_);
- ep_wrapper->SetLogger(session_logger.ToInternal());
-
- return ep_wrapper;
+  return std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)),
+ session_options, ep_factory_, devices_,
+ *session_logger.ToInternal());
}
///
@@ -86,10 +84,43 @@ struct PluginEpMetaDefNameFunctor {
// PluginExecutionProvider
//
+static OrtDevice GetOrtDeviceForPluginEp(gsl::span<const OrtEpDevice* const> ep_devices) {
+ // Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU.
+ // If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all.
+
+ ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices.
+
+ const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info;
+
+  // Check that all OrtEpDevice instances have equivalent device_memory_info values.
+ bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(),
+ [mem_a = device_memory_info](const OrtEpDevice* ep_device) {
+ const OrtMemoryInfo* mem_b = ep_device->device_memory_info;
+
+ if (mem_a == mem_b) {
+ return true; // Point to the same OrtMemoryInfo instance.
+ }
+
+ if (mem_a == nullptr || mem_b == nullptr) {
+ return false; // One is nullptr and the other is not.
+ }
+
+ // Both non-null but point to different instances. Use operator==.
+ return *mem_a == *mem_b;
+ });
+ if (!all_match) {
+ ORT_THROW("Error creating execution provider '", ep_devices[0]->ep_name,
+ "': expected all OrtEpDevice instances to use the same device_memory_info.");
+ }
+
+ return device_memory_info != nullptr ? device_memory_info->device : OrtDevice();
+}
+
PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options,
OrtEpFactory& ep_factory,
-                                                 gsl::span<const OrtEpDevice* const> ep_devices)
- : IExecutionProvider(ep->GetName(ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins?
+                                                 gsl::span<const OrtEpDevice* const> ep_devices,
+ const logging::Logger& logger)
+ : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), logger),
ort_ep_(std::move(ep)),
ep_factory_(ep_factory),
ep_devices_(ep_devices.begin(), ep_devices.end()) {
diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/ep_plugin_provider_interfaces.h
index 343d6c9ad464e..3ba3118fcaa36 100644
--- a/onnxruntime/core/session/ep_plugin_provider_interfaces.h
+++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.h
@@ -65,7 +65,7 @@ class PluginExecutionProvider : public IExecutionProvider {
public:
explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory,
-                                   gsl::span<const OrtEpDevice* const> ep_devices);
+                                   gsl::span<const OrtEpDevice* const> ep_devices, const logging::Logger& logger);
~PluginExecutionProvider();
  std::vector<std::unique_ptr<ComputeCapability>>
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 2551fbc8b6099..e7f60fd48a14f 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -3632,6 +3632,8 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::ReleaseSharedAllocator,
&OrtApis::GetTensorData,
+
+ &OrtApis::GetSessionOptionsConfigEntries,
};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 4c4ab07493237..cbacbfce0740d 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -690,4 +690,6 @@ ORT_API_STATUS_IMPL(ReleaseSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDe
_In_ OrtDeviceMemoryType mem_type);
ORT_API_STATUS_IMPL(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out);
+
+ORT_API_STATUS_IMPL(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out);
} // namespace OrtApis
diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc
index 6b54c33e9b10b..e8d62ab86f517 100644
--- a/onnxruntime/core/session/provider_policy_context.cc
+++ b/onnxruntime/core/session/provider_policy_context.cc
@@ -10,6 +10,7 @@
#include "core/framework/error_code_helper.h"
#include "core/session/abi_devices.h"
+#include "core/session/abi_logger.h"
#include "core/session/ep_factory_internal.h"
#include "core/session/ep_plugin_provider_interfaces.h"
#include "core/session/inference_session.h"
@@ -355,7 +356,7 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or
info.ep_factory->CreateEp(info.ep_factory, info.hardware_devices.data(), info.ep_metadata.data(),
info.hardware_devices.size(), &options, &logger, &api_ep)));
ep = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options,
- *info.ep_factory, info.devices);
+ *info.ep_factory, info.devices, *logger.ToInternal());
}
return Status::OK();
diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc
index 36b7f2965b483..18bc9cf05b36d 100644
--- a/onnxruntime/test/framework/ep_plugin_provider_test.cc
+++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc
@@ -9,6 +9,7 @@
#include "core/session/abi_devices.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test/util/include/asserts.h"
+#include "test/util/include/test_environment.h"
namespace onnxruntime::test {
@@ -56,6 +57,32 @@ struct TestOrtEpFactory : ::OrtEpFactory {
static TestOrtEpFactory g_test_ort_ep_factory{};
+std::unique_ptr<OrtHardwareDevice> MakeTestOrtHardwareDevice(OrtHardwareDeviceType type) {
+  auto hw_device = std::make_unique<OrtHardwareDevice>();
+ hw_device->type = type;
+ hw_device->vendor_id = 0xBE57;
+ hw_device->device_id = 0;
+ hw_device->vendor = "Contoso";
+ return hw_device;
+}
+
+std::unique_ptr<OrtEpDevice> MakeTestOrtEpDevice(const OrtHardwareDevice* hardware_device,
+ const OrtMemoryInfo* device_memory_info = nullptr,
+ const OrtMemoryInfo* host_accessible_memory_info = nullptr) {
+  auto ep_device = std::make_unique<OrtEpDevice>();
+ ep_device->ep_name = "TestOrtEp";
+ ep_device->ep_vendor = "Contoso";
+ ep_device->device = hardware_device;
+ ep_device->ep_factory = &g_test_ort_ep_factory;
+ ep_device->device_memory_info = device_memory_info;
+ ep_device->host_accessible_memory_info = host_accessible_memory_info;
+ return ep_device;
+}
+
+OrtDevice MakeTestOrtDevice(OrtDevice::DeviceType device_type, OrtDevice::MemoryType memory_type) {
+ return OrtDevice(device_type, memory_type, /*vendor_id*/ 0xBE57, /*device_id*/ 0, /*alignment*/ 16);
+}
+
struct MakeTestOrtEpResult {
  std::unique_ptr<IExecutionProvider> ep;  // the IExecutionProvider wrapping the TestOrtEp
  gsl::not_null<TestOrtEp*> ort_ep;        // the wrapped TestOrtEp, owned by `ep`
@@ -63,17 +90,25 @@ struct MakeTestOrtEpResult {
// Creates an IExecutionProvider that wraps a TestOrtEp.
// The TestOrtEp is also exposed so that tests can manipulate its function pointers directly.
-MakeTestOrtEpResult MakeTestOrtEp() {
+MakeTestOrtEpResult MakeTestOrtEp(std::vector<const OrtEpDevice*> ep_devices = {}) {
+  // Default OrtHardwareDevice and OrtEpDevice used if the caller does not explicitly provide ep_devices.
+  static std::unique_ptr<OrtHardwareDevice> ort_hw_device = MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU);
+  static std::unique_ptr<OrtEpDevice> ort_ep_device = MakeTestOrtEpDevice(ort_hw_device.get());
+
   auto ort_ep_raw = std::make_unique<TestOrtEp>().release();
auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory});
auto ort_session_options = Ort::SessionOptions{};
- auto ort_ep_device = OrtEpDevice{};
-  std::vector<const OrtEpDevice*> ep_devices{&ort_ep_device};
+ if (ep_devices.empty()) {
+ ep_devices.push_back(ort_ep_device.get());
+ }
+
+ auto& logging_manager = DefaultLoggingManager();
  auto ep = std::make_unique<PluginExecutionProvider>(std::move(ort_ep),
                                                      *static_cast<OrtSessionOptions*>(ort_session_options),
g_test_ort_ep_factory,
- ep_devices);
+ ep_devices,
+ logging_manager.DefaultLogger());
auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw};
return result;
@@ -177,4 +212,109 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) {
#endif // !defined(ORT_NO_EXCEPTIONS)
}
+TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) {
+ // 1 OrtEpDevice without a device_memory_info.
+ // PluginExecutionProvider should decide to use a default OrtDevice.
+ {
+ auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU);
+ auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()};
+
+ auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+ ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice());
+ }
+
+ // 1 OrtEpDevice with a device_memory_info.
+ // PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info.
+ {
+ auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT);
+ auto ort_memory_info = std::make_unique("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator,
+ ort_device, OrtMemTypeDefault);
+
+ auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+ auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(),
+ /*device_memory_info*/ ort_memory_info.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()};
+
+ auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+ ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device);
+ }
+
+ // 2 OrtEpDevice instances with the same device_memory_info.
+ // PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info.
+ {
+ auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT);
+ auto ort_memory_info = std::make_unique("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator,
+ ort_device, OrtMemTypeDefault);
+
+ auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+ auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU);
+ auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info.get());
+ auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()};
+
+ auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+ ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device);
+ }
+
+  // 2 OrtEpDevice instances with different (but equivalent) device_memory_info pointers.
+  // PluginExecutionProvider should decide to use an OrtDevice that is equal to the device used by both
+  // device_memory_info pointers.
+ {
+ auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT);
+ auto ort_memory_info_0 = std::make_unique("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator,
+ ort_device, OrtMemTypeDefault);
+ auto ort_memory_info_1 = std::make_unique("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator,
+ ort_device, OrtMemTypeDefault);
+
+ auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+ auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU);
+ auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_0.get());
+ auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_1.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()};
+
+ auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+ ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device);
+ }
+
+ // 1 OrtEpDevice with only a host_accessible_memory_info.
+  // PluginExecutionProvider should decide to use a default OrtDevice (CPU).
+ {
+ auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE);
+ auto ort_memory_info = std::make_unique("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator,
+ ort_device, OrtMemTypeDefault);
+
+ auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+ auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(),
+ /*device_memory_info*/ nullptr,
+ /*host_accessible_memory_info*/ ort_memory_info.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()};
+
+ auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
+ ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice());
+ }
+
+#if !defined(ORT_NO_EXCEPTIONS)
+ // 2 OrtEpDevice instances with DIFFERENT device_memory_info instances.
+ // Should throw an exception on construction of PluginExecutionProvider.
+ {
+ auto ort_device_gpu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT);
+ auto ort_memory_info_gpu = std::make_unique("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator,
+ ort_device_gpu, OrtMemTypeDefault);
+
+ auto ort_device_npu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT);
+ auto ort_memory_info_npu = std::make_unique("TestOrtEp NPU", OrtAllocatorType::OrtDeviceAllocator,
+ ort_device_npu, OrtMemTypeDefault);
+
+ auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
+ auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU);
+ auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_gpu.get());
+ auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_npu.get());
+    std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()};
+
+ ASSERT_THROW(test_plugin_ep::MakeTestOrtEp(ep_devices), OnnxRuntimeException);
+ }
+#endif // !defined(ORT_NO_EXCEPTIONS)
+}
+
} // namespace onnxruntime::test