diff --git a/.github/policies/updateStaleIssues.yml b/.github/policies/updateStaleIssues.yml new file mode 100644 index 0000000000000..2a041a50f93a7 --- /dev/null +++ b/.github/policies/updateStaleIssues.yml @@ -0,0 +1,65 @@ +name: Update Stale Issues +description: Update stale issues +resource: repository +configuration: + resourceManagementConfiguration: + scheduledSearches: + - description: Apply stale label to open, unassigned issues that have not been updated in the last 30 days + frequencies: + - daily: + time: 15:00 + filters: + - isIssue + - isOpen + - isNotAssigned + - isNotLabeledWith: + label: contributions welcome + - isNotLabeledWith: + label: documentation + - isNotLabeledWith: + label: feature request + - isNotLabeledWith: + label: regression + - noActivitySince: + days: 30 + actions: + - addReply: + reply: "Applying stale label due to no activity in 30 days" + - addLabel: + label: stale + - description: Close open, unassigned issues labeled stale that have not been updated in the last 30 days + frequencies: + - daily: + time: 15:00 + filters: + - hasLabel: + label: stale + - isIssue + - isOpen + - isNotAssigned + - noActivitySince: + days: 30 + actions: + - addReply: + reply: "Closing issue due to no activity in 30 days" + - closeIssue + eventResponderTasks: + - description: Remove stale label if open stale issue is commented on + if: + - payloadType: Issue_Comment + - hasLabel: + label: stale + then: + - removeLabel: + label: stale + - description: Re-open stale issue if closed stale issue is commented on + if: + - payloadType: Issue_Comment + - and: + - not: + isOpen + - hasLabel: + label: stale + then: + - reopenIssue + diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bfb6e7c38ccb4..1ffcabee8cc10 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -828,6 +828,7 @@ Do not modify directly.* |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| +|RotaryEmbedding|*in* X:**T**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**M**
*out* Y:**T**|23+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|19+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 65a6ec304bda2..7df3368ad4e0b 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -79,6 +79,10 @@ class IExecutionProvider { : default_device_(device), type_{type} { } + IExecutionProvider(const std::string& type, OrtDevice device, const logging::Logger& logger) + : default_device_(device), type_{type}, logger_{&logger} { + } + /* default device for this ExecutionProvider */ diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 7cdcbb3bc76bf..86c0b60db2bc4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6074,6 +6074,18 @@ struct OrtApi { * \since Version 1.23 */ ORT_API2_STATUS(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out); + + /** \brief Get Session configuration entries. + * + * \param[in] options The session options. + * \param[out] out A pointer to a newly created OrtKeyValuePairs instance. + * + * An OrtKeyValuePairs instance containing all session configuration entries. + * Note: the user should call OrtApi::ReleaseKeyValuePairs. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out); }; /* diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 9c7530f0126bb..a912bd6e6b43c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -54,6 +54,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* sin_cache = context->Input(8); const Tensor* position_ids = context->Input(9); const Tensor* attention_bias = context->Input(10); + const Tensor* head_sink = context->Input(11); GroupQueryAttentionParameters parameters = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -73,6 +74,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, attention_bias, + head_sink, parameters)); const int batch_size = parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 338c34acb3cfb..0f66119540b03 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -340,6 +340,7 @@ Status CheckInputs(const T* query, template Status CheckCustomAttentionInputs(const T* position_ids, const T* attention_bias, + const T* head_sink, const GroupQueryAttentionParameters& parameters) { if (position_ids != nullptr) { const auto& pos_ids_shape = position_ids->Shape(); @@ -377,6 +378,23 @@ Status CheckCustomAttentionInputs(const T* position_ids, } } + if (head_sink != nullptr) { + 
if (parameters.use_smooth_softmax) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "head_sink should not be provided when use_smooth_softmax is true.");
+    }
+
+    const auto& head_sink_shape = head_sink->Shape();
+    if (head_sink_shape.NumDimensions() != 1) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_sink must be a 1D tensor");
+    }
+
+    if (head_sink_shape[0] != parameters.num_heads) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "head_sink dimension 0 must be equal to the number of heads, got ", head_sink_shape[0]);
+    }
+  }
+
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index 2611dde238f48..bfbe1d81b1c15 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -18,7 +18,7 @@
 #include
 #include
-#include <cuda.h>
+#include <cuda.h>  // for CUDA_VERSION
 #include
 #include
 #include
@@ -38,19 +38,12 @@
 #include "moe_kernel.h"

-#if CUDA_VERSION >= 11000
 #include <cub/cub.cuh>
 #include <cub/device/device_radix_sort.cuh>
 #include <cub/util_type.cuh>
-#else
-#include "cub/cub.cuh"
-#include "cub/device/device_radix_sort.cuh"
-#include "cub/util_type.cuh"
-#endif

 namespace ort_fastertransformer {
 static constexpr int WARP_SIZE = 32;
-
 // ====================== Softmax things ===============================
 // We have our own implementation of softmax here so we can support transposing the output
 // in the softmax kernel when we extend this module to support expert-choice routing.
@@ -65,13 +58,6 @@ __launch_bounds__(TPB) __global__
   const int thread_row_offset = blockIdx.x * num_cols;

-#if CUDA_VERSION >= 12090
-  ::cuda::std::plus sum;
-#else
-  // Deprecated on CUDA 12.9
-  cub::Sum sum;
-#endif
-
   float threadData(-FLT_MAX);

   // Don't touch finished rows.
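The softmax hunks here swap cub::Sum/cub::Max, deprecated as of CUDA 12.9, for their libcu++ counterparts behind a version guard. A minimal sketch of the pattern; the MaxOp/SumOp alias names are illustrative only and not part of the patch:

```cpp
// Select reduction functors by toolkit version; CUDA 12.9 deprecates the
// cub functors in favor of the libcu++ equivalents used in the hunks below.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
#include <cuda/functional>      // ::cuda::maximum
#include <cuda/std/functional>  // ::cuda::std::plus
using MaxOp = ::cuda::maximum<float>;
using SumOp = ::cuda::std::plus<float>;
#else
#include <cub/cub.cuh>  // cub::Max, cub::Sum
using MaxOp = cub::Max;
using SumOp = cub::Sum;
#endif
// Usage then mirrors the diff: BlockReduce(tmpStorage).Reduce(threadData, MaxOp{});
```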
@@ -84,7 +70,12 @@ __launch_bounds__(TPB) __global__ threadData = max(static_cast(input[idx]), threadData); } +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090 + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::maximum()); +#else const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); +#endif + if (threadIdx.x == 0) { float_max = maxElem; } @@ -97,7 +88,12 @@ __launch_bounds__(TPB) __global__ threadData += exp((static_cast(input[idx]) - float_max)); } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090 + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::std::plus()); +#else + // Deprecated on CUDA 12.9 + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum()); +#endif if (threadIdx.x == 0) { normalizing_factor = 1.f / Z; @@ -993,6 +989,7 @@ void CutlassMoeFCRunner::get_total_rows_info(int64_t expe if (experts_start_index > 0) { total_past_rows = total_rows_before_expert_host_[experts_start_index - 1]; } + total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows; } diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 936b4483201ac..55bcf42f2f04b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -69,8 +69,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h return context.RunProgram(program); }; -void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) { - if (seqlen_k != nullptr) { +void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) { + if (has_seqlen_k) { ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n"; ss << "var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);\n"; } else { @@ -87,7 +87,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } - if (seqlen_k_ != nullptr) { + if (has_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::UseUniform); } shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); @@ -107,7 +107,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.N;\n"; std::ostringstream oss; - InitVarStub(oss, seqlen_k_); + InitVarStub(oss, has_seqlen_k_); shader.MainFunctionBody() << oss.str(); shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n"; if (has_present_key_) { @@ -182,7 +182,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 
2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_}; + components, parameters.is_first_prompt_, seqlen_k != nullptr, parameters.past_present_share_buffer_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -224,30 +224,44 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o } Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { - if (seqlen_k_) { + if (has_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::UseUniform); } + if (has_head_sink_) { + shader.AddInput("head_sink", ShaderUsage::UseUniform); + } shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AdditionalImplementation() << "var thread_max: array;\n" << "var thread_sum: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n" << "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n" + << "let head_idx = u32(workgroup_idx / sequence_length) % uniforms.num_heads;\n" << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; std::ostringstream oss; - InitVarStub(oss, seqlen_k_); + InitVarStub(oss, has_seqlen_k_); shader.MainFunctionBody() << oss.str() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n" - << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n" + << "let seq_causal_length = " << (has_seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n" << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" << "}\n" << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? 
"max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" - << "workgroupBarrier();\n" - << "var max_value = f32(-3.402823e+38f);\n" - << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << "workgroupBarrier();\n"; + + if (has_head_sink_) { + // Handle head sink + shader.MainFunctionBody() << "let sink_value: f32 = head_sink[head_idx];\n" + << "var max_value = sink_value;\n"; + } else if (use_smooth_softmax_) { + shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n"; + } else { + shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n"; + } + + shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" << " max_value = max(thread_max[i], max_value);\n" << "}\n" << "var sum_vector = f32_val_t(0);\n" @@ -259,8 +273,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var sum: f32 = 0;\n" << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" << " sum += thread_sum[i]\n;" - << "}\n" - << "if (sum == 0) {\n" + << "}\n"; + + if (has_head_sink_) { + shader.MainFunctionBody() << "sum += exp(sink_value - max_value);\n"; + } else if (use_smooth_softmax_) { + shader.MainFunctionBody() << "sum += exp(-max_value);\n"; + } + + shader.MainFunctionBody() << "if (sum == 0) {\n" << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n" << " }\n" @@ -270,7 +291,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" << " }\n" << "}\n"; - if (seqlen_k_) { + if (has_seqlen_k_) { shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n" << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" << "}\n"; @@ -280,7 +301,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length, - const Tensor* seqlen_k, bool is_first_prompt) { + const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink) { const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 
2 : 1)); int work_group_size = 64; const int total_sequence_length_comp = (total_sequence_length + components - 1) / components; @@ -289,12 +310,15 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso } const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size; - InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k}; + InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr}; if (seqlen_k != nullptr) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); } + if (head_sink != nullptr) { + program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); + } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) - .CacheHint(work_group_size) + .CacheHint(work_group_size, use_smooth_softmax) .SetDispatchGroupSize(batch_size * num_heads * sequence_length) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, @@ -443,7 +467,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? 
parameters.past_sequence_length_ : 0; const int total_sequence_length = @@ -457,7 +481,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T parameters, past_sequence_length, total_sequence_length, seqlen_k)); ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, - parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_)); + parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink)); ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, parameters, past_sequence_length, total_sequence_length, seqlen_k)); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 7c0cb40cc7f93..e64ca3539c23d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool has_seqlen_k = false, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), has_seqlen_k_(has_seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -62,15 +62,15 @@ class AttentionProbsProgram final : public Program { bool has_attention_bias_; int tile_size_; int components_; - const Tensor* seqlen_k_; + bool has_seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; }; class InPlaceSoftmaxProgram final : public Program { public: - InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr) - : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) { + InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink) + : Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -86,7 +86,9 @@ class InPlaceSoftmaxProgram final : public Program { private: int work_group_size_; int components_; - const Tensor* seqlen_k_; + bool use_smooth_softmax_; + bool has_seqlen_k_; + bool has_head_sink_; }; class VxAttentionScoreProgram final : public Program { diff --git 
a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 06b9c88ce8993..9d4740ede7143 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -123,7 +123,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, + const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr); } // namespace webgpu } // namespace contrib diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index fabfbfbdd142e..c9e182bf10f2f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -382,6 +382,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // sum is the second term of the same expression : Σ_j=1:b e^(Xi[j]-Mi) // o_ratio is the part of the first term of o'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i) / d'_i // + + // TODO: support smooth softmax and head_sink shader.MainFunctionBody() << R"MAIN_FN( var local_max_temp = max(qk_1, qk_2); if (sg_size > 8) diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index f002db108035f..f3334b13dc645 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -152,6 +152,9 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& const Tensor* total_seqlen_tensor = context.Input(6); const Tensor* cos_cache = context.Input(7); const Tensor* sin_cache = context.Input(8); + const Tensor* position_ids = context.Input(9); // TODO: support sliding window + const Tensor* attention_bias = context.Input(10); + const Tensor* head_sink = context.Input(11); GroupQueryAttentionParameters params = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -168,6 +171,13 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& total_seqlen_tensor, scale_, softcap_)); + params.use_smooth_softmax = use_smooth_softmax_; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, + attention_bias, + head_sink, + params)); + WebgpuAttentionParameters parameters(params); TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); @@ -184,8 +194,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_value = context.Output(2, present_kv_shape); parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); - if (!do_rotary_ && CanApplyFlashAttention(nullptr /* bias */, present_key, present_value, parameters, context)) { - return ApplyFlashAttention(query, key, value, nullptr /* attention_bias */, output, 
past_key, present_key, past_value, + if (!do_rotary_ && + head_sink == nullptr && !use_smooth_softmax_ && + CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { + return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); } @@ -224,8 +236,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_RETURN_IF_ERROR(TransferBSDToBNSH( context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format - return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlen_k); + return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context, head_sink, seqlen_k); } TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, @@ -241,8 +253,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_, value, nullptr, 0, &V)); - return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key, - present_value, parameters, context, seqlen_k); + return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context, head_sink, seqlen_k); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 66364f7dab96e..02d02e824b357 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -596,6 +596,11 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor uint32_t tile_size_k_vec = 16; uint32_t tile_size = 32; + if (context.AdapterInfo().vendor == std::string_view{"intel"}) { + tile_size_k_vec = 32; + tile_size = 4; + } + DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size, nbits, has_zero_points}; uint32_t num_N_tile = (N + tile_size - 1) / tile_size; mul_program.SetWorkgroupSize(128); diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc index 471e4ee7c03a3..8f78f9c3b6cc7 100644 --- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc +++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc @@ -421,7 +421,13 @@ Status CommonSubexpressionElimination::ApplyImpl(Graph& graph, bool& modified, i for (NodeIndex node_index : node_topology_list) { Node* node = graph.GetNode(node_index); - if (node == nullptr) + + // In the context of a model containing EPContext nodes, it's highly unlikely that two EPContext nodes will + // produce the same results. + // Furthermore, the EquivalenceClass constructor includes the node and all its attributes in the hash calculation, + // which can be particularly time-consuming when the "ep_cache_context" attribute contains a large binary blob. + // Therefore, EPContext nodes are excluded from this process. 
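+    // (For example, two structurally identical Cast nodes over the same input are still merged;
+    // only nodes whose op type is "EPContext" are skipped.)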
+ if (node == nullptr || node->OpType() == "EPContext") continue; ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); @@ -471,7 +477,12 @@ Status CommonSubexpressionElimination::ApplyImpl(Graph& graph, bool& modified, i for (NodeIndex node_index : node_topology_list) { Node* node = graph.GetNode(node_index); - if (node == nullptr) + // In the context of a model containing EPContext nodes, it's highly unlikely that two EPContext nodes will + // produce the same results. + // Furthermore, the EquivalenceClass constructor includes the node and all its attributes in the hash calculation, + // which can be particularly time-consuming when the "ep_cache_context" attribute contains a large binary blob. + // Therefore, EPContext nodes are excluded from this process. + if (node == nullptr || node->OpType() == "EPContext") continue; bool node_output_replaced = false; diff --git a/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc index 616374eee6ff1..f1b0d5850ab1f 100644 --- a/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc +++ b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc @@ -72,13 +72,13 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete const T* cos_data; const T* sin_data; int cache_offset; - if (position_ids_format == 0) { - cache_offset = (b * sequence_length + s) * half_rotary_emb_dim; - } else { - // Cache is (M, H/2) or (M, rotary_embedding_dim/2) - const int position_id = static_cast(position_ids[b * sequence_length + s]); - cache_offset = position_id * half_rotary_emb_dim; + // position_ids_format == 0 means position_ids is nullptr + // position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length) + int b_s_index = b * sequence_length + s; + if (position_ids_format != 0) { + b_s_index = static_cast(position_ids[b_s_index]); } + cache_offset = b_s_index * half_rotary_emb_dim; cos_data = cos_cache + cache_offset; sin_data = sin_cache + cache_offset; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index d9acb9ccdc30f..1f4c9fcdbc073 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1490,6 +1490,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16_BFloat16, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_MLFloat16, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_float, RMSNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding); #endif @@ -2480,6 +2483,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc b/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc new file mode 
100644
index 0000000000000..f259c6021a82e
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc
@@ -0,0 +1,85 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cpu/llm/rotary_embedding_helper.h"
+#include "core/providers/cuda/llm/rotary_embedding.h"
+#include "core/providers/cuda/llm/rotary_embedding_impl.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::rotary_embedding_helper;
+
+namespace onnxruntime {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T)                                        \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                        \
+      RotaryEmbedding,                                                  \
+      kOnnxDomain,                                                      \
+      23,                                                               \
+      T,                                                                \
+      kCudaExecutionProvider,                                           \
+      (*KernelDefBuilder::Create())                                     \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())        \
+          .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
+      RotaryEmbedding<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
+
+template <typename T>
+RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
+  rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
+  num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
+  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+}
+
+template <typename T>
+Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* cos_cache = context->Input<Tensor>(1);
+  const Tensor* sin_cache = context->Input<Tensor>(2);
+  const Tensor* position_ids = context->Input<Tensor>(3);  // Optional, can be nullptr
+
+  RotaryParameters parameters = {};
+  ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
+                                                                   position_ids,
+                                                                   cos_cache,
+                                                                   sin_cache,
+                                                                   num_heads,
+                                                                   rotary_embedding_dim,
+                                                                   &parameters));
+
+  Tensor* output = context->Output(0, input->Shape());
+
+  // Launch rotary embedding kernel
+  typedef typename ToCudaType<T>::MappedType CudaT;
+  auto& device_prop = GetDeviceProp();
+
+  // Handle optional position_ids - pass nullptr if position_ids is null
+  const int64_t* position_ids_data = (position_ids != nullptr) ? position_ids->Data<int64_t>() : nullptr;
+
+  return LaunchRotaryEmbeddingKernel<CudaT>(
+      Stream(context),
+      reinterpret_cast<CudaT*>(output->template MutableData<T>()),
+      reinterpret_cast<const CudaT*>(input->template Data<T>()),
+      position_ids_data,
+      reinterpret_cast<const CudaT*>(cos_cache->template Data<T>()),
+      reinterpret_cast<const CudaT*>(sin_cache->template Data<T>()),
+      parameters.batch_size,
+      parameters.sequence_length,
+      parameters.num_heads,
+      parameters.head_size,
+      parameters.rotary_embedding_dim,
+      parameters.max_sequence_length,
+      parameters.position_ids_format,
+      interleaved,
+      device_prop.maxThreadsPerBlock,
+      parameters.transposed);
+}
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding.h b/onnxruntime/core/providers/cuda/llm/rotary_embedding.h
new file mode 100644
index 0000000000000..09bd2e7dfde15
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding.h
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class RotaryEmbedding final : public CudaKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads; + int rotary_embedding_dim; + int interleaved; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu new file mode 100644 index 0000000000000..eda049dfae9a8 --- /dev/null +++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu @@ -0,0 +1,169 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for rotary embeddings. +*/ + +#include "core/providers/cuda/llm/rotary_embedding_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace cuda { + +template +__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH + const T* input, // BxSxNxH + const T* cos_cache, // BxSx(H/2) or Mx(H/2) + const T* sin_cache, // BxSx(H/2) or Mx(H/2) + const int64_t* position_ids, // (0) or BxS + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int position_ids_format, + const bool interleaved, + int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous +) { + // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length + // Use .x in innermost loop to access global memory efficiently + + const int b = blockIdx.y; + const int s = blockIdx.x; + const int n = blockIdx.z; + + const int i = threadIdx.x; + + if (i >= head_size) { + return; + } + + const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; + T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; + + if (i >= rotary_embedding_dim) { + output_data[i] = input_data[i]; + return; + } + + // Cache is (M, H/2) + const int half_rotary_embedding_dim = rotary_embedding_dim / 2; + int cache_offset; + + // position_ids_format == 0 means position_ids is nullptr + // position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length) + int b_s_index = b * sequence_length + s; + if (position_ids_format != 0) { + b_s_index = static_cast(position_ids[b_s_index]); + } + cache_offset = b_s_index * half_rotary_embedding_dim; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + if (interleaved) { + cache_idx = (i / 2) % half_rotary_embedding_dim; + sign = (i % 2 == 0) ? -1 : 1; + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + } else { + cache_idx = i % half_rotary_embedding_dim; + sign = (i < half_rotary_embedding_dim) ? 
-1 : 1; + j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; +} + +template +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const T* cos_cache, const T* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, const bool is_input_bnsh_format) { + int4 in_strides; + int4 out_strides; + if (is_input_bnsh_format) { + // Semantic meaning of the strides: + // int in_head_stride = sequence_length * head_size; + // int out_head_stride = sequence_length * head_size; + // in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1}; + // out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1}; + // Simplify to: + in_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1}; + out_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1}; + } else { + // input is in bshn format + // int in_head_stride = head_size; + // int out_head_stride = head_size; + // Simplify to: + in_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1}; + out_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1}; + } + return LaunchRotaryEmbeddingKernel( + stream, output, input, position_ids, + cos_cache, sin_cache, batch_size, + sequence_length, num_heads, head_size, + rotary_embedding_dim, max_sequence_length, + position_ids_format, interleaved, + max_threads_per_block, + in_strides, out_strides); +} + +template +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const T* cos_cache, const T* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, + int4 in_strides, int4 out_strides // strides in bnsh coord +) { + // Note: Current implementation assumes head_size <= max_threads_per_block + // because head_size is currently large for LLaMA-2. For smaller head_size + // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` + // instead. This will require kernel changes to support. 
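+  // For example, with head_size = 128 the launch below rounds tpb up to 128 (a multiple of the
+  // 32-thread warp), and each block of the (sequence_length, batch_size, num_heads) grid rotates
+  // the elements of one head.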
+  ORT_ENFORCE(head_size <= max_threads_per_block, "head_size must be <= max_threads_per_block");
+  // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1)
+  ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must be contiguous");
+
+  int tpb = (head_size + 31) / 32 * 32;
+
+  const dim3 block(tpb);
+  const dim3 grid(sequence_length, batch_size, num_heads);
+
+  assert(head_size <= max_threads_per_block);
+  RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
+                                                  num_heads, head_size, rotary_embedding_dim, position_ids_format,
+                                                  interleaved, in_strides, out_strides);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float* output, const float* input,
+                                                   const int64_t* position_ids, const float* cos_cache,
+                                                   const float* sin_cache, const int batch_size,
+                                                   const int sequence_length, const int num_heads, const int head_size,
+                                                   const int rotary_embedding_dim, const int max_sequence_length,
+                                                   const int position_ids_format, const bool interleaved,
+                                                   const int max_threads_per_block, const bool is_input_bnsh_format);
+
+template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half* output, const half* input,
+                                                  const int64_t* position_ids, const half* cos_cache,
+                                                  const half* sin_cache, const int batch_size,
+                                                  const int sequence_length, const int num_heads, const int head_size,
+                                                  const int rotary_embedding_dim, const int max_sequence_length,
+                                                  const int position_ids_format, const bool interleaved,
+                                                  const int max_threads_per_block, const bool is_input_bnsh_format);
+
+template Status LaunchRotaryEmbeddingKernel<BFloat16>(
+    cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids,
+    const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length,
+    const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length,
+    const int position_ids_format, const bool interleaved, const int max_threads_per_block,
+    const bool is_input_bnsh_format);
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h
new file mode 100644
index 0000000000000..2389c0596014e
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+    cudaStream_t stream,
+    T* output,
+    const T* input,
+    const int64_t* position_ids,
+    const T* cos_cache,
+    const T* sin_cache,
+    const int batch_size,
+    const int sequence_length,
+    const int num_heads,
+    const int head_size,
+    const int rotary_embedding_dim,
+    const int max_sequence_length,
+    const int position_ids_format,
+    const bool interleaved,
+    const int max_threads_per_block,
+    const bool is_input_bnsh_format);
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+    cudaStream_t stream,
+    T* output,
+    const T* input,
+    const int64_t* position_ids,
+    const T* cos_cache,
+    const T* sin_cache,
+    const int batch_size,
+    const int sequence_length,
+    const int num_heads,
+    const int head_size,
+    const int rotary_embedding_dim,
+    const int max_sequence_length,
+    const int position_ids_format,
+    const bool interleaved,
+    const int max_threads_per_block,
+    int4 in_strides,
+    int4 out_strides);
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h
index 1ab5e47a08523..3283aea7c33a2 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h
@@ -53,7 +53,7 @@ class CUDAExternalAllocator : public CUDAAllocator {

 // TODO: add a default constructor
 class CUDAPinnedAllocator : public IAllocator {
  public:
-  CUDAPinnedAllocator(const char* name, OrtDevice::DeviceId device_id)
+  CUDAPinnedAllocator(OrtDevice::DeviceId device_id, const char* name)
       : IAllocator(
             OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator,
                           OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA,
diff --git a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc
index 0305049e9b789..e7736c3f3afac 100644
--- a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc
+++ b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc
@@ -129,9 +129,6 @@ Status DequantizeLinear::ComputeInternal(ComputeContext& context) const {
   int64_t axis = (axis_ >= 0) ? axis_ : axis_ + x_shape.NumDimensions();
   int max_components = GetMaxComponents(x_size);
-  if (max_components != 4) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DequantizeLinear: components must be 4, but got ", max_components);
-  }

   // scaler - single scaler for all elements
   bool per_layer = x_scale_rank == 0 || (x_scale_rank == 1 && x_scale->Shape()[0] == 1);
diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc
index 7a17423112144..3df6d37d63794 100644
--- a/onnxruntime/core/session/abi_session_options.cc
+++ b/onnxruntime/core/session/abi_session_options.cc
@@ -278,6 +278,21 @@ ORT_API_STATUS_IMPL(OrtApis::GetSessionConfigEntry, _In_ const OrtSessionOptions
   API_IMPL_END
 }

+ORT_API_STATUS_IMPL(OrtApis::GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out) {
+  API_IMPL_BEGIN
+  if (options == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "options is nullptr");
+  }
+  auto& config_options = options->value.config_options.GetConfigOptionsMap();
+  auto kvps = std::make_unique<OrtKeyValuePairs>();
+  for (auto& kv : config_options) {
+    kvps->Add(kv.first.c_str(), kv.second.c_str());
+  }
+  *out = reinterpret_cast<OrtKeyValuePairs*>(kvps.release());
+  return nullptr;
+  API_IMPL_END
+}
+
 ORT_API_STATUS_IMPL(OrtApis::AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name,
                     _In_ const OrtValue* val) {
   API_IMPL_BEGIN
diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc
index cac91a4ec52d2..878a5384dfee7 100644
--- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc
+++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc
@@ -53,11 +53,9 @@ PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_
     ORT_THROW("Error creating execution provider: ", status.ToString());
   }

-  auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)),
-                                                              session_options, ep_factory_, devices_);
-  ep_wrapper->SetLogger(session_logger.ToInternal());
-
-  return ep_wrapper;
+  return std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)),
+                                                   session_options, ep_factory_, devices_,
+                                                   *session_logger.ToInternal());
 }

 ///
@@ -86,10 +84,43 @@ struct PluginEpMetaDefNameFunctor {
 // PluginExecutionProvider
 //

+static OrtDevice GetOrtDeviceForPluginEp(gsl::span<const OrtEpDevice* const> ep_devices) {
+  // Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU.
+  // If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all.
+
+  ORT_ENFORCE(!ep_devices.empty());  // Should not be possible to create an EP without OrtEpDevices.
+
+  const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info;
+
+  // Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos
+  bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(),
+                               [mem_a = device_memory_info](const OrtEpDevice* ep_device) {
+                                 const OrtMemoryInfo* mem_b = ep_device->device_memory_info;
+
+                                 if (mem_a == mem_b) {
+                                   return true;  // Point to the same OrtMemoryInfo instance.
+                                 }
+
+                                 if (mem_a == nullptr || mem_b == nullptr) {
+                                   return false;  // One is nullptr and the other is not.
+                                 }
+
+                                 // Both non-null but point to different instances. Use operator==.
+ return *mem_a == *mem_b; + }); + if (!all_match) { + ORT_THROW("Error creating execution provider '", ep_devices[0]->ep_name, + "': expected all OrtEpDevice instances to use the same device_memory_info."); + } + + return device_memory_info != nullptr ? device_memory_info->device : OrtDevice(); +} + PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, - gsl::span ep_devices) - : IExecutionProvider(ep->GetName(ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins? + gsl::span ep_devices, + const logging::Logger& logger) + : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), logger), ort_ep_(std::move(ep)), ep_factory_(ep_factory), ep_devices_(ep_devices.begin(), ep_devices.end()) { diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/ep_plugin_provider_interfaces.h index 343d6c9ad464e..3ba3118fcaa36 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.h @@ -65,7 +65,7 @@ class PluginExecutionProvider : public IExecutionProvider { public: explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, - gsl::span ep_devices); + gsl::span ep_devices, const logging::Logger& logger); ~PluginExecutionProvider(); std::vector> diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 2551fbc8b6099..e7f60fd48a14f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3632,6 +3632,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ReleaseSharedAllocator, &OrtApis::GetTensorData, + + &OrtApis::GetSessionOptionsConfigEntries, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
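For reference, a hypothetical caller of the entry registered above; a sketch only, with error checking elided, using only functions declared in the public C API:

```cpp
// Sketch: read back all session config entries through the new C API entry.
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

OrtSessionOptions* options = nullptr;
api->CreateSessionOptions(&options);
api->AddSessionConfigEntry(options, "session.use_env_allocators", "1");

OrtKeyValuePairs* kvps = nullptr;
api->GetSessionOptionsConfigEntries(options, &kvps);  // new in 1.23
// ... inspect kvps via the OrtKeyValuePairs accessors ...
api->ReleaseKeyValuePairs(kvps);  // required, per the header comment
api->ReleaseSessionOptions(options);
```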
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 4c4ab07493237..cbacbfce0740d 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -690,4 +690,6 @@ ORT_API_STATUS_IMPL(ReleaseSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDe _In_ OrtDeviceMemoryType mem_type); ORT_API_STATUS_IMPL(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out); + +ORT_API_STATUS_IMPL(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 6b54c33e9b10b..e8d62ab86f517 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -10,6 +10,7 @@ #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" #include "core/session/ep_factory_internal.h" #include "core/session/ep_plugin_provider_interfaces.h" #include "core/session/inference_session.h" @@ -355,7 +356,7 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or info.ep_factory->CreateEp(info.ep_factory, info.hardware_devices.data(), info.ep_metadata.data(), info.hardware_devices.size(), &options, &logger, &api_ep))); ep = std::make_unique(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options, - *info.ep_factory, info.devices); + *info.ep_factory, info.devices, *logger.ToInternal()); } return Status::OK(); diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 36b7f2965b483..18bc9cf05b36d 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -9,6 +9,7 @@ #include "core/session/abi_devices.h" #include "core/session/onnxruntime_cxx_api.h" #include "test/util/include/asserts.h" +#include "test/util/include/test_environment.h" namespace onnxruntime::test { @@ -56,6 +57,32 @@ struct TestOrtEpFactory : ::OrtEpFactory { static TestOrtEpFactory g_test_ort_ep_factory{}; +std::unique_ptr MakeTestOrtHardwareDevice(OrtHardwareDeviceType type) { + auto hw_device = std::make_unique(); + hw_device->type = type; + hw_device->vendor_id = 0xBE57; + hw_device->device_id = 0; + hw_device->vendor = "Contoso"; + return hw_device; +} + +std::unique_ptr MakeTestOrtEpDevice(const OrtHardwareDevice* hardware_device, + const OrtMemoryInfo* device_memory_info = nullptr, + const OrtMemoryInfo* host_accessible_memory_info = nullptr) { + auto ep_device = std::make_unique(); + ep_device->ep_name = "TestOrtEp"; + ep_device->ep_vendor = "Contoso"; + ep_device->device = hardware_device; + ep_device->ep_factory = &g_test_ort_ep_factory; + ep_device->device_memory_info = device_memory_info; + ep_device->host_accessible_memory_info = host_accessible_memory_info; + return ep_device; +} + +OrtDevice MakeTestOrtDevice(OrtDevice::DeviceType device_type, OrtDevice::MemoryType memory_type) { + return OrtDevice(device_type, memory_type, /*vendor_id*/ 0xBE57, /*device_id*/ 0, /*alignment*/ 16); +} + struct MakeTestOrtEpResult { std::unique_ptr ep; // the IExecutionProvider wrapping the TestOrtEp gsl::not_null ort_ep; // the wrapped TestOrtEp, owned by `ep` @@ -63,17 +90,25 @@ struct MakeTestOrtEpResult { // Creates an IExecutionProvider that wraps a TestOrtEp. 
// The TestOrtEp is also exposed so that tests can manipulate its function pointers directly. -MakeTestOrtEpResult MakeTestOrtEp() { +MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {}) { + // Default OrtHardwareDevice and OrtEpDevice used if the caller does not explicitly provide ep_devices. + static std::unique_ptr ort_hw_device = MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU); + static std::unique_ptr ort_ep_device = MakeTestOrtEpDevice(ort_hw_device.get()); + auto ort_ep_raw = std::make_unique().release(); auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory}); auto ort_session_options = Ort::SessionOptions{}; - auto ort_ep_device = OrtEpDevice{}; - std::vector ep_devices{&ort_ep_device}; + if (ep_devices.empty()) { + ep_devices.push_back(ort_ep_device.get()); + } + + auto& logging_manager = DefaultLoggingManager(); auto ep = std::make_unique(std::move(ort_ep), *static_cast(ort_session_options), g_test_ort_ep_factory, - ep_devices); + ep_devices, + logging_manager.DefaultLogger()); auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw}; return result; @@ -177,4 +212,109 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { #endif // !defined(ORT_NO_EXCEPTIONS) } +TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { + // 1 OrtEpDevice without a device_memory_info. + // PluginExecutionProvider should decide to use a default OrtDevice. + { + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice()); + } + + // 1 OrtEpDevice with a device_memory_info. + // PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info. + { + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); + auto ort_memory_info = std::make_unique("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator, + ort_device, OrtMemTypeDefault); + + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(), + /*device_memory_info*/ ort_memory_info.get()); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); + } + + // 2 OrtEpDevice instances with the same device_memory_info. + // PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info. 
+ { + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT); + auto ort_memory_info = std::make_unique("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator, + ort_device, OrtMemTypeDefault); + + auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); + auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info.get()); + auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info.get()); + std::vector ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); + } + + // 2 OrtEpDevice instances with the different (but equivalent) device_memory_info pointers. + // PluginExecutionProvider should decide to use a OrtDevice that is equal to the devices used by both + // device_memory_info pointers. + { + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT); + auto ort_memory_info_0 = std::make_unique("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator, + ort_device, OrtMemTypeDefault); + auto ort_memory_info_1 = std::make_unique("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator, + ort_device, OrtMemTypeDefault); + + auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); + auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_0.get()); + auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_1.get()); + std::vector ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); + } + + // 1 OrtEpDevice with only a host_accessible_memory_info. + // PluginExecutionProvider should decide to use a default OrtDevice (cpu). + { + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE); + auto ort_memory_info = std::make_unique("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator, + ort_device, OrtMemTypeDefault); + + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(), + /*device_memory_info*/ nullptr, + /*host_accessible_memory_info*/ ort_memory_info.get()); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice()); + } + +#if !defined(ORT_NO_EXCEPTIONS) + // 2 OrtEpDevice instances with DIFFERENT device_memory_info instances. + // Should throw an exception on construction of PluginExecutionProvider. 
+ { + auto ort_device_gpu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); + auto ort_memory_info_gpu = std::make_unique("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator, + ort_device_gpu, OrtMemTypeDefault); + + auto ort_device_npu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); + auto ort_memory_info_npu = std::make_unique("TestOrtEp NPU", OrtAllocatorType::OrtDeviceAllocator, + ort_device_npu, OrtMemTypeDefault); + + auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); + auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_gpu.get()); + auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_npu.get()); + std::vector ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; + + ASSERT_THROW(test_plugin_ep::MakeTestOrtEp(ep_devices), OnnxRuntimeException); + } +#endif // !defined(ORT_NO_EXCEPTIONS) +} + } // namespace onnxruntime::test
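As an appendix, not part of the patch: the per-element rotation computed by RotaryEmbeddingBSNH above, restated on the host for the non-interleaved layout as a reference sketch.

```cpp
// Reference rotation for one (batch, seq, head) slice, non-interleaved layout.
// Mirrors the CUDA kernel: out[i] = in[i]*cos[c] + sign*in[j]*sin[c].
void RotateNonInterleavedReference(const float* input, const float* cos_data, const float* sin_data,
                                   float* output, int rotary_embedding_dim) {
  const int half = rotary_embedding_dim / 2;
  for (int i = 0; i < rotary_embedding_dim; ++i) {
    const int cache_idx = i % half;                   // shared cos/sin entry for the pair
    const float sign = (i < half) ? -1.0f : 1.0f;     // first half subtracts, second half adds
    const int j = (i + half) % rotary_embedding_dim;  // partner element
    output[i] = input[i] * cos_data[cache_idx] + sign * input[j] * sin_data[cache_idx];
  }
}
```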