From b9b753085e79e17e99c206dc51c6d5fd67641f92 Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 9 Jul 2025 19:04:45 +0800 Subject: [PATCH 01/49] [webgpu] Update wgsl_templates README.md (#25336) ### Description Fix a broken URL and numbering in the ordered list in README.md. ### Motivation and Context See Above. --- onnxruntime/core/providers/webgpu/wgsl_templates/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md index c1a62e7fa7858..6bd2f98cc5713 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md @@ -64,7 +64,7 @@ This section includes instructions for how to use the template system in the dev 1. Create WGSL template files in `.wgsl.template` extension. - [Reference: Template Syntax](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#template-syntax) - - [Reference: Built-in Utilities](#Utilities) + - [Reference: Built-in Utilities](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#Utilities) - [Example: Pad](../tensor/pad.wgsl.template) 2. In the implementation of `YourProgram::GenerateShaderCode()`, load and use the generated template files. @@ -117,4 +117,4 @@ This section includes instructions for how to use the template system in the dev 1. Build ORT once with dynamic template mode 2. Launch wgsl-gen in watch mode 3. Run ORT to debug/validate the shader - 4. Make changes to the template files, and repeat step (3) + 4. Make changes to the template files, and repeat step (c) From bc0256fe84be0d7daf36bd491cfd078103b5d635 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 9 Jul 2025 23:18:34 +0800 Subject: [PATCH 02/49] [webgpu] Move the early return after copying for ScatterND (#25345) ### Description For ScatterND, if the indices are empty (nothing to update), it becomes a copy operation. So we should move the early return after copying. --- .../providers/webgpu/tensor/scatter_nd.cc | 22 +++++++++---------- .../cpu/tensor/scatter_nd_op_test.cc | 11 ++++++++++ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc index f13e86c185928..9f07e2d2a3988 100644 --- a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc @@ -146,24 +146,24 @@ Status ScatterND::ComputeInternal(ComputeContext& context) const { const auto* updates = context.Input(2); const auto& input_shape = input->Shape(); const auto& indices_shape = indices->Shape(); - auto indices_rank = indices_shape.NumDimensions(); - auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); - auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); - // TODO: support bool with components 4. - const size_t components = 1; - auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); auto* output = context.Output(0, input_shape); - if (output_size == 0) { - // If the output tensor is empty, we can return early. - return Status::OK(); - } - MLDataType data_type = input->DataType(); const void* source = input->DataRaw(); void* target = output->MutableDataRaw(); // If source and target pointers are not equal (non-inplace operation), we need to copy the data. 
if (target != source) { ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output)); } + if (indices_shape.Size() == 0) { + // If the indices are empty, we can return early. + return Status::OK(); + } + auto indices_rank = indices_shape.NumDimensions(); + auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); + auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); + // TODO: support bool with components 4. + const size_t components = 1; + auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); + MLDataType data_type = input->DataType(); ScatterNDProgram program(reduction_, data_type); program .CacheHint(static_cast(reduction_)) diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc index 895c8ab3e53e4..e6d113e1e4dca 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc @@ -235,5 +235,16 @@ TEST(ScatterNDOpTest, ScatterND_18_max) { test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } +// Test for ScatterND with empty indices - output should be same as input +TEST(ScatterNDOpTest, ScatterND_empty_indices) { + // Test with float data type and minimal empty case + OpTester test1("ScatterND", 11); + test1.AddInput("data", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + test1.AddInput("indices", {0, 1}, {}); // Empty indices tensor - no indices to process + test1.AddInput("updates", {0, 3}, {}); // Empty updates tensor + test1.AddOutput("output", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); // Same as input + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + } // namespace test } // namespace onnxruntime From 3b259e11209f50a6e1c6d3a712ceeef63532ce60 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 9 Jul 2025 09:42:35 -0700 Subject: [PATCH 03/49] [EP ABI] Utility to serialize OrtGraph to GraphProto (#25292) ### Description - Provides utility functions that serialize an `OrtGraph` to a `GraphProto` or `ModelProto`. - Header-only file that can be copied to a project that builds with ORT and ONNX. - Available in [include/onnxruntime/core/providers/utils/ort_graph_to_proto.h](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-ort-graph-to-onnx-protobuf/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h) - Updates the `Node_GetSubgraphs` API function to also return the attribute names associated with each subgraph. This is required to determine which subgraph corresponds to a given attribute. - Adds `Graph_GetNumOperatorSets` and `Graph_GetOperatorSets` API functions to get the opset version for each domain. ### Motivation and Context Provide a utility to facilitate porting of existing execution providers to the new EP ABI. The utilities introduced by this PR convert an `OrtGraph` into an ONNX protobuf representation, which some existing EPs currently convert to their internal representation. Ideally, we would prefer a more direct conversion from a `OrtGraph` to the EP's internal representation, but this is a large effort. These utilities enable an incremental transition. 
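For illustration, here is a minimal sketch (not part of this change) of how a plugin EP might consume the new operator-set query APIs added by this PR. The helper name `PrintOperatorSets` and the include paths are placeholders; only `Graph_GetNumOperatorSets` and `Graph_GetOperatorSets` come from this change.

```C++
#include <cstdint>
#include <cstdio>
#include <vector>

#include "core/session/onnxruntime_cxx_api.h"

// Hypothetical helper: prints every (domain, opset_version) pair declared by the model
// that owns `graph`, using the Graph_GetNumOperatorSets / Graph_GetOperatorSets APIs.
Ort::Status PrintOperatorSets(const OrtGraph* graph) {
  const OrtApi& ort_api = Ort::GetApi();

  size_t num_operator_sets = 0;
  if (OrtStatus* status = ort_api.Graph_GetNumOperatorSets(graph, &num_operator_sets)) {
    return Ort::Status{status};
  }

  std::vector<const char*> domains(num_operator_sets, nullptr);
  std::vector<int64_t> opset_versions(num_operator_sets, 0);
  if (OrtStatus* status = ort_api.Graph_GetOperatorSets(graph, domains.data(), opset_versions.data(),
                                                        num_operator_sets)) {
    return Ort::Status{status};
  }

  for (size_t i = 0; i < num_operator_sets; ++i) {
    // The ONNX domain is reported as an empty string.
    std::printf("domain='%s' opset=%lld\n", domains[i], static_cast<long long>(opset_versions[i]));
  }

  return Ort::Status{nullptr};
}
```

The ModelProto overload of the serialization utility below uses the same pair of calls to populate the model's `opset_import` entries.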
--- .../core/providers/utils/ort_graph_to_proto.h | 718 ++++++++++++++++++ .../core/session/onnxruntime_c_api.h | 50 +- onnxruntime/core/graph/abi_graph_types.h | 22 +- onnxruntime/core/graph/ep_api_types.cc | 55 +- onnxruntime/core/graph/ep_api_types.h | 12 +- .../core/graph/model_editor_api_types.h | 14 +- onnxruntime/core/session/onnxruntime_c_api.cc | 30 +- onnxruntime/core/session/ort_apis.h | 7 +- onnxruntime/test/ep_graph/test_ep_graph.cc | 195 ++++- .../test/ep_graph/test_ep_graph_utils.cc | 1 + .../test/ep_graph/test_ep_graph_utils.h | 1 + 11 files changed, 1084 insertions(+), 21 deletions(-) create mode 100644 include/onnxruntime/core/providers/utils/ort_graph_to_proto.h diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h new file mode 100644 index 0000000000000..37665542f614f --- /dev/null +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -0,0 +1,718 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/* + SUMMARY: + Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider + implementations that need to convert an OrtGraph instance into an ONNX protobuf model. + + Users may copy this file and modify as needed. + + USAGE: + This is a header-only implementation that includes both the function declarations and definitions. Copy this file + into a project that links with both ONNX Runtime and ONNX. + + Define the ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL preprocessor macro before the #include statement in exactly one C++ + file to define the implementation. Example: + + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + Other compilation units that depend on these utilities should include this file without defining the + preprocessor macro. + + Example program snippets are shown below. Refer to the function declarations for detailed usage information. + + EXAMPLE SNIPPET (initializers stored within TensorProto): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + onnx::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto); + + // graph_proto stores initializers internally + } + ``` + + EXAMPLE SNIPPET (large initializers stored in external file): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + std::string external_file_path = "weights.bin"; + std::ofstream out_file(external_file_path, std::ios::binary); + + auto handle_initializer_data = [&external_file_path, &out_file](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. 
+ return Ort::Status{nullptr}; + } + + offset = out_file.tellp(); + location = external_file_path; + out_file.write(static_cast(data), bytes); + out_file.flush(); + is_external = true; // True if is external initializer + return Ort::Status{nullptr}; + } + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto stores large initializers in an external file + } + ``` +*/ + +#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ +#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +#include +#include "core/session/onnxruntime_cxx_api.h" +#include "onnx/onnx_pb.h" + +namespace OrtEpUtils { + +/// +/// Signature of user-provided function to handle initializer data. Called by OrtGraphToProto() for every initializer. +/// +/// If the function sets the `is_external` output parameter to false, OrtGraphToProto() stores initializer data +/// within the TensorProto as raw_data. +/// +/// Otherwise, if the function sets `is_external` to true, OrtGraphToProto() assumes that this function stores the +/// initializer data in a file. In this case, OrtGraphToProto() configures the corresponding TensorProto to point the +/// location and offset returned via the `location` and `offset` output parameters. +/// +/// It is recommended to keep small initializers with byte size <= 127 stored inline the TensorProto to ensure +/// ONNX shape inference works correctly with the serialized ONNX model. +/// +/// OrtValueInfo for the initializer. Can be used to query name, type, shape, +/// and consumer nodes. +/// Opaque pointer to the initializer data. +/// Size in bytes of the initializer data. +/// Output parameter set to true if the initializer data is stored externally. The +/// implementer is responsible for writing the initializer data to file. If set to false, +/// the initializer will be stored within the TensorProto. +/// Output parameter set to the location (e.g., file) into which the initializer is stored +/// by the implementer of this function. Ignored if `is_external` is set to false. +/// Output parameter set to the offset (e.g., file offset) into which the initializer is stored +/// by the implementer of this function. Ignored if `is_external` is set to false. +/// An Ort::Status indicating success or an error. Serialization exits if this returns an error. +using HandleInitializerDataFunc = std::function; + +/// +/// Serializes the provided OrtGraph to a onnx::GraphProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. +/// Destination GraphProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); + +/// +/// Serializes the provided top-level OrtGraph to a onnx::ModelProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. 
+/// Destination ModelProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); +} // namespace OrtEpUtils + +// End of header +#endif // INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +// +// IMPLEMENTATION BELOW +// +#ifdef ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + +#include +#include +#include +#include +#include +#include + +#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return Ort::Status{_status}; \ + } \ + } while (0) + +#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ + do { \ + Ort::Status _status = (fn); \ + if (!_status.IsOK()) { \ + return _status; \ + } \ + } while (0) + +#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ + } \ + } while (0) + +namespace OrtEpUtils { + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims); +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // + // Set GraphProto metadata + // + const char* graph_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + size_t num_graph_inputs = 0; + size_t num_graph_outputs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); + + std::vector graph_inputs(num_graph_inputs); + std::vector graph_outputs(num_graph_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); + + for (const OrtValueInfo* ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + for (const OrtValueInfo* ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + // + // Set GraphProto nodes, value_infos, and initializers. + // + + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. 
+ std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer + + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. + auto collect_value_info = [&ort_api, &value_infos, + &initializer_value_infos](const OrtValueInfo& ort_value_info, + /*out*/ const char** value_name_out = nullptr) -> Ort::Status { + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + + if (value_name_out != nullptr) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return Ort::Status{nullptr}; // Already processed this OrtValueInfo. + } + + bool is_required_graph_input = false; + bool is_optional_graph_input = false; + bool is_graph_output = false; + bool is_constant_initializer = false; + bool is_from_outer_scope = false; + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); + + // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + // For values defined in an outer scope, just add the value info but not the initializer. + if (is_from_outer_scope) { + value_infos.emplace(value_name, &ort_value_info); + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, &ort_value_info); + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. + } + + return Ort::Status{nullptr}; + }; + + size_t num_nodes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); + + std::vector nodes(num_nodes); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. 
+ for (size_t i = 0; i < num_nodes; i++) { + const OrtNode* ort_node = nodes[i]; + onnx::NodeProto* node_proto = graph_proto.add_node(); + + const char* node_name = nullptr; + const char* node_domain = nullptr; + const char* node_op_type = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + size_t num_inputs = 0; + size_t num_implicit_inputs = 0; + size_t num_outputs = 0; + size_t num_attrs = 0; + size_t num_subgraphs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); + + // Handle node attributes + if (num_attrs > 0) { + std::vector ort_attrs(num_attrs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); + + for (const OrtOpAttr* ort_attr : ort_attrs) { + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + + Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + if (!status.IsOK()) { + // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. + // Can use Node_GetSubgraphs to get subgraphs. + continue; + } + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + } + } + + // Handle node subgraphs + if (num_subgraphs > 0) { + std::vector ort_subgraphs(num_subgraphs); + std::vector subgraph_attr_names(num_subgraphs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), + subgraph_attr_names.data())); + + for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { + const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; + const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); + + attr_proto->set_name(subgraph_attr_name); + attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); + } + } + + // Handle node inputs + if (num_inputs > 0) { + std::vector ort_inputs(num_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_inputs) { + if (ort_value_info == nullptr) { + // missing optional input. + node_proto->add_input(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_input(value_name); + } + } + + // Handle implicit inputs to this node. 
+ if (num_implicit_inputs > 0) { + std::vector ort_implicit_inputs(num_implicit_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), + ort_implicit_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { + assert(ort_value_info != nullptr); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); + } + } + + // Handle node outputs + if (num_outputs > 0) { + std::vector ort_outputs(num_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_outputs) { + if (ort_value_info == nullptr) { + // missing optional output. + node_proto->add_output(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_output(value_name); + } + } + } + + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const std::pair& entry : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); + } + + // Add initializers to GraphProto as TensorProto objects. + for (const std::pair& entry : initializer_value_infos) { + const OrtValueInfo* initializer_value_info = entry.second; + std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } + + const OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); + + const void* data = nullptr; + size_t data_bytes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; + + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } + + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + } else { + // User wants to store data inline the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + tensor_proto->set_raw_data(data, data_bytes); + } + } + + 
return Ort::Status{nullptr}; +} + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // Check that OrtGraph is a top-level graph (no parent node). + const OrtNode* parent_node = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); + ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); + + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); + + // Set ir version. + int64_t ir_version = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); + model_proto.set_ir_version(ir_version); + + // Set operator sets. + size_t num_operator_sets = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); + ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); + + std::vector domains(num_operator_sets, nullptr); + std::vector opset_versions(num_operator_sets); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), + num_operator_sets)); + + auto* operator_sets = model_proto.mutable_opset_import(); + + for (size_t i = 0; i < num_operator_sets; ++i) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(domains[i]); + operator_set->set_version(opset_versions[i]); + } + + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); + + return Ort::Status{nullptr}; +} + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims) { + const OrtApi& ort_api = Ort::GetApi(); + + const OrtTypeInfo* ort_type_info = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); + + ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); + + const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); + + size_t num_dims = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); + + std::vector ort_dims(num_dims, 0); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); + + elem_type = ort_elem_type; + dims = std::move(ort_dims); + + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), + ort_dim_syms.size())); + + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + 
symbolic_dims.push_back(sym_dim); + } + } + + return Ort::Status{nullptr}; +} + +// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, + onnx::ValueInfoProto& value_info_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + std::vector ort_dims; + std::vector ort_dim_syms; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, + ort_elem_type, ort_dims, ort_dim_syms)); + + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + value_info_proto.set_name(value_name); + + onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + type_proto_tensor->set_elem_type(ort_elem_type); + + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); + + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); + + if (ort_dims[dim_idx] >= 0) { + dim_proto->set_dim_value(ort_dims[dim_idx]); + } else { + const std::string& dim_param = ort_dim_syms[dim_idx]; + + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, + // which represents an unknown dimension. + if (!dim_param.empty()) { + dim_proto->set_dim_param(dim_param); + } + } + } + + return Ort::Status{nullptr}; +} + +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + const char* attr_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); + attr_proto.set_name(attr_name); + + size_t total_attr_bytes = 0; + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); + + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + + int64_t i_val = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); + attr_proto.set_i(i_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector i_vals(total_attr_bytes / sizeof(int64_t)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* ints = attr_proto.mutable_ints(); + for (int64_t val : i_vals) { + ints->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + + float f_val = 0.0f; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); + attr_proto.set_f(f_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. 
+ Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector f_vals(total_attr_bytes / sizeof(float)); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* floats = attr_proto.mutable_floats(); + for (float val : f_vals) { + floats->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::string* str = attr_proto.mutable_s(); + + str->resize(total_attr_bytes, '\0'); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, + &total_attr_bytes)); + + str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character. + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector chars(total_attr_bytes, '\0'); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* strs = attr_proto.mutable_strings(); + + // Strings are all in a single buffer, each separated with a '\0'. + // Extract each string and add it to the STRINGS attribute array. + char* at = chars.data(); + char* end = at + chars.size(); + + while (at < end) { + char* str_begin = at; + + while (*at && at < end) { + at++; + } + + strs->Add()->assign(str_begin, at - str_begin); + if (at < end) { + assert(*at == '\0'); + at++; // Skip '\0' to get to the beginning of the next string. + } + } + + break; + } + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + return Ort::Status{nullptr}; +} + +} // namespace OrtEpUtils +#endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 86c0b60db2bc4..bf1dd6e20ce64 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -66,6 +66,7 @@ extern "C" { #define _In_reads_(X) #define _Inout_updates_(X) #define _Out_writes_(X) +#define _Out_writes_opt_(X) #define _Inout_updates_all_(X) #define _Out_writes_bytes_all_(X) #define _Out_writes_all_(X) @@ -4749,6 +4750,8 @@ struct OrtApi { * \param[in] len Number of bytes allowed to store in data * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success * + * \note Does not support reading graph attributes. Refer to Node_GetSubgraphs. + * * \since Version 1.17. */ ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); @@ -5568,6 +5571,45 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); + /** \brief Returns the number of operator sets that the graph's model uses. 
+ * + * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at + * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by + * an empty string. + * + * \param[in] graph The OrtGraph instance. + * \param[out] num_operator_sets Output parameter set to the number of operator sets that the graph's model uses. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); + + /** \brief Returns the operator sets that the graph's model uses. + * + * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at + * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by + * an empty string. + * + * \param[in] graph The OrtGraph instance. + * \param[out] domains Pre-allocated array of `num_operator_sets` elements that is filled with + * null-terminated domain names. + * \param[out] opset_versions Pre-allocated array of `num_operator_sets` elements that is filled with + * the opset version of the corresponding domain in the `domains` array. + * \param[in] num_operator_sets The size of the `domains` and `opset_versions` arrays. + * Typical usage sets this to the result of Graph_GetNumOperatorSets(). + * An error status is returned if `num_operator_sets` is less than the actual number + * of operator sets. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetOperatorSets, _In_ const OrtGraph* graph, + _Out_writes_(num_operator_sets) const char** domains, + _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); + /** \brief Returns the number of graph inputs. * * \note The count includes initializers that are included in the list of graph inputs. @@ -5933,20 +5975,24 @@ struct OrtApi { /** \brief Get the subgraphs, as OrtGraph instances, contained by the given node. * - * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. + * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. ONNX nodes store subgraphs in + * their attributes, however, this function must be used to obtain subgraphs from an OrtNode. * * \param[in] node The OrtNode instance. * \param[out] subgraphs Pre-allocated array of `num_subgraphs` elements that is filled with the node's subgraphs. * \param[in] num_subgraphs The size of the `num_subgraphs` array. * Typical usage sets this to the result of Node_GetNumSubgraphs(). An error status is * returned if `num_subgraphs` is less than the number of node subgraphs. + * \param[out] attribute_names Optional pre-allocated array of `num_subgraphs` elements that is filled with the + * attribute names that correspond to the subgraphs. Ignored if set to NULL. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, + _Out_writes_opt_(num_subgraphs) const char** attribute_names); /** \brief Get the node's parent OrtGraph instance. 
* diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index c3dd9321ebb0b..47fbe08da41ff 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -247,8 +247,11 @@ struct OrtNode { /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). /// /// Buffer into which to copy the subgraphs. + /// Optional buffer into which to copy the attribute name for each subgraph. + /// If set, must point to a buffer with the same number of elements as `subgraphs`. /// A status indicating success or an error. - virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs) const = 0; + virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs, + const char** opt_attribute_names) const = 0; /// /// Gets the node's parent graph, which is the graph that contains this node. @@ -280,6 +283,23 @@ struct OrtGraph { /// The model's ONNX IR version. virtual int64_t GetOnnxIRVersion() const = 0; + /// + /// Gets the number of operator sets (domain, opset version) the graph's model relies on. + /// + /// Output parameter set to the number of operator sets. + /// A status indicating success or an error. + virtual onnxruntime::Status GetNumOperatorSets(size_t& num_operator_sets) const = 0; + + /// + /// Gets the operator sets the graph's model relies on. An operator set is uniquely identified by a + /// (domain, opset version) pair. + /// + /// Buffer into which to copy the domains. + /// Buffer into which to copy the opset version for each domain. + /// A status indicating success or an error. + virtual onnxruntime::Status GetOperatorSets(gsl::span domains, + gsl::span opset_versions) const = 0; + /// /// Returns the number of graph inputs, including initializers that appear in the list of graph inputs. 
/// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 698c7422a1e2a..8583fac30cfbf 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -129,11 +129,12 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_implicit_inputs, ep_node_implicit_inputs); - std::vector> node_subgraphs = node.GetSubgraphs(); - ep_node_subgraphs.reserve(node_subgraphs.size()); + std::unordered_map> subgraphs_map = node.GetAttributeNameToSubgraphMap(); + ep_node_subgraphs.reserve(subgraphs_map.size()); - for (gsl::not_null subgraph : node_subgraphs) { + for (const auto& [attr_name, subgraph] : subgraphs_map) { SubgraphState subgraph_state; + subgraph_state.attribute_name = attr_name; subgraph_state.subgraph_viewer = std::make_unique(*subgraph); ORT_RETURN_IF_ERROR(EpGraph::Create(*subgraph_state.subgraph_viewer, subgraph_state.ep_subgraph)); subgraph_state.ep_subgraph->SetParentNode(ep_node.get()); @@ -233,12 +234,17 @@ Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { return Status::OK(); } -Status EpNode::GetSubgraphs(gsl::span dst) const { +Status EpNode::GetSubgraphs(gsl::span subgraphs, + const char** opt_attribute_names) const { const size_t num_subgraphs = subgraphs_.size(); - ORT_RETURN_IF_ERROR((CheckCopyDestination("node attributes", num_subgraphs, dst))); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node subgraphs", num_subgraphs, subgraphs))); for (size_t i = 0; i < num_subgraphs; ++i) { - dst[i] = subgraphs_[i].ep_subgraph.get(); + subgraphs[i] = subgraphs_[i].ep_subgraph.get(); + + if (opt_attribute_names) { + opt_attribute_names[i] = subgraphs_[i].attribute_name.c_str(); + } } return Status::OK(); @@ -660,6 +666,43 @@ const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } int64_t EpGraph::GetOnnxIRVersion() const { return graph_viewer_.GetOnnxIRVersion(); } +Status EpGraph::GetNumOperatorSets(size_t& num_operator_sets) const { + num_operator_sets = graph_viewer_.DomainToVersionMap().size(); + return Status::OK(); +} + +Status EpGraph::GetOperatorSets(gsl::span domains, + gsl::span opset_versions) const { + const std::unordered_map& domain_to_version = graph_viewer_.DomainToVersionMap(); + size_t num_operator_sets = domain_to_version.size(); + + ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set domains", num_operator_sets, domains))); + ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set versions", num_operator_sets, opset_versions))); + + // Collect (domain, version) pairs and sort them by domain to ensure user always gets a stable ordering. + std::vector> pairs; + pairs.reserve(num_operator_sets); + + for (const auto& [domain, version] : domain_to_version) { + pairs.emplace_back(domain.c_str(), version); + } + + std::sort(pairs.begin(), pairs.end(), + [](const std::pair& a, const std::pair& b) -> bool { + return std::strcmp(a.first, b.first) < 0; + }); + + // Copy sorted (domain, version) pairs into the destination buffers. 
+ size_t index = 0; + for (const auto& [domain_c_str, version] : pairs) { + domains[index] = domain_c_str; + opset_versions[index] = version; + index++; + } + + return Status::OK(); +} + size_t EpGraph::GetNumInputs() const { return inputs_.size(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 4240f5636b7ae..12fa082d3f354 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -111,6 +111,7 @@ struct EpNode : public OrtNode { struct SubgraphState { SubgraphState() = default; SubgraphState(SubgraphState&& other) = default; + std::string attribute_name; std::unique_ptr subgraph_viewer; // The graph_viewer wrapped by EpGraph below. std::unique_ptr ep_subgraph; }; @@ -182,7 +183,8 @@ struct EpNode : public OrtNode { Status GetNumSubgraphs(size_t& num_subgraphs) const override; // Gets the subgraphs contained by this node. - Status GetSubgraphs(gsl::span subgraphs) const override; + Status GetSubgraphs(gsl::span subgraphs, + const char** opt_attribute_names) const override; // Gets this node's parent graph, which is the graph that directly contains this node. Status GetGraph(const OrtGraph*& parent_graph) const override; @@ -271,6 +273,14 @@ struct EpGraph : public OrtGraph { // Returns the model's ONNX IR version. int64_t GetOnnxIRVersion() const override; + // Gets the number of operator sets that the graph's model uses. + Status GetNumOperatorSets(size_t& num_operator_sets) const override; + + // Gets the operator sets that the graph's model uses. An operator set is uniquely identified by a + // (domain, opset version) pair. + Status GetOperatorSets(gsl::span domains, + gsl::span opset_versions) const override; + // Get the number of graph inputs, including initializers that are listed as graph inputs. 
size_t GetNumInputs() const override; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 6330a42c115db..6e7e17374bb59 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -136,7 +136,8 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } - Status GetSubgraphs(gsl::span /*subgraphs*/) const override { + Status GetSubgraphs(gsl::span /*subgraphs*/, + const char** /*opt_attribute_names*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } @@ -176,6 +177,17 @@ struct ModelEditorGraph : public OrtGraph { return ONNX_NAMESPACE::Version::IR_VERSION; } + Status GetNumOperatorSets(size_t& /*num_operator_sets*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph's operator sets."); + } + + Status GetOperatorSets(gsl::span /*domains*/, + gsl::span /*opset_versions*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph's operator sets."); + } + size_t GetNumInputs() const override { return inputs.size(); } Status GetInputs(gsl::span /*result*/) const override { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index e7f60fd48a14f..18b545483b38b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2591,6 +2591,29 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets) { + API_IMPL_BEGIN + if (num_operator_sets == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_operator_sets' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetNumOperatorSets(*num_operator_sets)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetOperatorSets, _In_ const OrtGraph* graph, + _Out_writes_(num_operator_sets) const char** domains, + _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets) { + API_IMPL_BEGIN + gsl::span domains_span(domains, num_operator_sets); + gsl::span versions_span(opset_versions, num_operator_sets); + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetOperatorSets(domains_span, versions_span)); + + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs) { API_IMPL_BEGIN if (num_inputs == nullptr) { @@ -2922,10 +2945,11 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetNumSubgraphs, _In_ const OrtNode* node, _Ou } ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs) { + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, + _Out_writes_opt_(num_subgraphs) const char** attribute_names) { API_IMPL_BEGIN gsl::span graphs_span(subgraphs, num_subgraphs); - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span)); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span, attribute_names)); return nullptr; API_IMPL_END } @@ -3594,6 +3618,8 @@ static constexpr OrtApi ort_api_1_to_23 = { 
&OrtApis::ValueInfo_IsFromOuterScope, &OrtApis::Graph_GetName, &OrtApis::Graph_GetOnnxIRVersion, + &OrtApis::Graph_GetNumOperatorSets, + &OrtApis::Graph_GetOperatorSets, &OrtApis::Graph_GetNumInputs, &OrtApis::Graph_GetInputs, &OrtApis::Graph_GetNumOutputs, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index cbacbfce0740d..75db44cb9e9ff 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -631,6 +631,10 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); +ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); +ORT_API_STATUS_IMPL(Graph_GetOperatorSets, _In_ const OrtGraph* graph, + _Out_writes_(num_operator_sets) const char** domains, + _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); ORT_API_STATUS_IMPL(Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs); ORT_API_STATUS_IMPL(Graph_GetInputs, _In_ const OrtGraph* graph, _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs); @@ -671,7 +675,8 @@ ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOp ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, + _Out_writes_opt_(num_subgraphs) const char** attribute_names); ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 60498e6510ec2..e9bed3ac45529 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include +#include #include #include #include @@ -12,6 +14,9 @@ #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/onnxruntime_cxx_api.h" +#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL +#include "core/providers/utils/ort_graph_to_proto.h" + #include "test/ep_graph/test_ep_graph_utils.h" #include "test/util/include/api_asserts.h" #include "test/util/include/asserts.h" @@ -68,6 +73,168 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {1, 1, 28, 28}; + std::vector input_data(28 * 28, 0.5f); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'Input3' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("Input3"); + + // Run session and get outputs + std::array output_names{"Plus214_Output_0"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 10); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + +// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_Mnist) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/mnist.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("mnist_serialized.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "mnist_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. 
+ + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, handle_initializer_data); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + RunMNISTModel(original_model_path, output_original); + RunMNISTModel(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + +static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {1}; + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'if_cond_input' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, &input_cond, 1, input_shape.data(), input_shape.size())); + ort_input_names.push_back("if_cond_input"); + + // Run session and get outputs + std::array output_names{"if_cond_output"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 1); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + +// Test serializing an OrtGraph to GraphProto. The model has 3 layers of nested subgraphs. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_3LayerSubgraphs) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/three_layer_nested_subgraph.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("three_layer_nested_subgraph_serialized.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to ModelProto (all initializers stored within TensorProtos). + ONNX_NAMESPACE::ModelProto model_proto; + OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + } + + // Compare output of the original and serialized models. Should be identical. 
+ std::vector output_original; + std::vector output_serialized; + + { + Run3LayerModel(original_model_path, true, output_original); + Run3LayerModel(serialized_model_path, true, output_serialized); + EXPECT_EQ(output_serialized, output_original); + } + + { + Run3LayerModel(original_model_path, false, output_original); + Run3LayerModel(serialized_model_path, false, output_serialized); + EXPECT_EQ(output_serialized, output_original); + } +} + // // Utils for traversing an OrtGraph and checking against GraphViewer. // @@ -470,9 +637,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check node subgraphs - std::vector> node_subgraphs = node->GetSubgraphs(); + std::unordered_map> node_subgraphs_map = + node->GetAttributeNameToSubgraphMap(); - if (!node_subgraphs.empty()) { + if (!node_subgraphs_map.empty()) { // Check node's implicit inputs to its subgraph nodes. const auto implicit_input_node_args = node->ImplicitInputDefs(); @@ -489,14 +657,27 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Recursively check subgraphs. size_t api_num_node_subgraphs = 0; ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumSubgraphs(api_node, &api_num_node_subgraphs)); + ASSERT_EQ(api_num_node_subgraphs, node_subgraphs_map.size()); std::vector api_node_subgraphs(api_num_node_subgraphs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size())); - - for (size_t subgraph_idx = 0; subgraph_idx < node_subgraphs.size(); subgraph_idx++) { - auto subgraph_viewer = std::make_unique(*node_subgraphs[subgraph_idx]); - const OrtGraph* api_subgraph = api_node_subgraphs[subgraph_idx]; + std::vector api_subgraph_attr_names(api_num_node_subgraphs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size(), + api_subgraph_attr_names.data())); + + for (const auto& [attr_name, subgraph] : node_subgraphs_map) { + // find index of this subgraph. 
+ size_t api_subgraph_idx = api_num_node_subgraphs; + for (size_t subgraph_idx = 0; subgraph_idx < api_num_node_subgraphs; subgraph_idx++) { + if (api_subgraph_attr_names[subgraph_idx] == attr_name) { + api_subgraph_idx = subgraph_idx; + break; + } + } + ASSERT_NE(api_subgraph_idx, api_num_node_subgraphs); + // Recursively check the subgraph + auto subgraph_viewer = std::make_unique(*subgraph); + const OrtGraph* api_subgraph = api_node_subgraphs[api_subgraph_idx]; CheckGraphCApi(*subgraph_viewer, *api_subgraph); } } diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc index b7743e65061de..3b3bc4c6da911 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc @@ -30,6 +30,7 @@ std::unique_ptr TestGraph::Load(const ORTCHAR_T* model_path) { const OrtGraph& TestGraph::GetOrtGraph() const { return *api_graph; } const GraphViewer& TestGraph::GetGraphViewer() const { return graph_viewer; } +const Model& TestGraph::GetModel() const { return *model; } static Status GetInputIndices(const Node& consumer_node, const std::string& name, /*out*/ std::vector& indices) { diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h index b0ed825f21d71..2ce107cf734c6 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.h +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -28,6 +28,7 @@ class TestGraph { static std::unique_ptr Load(const ORTCHAR_T* model_path); const OrtGraph& GetOrtGraph() const; const GraphViewer& GetGraphViewer() const; + const Model& GetModel() const; private: std::shared_ptr model; From 581bb20fcb6ae65204b1435021289f04492f4ed0 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 9 Jul 2025 10:11:14 -0700 Subject: [PATCH 04/49] Update vcpkg.json: remove optional-lite (#25339) The library is not used. C++ itself already has std::optional. --- cmake/vcpkg.json | 1 - 1 file changed, 1 deletion(-) diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index 7c6b2fed36d1b..da179d0bad564 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -43,7 +43,6 @@ "ms-gsl", "nlohmann-json", "onnx", - "optional-lite", { "name": "protobuf", "version>=": "3.21.12" From a7178fd89f8b159dd6e161f665aa837ece15a69b Mon Sep 17 00:00:00 2001 From: Fei Chen Date: Thu, 10 Jul 2025 02:29:31 +0800 Subject: [PATCH 05/49] Move buffer release or cache from OnRefresh to ReleaseBuffer in BucketCacheManager (#25276) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This PR is to move buffer release or cache from OnRefresh to ReleaseBuffer in BucketCacheManager. ### Motivation and Context The OnRefresh is executed after a batch(16) ep runs and inside the batch runs, the buffer can not be really reused which is a waste for gpu buffer resources. This PR proposed a strightforward optimization that release or cache the buffer early in ReleaseBuffer instead of OnRefresh to improve the buffer cache or release efficiency which will improve the peak and average GPU memory usage. The experimental result also shows a reasonable memory optimization without perf regressions. 
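The core of the change, condensed from the diff further below, is that `ReleaseBuffer` now decides immediately whether a returned buffer goes back into its size bucket or is released to the driver, and `OnRefresh` becomes a no-op. A minimal sketch of that path (member names as in the patch; error handling and the cast's template argument are assumed):

```cpp
// Sketch: reuse or free the buffer as soon as it is returned, instead of
// parking it in a pending list until the end-of-batch OnRefresh call.
void ReleaseBuffer(WGPUBuffer buffer) {
  auto buffer_size = static_cast<size_t>(wgpuBufferGetSize(buffer));
  auto it = buckets_.find(buffer_size);
  if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) {
    it->second.emplace_back(buffer);  // cached: the next run in the batch can reuse it
  } else {
    wgpuBufferRelease(buffer);        // freed: GPU memory goes back to the driver right away
  }
}
```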
#### Phi3 Optimization Strategy | Peak Memory (MB) | Avg Memory (MB) | Token Gen Latency (ms) | Tokens/sec -- | -- | -- | -- | -- Default Bucket | 3603.83 | 3127.05 | 7.17 | 139.50 Default Bucket with Early Release Optimization | 3534.77 (+1.92%) | 3073.97 (+1.70%) | 7.14 (+0.36%) | 140.01 (+0.36%) #### Deepseek-R1 Optimization Strategy | Peak Memory (MB) | Avg Memory (MB) | Token Gen Latency (ms) | Tokens/sec -- | -- | -- | -- | -- Default Bucket | 2089.03 | 1716.15 | 6.07 | 164.67 Default Bucket with Early Release Optimization | 2034.00 (+2.63%) | 1674.49 (+2.43%) | 6.09 (-0.20%) | 164.34 (-0.20%) #### LLama3.2-1B Optimization Strategy | Peak Memory (MB) | Avg Memory (MB) | Token Gen Latency (ms) | Tokens/sec -- | -- | -- | -- | -- Default Bucket | 1736.03 | 1424.64 | 3.37 | 296.53 Default Bucket with Early Release Optimization | 1659.78 (+4.39%) | 1366.78 (+4.06%) | 3.41 (-1.09%) | 293.34 (-1.08%) --- .../core/providers/webgpu/buffer_manager.cc | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index e8140a4d59eab..113a3f31be7f9 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -193,27 +193,21 @@ class BucketCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - pending_buffers_.emplace_back(buffer); - } + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { - for (auto& buffer : pending_buffers_) { - auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - auto it = buckets_.find(buffer_size); - if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { - it->second.emplace_back(buffer); - } else { - wgpuBufferRelease(buffer); - } + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.emplace_back(buffer); + } else { + wgpuBufferRelease(buffer); } + } - pending_buffers_.clear(); + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { + // no-op } ~BucketCacheManager() { - for (auto& buffer : pending_buffers_) { - wgpuBufferRelease(buffer); - } for (auto& pair : buckets_) { for (auto& buffer : pair.second) { wgpuBufferRelease(buffer); @@ -242,7 +236,6 @@ class BucketCacheManager : public IBufferCacheManager { } std::unordered_map buckets_limit_; std::unordered_map> buckets_; - std::vector pending_buffers_; std::vector buckets_keys_; }; From e17ec57cde682cff28f0b888da1727aa4fcf60f1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 9 Jul 2025 11:29:43 -0700 Subject: [PATCH 06/49] [web] Fix "npm run pull:wasm" script (#25330) ### Description following up for #25267 --- js/web/script/pull-prebuilt-wasm-artifacts.ts | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index c3300f7272bb9..87008f51ff4b9 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -38,7 +38,6 @@ Usage: Options: -d --debug specify the debug build type of the artifacts to download. -l --latest if set, will always use the latest build, even if it is not completed yet. 
- --webgpu-ep if set, will use the webgpu EP wasm build instead of the default(JSEP) one. -h --help print this message and exit `; @@ -81,9 +80,8 @@ try { // The following code checks both the command line arguments and the npm_config_* environment variables to get the correct values. const debug = args.debug || process.env.npm_config_d || process.env.npm_config_debug; const latest = args.latest || process.env.npm_config_l || process.env.npm_config_latest; -const webgpuEp = args['webgpu-ep'] || process.env.npm_config_webgpu_ep; -const folderName = (debug ? 'Debug_wasm' : 'Release_wasm') + (webgpuEp ? '_webgpu' : ''); +const folderName = debug ? 'Debug_wasm' : 'Release_wasm'; const allowImcomplete = latest; const run = args._[0]; // The first non-option argument @@ -151,13 +149,17 @@ async function downloadArtifactsForRun(run: any): Promise { if (!fs.existsSync(WASM_FOLDER)) { fs.mkdirSync(WASM_FOLDER); } else { - // TODO: revise artifacts download - const filesToDelete = ['ort-wasm-simd-threaded.jsep.mjs', 'ort-wasm-simd-threaded.jsep.wasm']; - if (!folderName.endsWith('_webgpu')) { - filesToDelete.push('ort-wasm-simd-threaded.mjs', 'ort-wasm-simd-threaded.wasm'); - } fs.readdirSync(WASM_FOLDER).forEach((file) => { - if (filesToDelete.includes(file)) { + if ( + [ + 'ort-wasm-simd-threaded.jsep.mjs', + 'ort-wasm-simd-threaded.jsep.wasm', + 'ort-wasm-simd-threaded.jsep.mjs', + 'ort-wasm-simd-threaded.jsep.wasm', + 'ort-wasm-simd-threaded.mjs', + 'ort-wasm-simd-threaded.wasm', + ].includes(file) + ) { const filePath = path.join(WASM_FOLDER, file); console.log(`Deleting old file: ${filePath}`); fs.unlinkSync(filePath); From b49fc62e01a14697e20ad73f212dd47c6a333c06 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 9 Jul 2025 13:20:37 -0700 Subject: [PATCH 07/49] [MLAS] DequantizeLinear int8/uint8 (#24818) ### Description - Adds multithreaded vectorized implementations of DequantizeLinear for int8 and uint8 inputs: - Intel SSE 2 - ARM NEON - All other architectures fallback to a multithreaded scalar reference implementation (previous was not multithreaded). - **Note**: only enabled if ORT is built for client/on-device workloads (`ORT_CLIENT_PACKAGE_BUILD` is defined). 
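As a reference for callers, a minimal usage sketch of the new API is shown below. The computation follows the ONNX DequantizeLinear definition, `Output = (Input - ZeroPoint) * Scale`; the include path and the sample values are illustrative only, not taken from the patch:

```cpp
#include <cstdint>
#include <vector>

#include "core/mlas/inc/mlas.h"  // assumed include path inside the ORT build

void DequantizeExample() {
  std::vector<int8_t> quantized(1024, 25);           // int8 quantized input
  std::vector<float> dequantized(quantized.size());  // float output
  const float scale = 0.05f;
  const int8_t zero_point = 5;

  // Dispatches to the SSE 2 / NEON kernel where available, otherwise the scalar fallback.
  MlasDequantizeLinear<int8_t>(quantized.data(), dequantized.data(),
                               quantized.size(), scale, zero_point);
  // Every output element is (25 - 5) * 0.05f == 1.0f.
}
```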
INT8 DequantizeLinear latency on Intel Core i9-10920X with 4 intra op threads (SSE 2 implementation) | Number of elements | Baseline latency (us) | Multithreaded+SIMD latency (us) | Speedup | | ----------------------- | ---------------------- | ------------------------------------ | ---------- | | 10 K | 1 | 1 | 1 | | 20 K | 2 | 2 | 1 | | 40 K | 5 | 5 | 1 | | 80 K | 11 | 4 | 2.75 | | 100 K | 14 | 5 | 2.80 | | 150 K | 21 | 7 | 3.00 | | 200 K | 28 | 8 | 3.50 | | 400 K | 68 | 15 | 4.53 | | 600 K | 107 | 21 | 5.10 | | 800 K | 142 | 28 | 5.07 | | 1 M | 187 | 42 | 4.45 | | 2 M | 376 | 102 | 3.69 | | 4 M | 880 | 236 | 3.73 | | 6 M | 1547 | 557 | 2.78 | | 8 M | 2438 | 1097 | 2.22 | | 10 M | 3192 | 1464 | 2.18 | | 100 M | 38718 | 17733 | 2.18 | INT8 DequantizeLinear latency on Snapdragon 8cx gen 3 @ 3.4GHz with 4 intra op threads (NEON implementation) | Number of elements | Baseline latency (us) | Multithreaded+SIMD latency (us) | Speedup | | ----------------------- | ---------------------- | ------------------------------------ | ---------- | | 10 K | 1 | 1 | 1 | | 20 K | 1 | 1 | 1 | | 40 K | 3 | 3 | 1 | | 80 K | 7 | 4 | 1.75 | | 100 K | 9 | 3 | 3.00 | | 150 K | 14 | 5 | 2.80 | | 200 K | 18 | 6 | 3.00 | | 400 K | 38 | 10 | 3.80 | | 600 K | 61 | 15 | 4.07 | | 800 K | 76 | 19 | 4.00 | | 1 M | 98 | 24 | 4.08 | | 2 M | 204 | 48 | 4.25 | | 4 M | 424 | 112 | 3.79 | | 6 M | 677 | 384 | 1.76 | | 8 M | 919 | 621 | 1.48 | | 10 M | 1132 | 776 | 1.46 | | 100 M | 11842 | 10566 | 1.12 | ### Motivation and Context Improves latency of quantized QDQ models that with large DQs that dominate the inference latency. --- cmake/onnxruntime_mlas.cmake | 1 + onnxruntime/core/mlas/inc/mlas.h | 15 + onnxruntime/core/mlas/lib/dequantize.cpp | 395 ++++++++++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 22 + onnxruntime/core/mlas/lib/platform.cpp | 2 + .../cpu/quantization/quantize_linear.cc | 98 +++-- onnxruntime/core/util/qmath.h | 49 +++ .../mlas/unittest/test_dequantizelinear.cpp | 75 ++++ .../cpu/tensor/quantize_linear_test.cc | 26 ++ 9 files changed, 645 insertions(+), 38 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/dequantize.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f8f5546ae9465..47e7779d93b33 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -31,6 +31,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp ${MLAS_SRC_DIR}/compute.cpp + ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp ${MLAS_SRC_DIR}/qladd.cpp diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 3575e30721af7..91182a4ca9c44 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1223,6 +1223,21 @@ MlasQuantizeLinearS4( int8_t ZeroPoint ); +// +// Linear dequantization routines. +// + +template +void +MLASCALL +MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ); + /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias diff --git a/onnxruntime/core/mlas/lib/dequantize.cpp b/onnxruntime/core/mlas/lib/dequantize.cpp new file mode 100644 index 0000000000000..175d3f668ac39 --- /dev/null +++ b/onnxruntime/core/mlas/lib/dequantize.cpp @@ -0,0 +1,395 @@ +/*++ + +Copyright (c) Microsoft Corporation. 
All rights reserved. + +Licensed under the MIT License. + +Module Name: + + dequantize.cpp + +Abstract: + + This module implements routines to dequantize buffers. + + The dequantization formula as specified in the ONNX operator documentation is: + + Output = (Input - ZeroPoint) * Scale + +--*/ + +#include "mlasi.h" + +// +// DequantizeLinear reference implementation using the C++ runtime. +// + +template +static +MLAS_FORCEINLINE +void +MlasDequantizeLinearRefImpl( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer with quantized data. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. + +--*/ +{ + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + Output[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } +} + +#if defined(MLAS_SSE2_INTRINSICS) +// Implementation for Intel SSE 2. Refer to the Intel Intrisics Guide: +// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 15] + __m128i VectorS8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Sign-extend into 2 vectors of 8 int16s + __m128i SignMaskS8 = _mm_cmpgt_epi8(Zeros, VectorS8); // 0xFF for every negative byte in VectorS8 + __m128i VectorS16_0 = _mm_unpacklo_epi8(VectorS8, SignMaskS8); // [0 ... 7] + __m128i VectorS16_1 = _mm_unpackhi_epi8(VectorS8, SignMaskS8); // [8 ... 15] + + // Subtract the zero-points in int16 domain. + VectorS16_0 = _mm_sub_epi16(VectorS16_0, ZeroPointS16Vector); + VectorS16_1 = _mm_sub_epi16(VectorS16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. 
+ MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 15] + __m128i VectorU8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Zero-extend into 2 vectors of 8 uint16s + __m128i VectorU16_0 = _mm_unpacklo_epi8(VectorU8, Zeros); // [0 ... 7] + __m128i VectorU16_1 = _mm_unpackhi_epi8(VectorU8, Zeros); // [8 ... 15] + + // Subtract the zero-points as uint16s. Due to two's compliment, negative results can be reinterpreted as int16 + __m128i VectorS16_0 = _mm_sub_epi16(VectorU16_0, ZeroPointS16Vector); + __m128i VectorS16_1 = _mm_sub_epi16(VectorU16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearS8Kernel( +#else + MlasDequantizeLinearS8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearU8Kernel( +#else + MlasDequantizeLinearU8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} +#elif defined(MLAS_NEON64_INTRINSICS) +// Implementation for ARM64 NEON. Refer to the ARM instrinsics guide: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/ + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); // Broadcast ZeroPoint (sign-extended to 16bits) + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 
15] + int8x16_t VectorS8 = vld1q_s8(Input); + + // Sign-extend into 2 vectors of 8 int16s + int16x8_t VectorS16_0 = vmovl_s8(vget_low_s8(VectorS8)); // [0 ... 7] + int16x8_t VectorS16_1 = vmovl_s8(vget_high_s8(VectorS8)); // [8 ... 15] + + // Subtract the zero-points in int16 domain. + VectorS16_0 = vsubq_s16(VectorS16_0, ZeroPointVector); + VectorS16_1 = vsubq_s16(VectorS16_1, ZeroPointVector); + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const uint8x8_t ZeroPointVector = vdup_n_u8(ZeroPoint); // Broadcast ZeroPoint to 8 uint8s + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 15] + uint8x16_t VectorU8 = vld1q_u8(Input); + + // Subtract zero-point. The vsubl_u8 instruction zero-extends its arguments to uint16 first. + // The reinterpret from uint16x8 to int16x8 is actually a NOP. + int16x8_t VectorS16_0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(VectorU8), ZeroPointVector)); // [0 ... 7] + int16x8_t VectorS16_1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(VectorU8), ZeroPointVector)); // [8 ... 15] + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. 
+ MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasDequantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasDequantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); +} +#else +// Implementation that uses the scalar reference implementation. + +template +void +MLASCALL +MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +{ + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + +template +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ); + +#endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0af3cd2e33b02..0879d1b0ba510 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -747,6 +747,24 @@ void float Scale, int8_t ZeroPoint); +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_U8_KERNEL)( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_S8_KERNEL)( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -903,6 +921,8 @@ extern "C" { MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelAvx512F; @@ -1246,6 +1266,8 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL* DequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL* DequantizeLinearU8Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 45d3a876beb86..45bba5363d4f2 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -285,6 +285,8 @@ Return Value: this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; + this->DequantizeLinearS8Kernel = MlasDequantizeLinearS8Kernel; + this->DequantizeLinearU8Kernel = MlasDequantizeLinearU8Kernel; #ifndef __APPLE__ #ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index adb2aee171f39..c691be6ffd0e8 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ 
b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" @@ -301,14 +302,31 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, const T* input, - const OutT* scale, OutT* output, const T* zero_point) { + const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { for (size_t m = 0; m < M; m++) { for (size_t k = 0; k < K; k++) { +#if defined(ORT_CLIENT_PACKAGE_BUILD) + // TODO: Only using multithreaded/SIMD DQ when ORT is built for client/on-device workloads. + // Make this the default behavior after more testing. + if constexpr (std::is_same_v || std::is_same_v) { + ParDequantizeLinearStd(input, output, N, scale[k], zero_point ? zero_point[k] : 0, thread_pool); + input += N; + output += N; + } else { + auto zp = zero_point ? static_cast(zero_point[k]) : 0; + auto sc = static_cast(scale[k]); + for (size_t n = 0; n < N; n++) { + *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); + } + } +#else + ORT_UNUSED_PARAMETER(thread_pool); auto zp = zero_point ? static_cast(zero_point[k]) : 0; auto sc = static_cast(scale[k]); for (size_t n = 0; n < N; n++) { *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); } +#endif // defined(ORT_CLIENT_PACKAGE_BUILD) } } } @@ -327,7 +345,8 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point) { + const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); if (zero_point) { for (size_t m = 0; m < M; m++) { for (size_t bd = 0; bd < K; bd += quant_block_size) { @@ -368,7 +387,8 @@ template struct DequantizeLinearApply { // per-tensor/layer or per-axis quantization void op(size_t M, size_t K, size_t N, - const T* input, const OutT* scale, OutT* output, const T* zero_point) { + const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); size_t input_index = 0; for (size_t m = 0; m < M; m++) { @@ -394,7 +414,8 @@ struct DequantizeLinearApply { // Blocked quantization // TODO(fajin) : add mlas kernel to utilize multithreading, refer MlasDequantizeBlockwise. 
void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point) { + const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); size_t input_index = 0; if (zero_point) { @@ -440,36 +461,36 @@ struct DequantizeLinearApply { #if !defined(DISABLE_FLOAT8_TYPES) -#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ - template \ - struct DequantizeLinearApply { \ - /* Per-tensor/layer or per-axis quantization */ \ - void op(size_t M, size_t K, size_t N, \ - const T* input, const OutT* scale, OutT* output, const T*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd++) { \ - auto sc = scale[bd]; \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - } \ - } \ - /* Blocked quantization */ \ - void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ - const T* input, const OutT* scale, OutT* output, const T*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd += quant_block_size) { \ - for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - auto sc = static_cast(scale[bs]); \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - scale += N; \ - } \ - } \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ + template \ + struct DequantizeLinearApply { \ + /* Per-tensor/layer or per-axis quantization */ \ + void op(size_t M, size_t K, size_t N, \ + const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd++) { \ + auto sc = scale[bd]; \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + } \ + } \ + /* Blocked quantization */ \ + void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ + const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd += quant_block_size) { \ + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + auto sc = static_cast(scale[bs]); \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + scale += N; \ + } \ + } \ + } \ }; DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E4M3FN) @@ -513,6 +534,7 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const auto to = x_scale.GetElementType(); const T* input = x.Data(); constexpr bool is_4bit = boost::mp11::mp_contains, T>::value; + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); if (to == ONNX_NAMESPACE::TensorProto::FLOAT) { const float* scale = x_scale.Data(); @@ -522,12 +544,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); @@ -537,12 +559,12 @@ Status 
DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) { ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet."); diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 0172902bdf4e2..f7d5cdb98aa1d 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -1001,4 +1001,53 @@ struct BlockedQuantizeLinear { #endif +/** + * @brief Run MlasDequantizeLinear in parallel, with provided thread pool + */ + +template +void ParDequantizeLinearStd(const InputQuantType* input, + float* output, + size_t num_elems, + float scale, + InputQuantType zero_point, + concurrency::ThreadPool* thread_pool) { + constexpr std::ptrdiff_t block_size = 128; + const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; + const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), + static_cast(block_size * sizeof(float)), + static_cast(block_size) * 2.0}; + concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto begin_idx = begin * block_size; + auto end_idx = std::min(static_cast(num_elems), end * block_size); + MlasDequantizeLinear(&(input[begin_idx]), &(output[begin_idx]), end_idx - begin_idx, scale, zero_point); + }); +} + +// Note: this doesn't use MLAS kernel. There are currently no MLAS kernels for fp16 QuantizeLinear or DequantizeLinear. +template +void ParDequantizeLinearStd(const InputQuantType* input, + MLFloat16* output, + size_t num_elems, + MLFloat16 scale, + InputQuantType zero_point, + concurrency::ThreadPool* thread_pool) { + constexpr std::ptrdiff_t block_size = 128; + const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; + const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), + static_cast(block_size * sizeof(MLFloat16)), + static_cast(block_size) * 2.0}; + + const int32_t zp_s32 = static_cast(zero_point); + const float sc_f32 = scale.ToFloat(); + + concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto begin_idx = begin * block_size; + auto end_idx = std::min(static_cast(num_elems), end * block_size); + for (; begin_idx != end_idx; ++begin_idx) { + output[begin_idx] = MLFloat16(static_cast(static_cast(input[begin_idx]) - zp_s32) * sc_f32); + } + }); +} + } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp new file mode 100644 index 0000000000000..b994981364947 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "test_util.h" + +template +class MlasDequantizeLinearTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + void GenerateReference(const QuantInt* Input, float* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } + } + + void Test(size_t N) { + QuantInt* Input = BufferInput.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + + std::uniform_real_distribution min_gen(-10.f, -10e-3f); + float MinimumValue = min_gen(generator); + + std::uniform_real_distribution max_gen(10e-3f, 10.f); + float MaximumValue = max_gen(generator); + + float Scale = (MaximumValue - MinimumValue) / 512.f; + + std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), + std::numeric_limits::max()); + QuantInt ZeroPoint = static_cast(zp_distribution(generator)); + + for (size_t n = 0; n < N; n++) { + Input[n] = static_cast(zp_distribution(generator)); + } + + GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); + MlasDequantizeLinear(Input, Output, N, Scale, ZeroPoint); + + for (size_t n = 0; n < N; n++) { + ASSERT_EQ(Output[n], OutputReference[n]) << ", size=" << N << ", index=" << n; + } + } + + public: + static const char* GetTestSuiteName() { + if constexpr (std::is_same_v) { + return "DequantizeLinearS8"; + } else { + return "DequantizeLinearU8"; + } + } + + void ExecuteShort(void) override { + for (size_t n = 1; n <= 512; n++) { + Test(n); + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + } + return count; +}); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 4e7a6356a5129..8fdbf0060eaa0 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -33,6 +33,32 @@ TEST(DequantizeLinearOpTest, Int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// scalar zero & scale with uint8 (large enough input to execute MLAS vectorized loop) +TEST(DequantizeLinearOpTest, Uint8_Large) { + OpTester test("DequantizeLinear", 10); + std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs + test.AddInput("x", dims, std::vector(1039, 1)); + test.AddInput("x_scale", {}, {1.0f}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOutput("y", dims, std::vector(1039, 0.0f)); + // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. + // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. 
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); +} + +// scalar zero & scale with int8 (large enough input to execute MLAS vectorized loop) +TEST(DequantizeLinearOpTest, Int8_Large) { + OpTester test("DequantizeLinear", 10); + std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs + test.AddInput("x", dims, std::vector(1039, 1)); + test.AddInput("x_scale", {}, {1.0f}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOutput("y", dims, std::vector(1039, 0.0f)); + // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. + // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); +} + // scalar zero & scale with int4 TEST(DequantizeLinearOpTest, Int4) { OpTester test("DequantizeLinear", 21); From cd5f91fe0624ad13047ea4dd46f08a35a020dc3a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 9 Jul 2025 14:19:28 -0700 Subject: [PATCH 08/49] [CPU] GQA supports head_sink input for smooth softmax (#25269) ### Description It is an extension of [Smooth Softmax](https://github.com/microsoft/onnxruntime/pull/21867) feature. The difference is that each head has a learnable smooth factor that adding to the denominator of softmax. The smooth factor is like an extra element that joins the softmax. The usage of the smooth factor in softmax is like the following: ```math softmax_{i} = \frac{exp(x_{i})}{exp(s)+ \sum_{j} exp(x_{j})} ``` The head_sink is a float tensor with length of number of attention heads. For h-th head, `head_sink[h]` is used as smooth factor s. When head_sink is not provided, constant 0 is used as smooth factor s. Changes: - [x] Update operator spec to add an optional new input `head_sink` - [x] Implement CPU (MLAS) kernel. - [x] Update test_gqa_cpu.py to test it. CUDA kernel will be updated later in a separate PR. --- docs/ContribOperators.md | 4 +- docs/OperatorKernels.md | 6 +- .../contrib_ops/cpu/bert/attention_helper.h | 6 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 11 +- .../cpu/bert/group_query_attention.cc | 4 +- .../core/graph/contrib_ops/bert_defs.cc | 5 + onnxruntime/core/mlas/inc/mlas.h | 1 + onnxruntime/core/mlas/lib/compute.cpp | 17 +- .../core/providers/cpu/math/softmax_shared.cc | 2 +- onnxruntime/core/providers/cpu/ml/ml_common.h | 2 +- .../test/mlas/bench/bench_computesoftmax.cpp | 4 +- .../test/mlas/unittest/test_softmax.cpp | 4 +- .../test/python/transformers/test_gqa_cpu.py | 251 ++++++++++++------ .../test/python/transformers/test_gqa_cuda.py | 3 +- .../transformers/test_paged_attention_cuda.py | 3 +- 15 files changed, 219 insertions(+), 104 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index b80918e6615e1..9388e7e2a47cd 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2555,7 +2555,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Softcap value for attention weights. Default value is 0.
-#### Inputs (7 - 11) +#### Inputs (7 - 12)
query : T
@@ -2580,6 +2580,8 @@ This version of the operator has been available since version 1 of the 'com.micr
2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element
attention_bias (optional) : T
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
+
head_sink (optional) : T
+
1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1ffcabee8cc10..e50702afe9975 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -538,7 +538,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| @@ -942,7 +942,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1420,7 +1420,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index ac32a4445f3ca..aef47edd5fcd2 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -17,13 +17,13 @@ namespace onnxruntime { namespace contrib { template -inline void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, true, tp); +inline void ComputeSmoothSoftmaxInplace(T* score, int D, float sink, ThreadPool* tp) { + MlasComputeSoftmax(score, score, 1, D, false, true, sink, tp); } template inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, false, tp); + MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp); } template diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index c79508cbae273..cec495ef7391e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -51,6 +51,7 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH + const T* head_sink, // Head sink for smooth softmax, nullptr if not used const Tensor* attention_bias, // Attention bias to add to QxK' const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) @@ -97,7 +98,7 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); @@ -110,7 +111,7 @@ class GQAAttentionBase { hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); } else { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); @@ -136,6 +137,7 @@ class GQAAttentionBase { void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. Its size is BxNxLxH + const T* head_sink, // for smooth softmax. Its size is N. const int32_t* seqlens_k, // total - 1 sequence lengths tensor const T* attention_bias, // optional attention bias const size_t batch_size, // batch size of self-attention @@ -310,8 +312,9 @@ class GQAAttentionBase { } } - if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); + if (use_smooth_softmax_ || head_sink != nullptr) { + float sink = (head_sink != nullptr) ? 
static_cast(head_sink[head_index]) : 0.0f; + ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast(window_size), sink, nullptr); } else { ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index a912bd6e6b43c..988151f778806 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -206,9 +206,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data() : nullptr; + // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), - attention_bias, past_key, past_value, output, present_k, present_v, + head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); } } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index f2757c2c96471..c2371487d9187 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1184,6 +1184,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) + .Input(11, + "head_sink", + "1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 91182a4ca9c44..4d85c35461825 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1020,6 +1020,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index 96a2398796777..669c73d2b9c06 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -74,6 +74,7 @@ struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; bool SmoothSoftmax; + float Sink; const T* Input; T* Output; size_t N; @@ -850,6 +851,7 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; + const float Sink = WorkBlock->Sink; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -880,11 +882,12 @@ Return Value: #else float Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif - float NegativeMaximum = -Maximum; - if (SmoothSoftmax && NegativeMaximum > 0.0f) { - NegativeMaximum = 0.0f; + if (SmoothSoftmax && Sink > Maximum) { + Maximum = Sink; } + float NegativeMaximum = -Maximum; + // // Compute the exponential function for each element of the row (save to Temp if provided) and // compute the sum of these exponential functions. 
@@ -897,7 +900,7 @@ Return Value: #endif if (SmoothSoftmax) { - Accumulation += expf(NegativeMaximum); + Accumulation += expf(Sink + NegativeMaximum); } if (LogSoftmax) { @@ -1014,6 +1017,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ) /*++ @@ -1039,6 +1043,8 @@ Routine Description: SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. + Sink - Supplies the smooth factor to use in the softmax operation. + ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -1060,6 +1066,7 @@ Return Value: WorkBlock.Output = Output; WorkBlock.N = N; WorkBlock.D = D; + WorkBlock.Sink = Sink; // // Compute the number of target threads given the complexity of the softmax @@ -1097,6 +1104,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1110,6 +1118,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 2817dda9d0085..e123414b03b21 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -99,7 +99,7 @@ common::Status SoftmaxCPU(size_t N, float* Ydata, bool logarithmic, onnxruntime::concurrency::ThreadPool* thread_pool) { - MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool); + MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, 0.0f, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 3359b2a69fe83..f7cc2523adbf6 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -445,7 +445,7 @@ void batched_update_scores_inplace(gsl::span scores, int64_t num_batches_in, } if (use_mlas) { - MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, threadpool); + MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, 0.0f, threadpool); } else { while (s < s_end) { gsl::span scores_for_batch(s, s + batch_size); diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp index 65822eb294d7d..ea36383f70621 100644 --- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp +++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp @@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); } free(ptr.underlying_buffer); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 041b6c61cd5bf..4d7a45143b311 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -152,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase { } void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, 
bool LogSoftmax, bool SmoothSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; @@ -206,7 +206,7 @@ class MlasSoftmaxTest : public MlasTestBase { InputReference[nd] = Input[nd].ToFloat(); } - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 5e-3f; diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 461c243b82212..ce0649e55f7c5 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -54,6 +54,7 @@ class Config: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False + has_head_sink: bool = False @dataclass @@ -67,6 +68,7 @@ class PromptConfig: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False + has_head_sink: bool = False # LLaMA Microsoft model @@ -166,6 +168,7 @@ def create_group_query_attention_graph_prompt( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -289,6 +292,15 @@ def create_group_query_attention_graph_prompt( ), ] + if config.has_head_sink: + graph_input += [ + helper.make_tensor_value_info( + "head_sink", + ort_type, + [config.num_heads], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -380,6 +392,7 @@ def create_group_query_attention_graph_past( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -441,6 +454,7 @@ def create_group_query_attention_graph_past( [1], ), ] + if not packed: graph_input += [ helper.make_tensor_value_info( @@ -462,6 +476,7 @@ def create_group_query_attention_graph_past( ], ), ] + if rotary: graph_input += [ helper.make_tensor_value_info( @@ -498,6 +513,15 @@ def create_group_query_attention_graph_past( ), ] + if config.has_head_sink: + graph_input += [ + helper.make_tensor_value_info( + "head_sink", + ort_type, + [config.num_heads], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -552,17 +576,17 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) + q: (batch_size, seqlen_q, num_heads, d) + k: (batch_size, seqlen_k, num_heads_k, d) + v: (batch_size, seqlen_k, num_heads_k, d) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - 
assert v.shape == (batch_size, seqlen_k, nheads_k, d) + batch_size, seqlen_q, num_heads, d = q.shape + _, seqlen_k, num_heads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, num_heads_k, d) + assert v.shape == (batch_size, seqlen_k, num_heads_k, d) if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) @@ -593,7 +617,7 @@ def output_pad_fn(output_unpad): if qkvpacked: assert (query_padding_mask == key_padding_mask).all() - assert nheads == nheads_k + assert num_heads == num_heads_k qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) qkv = torch.stack([q, k, v], dim=2) if query_padding_mask is not None: @@ -714,6 +738,7 @@ def gqa_prompt_func( seqlens_k=None, position_ids=None, attention_bias=None, + head_sink=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -749,6 +774,11 @@ def gqa_prompt_func( if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -758,9 +788,6 @@ def gqa_prompt_func( "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -797,25 +824,19 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() + if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() @@ -836,11 +857,16 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v + + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + 
ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v def gqa_past_func( @@ -855,6 +881,7 @@ def gqa_past_func( seqlens_k=None, position_ids=None, attention_bias=None, + head_sink=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, @@ -890,6 +917,11 @@ def gqa_past_func( if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -901,9 +933,7 @@ def gqa_past_func( .cpu() .numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -940,11 +970,6 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -958,9 +983,7 @@ def gqa_past_func( .cpu() .numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -988,11 +1011,16 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v + + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): @@ -1025,11 +1053,28 @@ def construct_local_mask( ) -def smooth_softmax_ref(x): - x_max = x.amax(axis=-1, keepdim=True) - x_max = torch.maximum(x_max, torch.zeros_like(x_max)) - w = torch.exp(x - x_max) - return w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-x_max)) +def smooth_softmax_ref(x, head_sink): + """ + Arguments: + x: (batch_size, num_heads, seqlen_q, seqlen_k) + head_sink: (num_heads) or None + Output: + y: (batch_size, num_heads, seqlen_q, seqlen_k) + """ + assert len(x.shape) == 4 + b, n, s, t = x.shape 
+ + if head_sink is not None: + assert len(head_sink.shape) == 1 + assert head_sink.shape[0] == x.shape[1] + sink = head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1) + else: + sink = torch.zeros(b, n, s, 1, dtype=x.dtype) + + y = torch.cat([x, sink], dim=-1) + y = torch.softmax(y, dim=-1) + y = y[..., :-1] + return y def attention_ref( @@ -1046,16 +1091,17 @@ def attention_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, + head_sink=None, ): """ Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) + q: (batch_size, seqlen_q, num_heads, head_dim) + k: (batch_size, seqlen_k, num_heads_k, head_dim) + v: (batch_size, seqlen_k, num_heads_k, head_dim) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + dropout_mask: (batch_size, num_heads, seqlen_q, seqlen_k) causal: whether to apply causal masking window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast @@ -1064,9 +1110,10 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. use_smooth_softmax: whether use smooth softmax or not + head_sink: (num_heads) or None Output: - output: (batch_size, seqlen_q, nheads, head_dim) - attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + output: (batch_size, seqlen_q, num_heads, head_dim) + attention: (batch_size, num_heads, seqlen_q, seqlen_k), softmax after dropout """ if causal: window_size = (window_size[0], 0) @@ -1098,8 +1145,8 @@ def attention_ref( ) scores.masked_fill_(local_mask, float("-inf")) - if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + if use_smooth_softmax or (head_sink is not None): + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) @@ -1133,6 +1180,7 @@ def attention_qkvpacked_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, + head_sink=None, ): return attention_ref( qkv[:, :, 0], @@ -1146,6 +1194,7 @@ def attention_qkvpacked_ref( causal=causal, reorder_ops=reorder_ops, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) @@ -1186,6 +1235,10 @@ def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=Fa return position_ids +def get_custom_head_sink(num_heads, torch_type=torch.float16): + return torch.rand(num_heads, dtype=torch_type) + + def parity_check_gqa_prompt( config, torch_type, @@ -1248,6 +1301,8 @@ def parity_check_gqa_prompt( requires_grad=False, ) + head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None + window_size = (-1, -1) left_window_size = -1 if local: @@ -1327,6 +1382,7 @@ def parity_check_gqa_prompt( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1349,6 +1405,7 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, + head_sink, left_window_size, past_format, True, @@ -1371,6 +1428,7 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, + head_sink, left_window_size, past_format, True, @@ -1531,6 +1589,8 @@ def parity_check_gqa_prompt_no_buff( else None ) + head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink 
else None + brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded @@ -1548,6 +1608,7 @@ def parity_check_gqa_prompt_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1570,6 +1631,7 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, + head_sink, left_window_size, past_format, False, @@ -1592,6 +1654,7 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, + head_sink, left_window_size, past_format, False, @@ -1759,6 +1822,8 @@ def parity_check_gqa_past( cos, sin = None, None q_ro, k_ro = q, new_k + head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None + arange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -1781,6 +1846,7 @@ def parity_check_gqa_past( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1822,6 +1888,7 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, + head_sink, past_format, True, left_window_size, @@ -1844,6 +1911,7 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, + head_sink, past_format, True, left_window_size, @@ -1882,6 +1950,8 @@ def parity_check_gqa_past( softcap, " smooth_softmax:", use_smooth_softmax, + " head_sink:", + config.has_head_sink, " B:", config.batch_size, " S:", @@ -2017,6 +2087,8 @@ def parity_check_gqa_past_no_buff( cos, sin = None, None q_ro, k_ro = q, new_k + head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -2039,6 +2111,7 @@ def parity_check_gqa_past_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -2080,6 +2153,7 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, + head_sink, past_format, False, window_size=left_window_size, @@ -2102,6 +2176,7 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, + head_sink, past_format, False, window_size=left_window_size, @@ -2134,6 +2209,8 @@ def parity_check_gqa_past_no_buff( softcap, " smooth_softmax:", use_smooth_softmax, + " head_sink:", + config.has_head_sink, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -2202,33 +2279,47 @@ def run_test_config( for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: for has_pos, has_attn in pos_ids_attn_bias: - if config_class == PromptConfig: - config = config_class( - b, s, s2, s + s2 + 8, n, n2, h, has_pos, has_attn - ) - else: # Config - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = config_class(b, s, s2, sp, n, n2, h, has_pos, has_attn) - - params = { - "config": config, - "torch_type": precision["torch_type"], - "numpy_type": 
precision["numpy_type"], - "ort_type": precision["ort_type"], - "rtol": precision["rtol"], - "atol": precision["atol"], - "local": local, - "past_format": Formats.BNSH, - "rotary": rotary, - "rotary_interleaved": rotary_interleaved, - "packed": packed, - "softcap": softcap, - "use_smooth_softmax": use_smooth_softmax, - } - params.update(additional_params) - - all_close = test_func(**params) - self.assertTrue(all_close) + for head_sink in [False, True]: + if use_smooth_softmax and head_sink: + continue + if config_class == PromptConfig: + config = config_class( + b, + s, + s2, + s + s2 + 8, + n, + n2, + h, + has_pos, + has_attn, + head_sink, + ) + else: # Config + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = config_class( + b, s, s2, sp, n, n2, h, has_pos, has_attn, head_sink + ) + + params = { + "config": config, + "torch_type": precision["torch_type"], + "numpy_type": precision["numpy_type"], + "ort_type": precision["ort_type"], + "rtol": precision["rtol"], + "atol": precision["atol"], + "local": local, + "past_format": Formats.BNSH, + "rotary": rotary, + "rotary_interleaved": rotary_interleaved, + "packed": packed, + "softcap": softcap, + "use_smooth_softmax": use_smooth_softmax, + } + params.update(additional_params) + + all_close = test_func(**params) + self.assertTrue(all_close) def test_gqa_no_past(self): print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py index 2f5b638a57d0c..79976a92e54bf 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cuda.py +++ b/onnxruntime/test/python/transformers/test_gqa_cuda.py @@ -782,7 +782,8 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + head_sink = None + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py index 410860a324a9d..ca5c9c2ce133f 100644 --- a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py +++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py @@ -401,7 +401,8 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + head_sink = None + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) From 14e0ad7f637ad341e446e179a5aacea0b48b73bd Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 9 Jul 2025 15:14:01 -0700 Subject: [PATCH 09/49] Add PackageVersion parameter to NuGet packaging stage (#25315) Fix: `Microsoft.ML.OnnxRuntime.Managed.nupkg` artifact from GPU pipeline does not have package version. 
![image](https://github.com/user-attachments/assets/4a6135ab-4774-4aa6-aeb1-d5b06948ba8f) --- .../azure-pipelines/stages/nuget-cuda-packaging-stage.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 84b6d30ee32ac..a87bb55441ac7 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -72,6 +72,8 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + - template: ../templates/set-version-number-variables-step.yml + # Reconstruct the build dir - task: PowerShell@2 displayName: 'PS: Extract nuget files gpu' @@ -114,6 +116,7 @@ stages: -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:PackageVersion=$(OnnxRuntimeVersion) workingDirectory: '$(Build.SourcesDirectory)\csharp' - template: ../templates/win-esrp-dll.yml From d29328588e00bf2e578d904403fa5c6627754346 Mon Sep 17 00:00:00 2001 From: qti-yuduo Date: Wed, 9 Jul 2025 20:17:13 -0700 Subject: [PATCH 10/49] [QNN EP] Fix pool with reshape name conflicts (#25332) Naming conflicts when expand-pool2d-squeeze (implemented as reshape) logic is invoked during ONNX -> QNN op lowering. Model with multiple pool 1D ops would hit this issue. --- .../qnn/builder/opbuilder/pool_op_builder.cc | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 86b684f8c6ebd..21947a22e2b92 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); bool needs_reshape = false; - const std::string reshape4d = input_names[0] + "_pre_reshape"; + const std::string reshape_prior_out = input_names[0] + "_prior_reshape"; if (input_shape.size() == 3) { needs_reshape = true; // build new_shape = {N, 1, C, L} @@ -245,25 +245,24 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra input_shape[1], input_shape[2]}; - const std::string reshape_node_name = "pre_reshape"; - QnnTensorWrapper rw( - reshape4d, + QnnTensorWrapper reshape_prior_tensor( + reshape_prior_out, QNN_TENSOR_TYPE_NATIVE, reshape_input_info.qnn_data_type, reshape_input_info.quant_param.Copy(), std::move(new_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)), - "Failed to add reshape-4d tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)), + "Failed to add reshape prior tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_prior", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {input_names[0]}, - {reshape4d}, + {reshape_prior_out}, {}, do_op_validation), - "Failed to create reshape-4d node."); - input_names[0] = reshape4d; + "Failed to create reshape prior node for pool op."); + input_names[0] = reshape_prior_out; input_shape = {input_shape[0], 1, 
input_shape[1], input_shape[2]}; } @@ -446,9 +445,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); - const std::string pool_name = "poolmax2d"; - const std::string pool_out = real_out + "_post_reshape"; - const std::string post_reshape_node_name = "post_reshape"; + const std::string pool_out = real_out + "_reshape_after"; const std::string qnn_op = GetQnnOpType(op_type); TensorInfo output_info{}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); @@ -466,33 +463,34 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor for pool_out"); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - pool_name, + utils::GetNodeName(node_unit) + "_pool2d", QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op, - {reshape4d}, + {reshape_prior_out}, {pool_out}, std::move(param_tensor_names), do_op_validation), - "Failed to create QNN Pool node for rank-3 input."); + "Failed to create pool node for rank-3 input."); std::vector final_shape3d = output_info.shape; - QnnTensorWrapper reshape_back_tensor( + QnnTensorWrapper reshape_after_tensor( real_out, tensor_type, output_info.qnn_data_type, output_info.quant_param.Copy(), std::move(final_shape3d)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)), + "Failed to add reshape after tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - post_reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_after", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {pool_out}, {real_out}, {}, do_op_validation), - "Failed to create reshape-back node."); + "Failed to create reshape after node for pool op."); return Status::OK(); } From ff815674fdbbe6b1b78309950c2dad9d49cf4e8f Mon Sep 17 00:00:00 2001 From: Akupadhye Date: Thu, 10 Jul 2025 23:38:16 +0530 Subject: [PATCH 11/49] Added creation of QDQ for TopK node (#25309) - Added TopK in registry.py so as to create QDQ nodes for the op - Ensure that both the input and output quantization params are equal - Added unit test to verify the creation of QDQ nodes for TopK ### Description: Added support for creation of QDQ nodes for TopK when quantized with ORT static quantization tool ### Motivation and Context: Currently there is support to form a node unit for TopK operator when QDQ nodes are present and both the input and output quantization params are equal. 
But there was no support to create QDQ nodes for TopK operator in the ORT static quantization tool --- .../python/tools/quantization/registry.py | 1 + .../test/python/quantization/test_op_topk.py | 103 ++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 onnxruntime/test/python/quantization/test_op_topk.py diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index fbeae39c39d21..319c5aa468f7e 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -86,6 +86,7 @@ "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, "BatchNormalization": QDQNormalization, + "TopK": QDQDirect8BitOp, } diff --git a/onnxruntime/test/python/quantization/test_op_topk.py b/onnxruntime/test/python/quantization/test_op_topk.py new file mode 100644 index 0000000000000..1fdd0c987d1e8 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_topk.py @@ -0,0 +1,103 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest + +import numpy as np +from onnx import TensorProto, helper, save +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestTopKModel(unittest.TestCase): + @staticmethod + def construct_model(model_path, input_shape, axis_attr, k): + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape) + k_tensor = helper.make_tensor("k", TensorProto.INT64, [1], [k]) + output_shape = input_shape[:] + output_shape[axis_attr] = k + output_values = helper.make_tensor_value_info("values", TensorProto.FLOAT, [1, k]) + output_indices = helper.make_tensor_value_info("indices", TensorProto.INT64, [1, k]) + + node = helper.make_node( + "TopK", inputs=["input", "k"], outputs=["values", "indices"], name="topk_node", axis=axis_attr + ) + + graph = helper.make_graph( + [node], + "quant_topk_op_test", + [input_tensor], + [output_values, output_indices], + initializer=[k_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 16), helper.make_opsetid("com.microsoft", 1)] + ) + save(model, model_path) + + def quantize_topk_test(self, activation_type, weight_type, extra_options={}): # noqa: B006 + model_fp32_path = "topk_fp32.onnx" + input_shape = [1, 10] + axis = 1 + k = 3 + self.construct_model(model_fp32_path, input_shape, axis, k) + + input_data_list = [ + {"input": np.array([[1.8, 2.5, -5.9, 5.2, 4.1, 7.3, 0.2, -0.5, 0.845, 3.9]], dtype=np.float32)} + ] + data_reader = TestDataFeeds(input_data_list) + + activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_qdq_path = f"topk_{activation_type_str}{weight_type_str}_{'QNoInCk' if extra_options['ForceQuantizeNoInputCheck'] else 'NoQNoInCk'}_qdq.onnx" + + # Verify QDQ mode + data_reader.rewind() + quantize_static( + model_fp32_path, + model_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + 
activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = ( + { + "TopK": 1, + "QuantizeLinear": 2, + "DequantizeLinear": 2, + } + if extra_options["ForceQuantizeNoInputCheck"] + else { + "TopK": 1, + "QuantizeLinear": 0, + "DequantizeLinear": 0, + } + ) + check_op_type_count(self, model_qdq_path, **qdqnode_counts) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + check_qtype_by_node_type(self, model_qdq_path, qnode_io_qtypes) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + + def test_quantize_topk_u8u8(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": True}) + + def test_quantize_topk_u8u8_no_force_quantize_no_input_check(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": False}) + + +if __name__ == "__main__": + unittest.main() From d3820a25de02e24d72a9977aec28d205e9d99ee2 Mon Sep 17 00:00:00 2001 From: Wang Ning Date: Fri, 11 Jul 2025 02:32:59 +0800 Subject: [PATCH 12/49] [WebNN] Refactor webnn op input rank check and add validation for ops (#25185) ### Description Development for webnn op input rank range check ### Motivation and Context - refactor webnn op input rank check - add validation for various ops - take `gemm` op as an example to perform inputs rank check of decomposed ops @honry @fdwr PTAL --- .../core/providers/webnn/builders/helper.cc | 126 +++++++++++------- .../core/providers/webnn/builders/helper.h | 34 +++++ .../webnn/builders/impl/concat_op_builder.cc | 2 +- .../impl/gatherElements_op_builder.cc | 5 +- .../builders/impl/gatherND_op_builder.cc | 5 +- .../webnn/builders/impl/gather_op_builder.cc | 6 +- .../webnn/builders/impl/gemm_op_builder.cc | 44 +++++- .../webnn/builders/impl/gru_op_builder.cc | 2 +- .../webnn/builders/impl/logical_op_builder.cc | 2 +- .../webnn/builders/impl/lstm_op_builder.cc | 2 +- .../webnn/builders/impl/max_min_op_builder.cc | 2 +- .../webnn/builders/impl/qdq_op_builder.cc | 2 +- .../impl/scatterElements_op_builder.cc | 5 +- .../builders/impl/scatterND_op_builder.cc | 5 +- .../webnn/builders/impl/ternary_op_builder.cc | 2 +- .../core/providers/webnn/builders/map_info.h | 2 +- 16 files changed, 168 insertions(+), 78 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index e821265fff80d..142d64caa64aa 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -99,69 +99,93 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n return true; } -// Check if all input tensor ranks of the given node are supported by WebNN. -bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { - const std::string_view op_type = node.OpType(); - const auto it = op_inputs_map.find(op_type); - if (it == op_inputs_map.end()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type << "] is not found in the op inputs map."; +// Check if a single input's rank of an ONNX op is supported by corresponding WebNN op. 
+bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger) { + const std::string webnn_op_type_str(webnn_op_type); + const std::string input_name_str(input_name); + + if (wnn_limits[webnn_op_type_str].isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type: [" << webnn_op_type + << "] is not defined in WebNN MLOpSupportLimits."; return false; } - const auto& input_defs = node.InputDefs(); - const std::string_view webnn_op_type = it->second.opType; - const std::string webnn_op_type_str(webnn_op_type); + const emscripten::val input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - for (const auto& input : it->second.inputs) { - if (static_cast(input.index) >= input_defs.size() || input_defs[input.index] == nullptr) { - LOGS(logger, VERBOSE) << "Input index [" << input.index - << "] for operator type [" << op_type - << "], corresponding WebNN op type [" << webnn_op_type - << "], WebNN input name [" << input.name - << "] is invalid."; - return false; - } + if (input_limits.isUndefined()) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "], WebNN op type: [" << webnn_op_type + << "], input [" << input_name + << "]: limits are not defined in WebNN MLOpSupportLimits."; + return false; + } - std::vector input_shape; - if (!GetShape(*input_defs[input.index], input_shape, logger)) { - return false; - } + const emscripten::val rank_range = input_limits["rankRange"]; + if (rank_range.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: missing 'rankRange' attribute."; + return false; + } - const std::string input_name_str(input.name); - if (wnn_limits[webnn_op_type_str].isUndefined() || - wnn_limits[webnn_op_type_str][input_name_str].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << " is not defined in wnn_limits."; - return false; - } + const emscripten::val min_val = rank_range["min"]; + const emscripten::val max_val = rank_range["max"]; + if (min_val.isUndefined() || max_val.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: its 'rankRange' limits is missing valid 'min' or 'max' attributes."; + return false; + } - const auto& input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - if (input_limits["rankRange"].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << "'s rankRange is not defined."; - return false; + size_t min_rank = min_val.as(); + size_t max_rank = max_val.as(); + if (input_rank < min_rank || input_rank > max_rank) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "] WebNN op type [" << webnn_op_type + << "] input [" << input_name << "] rank " << input_rank + << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + return false; + } + + return true; +} + +bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { + const std::string_view onnx_op_type = node.OpType(); + const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type); + 
+ if (webnn_op_type.empty()) { + LOGS(logger, VERBOSE) << "ONNX op type: [" << onnx_op_type << "]'s corresponding WebNN op is not found."; + return false; + } + + std::vector inputs; + if (!GetWebNNOpInputs(onnx_op_type, inputs, logger)) { + return false; + } + + const auto& input_defs = node.InputDefs(); + + for (const auto& input : inputs) { + // If it is an optional input and is absent, skip. + if (!TensorExists(input_defs, input.index)) { + continue; } - int input_dim_size = static_cast(input_shape.size()); - int min_rank = input_limits["rankRange"]["min"].as(); - int max_rank = input_limits["rankRange"]["max"].as(); - - if (input_dim_size < min_rank || input_dim_size > max_rank) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name: " << input.name - << ", input size " << input_dim_size - << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + std::vector shape; + if (!GetShape(*input_defs[input.index], shape, logger) || + !IsInputRankSupported(wnn_limits, webnn_op_type, input.name, + shape.size(), + node.Name(), logger)) { return false; } } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d59788600f997..50e361ede221e 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -216,6 +216,13 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger); +bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger); + // Get a set of nodes supported by WebNN EP. std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, @@ -244,6 +251,33 @@ inline std::string_view GetWebNNOpType(const std::string_view onnx_op_type) { return (it != op_inputs_map.end()) ? 
it->second.opType : ""; } +// Get corresponding input name of WebNN op type by ONNX op type from op_input_map +inline std::string_view GetWebNNInputName(const std::string_view onnx_op_type, const int input_index) { + const auto it = op_inputs_map.find(onnx_op_type); + + if (it != op_inputs_map.end()) { + for (const auto& input : it->second.inputs) { + if (input.index == input_index) { + return input.name; + } + } + } + + return ""; +} + +inline bool GetWebNNOpInputs(const std::string_view onnx_op_type, + std::vector& inputs, + const logging::Logger& logger) { + const auto it = op_inputs_map.find(onnx_op_type); + if (it == op_inputs_map.end()) { + LOGS(logger, VERBOSE) << "WebNN op inputs not found for op type: " << onnx_op_type; + return false; + } + inputs = it->second.inputs; + return true; +} + bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 8589237617745..e0cd48b6883c2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -75,7 +75,7 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); } void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index 06beb56415609..b4b9d9a0d4c6b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -56,13 +56,12 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const N const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index 9200c596c0e53..a15542061dd60 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -61,13 +61,12 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& n const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } - return IsDataTypeSupportedByOp(op_type, 
data_type, wnn_limits, "input", "data", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index d84c70032e1d1..86408557013a0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -74,13 +74,13 @@ bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t input_type; - int32_t indices_type; + int32_t input_type, indices_type; + if (!GetType(input, input_type, logger) || !GetType(indices, indices_type, logger)) return false; - return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 02f46c85d1d06..7af17fdc5db78 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -91,7 +91,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); std::vector a_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[2], a_zero_point_shape, logger), "Cannot get shape of a_zero_point"); - // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to default value 1.0f. // The scale input should have the same shape as the zero point input. a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, @@ -268,11 +268,45 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - if (op_type == "MatMulInteger") { - // The first decomposed op of MatMulInteger is DequantizeLinear, and so - // we only need to ensure it supports the input0_type. + if (op_type == "Gemm") { + return IsInputRankSupportedByOp(node, wnn_limits, logger) && + IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); + } else if (op_type == "MatMulInteger") { + // Check up to 4 inputs for MatMulInteger + for (size_t i = 0; i < input_defs.size(); ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + // We made workaround to support 1D for input A and B, skip further checks if they are 1D + if (i <= 1 && shape.size() == 1) { + continue; + } + + // For DequantizeLinear, input indices: 0 (x), 1 (scale), 2 (zero_point) + if (!IsInputRankSupported(wnn_limits, "dequantizeLinear", + (i < 2) ? 
"input" : "zeroPoint", + shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); - } else { + } else { // MatMul + for (int i = 0; i < 2; ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + if (shape.size() == 1) { + continue; + } + + if (!IsInputRankSupported(wnn_limits, "matmul", (i == 0) ? "a" : "b", shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index dfe80dd419092..6e86ca77464e5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -219,7 +219,7 @@ bool GruOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); } bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 42940083cad8e..1675615280de9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -92,7 +92,7 @@ bool LogicalOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no } std::string onnx_input_name = op_type == "Not" ? 
"X" : "A"; - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 09e584bc66f8a..fcdc84b75c048 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -242,7 +242,7 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 4e4014e3553ea..4d9cc39bd38fe 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -108,7 +108,7 @@ bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index dd25fb9bf9315..eccf67cc46c9a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -167,7 +167,7 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) && (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger)); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index f894e8bfbd517..ae3d559023625 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -71,7 +71,6 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; - const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -85,7 +84,9 @@ bool 
ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const
     return false;
   }
 
-  return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
+  const std::string_view op_type = node.OpType();
+
+  return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
          IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
 }
 
diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc
index e61ac3dcc9617..5467e91761823 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc
@@ -63,7 +63,6 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node&
   const auto& data = *node.InputDefs()[0];
   const auto& indices = *node.InputDefs()[1];
   const auto& updates = *node.InputDefs()[2];
-  const std::string_view op_type = node.OpType();
 
   int32_t data_type;
   int32_t indices_type;
@@ -76,8 +75,8 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node&
   if (data_type != updates_type) {
     return false;
   }
-
-  return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
+  const std::string_view op_type = node.OpType();
+  return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
          IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
 }
 
diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc
index 7a7f64b1ec96d..5d6d59663da61 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc
@@ -66,7 +66,7 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no
     return false;
   }
 
-  return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger);
+  return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger);
 }
 
 void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h
index 5e860eea7cac9..bf95527beb44e 100644
--- a/onnxruntime/core/providers/webnn/builders/map_info.h
+++ b/onnxruntime/core/providers/webnn/builders/map_info.h
@@ -139,7 +139,7 @@ const std::unordered_map op_inputs_map = {
     {"Mul", {"mul", {{0, "a"}, {1, "b"}}}},
     {"Pow", {"pow", {{0, "a"}, {1, "b"}}}},
     {"Concat", {"concat", {{0, "inputs"}}}},
-    {"Not", {"logicalNot", {{0, "input"}}}},
+    {"Not", {"logicalNot", {{0, "a"}}}},
     {"Flatten", {"reshape", {{0, "input"}}}},
     {"LpPool", {"l2Pool2d", {{0, "input"}}}},
     {"Reshape", {"reshape", {{0, "input"}}}},
From 8a27eabb05ad6bd3319792c4c4b5b5dd61c7be65 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com>
Date: Fri, 11 Jul 2025 03:10:12 +0800
Subject: [PATCH 13/49] Make TRT plugins optional (#25261)

### Description
The parser no longer links against the plugin library but instead loads it dynamically.
Due to that I think we should also make the library optional in ORT. @chilo-ms --- cmake/onnxruntime_providers_tensorrt.cmake | 23 +++------- .../nv_tensorrt_rtx/nv_execution_provider.cc | 2 +- .../tensorrt_execution_provider_custom_ops.cc | 44 ++++++++++++++++++- 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 69c81a5ec7b9d..4184e0b049afc 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -72,10 +72,9 @@ endif() # TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows, - # for example, nvinfer_10.dll, nvinfer_plugin_10.dll, nvonnxparser_10.dll ... + # for example, nvinfer_10.dll, nvonnxparser_10.dll ... if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA) set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}") - set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}") set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}") endif() @@ -83,15 +82,11 @@ set(NVINFER_LIB "nvinfer") endif() - if (NOT NVINFER_PLUGIN_LIB) - set(NVINFER_PLUGIN_LIB "nvinfer_plugin") - endif() - if (NOT PARSER_LIB) set(PARSER_LIB "nvonnxparser") endif() - MESSAGE(STATUS "Looking for ${NVINFER_LIB} and ${NVINFER_PLUGIN_LIB}") + MESSAGE(STATUS "Looking for ${NVINFER_LIB}") find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB} HINTS ${TENSORRT_ROOT} @@ -101,14 +96,6 @@ MESSAGE(STATUS "Can't find ${NVINFER_LIB}") endif() - find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB} - HINTS ${TENSORRT_ROOT} - PATH_SUFFIXES lib lib64 lib/x64) - - if (NOT TENSORRT_LIBRARY_INFER_PLUGIN) - MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}") - endif() - if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) MESSAGE(STATUS "Looking for ${PARSER_LIB}") @@ -120,7 +107,7 @@ MESSAGE(STATUS "Can't find ${PARSER_LIB}") endif() - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") else() if (TRT_GREATER_OR_EQUAL_TRT_10_GA) @@ -153,7 +140,7 @@ endif() # Static libraries are just nvonnxparser_static on all platforms set(onnxparser_link_libs nvonnxparser_static) - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") endif() @@ -161,7 +148,7 @@ # nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries. - # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}. + # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER}. 
if(onnxruntime_CUDA_MINIMAL) set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) else() diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 711d81186bad1..c5b6507ac847b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1304,7 +1304,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, CUDA_PINNED); + return std::make_unique(CUDA_PINNED, device_id); }, narrow(device_id_)); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 90a4294fb47f0..1e9fafe8aa323 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -7,6 +7,25 @@ #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + +#ifdef _WIN32 +#define ORT_DEF2STR_HELPER(x) L#x +#else +#define ORT_DEF2STR_HELPER(X) #X +#endif +#define ORT_DEF2STR(x) ORT_DEF2STR_HELPER(x) + namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose); @@ -58,8 +77,31 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& // Get all registered TRT plugins from registry LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ..."; TensorrtLogger trt_logger = GetTensorrtLogger(false); - initLibNvInferPlugins(&trt_logger, ""); + try { + void* library_handle = nullptr; + const auto& env = onnxruntime::GetDefaultEnv(); +#if NV_TENSORRT_MAJOR < 10 + auto full_path = env.GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION); +#else +#ifdef _WIN32 + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin_" ORT_DEF2STR(NV_TENSORRT_MAJOR)) LIBRARY_EXTENSION); +#else + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION ORT_TSTR("." 
ORT_DEF2STR(NV_TENSORRT_MAJOR))); +#endif +#endif + + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, false, &library_handle)); + bool (*dyn_initLibNvInferPlugins)(void* logger, char const* libNamespace); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "initLibNvInferPlugins", (void**)&dyn_initLibNvInferPlugins)); + if (!dyn_initLibNvInferPlugins(&trt_logger, "")) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library was found but was not able to initialize default plugins."; + } + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugins successfully loaded."; + } catch (const std::exception&) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library is not on the path and is therefore ignored"; + } int num_plugin_creator = 0; auto plugin_creators = getPluginRegistry()->getAllCreators(&num_plugin_creator); std::unordered_set registered_plugin_names; From e6658c020a9accf2263c31909eb15147b9848b20 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 10 Jul 2025 14:54:35 -0700 Subject: [PATCH 14/49] [EP ABI] Add Graph_GetGraphView API to get a OrtGraph from a subset of nodes (#25191) Added an API that creates a sub-graph from a set of nodes in an OrtGraph. This API is needed in the GetCapability EP ABI porting when EP wants to check whether a 'sub-graph' of the graph is supported by the hardware backend. --- include/onnxruntime/core/graph/graph.h | 5 +- .../core/session/onnxruntime_c_api.h | 18 ++++ onnxruntime/core/graph/ep_api_types.cc | 24 +++++ onnxruntime/core/graph/ep_api_types.h | 30 ++++++ onnxruntime/core/graph/graph.cc | 4 + onnxruntime/core/graph/graph_viewer.cc | 12 ++- onnxruntime/core/session/onnxruntime_c_api.cc | 86 ++++++++++++++++++ onnxruntime/core/session/ort_apis.h | 2 + onnxruntime/test/ep_graph/test_ep_graph.cc | 59 ++++++++++++ .../three_layer_nested_subgraph_v2.onnx | Bin 0 -> 1892 bytes 10 files changed, 237 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 54e03a31fceef..c18a42cc1bbc1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -952,9 +952,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return const_cast(this)->GetNodeArg(name); } - // search this and up through any parent_graph_ instance for a NodeArg + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name); + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg + const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const; + /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found. @param name The NodeArg name. @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created. 
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bf1dd6e20ce64..051a3f7283cbe 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5748,6 +5748,24 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); + /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. + * + * Note: + * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * the same underlying graph. + * + * \param[in] src_graph The source OrtGraph instance. + * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. + * \param[in] num_nodes Number of nodes. + * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes, + _In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph); + /// @} /// \name OrtNode diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 8583fac30cfbf..7f81ab3433911 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -505,10 +505,34 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node) EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag) : OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {} +EpGraph::EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag) + : OrtGraph(OrtGraphIrApi::kEpApi), + graph_viewer_(*graph_viewer.get()), + owned_graph_viewer_(std::move(graph_viewer)), + owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {} + // Static class function to create a std::unique_ptr. Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +// Static class function to create a std::unique_ptr. +Status EpGraph::Create(std::unique_ptr src_graph_viewer, + std::unique_ptr src_indexed_sub_graph, + /*out*/ std::unique_ptr& result) { + auto& graph_viewer = *src_graph_viewer.get(); + auto ep_graph = std::make_unique(std::move(src_graph_viewer), + std::move(src_indexed_sub_graph), + PrivateTag{}); + + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); std::unordered_map> value_infos_map; diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 12fa082d3f354..7b67f21bf4eb4 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -251,15 +251,32 @@ struct EpGraph : public OrtGraph { public: EpGraph(const GraphViewer& graph_viewer, PrivateTag); + EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag); /// /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not onwed by this EpGraph. 
/// /// /// /// static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + /// + /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a subset of nodes in another EpGraph. + /// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance + /// must take ownership of both the GraphViewer and IndexedSubGraph. + /// + /// + /// + /// + static Status Create(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + /*out*/ std::unique_ptr& result); + // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph. DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi) @@ -331,9 +348,22 @@ struct EpGraph : public OrtGraph { const OrtValue* GetInitializerValue(std::string_view name) const; private: + /// + /// The real implementation of creating an EpGraph instance. + /// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly. + /// + /// + /// + /// + /// + static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + const GraphViewer& graph_viewer_; const EpNode* parent_node_ = nullptr; + std::unique_ptr owned_graph_viewer_ = nullptr; + std::unique_ptr owned_indexed_sub_graph_ = nullptr; + std::vector> nodes_; IndexToEpNodeMap index_to_ep_node_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index ca40bad2b4250..4d3091520d876 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,6 +1818,10 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name return node_arg; } +const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const { + return const_cast(this)->GetNodeArgIncludingParentGraphs(node_arg_name); +} + void Graph::ReverseDFSFrom(gsl::span from, const std::function& enter, const std::function& leave, diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 1842c2b4a0d1f..948ebaa5f7e15 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -168,7 +168,15 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size()); for (const auto& input : metadef->inputs) { - const auto* nodearg = graph.GetNodeArg(input); + // NodeArgs from the current scope or any outer scopes should be handled correctly. + // + // There is an edge case where the model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + // When constructing a new GraphViewer for the second- and third-layer subgraphs, + // the second-layer graph may not have the corresponding value_info for that first-layer input, + // because the second-layer graph itself doesn't consume it. + // Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info. + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. 
Input not found:", input); filtered_node_inputs_including_initializers_.push_back(nodearg); if (!graph.IsInitializedTensor(input)) { @@ -177,7 +185,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } for (const auto& output : metadef->outputs) { - const auto* nodearg = graph.GetNodeArg(output); + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Output not found:", output); filtered_node_outputs_.push_back(nodearg); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 18b545483b38b..312ddd7e52e00 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2714,6 +2714,91 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, + _In_ const OrtNode** nodes, + _In_ size_t num_nodes, + _Outptr_ OrtGraph** dst_graph) { + API_IMPL_BEGIN + + if (num_nodes == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0"); + } + + const EpGraph* ep_graph = EpGraph::ToInternal(src_graph); + if (ep_graph == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph."); + } + const Graph& graph = ep_graph->GetGraphViewer().GetGraph(); + + // Create a GraphViewer with filtered info + std::unique_ptr indexed_sub_graph = std::make_unique(); + std::unique_ptr metadef = std::make_unique(); + metadef->name = "sub_graph"; + metadef->since_version = 1; + std::unordered_set outputs; + std::unordered_set initializers; + + auto add_inputs = [&](ConstPointerContainer> defs) { + for (const auto* def : defs) { + if (def->Exists()) { + // not the output of a previous node + if (outputs.count(def->Name()) == 0) { + metadef->inputs.push_back(def->Name()); + } else { + // consumed by node so no longer subgraph output + // NOTE: Ignoring edge case where a node output is an overall graph output AND a node input + outputs.erase(def->Name()); + } + + if (graph.IsInitializedTensor(def->Name())) { + initializers.insert(def); + } + } + } + }; + + auto add_node = [&](const Node& node) { + indexed_sub_graph->nodes.push_back(node.Index()); + add_inputs(node.InputDefs()); + add_inputs(node.ImplicitInputDefs()); + + for (const auto* def : node.OutputDefs()) { + outputs.insert(def->Name()); + } + }; + + // Add nodes + for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { + const OrtNode* ort_node = nodes[node_idx]; + const EpNode* ep_node = EpNode::ToInternal(ort_node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph."); + } + add_node(ep_node->GetInternalNode()); + } + + // Add initializers + for (auto& initializer : initializers) { + metadef->constant_initializers.push_back(initializer->Name()); + } + + // Add outputs + for (auto& output : outputs) { + metadef->outputs.push_back(output); + } + + indexed_sub_graph->SetMetaDef(std::move(metadef)); + auto graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); + + std::unique_ptr result; + ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result)); + + *dst_graph = result.release(); + + return nullptr; + API_IMPL_END +} + 
// // OrtNode // @@ -3629,6 +3714,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetNumNodes, &OrtApis::Graph_GetNodes, &OrtApis::Graph_GetParentNode, + &OrtApis::Graph_GetGraphView, &OrtApis::Node_GetId, &OrtApis::Node_GetName, &OrtApis::Node_GetOperatorType, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 75db44cb9e9ff..b53863c02cfef 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -649,6 +649,8 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); +ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, + _Outptr_ OrtGraph** subgraph); // OrtNode ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index e9bed3ac45529..17e829e37f729 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -7,12 +7,15 @@ #include #include #include +#include #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/ep_api_types.h" +#include "core/graph/graph_proto_serializer.h" #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL #include "core/providers/utils/ort_graph_to_proto.h" @@ -31,6 +34,7 @@ namespace test { // forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent // to a graph represented by the internal ORT GraphViewer class. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); +static void Check_Graph_GetSubgraph(const OrtGraph& api_graph); // // Tests @@ -73,6 +77,16 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { + // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test. + // The model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). 
+ auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -474,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); + + // Select a half of nodes to create a OrtGraph + size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); + std::vector selected_nodes(num_selected_nodes); + + for (size_t i = 0; i < num_selected_nodes; i++) { + selected_nodes[i] = nodes[i]; + } + + OrtGraph* sub_graph; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); + + // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. + // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. + const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); + std::unique_ptr model = std::make_unique(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger()); + auto model_proto = std::make_unique(model->ToProto()); + GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + const char* graph_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); + std::string name = graph_name; + name += "_half.onnx"; + + // Dump the graph for debugging + // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); + // model_proto->SerializeToOstream(&dump); + + ort_api.ReleaseGraph(sub_graph); +} + // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. 
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { @@ -682,6 +738,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } } } + + // Check creating an OrtGraph from a subset of nodes in an OrtGraph + Check_Graph_GetSubgraph(api_graph); } } // namespace test diff --git a/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d036541a70aa087f6007ec7261f5f1115b0e22f2 GIT binary patch literal 1892 zcmc&#&2G~`5Y8GWaV8~=whcrEAaP&c$9+r)V_e7+CMS24QOn};JS@5uuwhojaXfGV8^v_J5P zdou3)00^KKITpQ`lc^PqiAS-P$a?eD%nd@~hP}}dH(80r++4G?cA)rhT<3}SzV zo0H)NM;CKS-}5BRvOL4jGD9eH!WEX$*~psB!&{MZ1XACWRiwTs0x4Tok{~7J8<3Kg zyCC)P1w$%{yoS^cLrIz#W*xi{bu2O*#%bu4m&0LSw8Ff{j<~^0?WofhLE6E5aOx9p z+-hn{z1&rhvR6xj#l0Rpft7%`1{)f}8YmiKUuB|a&*yC0BB8Y#SE$(ftwJ>%Q#WDR zcU7`1t}VgNkwx6ZGU0g_>iSUC`&kVX?S*?o`uvGX(i5vu@IN| z?*bMeM{k4CleJI+1N)Ij+#zSHS&GlH!_In#AF`<{cM)O@maxVVTc1$c`~O>;pxRP( zIXcxvj{r1AK$R14@*wLUUe<5PZY?Vr@Ax{%IKP6((s~;_f@}%gISCP9`Mn8CLM$zz ztjLTTtJ|gos#d`TJ`=yt>P%cC=%)K$iEO=@Z0!RYCTsuD^;qlE&7WDoU|`u|{wi#u zIc1`bUgE^=dGRX1OxccXz73K+AWBcXbEQ8{)4^}*ur}5cKJ0ex&TT6IS7u(=7R%@O gpK*_GjBuRe!e9%~W$t+nclOVTCER-|6zZFQ0aGjCcK`qY literal 0 HcmV?d00001 From 591003b1ecd13e7862d655f91bed8fba27499cf6 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:17:37 -0700 Subject: [PATCH 15/49] [webgpu] a few optimization to WGSL template (#25333) ### Description This change is a follow up to #25130. - consume duktape from vcpkg if --use_vcpkg is specified - ~~add a Windows CI pipeline for dynamic WGSL template~~ (Will do in a separate PR) - upgrade wgsl-template package from 0.1.10 to 0.1.13 - support adding contribop folder as input --- .../external/onnxruntime_external_deps.cmake | 25 +++++++++++++------ cmake/onnxruntime_providers_webgpu.cmake | 11 ++++---- cmake/vcpkg.json | 8 ++++++ .../webgpu/wgsl_templates/package-lock.json | 8 +++--- .../webgpu/wgsl_templates/package.json | 2 +- tools/ci_build/build.py | 2 ++ 6 files changed, 39 insertions(+), 17 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e8f6bbe895d29..228906030d14c 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -774,13 +774,24 @@ if (onnxruntime_USE_WEBGPU) endif() if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - onnxruntime_fetchcontent_declare( - duktape - URL ${DEP_URL_duktape} - URL_HASH SHA1=${DEP_SHA1_duktape} - EXCLUDE_FROM_ALL - ) - onnxruntime_fetchcontent_makeavailable(duktape) + if(onnxruntime_USE_VCPKG) + find_package(unofficial-duktape CONFIG REQUIRED) + add_library(duktape_static ALIAS unofficial::duktape::duktape) + else() + onnxruntime_fetchcontent_declare( + duktape + URL ${DEP_URL_duktape} + URL_HASH SHA1=${DEP_SHA1_duktape} + EXCLUDE_FROM_ALL + ) + onnxruntime_fetchcontent_makeavailable(duktape) + + if(NOT TARGET duktape_static) + add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") + target_compile_features(duktape_static PRIVATE c_std_99) + target_include_directories(duktape_static INTERFACE $) + endif() + endif() endif() endif() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 5b80b1262464d..2865ad33b39f4 100644 --- 
a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -172,10 +172,12 @@ file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR}) # Find all WGSL template input files - file(GLOB_RECURSE WGSL_TEMPLATE_FILES "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template") + file(GLOB_RECURSE WGSL_TEMPLATE_FILES + "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template" + "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template") # Set wgsl-gen command line options as a list - set(WGSL_GEN_OPTIONS "-i" "../" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") + set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") if (onnxruntime_WGSL_TEMPLATE STREQUAL "static") if (CMAKE_BUILD_TYPE STREQUAL "Debug") list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal") @@ -207,10 +209,9 @@ # Add the generated directory to include paths target_include_directories(onnxruntime_providers_webgpu PRIVATE ${WGSL_GENERATED_ROOT}) elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") - target_compile_features(duktape_static PRIVATE c_std_99) target_link_libraries(onnxruntime_providers_webgpu duktape_static) - target_include_directories(onnxruntime_providers_webgpu PRIVATE ${duktape_SOURCE_DIR}/src) + onnxruntime_add_include_to_target(onnxruntime_providers_webgpu duktape_static) + # Define the path to the generated templates.js file target_compile_definitions(onnxruntime_providers_webgpu PRIVATE "ORT_WGSL_TEMPLATES_JS_PATH=\"${WGSL_GENERATED_TEMPLATES_JS}\"") diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index da179d0bad564..373ecec440921 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -93,6 +93,10 @@ "webgpu-ep": { "description": "Build with WebGPU EP", "dependencies": [] + }, + "webgpu-ep-wgsl-template-dynamic": { + "description": "Build with WebGPU EP with dynamic WGSL template code generator", + "dependencies": ["duktape"] } }, "overrides": [ @@ -103,6 +107,10 @@ { "name": "flatbuffers", "version": "23.5.26" + }, + { + "name": "duktape", + "version": "2.7.0#2" } ] } diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json index 7cde6c17f54e9..df1940ed6416b 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json @@ -9,13 +9,13 @@ "version": "1.0.0", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } }, "node_modules/@fs-eire/wgsl-template": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.10.tgz", - "integrity": "sha512-F5qQZxNweZ3ZD3d9RNc/g3nTiW7jyaAVi7SlMOL4wOfXh+Nm/qca2DISNTf3kjpVqkoazMJGbZ6TPQ4a/vjw0g==", + "version": "0.1.13", + "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.13.tgz", + "integrity": "sha512-SOQjVCQCUmXb9qYr2E3CKNs88/FzINuhFJiobBEkSAsyKtJby9oFWGZnrEO+hIl/oDTLA01LbjiDxuf6TGHE/w==", "license": "MIT", "dependencies": { "minimist": "^1.2.8" diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json index 34831ccddeb33..246e7365531e0 100644 --- 
a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json @@ -10,6 +10,6 @@ "author": "", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } } diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index f6e37d33b2414..f864b8eb4a74d 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -284,6 +284,8 @@ def generate_vcpkg_install_options(build_dir, args): vcpkg_install_options.append("--x-feature=vsinpu-ep") if args.use_webgpu: vcpkg_install_options.append("--x-feature=webgpu-ep") + if args.wgsl_template == "dynamic": + vcpkg_install_options.append("--x-feature=webgpu-ep-wgsl-template-dynamic") if args.use_webnn: vcpkg_install_options.append("--x-feature=webnn-ep") if args.use_xnnpack: From 57c9743e58edffa8351f23c5b99471432f01026d Mon Sep 17 00:00:00 2001 From: George Wu Date: Thu, 10 Jul 2025 18:52:03 -0700 Subject: [PATCH 16/49] add --client_package_build option (#25351) add a build option to enable default options more appropriate for client/on-device workloads. initial use case will be to set the default thread pool allow_spinning policy , which we want to default to 0/false for builds targeted for client/on-device workloads. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- cmake/CMakeLists.txt | 1 + cmake/adjust_global_compile_flags.cmake | 5 +++++ .../onnxruntime_session_options_config_keys.h | 4 +++- onnxruntime/core/session/inference_session.cc | 12 ++++++++++++ onnxruntime/core/util/thread_utils.h | 6 ++++++ tools/ci_build/build.py | 1 + tools/ci_build/build_args.py | 10 ++++++++++ .../github/azure-pipelines/templates/qnn-ep-win.yml | 2 +- .../azure-pipelines/win-qnn-arm64-ci-pipeline.yml | 2 +- 9 files changed, 40 insertions(+), 3 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index fb4238731ffc3..b01110b2a4a03 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -151,6 +151,7 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) +option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." OFF) cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 59d99ade131cd..6d517003fa6b6 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -95,6 +95,11 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() +# ORT build with default settings more appropriate for client/on-device workloads. 
+if (onnxruntime_CLIENT_PACKAGE_BUILD) + add_compile_definitions(ORT_CLIENT_PACKAGE_BUILD) +endif() + if (onnxruntime_ENABLE_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT ipo_enabled OUTPUT ipo_output) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 97e53e6acee5a..314cf76cc8044 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -148,7 +148,9 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = " // Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking // "0": thread will block if found no job to run -// "1": default, thread will spin a number of times before blocking +// "1": thread will spin a number of times before blocking +// The default is "0" when ORT is built with "ORT_CLIENT_PACKAGE_BUILD" and "1" otherwise. +// Thread spinning is disabled by default for client/on-device workloads to reduce cpu utilization and improve power efficiency. static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 86a61a4d0ee74..f147242da668f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -423,7 +423,13 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, { if (!external_intra_op_thread_pool_) { bool allow_intra_op_spinning = +#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "1") == "1"; +#else + // default KOrtSessionOptionsConfigAllowIntraOpSpinning to "0" for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0") == "1"; +#endif OrtThreadPoolParams to = session_options_.intra_op_param; std::basic_stringstream ss; if (to.name) { @@ -461,7 +467,13 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL) { if (!external_inter_op_thread_pool_) { bool allow_inter_op_spinning = +#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "1") == "1"; +#else + // default kOrtSessionOptionsConfigAllowInterOpSpinning to "0" for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. 
+ session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "0") == "1"; +#endif OrtThreadPoolParams to = session_options_.inter_op_param; to.auto_set_affinity = to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL; std::basic_stringstream ss; diff --git a/onnxruntime/core/util/thread_utils.h b/onnxruntime/core/util/thread_utils.h index d63d620dbc321..0b99723b2c75b 100644 --- a/onnxruntime/core/util/thread_utils.h +++ b/onnxruntime/core/util/thread_utils.h @@ -19,7 +19,13 @@ struct OrtThreadPoolParams { bool auto_set_affinity = false; // If it is true, the thread pool will spin a while after the queue became empty. +#if !defined(ORT_CLIENT_PACKAGE_BUILD) bool allow_spinning = true; +#else + // default allow_spinning to false for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. + bool allow_spinning = false; +#endif // It it is non-negative, thread pool will split a task by a decreasing block size // of remaining_of_total_iterations / (num_of_threads * dynamic_block_base_) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index f864b8eb4a74d..893f3c80fa4b8 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -472,6 +472,7 @@ def generate_build_tree( else "OFF" ), "-Donnxruntime_REDUCED_OPS_BUILD=" + ("ON" if is_reduced_ops_build(args) else "OFF"), + "-Donnxruntime_CLIENT_PACKAGE_BUILD=" + ("ON" if args.client_package_build else "OFF"), "-Donnxruntime_BUILD_MS_EXPERIMENTAL_OPS=" + ("ON" if args.ms_experimental else "OFF"), "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"), "-Donnxruntime_USE_ACL=" + ("ON" if args.use_acl else "OFF"), diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index ad27b8124c458..53d53f3e15e99 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -527,6 +527,15 @@ def add_size_reduction_args(parser: argparse.ArgumentParser) -> None: ) +def add_client_package_args(parser: argparse.ArgumentParser) -> None: + """Adds arguments for client package build package.""" + parser.add_argument( + "--client_package_build", + action="store_true", + help="Create ORT package with default settings more appropriate for client/on-device workloads.", + ) + + def add_python_binding_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for Python bindings.""" parser.add_argument("--enable_pybind", action="store_true", help="Enable Python bindings.") @@ -833,6 +842,7 @@ def convert_arg_line_to_args(self, arg_line: str) -> list[str]: # Use list[str] add_dependency_args(parser) add_extension_args(parser) add_size_reduction_args(parser) + add_client_package_args(parser) # Language Bindings add_python_binding_args(parser) diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 1406ce338f13e..b600341827aad 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -20,7 +20,7 @@ stages: name: ${{ parameters.qnn_ep_build_pool_name }} variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} - commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' + 
commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --client_package_build --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' steps: - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 78fce1f9b9602..66df2d6053d51 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -50,7 +50,7 @@ jobs: matrix: SHARED_LIB: QnnLibKind: 'shared_lib' - ExtraQnnBuildArgs: '' + ExtraQnnBuildArgs: '--client_package_build' STATIC_LIB: QnnLibKind: 'static_lib' ExtraQnnBuildArgs: '' From fb0f6c652be5db0a3182c424a995efecf792d41c Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 11 Jul 2025 09:57:18 +0800 Subject: [PATCH 17/49] [WebNN] Fix bug in Float16Array availability check (#25354) The `from` is not a property of `Float16Array` but an inherited function, we can use `Float16Array['from']` to check if it is available. --- .../providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc | 2 +- onnxruntime/core/providers/webnn/builders/model_builder.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index 893ca9d2419c7..37071b1030e11 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -285,7 +285,7 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build sign_buffer.set(1, 1.0f); } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { if (model_builder.IsFloat16ArrayAvailable()) { - // Float16Array is avaliable - use Float16Array. + // Float16Array is available - use Float16Array. sign_buffer = emscripten::val::global("Float16Array").new_(2); sign_buffer.set(0, -1.0f); sign_buffer.set(1, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 4468831181d42..d2cd0639affd0 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -78,7 +78,7 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; const bool is_float16array_available_ = !emscripten::val::global("Float16Array").isUndefined() && - emscripten::val::global("Float16Array").hasOwnProperty("from"); + !emscripten::val::global("Float16Array")["from"].isUndefined(); emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); From 9fc41c3179c63c0a1239f7e163f009f0ba14e307 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 10 Jul 2025 21:42:04 -0700 Subject: [PATCH 18/49] [EP ABI] Add Node_GetEpType API (#25350) Add a new API `Node_GetEpType` to get the EP that the node is assigned to run on. 
This API is needed when porting the plugin TRT EP in `GetCapability` where ep needs to know whether the subgraph(s) of the control flow node is assigned to the ep and then to add this control flow op to the support list. --- .../core/session/onnxruntime_c_api.h | 12 ++++++++++++ onnxruntime/core/graph/ep_api_types.cc | 4 ++++ onnxruntime/core/graph/ep_api_types.h | 3 +++ onnxruntime/core/session/onnxruntime_c_api.cc | 18 ++++++++++++++++++ onnxruntime/core/session/ort_apis.h | 1 + onnxruntime/test/autoep/library/ep.cc | 6 ++++++ 6 files changed, 44 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 051a3f7283cbe..9172965e18fcf 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6026,6 +6026,18 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); + /** \brief Returns the execution provider type (name) that this node is assigned to run on. + * Returns NULL if the node has not been assigned to any execution provider yet. + * + * \param[in] node The OrtNode instance. + * \param[out] out Output execution provider type and can be NULL if node has not been assigned. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetEpType, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); + /// @} /// \name OrtRunOptions diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 7f81ab3433911..073c6a2c743eb 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -276,6 +276,10 @@ const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { } } +const std::string& EpNode::GetEpType() const { + return node_.GetExecutionProviderType(); +} + // // EpValueInfo // diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 7b67f21bf4eb4..1acbcc478a99b 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -208,6 +208,9 @@ struct EpNode : public OrtNode { // Helper that gets the node's attributes by name. const OrtOpAttr* GetAttribute(const std::string& name) const; + // Helper that gets the execution provider that this node is assigned to run on. + const std::string& GetEpType() const; + private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. // Will be nullptr if the EpNode was created without an owning graph. 
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 312ddd7e52e00..55bc28cd7139f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3052,6 +3052,23 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Node_GetEpType, _In_ const OrtNode* node, + _Outptr_result_maybenull_ const char** out) { + API_IMPL_BEGIN + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); + } + + const EpNode* ep_node = EpNode::ToInternal(node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpType."); + } + + *out = ep_node->GetEpType().c_str(); + return nullptr; + API_IMPL_END +} + ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS if (version >= 13 && version <= ORT_API_VERSION) @@ -3734,6 +3751,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumSubgraphs, &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetGraph, + &OrtApis::Node_GetEpType, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index b53863c02cfef..fed7009828999 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -680,6 +680,7 @@ ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, _Out_writes_opt_(num_subgraphs) const char** attribute_names); ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); +ORT_API_STATUS_IMPL(Node_GetEpType, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index b498c40079f48..a5b46c74ecc21 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -328,6 +328,12 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); + const char* ep_type = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetEpType(fused_nodes[0], &ep_type)); + if (std::strncmp(ep_type, "example_ep", 11) != 0) { + return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on"); + } + // Associate the name of the fused node with our MulKernel. 
const char* fused_node_name = nullptr; RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); From cee25ba2d627f9d17f5e9768c89cc54f33b6eaa8 Mon Sep 17 00:00:00 2001 From: quic-calvnguy Date: Thu, 10 Jul 2025 22:10:07 -0700 Subject: [PATCH 19/49] QNN-EP: DSPQueue Polling (#25361) ### Description Enable DSP queue polling when performance profile is burst --- .../qnn/builder/qnn_backend_manager.cc | 33 ++++++++++++------- .../qnn/builder/qnn_backend_manager.h | 5 +-- .../providers/qnn/qnn_execution_provider.cc | 25 +++++++++----- .../providers/qnn/qnn_execution_provider.h | 4 ++- 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index d22edaf33eb1c..98a078ff3eb87 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1426,13 +1426,33 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, return Status::OK(); } -Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency) { +Status QnnBackendManager::SetRpcPowerConfigs(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency, + uint32_t rpc_polling_time) { // This function is called in QNN EP's OnRunStart() even if QNN backend setup failed and the model is assigned // to a different EP. Therefore, we have to check that backend setup actually completed before trying to // set RPC control latency. Otherwise, this causes a segfault because the QNN backend library is unloaded. ORT_RETURN_IF_NOT(backend_setup_completed_, "Cannot set HTP RPC control latency if backend setup is not complete."); + + constexpr int kNumRpcPollingPowerConfigs = 2; + std::vector rpc_power_configs; + rpc_power_configs.reserve(kNumRpcPollingPowerConfigs); + + // Set rpc control latency here if (rpc_control_latency != 0) { + auto& rpc_control_latency_cfg = rpc_power_configs.emplace_back(); + rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; + rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; + } + + // Note: v68 does not support rpc polling mode + if (rpc_polling_time != 0) { + auto& rpc_polling_time_cfg = rpc_power_configs.emplace_back(); + rpc_polling_time_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; + rpc_polling_time_cfg.rpcPollingTimeConfig = rpc_polling_time; + } + + if (rpc_power_configs.size() > 0) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -1442,15 +1462,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_ "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; - // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. - constexpr int kNumRpcPollingPowerConfigs = 2; - std::vector rpc_power_configs(kNumRpcPollingPowerConfigs); - QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0]; - // v68 doesn't support this. 
- QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1]; - rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; - rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; - rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; std::vector perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs); status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 3e68df3024565..4cc805bcff0c9 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -159,8 +159,9 @@ class QnnBackendManager : public std::enable_shared_from_this Status SetHtpPowerConfig(uint32_t htp_power_config_client_id, HtpPerformanceMode htp_performance_mode); - Status SetRpcControlLatency(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency); + Status SetRpcPowerConfigs(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency, + uint32_t rpc_polling_time); const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 236447cc95c3d..aeaaa1df9c0df 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1356,7 +1356,8 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency) + uint32_t default_rpc_control_latency, + uint32_t default_rpc_polling_time) : qnn_backend_manager_(qnn_backend_manager) { Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_); is_htp_power_config_id_valid_ = rt.IsOK(); @@ -1367,9 +1368,10 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_, default_htp_performance_mode)); } - if (default_rpc_control_latency > 0) { - ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_, - default_rpc_control_latency)); + if (default_rpc_control_latency > 0 || default_rpc_polling_time > 0) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcPowerConfigs(htp_power_config_id_, + default_rpc_control_latency, + default_rpc_polling_time)); } } } @@ -1400,7 +1402,8 @@ QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContex if (context_state_.retired_context_pool.empty()) { uint32_t core_id = 0; context = std::make_shared(qnn_backend_manager_.get(), device_id_, core_id, - default_htp_performance_mode_, default_rpc_control_latency_); + default_htp_performance_mode_, default_rpc_control_latency_, + default_rpc_polling_time_); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -1468,15 +1471,21 @@ Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_optio LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; } + uint32_t rpc_polling_time = 0; + if (qnn::HtpPerformanceMode::kHtpBurst != htp_performance_mode) { + 
rpc_polling_time = 9999; + } + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), htp_performance_mode)); } - if (rpc_control_latency > 0) { - ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(), - rpc_control_latency)); + if (rpc_control_latency > 0 || rpc_polling_time > 0) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcPowerConfigs(GetPerThreadContext().GetHtpPowerConfigId(), + rpc_control_latency, + rpc_polling_time)); } } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 06f9726ae96cf..9dcfa4c1291c8 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -96,6 +96,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t device_id_ = 0; qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; + uint32_t default_rpc_polling_time_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; bool stop_share_ep_contexts_ = false; @@ -116,7 +117,8 @@ class QNNExecutionProvider : public IExecutionProvider { PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency); + uint32_t default_rpc_control_latency, + uint32_t default_rpc_polling_time); ~PerThreadContext(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); From 56078fe90725158f6ddbfdd21b674d56b7739452 Mon Sep 17 00:00:00 2001 From: quic-calvnguy Date: Fri, 11 Jul 2025 08:35:43 -0700 Subject: [PATCH 20/49] [QNN_EP] Implement Efficient Mode API (#25146) ### Description - Set context priority to low when workload type is Efficient - Set context priority to command line configured value if Default - Error out otherwise (invalid argument) --- .../qnn/builder/qnn_backend_manager.cc | 17 +++++ .../qnn/builder/qnn_backend_manager.h | 5 ++ .../providers/qnn/qnn_execution_provider.cc | 34 +++++++++ .../providers/qnn/qnn_execution_provider.h | 3 + .../test/providers/qnn/qnn_ep_context_test.cc | 69 ++++++++++++++++++- 5 files changed, 126 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 98a078ff3eb87..3dc103046424e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -839,6 +839,23 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord return Status::OK(); } +Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) { + QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; + ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority, context_priority_config)); + + QnnContext_Config_t* configs[] = {&context_priority_config, nullptr}; + for (const auto& context_handle : contexts_) { + auto result = qnn_interface_.contextSetConfig(context_handle, (const QnnContext_Config_t**)configs); + ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to set context priority for context handle: ", context_handle); + } + + return Status::OK(); +} 
+ +Status QnnBackendManager::ResetContextPriority() { + return SetContextPriority(context_priority_); +} + Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { if (true == context_created_) { LOGS_DEFAULT(INFO) << "Context created already."; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 4cc805bcff0c9..2a71c7391b180 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -220,6 +220,11 @@ class QnnBackendManager : public std::enable_shared_from_this // For each node name, a mapping to the context handle will be created void ProcessContextFromBinListAsync(Qnn_ContextHandle_t handle, void* notifyParam); + // Sets the context priority to the given value, if valid + Status SetContextPriority(ContextPriority context_priority); + // Resets the context priority to the session default as defined by context_priority_ + Status ResetContextPriority(); + private: Status LoadBackend(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index aeaaa1df9c0df..3acb3347acee1 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1554,4 +1554,38 @@ OrtDevice QNNExecutionProvider::GetOrtDeviceByMemType(OrtMemType /* em_type */) return default_device_; } +Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span keys, + gsl::span values) { + if (keys.size() != values.size()) { + LOGS_DEFAULT(ERROR) << "SetEpDynamicOptions: number of keys (" << keys.size() + << ") does not equal number of values (" << values.size() << ")."; + } + auto key_it = keys.begin(); + auto value_it = values.begin(); + + while (key_it != keys.end() && value_it != values.end()) { + std::string key(*key_it); + std::string value(*value_it); + + if (key == kOrtEpDynamicOptionsWorkloadType) { + if (value == "Default") { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->ResetContextPriority()); + } else if (value == "Efficient") { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetContextPriority(qnn::ContextPriority::LOW)); + } else { + LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type."); + } + } else { + LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported."; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option"); + } + + key_it++; + value_it++; + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 9dcfa4c1291c8..6adf613932d66 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -57,6 +57,9 @@ class QNNExecutionProvider : public IExecutionProvider { OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; + Status SetEpDynamicOptions(gsl::span keys, + gsl::span value) override; + private: std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 4febfe7ba836d..3335c242112ab 100644 --- 
a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -1649,7 +1649,6 @@ static void DumpModelWithSharedCtx(ProviderOptions provider_options, Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); } -#if defined(__aarch64__) || defined(_M_ARM64) static void GetModelInputNames(const std::string& model_path, std::vector& input_names, std::vector& output_names, @@ -1669,7 +1668,6 @@ static void GetModelInputNames(const std::string& model_path, output_names.push_back(output->Name()); } } -#endif // 1. Create 2 QDQ models // 2. Initialize 2 Ort sessions which share the same QNN EP from these 2 QDQ models @@ -1994,6 +1992,73 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) { }); } } + +TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + Ort::SessionOptions so; + so.AppendExecutionProvider("QNN", provider_options); + so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx"), so); + + std::vector input_names; + std::vector output_names; + GetModelInputNames("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx", input_names, output_names, + DefaultLoggingManager().DefaultLogger()); + + // Run sessions + // prepare input + std::vector input_dim{3, 4}; + std::vector input_value(3 * 4, 0.0f); + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + std::vector ort_inputs; + std::vector input_names_c; + for (size_t i = 0; i < input_names.size(); ++i) { + auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(), + input_dim.data(), input_dim.size()); + ort_inputs.push_back(std::move(input_tensor)); + input_names_c.push_back(input_names[i].c_str()); + } + std::vector output_names_c; + for (size_t i = 0; i < output_names.size(); ++i) { + output_names_c.push_back(output_names[i].c_str()); + } + + auto ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + const char* const workload_type[] = {"ep.dynamic.workload_type"}; + const char* const efficient_type[] = {"Efficient"}; + const char* const default_type[] = {"Default"}; + + // Test Efficient & Default options + session.SetEpDynamicOptions(workload_type, efficient_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + session.SetEpDynamicOptions(workload_type, default_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + // Test invalid EP dynamic option and invalid workload type + const char* const dne[] = {"DNE"}; + try { + session.SetEpDynamicOptions(workload_type, dne, 1); + FAIL() << "Expected exception to be thrown for workload type DNE but was set successfully"; + } catch (const std::exception& e) { + EXPECT_STREQ("Invalid EP Workload Type.", e.what()); + } + + try { + session.SetEpDynamicOptions(dne, efficient_type, 1); + FAIL() << "Expected exception to be thrown for dynamic option DNE but was set successfully"; + } catch (const std::exception& e) { + EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); + } +} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test 
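The test above exercises the new option end to end. As a quick reference, here is a minimal sketch (not part of the patch; the helper and its name are illustrative) of how an application could toggle the workload type on a live session through the C++ API, using the `ep.dynamic.workload_type` key and the `Efficient`/`Default` values handled by this change:

```cpp
#include <onnxruntime_cxx_api.h>

// Illustrative helper: switch an existing session between the default workload
// type and the power-saving "Efficient" type between runs. With this change,
// QNN EP lowers the context priority for "Efficient" and restores the
// session-configured priority for "Default".
void SetQnnWorkloadType(Ort::Session& session, bool efficient) {
  const char* const keys[] = {"ep.dynamic.workload_type"};
  const char* const values[] = {efficient ? "Efficient" : "Default"};
  session.SetEpDynamicOptions(keys, values, 1);
}
```

Any other value is rejected with `ORT_INVALID_ARGUMENT`, which the C++ wrapper surfaces as an exception, as the negative cases in the test above show.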
From 56a93a07bdecaba2118c764d56b79743df7e805d Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 11 Jul 2025 10:48:01 -0700 Subject: [PATCH 21/49] Add Compile API to set the location for the context binary file (#25356) Add Compile API ModelCompilationOptions_SetEpContextBinaryInformation to set the folder path and model name so that the EP can get the right place to dump the [model_name]_[ep].bin file. --- .../core/session/onnxruntime_c_api.h | 18 ++++++++++ .../core/session/onnxruntime_cxx_api.h | 2 ++ .../core/session/onnxruntime_cxx_inline.h | 9 +++++ onnxruntime/core/session/compile_api.cc | 30 ++++++++++++++++ onnxruntime/core/session/compile_api.h | 2 ++ .../core/session/model_compilation_options.cc | 36 +++++++++++++++++-- .../core/session/model_compilation_options.h | 10 ++++++ .../test/providers/qnn/qnn_ep_context_test.cc | 11 ++++++ 8 files changed, 115 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9172965e18fcf..5f4e927b901ba 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6886,6 +6886,24 @@ struct OrtCompileApi { */ ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, size_t flags); + + /** Sets information related to EP context binary file. + * + * EP uses this information to decide the location and context binary file name. + * Used while compiling model with input and output in memory buffer + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] output_directory Null terminated string of the path (wchar on Windows, char otherwise). + * \param[in] model_name Null terminated string of the model name (wchar on Windows, char otherwise). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. 
+ */ + ORT_API2_STATUS(ModelCompilationOptions_SetEpContextBinaryInformation, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* output_directory, + _In_ const ORTCHAR_T* model_name); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index c59baa59c91a5..d1b08f127fa2a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1161,6 +1161,8 @@ struct ModelCompilationOptions : detail::Base { size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, + const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 612adc81d3309..ba5d53e6c2dd0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -819,6 +819,15 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextBinaryInformation( + const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextBinaryInformation( + this->p_, + output_directory, + model_name)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalInitializersFile( const ORTCHAR_T* file_path, size_t initializer_size_threshold) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile( diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index d910e3ea74b57..59b0992d827e1 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -128,6 +128,35 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + const ORTCHAR_T* output_directory, + const ORTCHAR_T* model_name) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + std::string output_dir = PathToUTF8String(output_directory); + if (output_dir.empty()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty"); + } + + std::string model_name_str = ToUTF8String(model_name); + if (model_name_str.empty()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str)); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(output_directory); + ORT_UNUSED_PARAMETER(model_name); + return 
OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExternalInitializersFile, _In_ OrtModelCompilationOptions* ort_model_compile_options, const ORTCHAR_T* external_initializers_file_path, @@ -248,6 +277,7 @@ static constexpr OrtCompileApi ort_compile_api = { // End of Version 22 - DO NOT MODIFY ABOVE &OrtCompileAPI::ModelCompilationOptions_SetFlags, + &OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 5f11b894f2004..93cc5dbf20fce 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -30,5 +30,7 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options, size_t flags); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 5de0f03fafc08..bbb110033f54c 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -72,8 +72,8 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() - << ") exceeds limit of " << ConfigOptions::kMaxKeyLength << " characters." - << "ORT will still generated the expected output file, but EPs will see an empty " + << ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters." 
+ << "ORT will still generate the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; } } @@ -98,6 +98,36 @@ Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr a return Status::OK(); } +Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory, + const std::string& model_name) { + if (output_directory.empty() || model_name.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty."); + } + + std::filesystem::path output_dir_path(output_directory); + if (output_dir_path.has_filename() && output_dir_path.extension() == "") { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory."); + } + + std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name); + + if (ctx_model_path.string().size() <= ConfigOptions::kMaxValueLength) { + ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, + ctx_model_path.string().c_str())); + } else { + logging::LoggingManager* log_manager = env_.GetLoggingManager(); + if (log_manager != nullptr && log_manager->HasDefaultLogger()) { + const logging::Logger& logger = log_manager->DefaultLogger(); + LOGS(logger, WARNING) << "output_directory length with model_name length together exceeds limit of " + << ConfigOptions::kMaxValueLength << " characters." + << "ORT will still generate the expected output file, but EPs will see an empty " + << "output path in SessionOption's ConfigOptions."; + } + } + + return Status::OK(); +} + Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_model) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry( kOrtSessionOptionEpContextEmbedMode, embed_ep_context_in_model ? "1" : "0")); @@ -146,7 +176,7 @@ Status ModelCompilationOptions::ResetOutputModelSettings() { ep_context_gen_options.output_model_buffer_ptr = nullptr; ep_context_gen_options.output_model_buffer_size_ptr = nullptr; ep_context_gen_options.output_model_buffer_allocator = nullptr; - return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ""); + return Status::OK(); } Status ModelCompilationOptions::CheckInputModelSettings() const { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index f96f0317cdaca..2824df863013d 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -72,6 +72,16 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); + /// + /// Sets information relate to EP context binary file. + /// EP use this information to decide the location and context binary file name. + /// Used while compiling model with input and output in memory buffer + /// + /// The folder path to the generated context binary file + /// Model name used to decide the context binary file name: [model_name]_[ep].bin + /// Status indicating potential error + Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name); + /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext /// nodes. Defaults to false (dumped to file). 
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 3335c242112ab..739e39a6975e2 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -509,6 +509,11 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB Ort::ModelCompilationOptions compile_options(*ort_env, session_options); compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size); + std::string target_dir = "./testdata/"; + std::string model_name = "test_model_in_mem.onnx"; + auto pos = model_name.rfind(".onnx"); + std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin"; + compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str()); compile_options.SetEpContextEmbedMode(false); // Compile the model. @@ -519,12 +524,18 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB ASSERT_TRUE(output_model_buffer != nullptr); ASSERT_TRUE(output_model_buffer_size > 0); + ASSERT_TRUE(std::filesystem::exists(target_dir + bin_file_name)) << "expected context binary file should exist"; + // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + // Add session option "ep.context_file_path" so that the session can use it to locate the [model_name]_qnn.bin file + std::string ctx_model = target_dir + model_name; + session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ctx_model.c_str()); // Should be able to create a session with the compiled model and the original session options. EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options))); + std::filesystem::remove(target_dir + bin_file_name); allocator.Free(output_model_buffer); } } From 0d6e2d994330bf187970f9e4fbd2caa8de2f778d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 11 Jul 2025 10:56:40 -0700 Subject: [PATCH 22/49] add build matrix for wgsl template (#25352) ### Description Windows WebGPU CI: add build matrix for wgsl template --- .github/workflows/windows_webgpu.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 70e8ea7e2792f..996e0d816d51a 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -22,6 +22,7 @@ jobs: strategy: matrix: vcpkg_option: [novcpkg, vcpkg] + wgsl_template: [static, dynamic] env: OrtPackageId: Microsoft.ML.OnnxRuntime OnnxRuntimeBuildDirectory: ${{ github.workspace }} @@ -123,6 +124,7 @@ jobs: --build_nodejs ` --build_java ` --use_webgpu ` + --wgsl_template ${{ matrix.wgsl_template }} ` ${{ matrix.vcpkg_option == 'vcpkg' && '--use_vcpkg' || '' }} ` --cmake_extra_defines ` onnxruntime_BUILD_UNIT_TESTS=ON ` From a532c8aee77894454329e22674c8be8a93a440c1 Mon Sep 17 00:00:00 2001 From: Jie Chen Date: Sat, 12 Jul 2025 04:21:16 +0800 Subject: [PATCH 23/49] [JSEP] Fix inputShape index OOB in slice.ts (#25364) Use `inputShape.length - 1` instead of `inputShape.length` to avoid out-of-bounds access. 
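Although the fix is a small change in the WGSL template, the underlying pattern is worth spelling out. A self-contained C++ illustration (not ORT code) of the reverse loop in question:

```cpp
#include <cstddef>
#include <vector>

// Walk a shape from the innermost dimension outwards. Starting the loop at
// shape.size() instead of shape.size() - 1 would make the first iteration
// read shape[shape.size()], one element past the end of the array.
size_t ElementCount(const std::vector<size_t>& shape) {
  size_t count = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    count *= shape[i];
  }
  return count;
}
```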
--- js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 5a837fd1e0bfa..c2085342efd80 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -98,7 +98,7 @@ const calculateInputIndicesImpl = ( `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { var input_indices: ${input.type.indices}; var carry = 0u; - for (var i = ${inputShape.length}; i >= 0; i--) { + for (var i = ${inputShape.length - 1}; i >= 0; i--) { let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)}; let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)}; From 6ef13e3a7fba7fa03bd7b8b5b49dc177c5884a9a Mon Sep 17 00:00:00 2001 From: xhcao Date: Sat, 12 Jul 2025 04:27:46 +0800 Subject: [PATCH 24/49] [webgpu] extend cast version to 23 (#25235) --- .../core/providers/webgpu/tensor/cast.cc | 20 ++++++++++++++++++- .../webgpu/webgpu_execution_provider.cc | 8 ++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 7f92ea4ed3776..313a96ba25509 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -52,10 +52,28 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T1", CastOpTypeConstraints()) .TypeConstraint("T2", CastOpTypeConstraints()), Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); ONNX_OPERATOR_KERNEL_EX( Cast, kOnnxDomain, - 19, + 23, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", CastOpTypeConstraints()) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 460d220ecf1b9..6e09f494f4a8d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -123,7 +123,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip); @@ -455,7 +457,9 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), - KERNEL_CREATE_INFO(19, Cast), + KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast), + KERNEL_CREATE_INFO_VERSIONED(21, 22, Cast), + KERNEL_CREATE_INFO(23, Cast), // // activations BuildKernelCreateInfo, From 47b378dc25cc9019c1061f7a2e7a83ae59b72608 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 11 Jul 2025 13:31:48 -0700 Subject: [PATCH 25/49] Fix a security warning (#18979) Description (reference: https://github.com/advisories/GHSA-5crp-9r3c-p9vr) Newtonsoft.Json prior to version 13.0.1 is vulnerable to Insecure Defaults due to improper handling of expressions with high nesting level that lead to StackOverFlow exception or high CPU and RAM usage. Exploiting this vulnerability results in Denial Of Service (DoS). To mitigate the issue one either need to update Newtonsoft.Json to 13.0.1 or set MaxDepth parameter in the JsonSerializerSettings. ``` JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; ``` This file is the only place using `JsonConvert`, so I blindly put this fix and hope the warning will disappear. --- .../EndToEndTests.Mobile.Automation/Tests.cs | 4 +++- .../TestResultProcessor.cs | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs index c28830ec72157..6e6190b8227b8 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs @@ -40,10 +40,12 @@ public void RunPlatformUnitTest() var serializedResultSummary = _app.Invoke(_getResultsBackdoorMethodName)?.ToString(); Assert.IsNotEmpty(serializedResultSummary, "Test results were not returned"); + // Fix security issue (overflow with too much nesting): GHSA-5crp-9r3c-p9vr + JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var testSummary = JsonConvert.DeserializeObject(serializedResultSummary); Assert.AreEqual(testSummary.Failed, 0, $"{testSummary.Failed} tests failed"); _app.Screenshot("Post-testing"); } } -} \ No newline at end of file +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs index 8419d261e4a41..625cc2c54055c 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs @@ -45,8 +45,9 @@ public TestResultSummary GetResults() public string GetSerializedResults() { var resultSummary = GetResults(); + JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var serializedResultSummary = JsonConvert.SerializeObject(resultSummary, Formatting.Indented); return serializedResultSummary; } } -} \ No newline at end of file +} From a2b7a48853c4d1544ed75ff80513ee240ea11a62 Mon Sep 17 00:00:00 2001 From: quic-hungjuiw Date: Sat, 12 Jul 2025 04:33:01 +0800 Subject: [PATCH 26/49] Fix AutoEpSelection and OrtEpLibrary tests when using AuthenticAMD (#24754) --- 
onnxruntime/core/common/cpuid_info.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 8ea593f107833..c4667d53c0674 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -170,7 +170,7 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) { uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { if (vendor == "GenuineIntel") return 0x8086; - if (vendor == "GenuineAMD") return 0x1022; + if (vendor == "AuthenticAMD") return 0x1022; if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); if (vendor.find("NV") == 0) return 0x10DE; return 0; From 4ac95cee1d94e58e4f7ab569f4dbafe8ba1d26bd Mon Sep 17 00:00:00 2001 From: Ian Hunter Date: Fri, 11 Jul 2025 21:33:16 +0100 Subject: [PATCH 27/49] Missing datatype in assertion (#23578) --- onnxruntime/python/tools/quantization/base_quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 9a297e451213a..e3303dac6c8c5 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -42,7 +42,7 @@ def __init__(self, **data: dict[str, Any]): for k, v in data.items(): if not isinstance(k, str): raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.") - if k != "axis" and not isinstance(v, (int, str, np.ndarray)): + if k != "axis" and not isinstance(v, (int, str, np.ndarray, float)): raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.") if k == "axis" and not isinstance(v, int) and v is not None: raise TypeError(f"Axis value must be an int or None, not {type(v)}.") From aa644e8cbac8eceacbac1f4f10e5ce6a37b115c0 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Fri, 11 Jul 2025 21:20:55 -0700 Subject: [PATCH 28/49] [EP ABI] Update to use Node_GetEpName (#25363) Change to use `Node_GetEpName` API name to avoid confusion. For plugin EPs, the EP factory can use whatever name that registered with ORT, so make the API name `Node_GetEpName` to align with `OrtEpFactory.GetName.` --- include/onnxruntime/core/session/onnxruntime_c_api.h | 5 +++-- onnxruntime/core/graph/ep_api_types.cc | 2 +- onnxruntime/core/graph/ep_api_types.h | 4 ++-- onnxruntime/core/session/onnxruntime_c_api.cc | 8 ++++---- onnxruntime/core/session/ort_apis.h | 2 +- onnxruntime/test/autoep/library/ep.cc | 6 +++--- 6 files changed, 14 insertions(+), 13 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5f4e927b901ba..82e782112974f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6026,8 +6026,9 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); - /** \brief Returns the execution provider type (name) that this node is assigned to run on. + /** \brief Returns the execution provider name that this node is assigned to run on. * Returns NULL if the node has not been assigned to any execution provider yet. + * For plugin execution providers, the name is the one returned by OrtEp::GetName. * * \param[in] node The OrtNode instance. 
* \param[out] out Output execution provider type and can be NULL if node has not been assigned. @@ -6036,7 +6037,7 @@ struct OrtApi { * * \since Version 1.23. */ - ORT_API2_STATUS(Node_GetEpType, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); + ORT_API2_STATUS(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); /// @} diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 073c6a2c743eb..f57543416a68f 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -276,7 +276,7 @@ const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { } } -const std::string& EpNode::GetEpType() const { +const std::string& EpNode::GetEpName() const { return node_.GetExecutionProviderType(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 1acbcc478a99b..d3921e051e18a 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -208,8 +208,8 @@ struct EpNode : public OrtNode { // Helper that gets the node's attributes by name. const OrtOpAttr* GetAttribute(const std::string& name) const; - // Helper that gets the execution provider that this node is assigned to run on. - const std::string& GetEpType() const; + // Helper that gets the execution provider name that this node is assigned to run on. + const std::string& GetEpName() const; private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 55bc28cd7139f..db2a62c77d1bc 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3052,7 +3052,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetEpType, _In_ const OrtNode* node, +ORT_API_STATUS_IMPL(OrtApis::Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out) { API_IMPL_BEGIN if (out == nullptr) { @@ -3061,10 +3061,10 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetEpType, _In_ const OrtNode* node, const EpNode* ep_node = EpNode::ToInternal(node); if (ep_node == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpType."); + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpName."); } - *out = ep_node->GetEpType().c_str(); + *out = ep_node->GetEpName().c_str(); return nullptr; API_IMPL_END } @@ -3751,7 +3751,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumSubgraphs, &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetGraph, - &OrtApis::Node_GetEpType, + &OrtApis::Node_GetEpName, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fed7009828999..9ab927006c320 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -680,7 +680,7 @@ ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, _Out_writes_opt_(num_subgraphs) const char** attribute_names); ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); -ORT_API_STATUS_IMPL(Node_GetEpType, _In_ const 
OrtNode* node, _Outptr_result_maybenull_ const char** out); +ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index a5b46c74ecc21..78261162ebaf8 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -328,9 +328,9 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); - const char* ep_type = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetEpType(fused_nodes[0], &ep_type)); - if (std::strncmp(ep_type, "example_ep", 11) != 0) { + const char* ep_name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetEpName(fused_nodes[0], &ep_name)); + if (std::strncmp(ep_name, "example_ep", 11) != 0) { return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on"); } From 2b8c555837fed2b09ea084cede9d4d9350c9ff9d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 14 Jul 2025 05:54:16 +0000 Subject: [PATCH 29/49] Bump clang-format from 20.1.7 to 20.1.8 (#25381) --- requirements-lintrunner.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 450b955f161af..5bfe909f97aba 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -3,4 +3,4 @@ lintrunner==0.12.7 lintrunner-adapters==0.12.4 ruff==0.12.2 -clang-format==20.1.7 +clang-format==20.1.8 From 613d22dadf9be8ceeb6ebbd631642b1abb6db2f2 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Mon, 14 Jul 2025 09:07:15 -0700 Subject: [PATCH 30/49] Fix number of layers in Whisper export (#25375) ### Description This PR fixes the number of hidden layers used during the export of Whisper by always using the number of hidden layers in the decoder. ### Motivation and Context Most of the Whisper models contain the same number of hidden layers in the encoder and decoder. However, Whisper large v3 turbo contains 32 hidden layers in the encoder and only 4 hidden layers in the decoder. This PR also fixes [this issue](https://github.com/microsoft/onnxruntime-genai/issues/1611). 
--- .../tools/transformers/models/whisper/convert_to_onnx.py | 2 +- .../tools/transformers/models/whisper/requirements.txt | 2 +- .../tools/transformers/models/whisper/whisper_decoder.py | 7 +++---- .../models/whisper/whisper_encoder_decoder_init.py | 4 ++-- .../tools/transformers/models/whisper/whisper_helper.py | 4 ++-- .../tools/transformers/models/whisper/whisper_inputs.py | 6 +++--- .../transformers/models/whisper/whisper_jump_times.py | 2 +- 7 files changed, 13 insertions(+), 14 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index ac696ff3788aa..e092285d57358 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -410,7 +410,7 @@ def export_onnx_models( precision == Precision.FLOAT16, model.config.encoder_attention_heads, model.config.d_model, - model.config.num_hidden_layers, + model.config.decoder_layers, use_external_data_format, use_gpu=use_gpu, provider=provider, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index f1758cc52280f..37fc72cd26e07 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,5 +1,5 @@ torch>=2.7.0 -transformers>=4.52.3 +transformers==4.52.3 openai-whisper==20240927 ffmpeg-python datasets diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index fadf271ae913b..e10e616d35d38 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -187,7 +187,7 @@ def input_names(self): *list( chain.from_iterable( (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}") - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -205,7 +205,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -214,8 +214,7 @@ def output_names(self): "logits", *list( chain.from_iterable( - (f"present_key_self_{i}", f"present_value_self_{i}") - for i in range(self.config.num_hidden_layers) + (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 26dc3aee7018b..cd81edc1001be 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -127,7 +127,7 @@ def output_names(self): *list( chain.from_iterable( (f"present_key_cross_{i}", f"present_value_cross_{i}") - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -143,7 +143,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] diff --git 
a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index f66aa22eb0972..a236c4da1738e 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -763,7 +763,7 @@ def optimize_onnx( is_float16: bool, num_attention_heads: int, hidden_size: int, - num_layers: int, + num_decoder_layers: int, use_external_data_format: bool = False, use_gpu: bool = False, provider: str = "cpu", @@ -801,7 +801,7 @@ def optimize_onnx( m = add_cache_indirection_to_mha(m, past_seq_len_name) if output_qk: - m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_layers, 2))) + m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_decoder_layers, 2))) m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py index 0b0882eface72..8937fea900d14 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py @@ -94,14 +94,14 @@ def get_sample_past_key_values( torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] cross_attention_kv_caches = [ ( torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches) @@ -187,7 +187,7 @@ def get_sample_QKs( # noqa: N802 torch.rand( batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] return QKs diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py index a7c0d3538b8da..4dd5d7de1752b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py @@ -156,7 +156,7 @@ def input_names(self): "alignment_heads", "sot_sequence_length", "segment_length", - *[f"cross_qk_{i}" for i in range(self.config.num_hidden_layers)], + *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)], ] return input_names From 491f435301b310ed2492422ba29742e30e0ad2e7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 14 Jul 2025 13:45:54 -0700 Subject: [PATCH 31/49] Bump transformers from 4.48.0 to 4.52.1 in /onnxruntime/python/tools/transformers/models/llama (#25328) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [transformers](https://github.com/huggingface/transformers) from 4.48.0 to 4.52.1.
Release notes

Sourced from transformers's releases.

Patch release v4.51.3

A mix of bugs were fixed in this patch; very exceptionally, we diverge from semantic versioning to merge GLM-4 in this patch release.

  • Handle torch ver in flexattn (#37400)
  • handle torch version edge cases (#37399)
  • Add glm4 (#37388)

Patch Release 4.51.2

This is another round of bug fixes, but they are a lot more minor and outputs were not really affected!

Patch release v4.51.1

Since the release of Llama 4, we have fixed a few issues that we are now releasing in patch v4.51.1

  • Fixing flex attention for torch=2.6.0 (#37285)
  • more fixes for post-training llama4 (#37329)
  • Remove HQQ from caching allocator warmup (#37347)
  • fix derived berts _init_weights (#37341)
  • Fix init empty weights without accelerate (#37337)
  • Fix deepspeed with quantization (#37324)
  • fix llama4 training (#37319)
  • fix flex attn when optional args aren't passed (#37327)
  • Multiple llama4 fixe (#37353)

Thanks all for your patience

v4.51.0: Llama 4, Phi4-Multimodal, DeepSeek-v3, Qwen3

New Model Additions

Llama 4

Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture. This generation includes two models:

  • The highly capable Llama 4 Maverick with 17B active parameters out of ~400B total, with 128 experts.
  • The efficient Llama 4 Scout also has 17B active parameters out of ~109B total, using just 16 experts.

Both models leverage early fusion for native multimodality, enabling them to process text and image inputs. Maverick and Scout are both trained on up to 40 trillion tokens on data encompassing 200 languages (with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).

For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats. These models are released under the custom Llama 4 Community License Agreement, available on the model repositories

Getting started with Llama 4 using transformers is straightforward. Make sure you have transformers v4.51.0 or later installed:

pip install -U transformers[hf_xet]

... (truncated)

Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../python/tools/transformers/models/llama/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index 6bd698f8b75b4..e16957eab80a1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,7 +1,7 @@ onnxscript>=0.2.3 optimum>=1.14.1 optree -transformers==4.48.0 +transformers==4.52.1 torch>=2.7.0 onnx==1.17.0 datasets>=2.8.0 From 10e2c1e8f72680cfb83b2aba43d19135823a1d36 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 14 Jul 2025 15:09:51 -0700 Subject: [PATCH 32/49] Bump ruff from 0.12.2 to 0.12.3 (#25382) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [ruff](https://github.com/astral-sh/ruff) from 0.12.2 to 0.12.3.
Release notes

Sourced from ruff's releases.

0.12.3

Release Notes

Preview features

  • [flake8-bugbear] Support non-context-manager calls in B017 (#19063)
  • [flake8-use-pathlib] Add autofixes for PTH100, PTH106, PTH107, PTH108, PTH110, PTH111, PTH112, PTH113, PTH114, PTH115, PTH117, PTH119, PTH120 (#19213)
  • [flake8-use-pathlib] Add autofixes for PTH203, PTH204, PTH205 (#18922)

Bug fixes

  • [flake8-return] Fix false-positive for variables used inside nested functions in RET504 (#18433)
  • Treat form feed as valid whitespace before a line continuation (#19220)
  • [flake8-type-checking] Fix syntax error introduced by fix (TC008) (#19150)
  • [pyupgrade] Keyword arguments in super should suppress the UP008 fix (#19131)

Documentation

  • [flake8-pyi] Make example error out-of-the-box (PYI007, PYI008) (#19103)
  • [flake8-simplify] Make example error out-of-the-box (SIM116) (#19111)
  • [flake8-type-checking] Make example error out-of-the-box (TC001) (#19151)
  • [flake8-use-pathlib] Make example error out-of-the-box (PTH210) (#19189)
  • [pycodestyle] Make example error out-of-the-box (E272) (#19191)
  • [pycodestyle] Make example not raise unnecessary SyntaxError (E114) (#19190)
  • [pydoclint] Make example error out-of-the-box (DOC501) (#19218)
  • [pylint, pyupgrade] Fix syntax errors in examples (PLW1501, UP028) (#19127)
  • [pylint] Update missing-maxsplit-arg docs and error to suggest proper usage (PLC0207) (#18949)
  • [flake8-bandit] Make example error out-of-the-box (S412) (#19241)

Contributors

... (truncated)

Commits
  • 5bc81f2 Bump 0.12.3 (#19279)
  • 6908e26 Filter ruff_linter::VERSION out of SARIF output tests (#19280)
  • 25c4295 [ty] Avoid stale diagnostics for open files diagnostic mode (#19273)
  • 426fa4b [ty] Add signature help provider to playground (#19276)
  • b0b65c2 [ty] Initial implementation of signature help provider (#19194)
  • 08bc6d2 Add simple integration tests for all output formats (#19265)
  • f2ae12b [flake8-return] Fix false-positive for variables used inside nested functio...
  • 965f415 [ty] Add a --quiet mode (#19233)
  • 83b5bbf Treat form feed as valid whitespace before a line continuation (#19220)
  • 87f6f08 [ty] Make check_file a salsa query (#19255)
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ruff&package-manager=pip&previous-version=0.12.2&new-version=0.12.3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`.

[//]: # (dependabot-automerge-start)
Dependabot will merge this PR once CI passes on it, as requested by @fs-eire.
[//]: # (dependabot-automerge-end)

---
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-lintrunner.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 5bfe909f97aba..309004580d413 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -2,5 +2,5 @@ # When any package below is changed, you shall run "lintrunner init" again. lintrunner==0.12.7 lintrunner-adapters==0.12.4 -ruff==0.12.2 +ruff==0.12.3 clang-format==20.1.8 From 440ac68a333de47b6c78afe1e4a4523269c990a0 Mon Sep 17 00:00:00 2001 From: Jeff Kilpatrick Date: Mon, 14 Jul 2025 15:36:29 -0700 Subject: [PATCH 33/49] [QNN EP] Upgrade QNN to 2.36.1 (#25388) ### Description Update Qnn default version to 2.36.1.250708 Co-authored-by: Jeff Kilpatrick --- .../android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- .../azure-pipelines/c-api-noopenmp-packaging-pipelines.yml | 2 +- .../github/azure-pipelines/custom-nuget-packaging-pipeline.yml | 2 +- tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml | 2 +- tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml | 2 +- .../github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../github/azure-pipelines/stages/py-cpu-packaging-stage.yml | 2 +- .../azure-pipelines/templates/android-java-api-aar-test.yml | 2 +- .../github/azure-pipelines/templates/android-java-api-aar.yml | 2 +- tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml | 2 +- .../azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml | 2 +- .../azure-pipelines/templates/jobs/download_win_qnn_sdk.yml | 2 +- .../ci_build/github/azure-pipelines/templates/py-linux-qnn.yml | 2 +- .../github/azure-pipelines/templates/py-win-arm64-qnn.yml | 2 +- .../github/azure-pipelines/templates/py-win-arm64ec-qnn.yml | 2 +- .../github/azure-pipelines/templates/py-win-x64-qnn.yml | 2 +- tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml | 2 +- .../github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml | 2 +- tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- 19 files changed, 19 insertions(+), 19 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index ee7f8f2fa386a..e5e2a4749ef85 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index aa25e3f31166a..202aa61da0b80 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 7addb3217072a..69dc9d1a8f63d 100644 --- 
a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,7 +6,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index cf8bbbed70525..526ed71df2006 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index de024f0b3456f..b99246625cb77 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.36.0.250627 + default: 2.36.1.250708 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 4fa916db0de39..626a638121858 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 433250f05125e..e2c6b25f48b6d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' 
- default: 2.36.0.250627 + default: 2.36.1.250708 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index ab779e164b36e..74f7f782fe1b2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.0.250627' + default: '2.36.1.250708' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 110f83ff587c8..92e862bd79008 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.0.250627' + default: '2.36.1.250708' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 535784933a087..5b48a14e2afc3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 3e7427cc7a2e3..930dc83b73460 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.0.250627' + default: '2.36.1.250708' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index e3f549e2d649f..96eea6cd6d2fb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.0.250627' + default: '2.36.1.250708' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index d533fb7c83ddd..caee5367950e6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: is1ES displayName: 'Whether the pipeline is running in 1ES' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index cd060d1fbf19f..185f41822a7e5 100644 --- 
a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 2a2ac49b4e073..9a1e7e5e251c9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 8528fa3907e96..5affc152a0a4a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index b600341827aad..29ebb8c4e4e61 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.36.0.250627' + QnnSdk: '2.36.1.250708' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 66df2d6053d51..7ebf5394e4530 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: 'BUILD_QNN_EP' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index eb77c9422853d..ffeb577547f69 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: 'BUILD_QNN_EP' From fab30696f314d75d450b87a4755b3ab63a1520d6 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 15 Jul 2025 14:28:05 +1000 Subject: [PATCH 34/49] Add vendor id to OrtEpFactory and default ORT logger to CreateEpFactories (#25365) ### Description Add vendor id to OrtEpFactory. It's easier to get the vendor id than name on other platforms. Update the selection policy to prefer match on vendor id with fallback to vendor name. Add default ORT logger to CreateEpFactories. The OrtEpFactory currently has no way to log informational messages or issues. CreateEp is given the session logger for use by the OrtEp instance so that part of things is good. Misc cleanups. 
Make usage of ORT_API2_STATUS and ORT_API_T consistent on onnxruntime_ep_c_api.h. See ort_version_supported in some EP factories where it was missed. ### Motivation and Context Vendor id is easier to match against OrtHardwareDevice when doing auto EP selection. OrtEpFactory should have a logger. Last chance to cleanup APIs before 1.23 release --- .../core/session/onnxruntime_ep_c_api.h | 90 ++++++++++--------- .../providers/cuda/cuda_provider_factory.cc | 8 ++ .../providers/qnn/qnn_provider_factory.cc | 14 ++- onnxruntime/core/session/ep_api_utils.h | 4 + .../core/session/ep_factory_internal.cc | 4 +- .../core/session/ep_factory_internal.h | 4 +- .../core/session/ep_library_internal.cc | 9 +- .../session/ep_library_provider_bridge.cc | 1 + .../core/session/provider_policy_context.cc | 8 +- onnxruntime/test/autoep/library/ep.cc | 6 +- onnxruntime/test/autoep/library/ep.h | 6 +- onnxruntime/test/autoep/library/ep_factory.cc | 7 ++ onnxruntime/test/autoep/library/ep_factory.h | 2 + .../test/framework/ep_plugin_provider_test.cc | 14 +-- 14 files changed, 113 insertions(+), 64 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 44c7bb6ee424a..5d00ce4940d02 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -358,7 +358,7 @@ struct OrtEp { * * \since Version 1.22. */ - const char*(ORT_API_CALL* GetName)(_In_ const OrtEp* this_ptr); + ORT_API_T(const char*, GetName, _In_ const OrtEp* this_ptr); /** \brief Get information about the nodes supported by the OrtEp instance. * @@ -376,8 +376,8 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* GetCapability)(_In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, - _Inout_ OrtEpGraphSupportInfo* graph_support_info); + ORT_API2_STATUS(GetCapability, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, + _Inout_ OrtEpGraphSupportInfo* graph_support_info); /** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set a OrtNodeComputeInfo instance * for each OrtGraph in order to define its computation function. @@ -416,10 +416,10 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* Compile)(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, - _In_ const OrtNode** fused_nodes, _In_ size_t count, - _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes); + ORT_API2_STATUS(Compile, _In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes); /** \brief Release OrtNodeComputeInfo instances. * @@ -429,9 +429,9 @@ struct OrtEp { * * \since Version 1.23. */ - void(ORT_API_CALL* ReleaseNodeComputeInfos)(_In_ OrtEp* this_ptr, - OrtNodeComputeInfo** node_compute_infos, - _In_ size_t num_node_compute_infos); + ORT_API_T(void, ReleaseNodeComputeInfos, _In_ OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + _In_ size_t num_node_compute_infos); /** \brief Get the EP's preferred data layout. * @@ -445,8 +445,7 @@ struct OrtEp { * * \since Version 1.23. 
*/ - OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr, - _Out_ OrtEpDataLayout* preferred_data_layout); + ORT_API2_STATUS(GetPreferredDataLayout, _In_ OrtEp* this_ptr, _Out_ OrtEpDataLayout* preferred_data_layout); /** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout * should be converted to `target_data_layout`. @@ -470,11 +469,10 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* ShouldConvertDataLayoutForOp)(_In_ OrtEp* this_ptr, - _In_z_ const char* domain, - _In_z_ const char* op_type, - _In_ OrtEpDataLayout target_data_layout, - _Outptr_ int* should_convert); + ORT_API2_STATUS(ShouldConvertDataLayoutForOp, _In_ OrtEp* this_ptr, + _In_z_ const char* domain, _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert); /** \brief Set dynamic options on this EP. * @@ -492,10 +490,10 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* SetDynamicOptions)(_In_ OrtEp* this_ptr, - _In_reads_(num_options) const char* const* option_keys, - _In_reads_(num_options) const char* const* option_values, - _In_ size_t num_options); + ORT_API2_STATUS(SetDynamicOptions, _In_ OrtEp* this_ptr, + _In_reads_(num_options) const char* const* option_keys, + _In_reads_(num_options) const char* const* option_values, + _In_ size_t num_options); /** \brief Called by ORT to notify the EP of the start of a run. * @@ -508,8 +506,7 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* OnRunStart)(_In_ OrtEp* this_ptr, - _In_ const OrtRunOptions* run_options); + ORT_API2_STATUS(OnRunStart, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options); /** \brief Called by ORT to notify the EP of the end of a run. * @@ -524,9 +521,7 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* OnRunEnd)(_In_ OrtEp* this_ptr, - _In_ const OrtRunOptions* run_options, - _In_ bool sync_stream); + ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. @@ -586,7 +581,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); + ORT_API_T(const char*, GetName, const OrtEpFactory* this_ptr); /** \brief Get the name of vendor who owns the execution provider that the factory creates. * @@ -597,7 +592,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr); // return EP vendor + ORT_API_T(const char*, GetVendor, const OrtEpFactory* this_ptr); // return EP vendor /** \brief Get information from the execution provider about OrtHardwareDevice support. * @@ -616,12 +611,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices); + ORT_API2_STATUS(GetSupportedDevices, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); /** \brief Function to create an OrtEp instance for use in a Session. 
* @@ -647,12 +642,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); + ORT_API2_STATUS(CreateEp, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); /** \brief Release the OrtEp instance. * @@ -661,7 +656,18 @@ struct OrtEpFactory { * * \since Version 1.22. */ - void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); + ORT_API_T(void, ReleaseEp, OrtEpFactory* this_ptr, struct OrtEp* ep); + + /** \brief Get the vendor id who owns the execution provider that the factory creates. + * + * This is typically the PCI vendor ID. See https://pcisig.com/membership/member-companies + * + * \param[in] this_ptr The OrtEpFactory instance. + * \return vendor_id The vendor ID of the execution provider the factory creates. + * + * \since Version 1.23. + */ + ORT_API_T(uint32_t, GetVendorId, const OrtEpFactory* this_ptr); /** \brief Get the version of the execution provider that the factory creates. * @@ -675,7 +681,7 @@ struct OrtEpFactory { * * \since Version 1.23. */ - const char*(ORT_API_CALL* GetVersion)(_In_ const OrtEpFactory* this_ptr); + ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); /** \brief Create an OrtAllocator for the given OrtMemoryInfo. * diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 2de496a9168a0..f00bf51ae143d 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -313,8 +313,10 @@ CUDA_Provider* GetProvider() { // OrtEpApi infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection. 
struct CudaEpFactory : OrtEpFactory { CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { + ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -331,6 +333,11 @@ struct CudaEpFactory : OrtEpFactory { return factory->vendor.c_str(); } + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id; + } + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { return ORT_VERSION; } @@ -374,6 +381,7 @@ struct CudaEpFactory : OrtEpFactory { const OrtApi& ort_api; const std::string ep_name{kCudaExecutionProvider}; // EP name const std::string vendor{"Microsoft"}; // EP vendor name + uint32_t vendor_id{0x1414}; // Microsoft vendor ID }; extern "C" { diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index c679ea1adb286..785177ce37788 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -125,8 +125,10 @@ struct QnnEpFactory : OrtEpFactory { OrtHardwareDeviceType hw_type, const char* qnn_backend_type) : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} { + ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -142,7 +144,12 @@ struct QnnEpFactory : OrtEpFactory { static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); - return factory->vendor.c_str(); + return factory->ep_vendor.c_str(); + } + + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_vendor_id; } static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { @@ -195,8 +202,9 @@ struct QnnEpFactory : OrtEpFactory { } const OrtApi& ort_api; - const std::string ep_name; // EP name - const std::string vendor{"Microsoft"}; // EP vendor name + const std::string ep_name; // EP name + const std::string ep_vendor{"Microsoft"}; // EP vendor name + uint32_t ep_vendor_id{0x1414}; // Microsoft vendor ID // Qualcomm vendor ID. 
Refer to the ACPI ID registry (search Qualcomm): https://uefi.org/ACPI_ID_List const uint32_t vendor_id{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}; diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index daccd24453371..a0904c32011a7 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -16,6 +16,10 @@ struct ForwardToFactory { return static_cast(this_ptr)->GetVendor(); } + static uint32_t ORT_API_CALL GetVendorId(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->GetVendorId(); + } + static const char* ORT_API_CALL GetVersion(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetVersion(); } diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index b289010cc6c5b..fa4ef2515ca92 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -14,17 +14,19 @@ namespace onnxruntime { using Forward = ForwardToFactory; -EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, +EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func) : ep_name_{ep_name}, vendor_{vendor}, + vendor_id_{vendor_id}, get_supported_func_{std::move(get_supported_func)}, create_func_{create_func} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; OrtEpFactory::GetVendor = Forward::GetVendor; + OrtEpFactory::GetVendorId = Forward::GetVendorId; OrtEpFactory::GetVersion = Forward::GetVersion; OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index 087c0c60f8f4e..ee08e2233c529 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -33,12 +33,13 @@ class EpFactoryInternal : public OrtEpFactory { const OrtSessionOptions* session_options, const OrtLogger* logger, std::unique_ptr* ep)>; - EpFactoryInternal(const std::string& ep_name, const std::string& vendor, + EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func); const char* GetName() const noexcept { return ep_name_.c_str(); } const char* GetVendor() const noexcept { return vendor_.c_str(); } + uint32_t GetVendorId() const noexcept { return vendor_id_; } const char* GetVersion() const noexcept; OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -67,6 +68,7 @@ class EpFactoryInternal : public OrtEpFactory { private: const std::string ep_name_; // EP name library was registered with const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID const GetSupportedFunc get_supported_func_; // function to return supported devices const CreateFunc create_func_; // function to create the EP instance diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index 25f70f7549a16..ce5736f601b45 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -61,7 +61,8 @@ std::unique_ptr EpLibraryInternal::CreateCpuEp() { }; std::string 
ep_name = kCpuExecutionProvider; - auto cpu_factory = std::make_unique(ep_name, "Microsoft", get_supported, create_cpu_ep); + auto cpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, + get_supported, create_cpu_ep); return std::make_unique(std::move(cpu_factory)); } @@ -122,7 +123,8 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { return nullptr; }; - auto dml_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_dml_ep); + auto dml_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, + is_supported, create_dml_ep); return std::make_unique(std::move(dml_factory)); } @@ -170,7 +172,8 @@ std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { return nullptr; }; - auto webgpu_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_webgpu_ep); + auto webgpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, + is_supported, create_webgpu_ep); return std::make_unique(std::move(webgpu_factory)); } diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc index 73423a4744576..70937bdc5d3e8 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/ep_library_provider_bridge.cc @@ -72,6 +72,7 @@ Status EpLibraryProviderBridge::Load() { auto internal_factory = std::make_unique(factory->GetName(factory), factory->GetVendor(factory), + factory->GetVendorId(factory), is_supported_fn, create_fn); factory_ptrs_.push_back(internal_factory.get()); diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index e8d62ab86f517..211bf8b2d15a4 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -22,7 +22,13 @@ namespace onnxruntime { namespace { bool MatchesEpVendor(const OrtEpDevice* d) { - // TODO: Would be better to match on Id. Should the EP add that in EP metadata? 
+ // match on vendor id if provided + uint32_t factory_vendor_id = d->ep_factory->GetVendorId(d->ep_factory); + if (factory_vendor_id != 0 && d->device->vendor_id == factory_vendor_id) { + return true; + } + + // match on vendor name return d->device->vendor == d->ep_vendor; } diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index 78261162ebaf8..44b3f9a213abf 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -226,7 +226,7 @@ OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { /*static*/ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) { + OrtEpGraphSupportInfo* graph_support_info) noexcept { ExampleEp* ep = static_cast(this_ptr); size_t num_nodes = 0; @@ -290,7 +290,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes) { + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { ExampleEp* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -360,7 +360,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const /*static*/ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos) { + size_t num_node_compute_infos) noexcept { (void)this_ptr; for (size_t i = 0; i < num_node_compute_infos; i++) { delete node_compute_infos[i]; diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index b8c63f39438ba..dfebcc52a0caf 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -31,14 +31,14 @@ class ExampleEp : public OrtEp, public ApiPtrs { private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; static OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info); + OrtEpGraphSupportInfo* graph_support_info) noexcept; static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes); + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos); + size_t num_node_compute_infos) noexcept; OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index d4895102b0bf1..19a44008b8c97 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -14,6 +14,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; @@ -87,6 +88,12 @@ const char* ORT_API_CALL ExampleEpFactory::GetVendorImpl(const OrtEpFactory* thi return factory->vendor_.c_str(); } +/*static*/ +uint32_t ORT_API_CALL ExampleEpFactory::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id_; +} + /*static*/ const char* ORT_API_CALL ExampleEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index fda77f12c4814..72fa1c1301841 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -21,6 +21,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; @@ -53,6 +54,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name + const uint32_t vendor_id_{0xB357}; // EP vendor ID const std::string ep_version_{"0.1.0"}; // EP version // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 18bc9cf05b36d..4c5dcd2bd7580 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -36,7 +36,7 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { // Individual tests should fill out the other function pointers as needed. 
} - static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) { + static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) noexcept { constexpr const char* ep_name = "TestOrtEp"; return ep_name; } @@ -50,7 +50,7 @@ struct TestOrtEpFactory : ::OrtEpFactory { ReleaseEp = ReleaseEpImpl; } - static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) { + static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { delete static_cast(ep); } }; @@ -125,7 +125,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { + auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { *preferred_data_layout = OrtEpDataLayout::OrtEpDataLayout_NCHW; return nullptr; }; @@ -135,7 +135,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { #if !defined(ORT_NO_EXCEPTIONS) { - auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { + auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { *preferred_data_layout = static_cast(-1); return nullptr; }; @@ -144,7 +144,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) -> ::OrtStatus* { + auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) noexcept -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "I can't decide what data layout I prefer."); }; @@ -167,7 +167,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* node_op_type, OrtEpDataLayout target_data_layout, - int* should_convert) -> ::OrtStatus* { + int* should_convert) noexcept -> ::OrtStatus* { EXPECT_EQ(target_data_layout, OrtEpDataLayout::OrtEpDataLayout_NHWC); if (node_op_type == std::string_view{"Conv"}) { @@ -201,7 +201,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* /*node_op_type*/, OrtEpDataLayout /*target_data_layout*/, - int* /*should_convert*/) -> ::OrtStatus* { + int* /*should_convert*/) noexcept -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "To convert to NHWC or not to convert to NHWC..."); From 9de58ac7a3d18d6ae7f7ae502b3f91361067f1b5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 15 Jul 2025 05:22:56 +0000 Subject: [PATCH 35/49] Bump lintrunner-adapters from 0.12.4 to 0.12.5 (#25380) --- requirements-lintrunner.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 309004580d413..f02e3e8058c29 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -1,6 +1,6 @@ # This file is auto updated by dependabot # When any package below is changed, you shall run "lintrunner init" again. 
lintrunner==0.12.7 -lintrunner-adapters==0.12.4 +lintrunner-adapters==0.12.5 ruff==0.12.3 clang-format==20.1.8 From d9ce6a9d5e29e02fb5dddb8cd064add819d6118d Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 16 Jul 2025 02:54:33 +0800 Subject: [PATCH 36/49] [WebNN] Add rank range validation for rest ops (#25383) - Add common rank range validation to base_op_builder.cc - Handle specific rank range validation for rest ops - Remove duplicated input_shape validation - Fix some typos BTW --- .../builders/impl/argmax_min_op_builder.cc | 18 ---- .../webnn/builders/impl/base_op_builder.cc | 7 +- .../webnn/builders/impl/binary_op_builder.cc | 5 +- .../webnn/builders/impl/concat_op_builder.cc | 3 +- .../webnn/builders/impl/conv_op_builder.cc | 2 +- .../webnn/builders/impl/cumsum_op_builder.cc | 4 - .../webnn/builders/impl/dropout_op_builder.cc | 20 +---- .../webnn/builders/impl/einsum_op_builder.cc | 90 +++++++++++++------ .../impl/gatherElements_op_builder.cc | 5 +- .../builders/impl/gatherND_op_builder.cc | 5 +- .../webnn/builders/impl/gather_op_builder.cc | 26 +----- .../webnn/builders/impl/gru_op_builder.cc | 3 +- .../webnn/builders/impl/logical_op_builder.cc | 4 +- .../webnn/builders/impl/lrn_op_builder.cc | 15 +--- .../webnn/builders/impl/lstm_op_builder.cc | 3 +- .../builders/impl/matMulNBits_op_builder.cc | 19 ++-- .../webnn/builders/impl/max_min_op_builder.cc | 24 +---- .../builders/impl/normalization_op_builder.cc | 87 ++++++++---------- .../webnn/builders/impl/pool_op_builder.cc | 14 --- .../webnn/builders/impl/qdq_op_builder.cc | 3 +- .../builders/impl/reduction_op_builder.cc | 8 +- .../webnn/builders/impl/reshape_op_builder.cc | 5 -- .../impl/scatterElements_op_builder.cc | 5 +- .../builders/impl/scatterND_op_builder.cc | 5 +- .../webnn/builders/impl/slice_op_builder.cc | 21 +++-- .../webnn/builders/impl/softmax_op_builder.cc | 19 ---- .../impl/squeeze_unsqueeze_op_builder.cc | 3 - .../webnn/builders/impl/ternary_op_builder.cc | 3 +- .../webnn/builders/impl/tile_op_builder.cc | 9 -- .../builders/impl/triangular_op_builder.cc | 9 -- .../core/providers/webnn/builders/map_info.h | 2 +- 31 files changed, 167 insertions(+), 279 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index fc630af8cf1e3..fdf1709d87bac 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -18,10 +18,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - WebnnDeviceType device_type, const logging::Logger& logger) const override; }; // Add operator related. @@ -65,20 +61,6 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. 
-bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer& /* initializers */, - const Node& node, - WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - return true; -} - void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index b0ec006db6986..3c8e7fa34f7ed 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -62,13 +62,12 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, int32_t input_type; if (!GetType(input, input_type, logger)) return false; - const std::string_view webnn_op_type = GetWebNNOpType(op_type); - if (webnn_op_type.empty()) - return false; + const std::string_view webnn_op_type = GetWebNNOpType(op_type); const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits, - webnn_input_name, "input", logger); + webnn_input_name, "input", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 280ffc83eae89..851dc373923ac 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -73,9 +73,10 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } - std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? 
"X" : "A"; - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index e0cd48b6883c2..db5e8cd51656c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -75,7 +75,8 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index b9383a63fe307..e0bfb3bd682e8 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -324,7 +324,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N x_zero_point = model_builder.CreateOrGetConstant(x_type, 0); } - // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to default value 1.0f. // The x_zero_point must be a scalar and the scale input should have the same shape as the zero point input. // So the x_scale must be a scalar too. x_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc index 7528d9ad2ff51..f3c392b608e45 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc @@ -77,10 +77,6 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const std::string axis_name = GetTensorName(input_defs, 1); // Inputs contain optional 'axis' input. 
const auto* init = graph_viewer.GetConstantInitializer(axis_name); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index c22dd9e97bb1a..37a00fcb12abd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -21,11 +21,6 @@ class DropoutOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; // Add operator related. @@ -65,26 +60,13 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val options = emscripten::val::object(); options.set("label", output_defs[1]->Name() + "_identity"); // Add additional identity op in case the mask is the output of a WebNN graph, - // beacuse WebNN does not support a constant operand as output. + // because WebNN does not support a constant operand as output. emscripten::val mask_output = model_builder.GetBuilder().call("identity", one_constant, options); model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output)); } return Status::OK(); } -// Operator support related. -bool DropoutOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - return true; -} - void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc index e5b4fcddc4221..6aa760c0f4baf 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -28,6 +28,8 @@ class EinsumOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& /* node */, const emscripten::val& /* wnn_limits */, + const logging::Logger& /* logger */) const override; }; // Helper functions, thanks for DML EP's OperatorHelper. 
@@ -42,12 +44,6 @@ enum class RecognizedOperatorType { Total, }; -struct RecognizedOperatorInfo { - RecognizedOperatorType recognized_operator_type; - std::initializer_list component_ranks; - std::initializer_list label_indices; -}; - struct Component { uint32_t label_index_begin; uint32_t label_index_end; @@ -598,7 +594,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } } - // tranpose input + // transpose input std::vector permutation(input_labels.size()); for (uint32_t idx = 0; idx < input_labels.size(); idx++) { if (idx != diagonal_idx_1 && idx != diagonal_idx_2) { @@ -620,7 +616,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options_trilu.set("upper", false); output = model_builder.GetBuilder().call("triangular", output, options_trilu); // tril - // reducesum to achieve the diagonal values + // reduceSum to achieve the diagonal values std::vector input_shape; std::vector reduced_axes; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); @@ -700,12 +696,6 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - if (input_defs.size() > 2) { - // TODO: Support more than two inputs. - LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; - return false; - } - NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -724,13 +714,6 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } - RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, - output_dimensions); - if (recognized_operator_type == RecognizedOperatorType::None) { - LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; - return false; - } - return true; } @@ -738,9 +721,14 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + if (input_defs.size() > 2) { + // TODO: Support more than two inputs. 
+ LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; + return false; + } + const std::string_view op_type = node.OpType(); - int32_t input0_type; - int32_t input1_type; + int32_t input0_type, input1_type; bool has_input1 = TensorExists(input_defs, 1); if (!GetType(*input_defs[0], input0_type, logger) || @@ -754,6 +742,13 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } + std::vector input0_shape; + std::vector input1_shape; + if (!GetShape(*input_defs[0], input0_shape, logger) || + (has_input1 && !GetShape(*input_defs[1], input1_shape, logger))) { + return false; + } + NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -770,17 +765,54 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, output_dimensions); + std::string_view decomposed_op_type; if (recognized_operator_type == RecognizedOperatorType::None) { LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; return false; - } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { - // Map to WebNN's gemm or matmul - return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); + } else if (recognized_operator_type == RecognizedOperatorType::Multiply) { + decomposed_op_type = "Mul"; } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { - return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); - } else { - return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); + decomposed_op_type = "ReduceSum"; + } else if (recognized_operator_type == RecognizedOperatorType::Diagonal) { + decomposed_op_type = "Trilu"; + } else if (recognized_operator_type == RecognizedOperatorType::Transpose) { + decomposed_op_type = "Transpose"; + } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { + decomposed_op_type = "MatMul"; + } else { // Identity + // For the Identity case, we simply forward the input to the output without any modification. + return true; + } + + const std::string_view wnn_input0_name = GetWebNNInputName(decomposed_op_type, 0); + const std::string_view decompose_wnn_op_type = GetWebNNOpType(decomposed_op_type); + if (decompose_wnn_op_type.empty() || + !IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input0_type, + wnn_limits, wnn_input0_name, "inputs", logger) || + !IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input0_name, + input0_shape.size(), node.Name(), logger)) { + return false; + } + + if (has_input1) { + const std::string_view wnn_input1_name = GetWebNNInputName(decomposed_op_type, 1); + return IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input1_type, + wnn_limits, wnn_input1_name, "inputs", logger) && + IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input1_name, + input1_shape.size(), node.Name(), logger); } + + return true; +} + +bool EinsumOpBuilder::HasSupportedOutputsImpl(const Node& /* node */, + const emscripten::val& /* wnn_limits */, + const logging::Logger& /* logger */) const { + // The Einsum op produces output with the same data type as its input. + // Therefore, checking the output data type is unnecessary. 
+ // This override prevents calling the base class implementation, as the base implementation + // would return false due to Einsum being a decomposed op. + return true; } void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index b4b9d9a0d4c6b..ae4c3705fdb2e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -61,8 +61,9 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const N return false; } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index a15542061dd60..af508c2800f4b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -66,8 +66,9 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& n return false; } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 86408557013a0..7111a8f6beaa3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -20,8 +20,6 @@ class GatherOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -50,25 +48,6 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
- -bool GatherOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const auto rank = input_shape.size(); - if (rank < 1) { - LOGS(logger, VERBOSE) << "Gather only supports input shapes >= 1D, but input is " - << rank << "d shape"; - return false; - } - - return true; -} - bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; @@ -80,8 +59,9 @@ bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod !GetType(indices, indices_type, logger)) return false; - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 6e86ca77464e5..95e75a3083cc2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -219,7 +219,8 @@ bool GruOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); + return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 1675615280de9..55d468c4843cb 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -91,8 +91,10 @@ bool LogicalOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no } } + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); std::string onnx_input_name = op_type == "Not" ? 
"X" : "A"; - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc index 8936bda875aef..e8aab725375ad 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -21,8 +21,6 @@ class LRNOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, @@ -128,11 +126,10 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. -bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { +bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) return false; @@ -143,12 +140,6 @@ bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } - return true; -} - -bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, - const emscripten::val& wnn_limits, const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); int32_t input_type = 0; if (!GetType(*input_defs[0], input_type, logger)) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index fcdc84b75c048..04d59e2f30d15 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -242,7 +242,8 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc index 111d03571e974..9ab403b7051d2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc @@ -48,7 +48,7 @@ void MatMulNBitsBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons // DequantizeLinear + 
Transpose + MatMul. Given that the CPU EP currently only supports // 4-bit quantization, we only handle 4-bit quantization here. // -// To align with WebNN's dequantizeLinear op contraints, the following transformations are +// To align with WebNN's dequantizeLinear op constraints, the following transformations are // required for MatMulNBits inputs: // 1. B: must be a constant initializer and registered as a 'uint4' WebNN constant with shape // [N, n_blocks_per_col, blob_size * 2]. @@ -159,10 +159,6 @@ bool MatMulNBitsBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& name = node.Name(); const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } // Inputs B and zero_points (if present) must be initializers if (!graph_viewer.GetConstantInitializer(input_defs[1]->Name())) { // B @@ -193,6 +189,10 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } int32_t A_type = 0; int32_t B_type = 0; @@ -227,10 +227,13 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, return false; } - // We only support 4-bit quantization, which is represented as the uint4 data type in WebNN. - // Ensure that uint4 is supported. + // Data type: Currently, only 4-bit quantization is supported, represented as the uint4 data type in WebNN. + // Ensure that the uint4 data type is supported by WebNN's dequantizeLinear op. + // Input rank: Only the rank of the first input (A) is flexible. Verify that its rank is supported by + // WebNN's matmul op. return IsDataTypeSupportedByOp("DequantizeLinear", ONNX_NAMESPACE::TensorProto_DataType_UINT4, - wnn_limits, "input", "x", logger); + wnn_limits, "input", "x", logger) && + IsInputRankSupported(wnn_limits, "matmul", "a", input_shape.size(), node.Name(), logger); } bool MatMulNBitsBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 4d9cc39bd38fe..9f5ac6ef15735 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -20,8 +20,6 @@ class MaxMinOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node& node, - WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -68,25 +66,6 @@ Status MaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
-bool MaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - const auto& op_type = node.OpType(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - if (input_defs.size() < 1) { - LOGS(logger, VERBOSE) << op_type << " requires at least one input (data)"; - return false; - } - - return true; -} - bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -108,7 +87,8 @@ bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 148eacac98e4a..9fb643f055ef3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -46,28 +46,14 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - std::vector scale_shape; const size_t scale_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 2 : 1; - ORT_RETURN_IF_NOT(GetShape(*input_defs[scale_input_index], scale_shape, logger), "Cannot get scale shape"); - const auto scale_size = scale_shape.size(); - // Except LayerNormalization, other normalization ops' scale input should be 1-D. - if (op_type == "LayerNormalization") { - ORT_RETURN_IF_NOT(scale_size >= 1 && scale_size <= rank, - "The scale size should be less than or equal to input size."); - } else { - ORT_RETURN_IF_NOT(scale_size == 1, "The scale size should be one."); - } - emscripten::val scale = model_builder.GetOperand(input_defs[scale_input_index]->Name()); options.set("scale", scale); const size_t bias_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 3 : 2; emscripten::val bias = emscripten::val::undefined(); if (TensorExists(input_defs, bias_input_index)) { - // Bias input exists, and bias's shape should be the same as scale's shape. - std::vector bias_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[bias_input_index], bias_shape, logger), "Cannot get bias shape"); - ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape."); + // Bias input exists. 
bias = model_builder.GetOperand(input_defs[bias_input_index]->Name()); options.set("bias", bias); } @@ -279,12 +265,6 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - LOGS(logger, VERBOSE) << "Cannot get input shape."; - return false; - } - const auto& output_defs = node.OutputDefs(); if (op_type == "SkipSimplifiedLayerNormalization") { if (output_defs.size() > 4) { @@ -316,33 +296,28 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); - int32_t input0_type; // input data type - int32_t input1_type; // scale data type - int32_t input2_type; // B data type - int32_t input3_type; // mean data type - int32_t input4_type; // var data type - bool has_input2 = TensorExists(input_defs, 2); - bool has_input3 = TensorExists(input_defs, 3); - bool has_input4 = TensorExists(input_defs, 4); - - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger) || - (has_input2 && !GetType(*input_defs[2], input2_type, logger)) || - (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || - (has_input4 && !GetType(*input_defs[4], input4_type, logger))) { - return false; - } - std::vector input_types = {input0_type, input1_type}; - if (has_input2) { - input_types.push_back(input2_type); - } - if (has_input3) { - input_types.push_back(input3_type); + std::vector input_types; + bool all_types_valid = true; + + // Iterate through all inputs and check their existence and types + for (size_t i = 0; i <= input_defs.size(); ++i) { + if (TensorExists(input_defs, i)) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + all_types_valid = false; + break; + } + input_types.push_back(input_type); + } } - if (has_input4) { - input_types.push_back(input4_type); + + // Return false if any input type is invalid + if (!all_types_valid) { + return false; } + + // Check if all input data types are the same if (!AreDataTypesSame(op_type, input_types, logger)) { return false; } @@ -355,13 +330,29 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); const std::string_view webnn_input_name = GetWebNNOpFirstInputName(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp( - op_type, webnn_op_type, input0_type, wnn_limits, webnn_input_name, "input", logger)) { + op_type, webnn_op_type, input_types[0], wnn_limits, webnn_input_name, "input", logger)) { return false; } } - return true; + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + // It's complicated to check all the decomposed ops' input rank support. + // Ensure at least the first input rank is supported by the decomposed ops (pow and div accept the first input). 
+ return IsInputRankSupported(wnn_limits, "pow", "a", input_shape.size(), node.Name(), logger) && + IsInputRankSupported(wnn_limits, "div", "a", input_shape.size(), node.Name(), logger); } else { - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + bool is_data_type_supported = IsDataTypeSupportedByOp(op_type, input_types[0], wnn_limits, "input", "X", logger); + if (op_type == "InstanceNormalization") { + // Skip input rank check for InstanceNormalization, as we will reshape the input to 4D if necessary. + return is_data_type_supported; + } + + // For other ops, check both data type and input rank compatibility. + bool is_input_rank_supported = IsInputRankSupportedByOp(node, wnn_limits, logger); + return is_input_rank_supported && is_data_type_supported; } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index f2a3f08b73148..5d921c5176a64 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -133,20 +133,6 @@ bool PoolOpBuilder::IsOpSupportedImpl(const GraphViewer&, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& op_type = node.OpType(); - const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - const auto input_size = input_shape.size(); - if (input_size != 4) { - LOGS(logger, VERBOSE) - << op_type << " only supports rank-4 tensor, input [" - << input_defs[0]->Name() << "] has actual dim count " << input_size; - return false; - } - NodeAttrHelper helper(node); if (op_type == "AveragePool" || op_type == "LpPool" || op_type == "MaxPool") { if (helper.Get("kernel_shape", std::vector{1, 1}).size() != 2) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index eccf67cc46c9a..053c41773db40 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -167,7 +167,8 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && + IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) && (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger)); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index a3a0397eda4a3..6ea9b0a440d93 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -128,16 +128,10 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - const auto& op_type = 
node.OpType(); const std::string axes_name = GetTensorName(input_defs, 1); // If the optional input 'axes' is provided, it must be an initializer. if (!axes_name.empty() && !graph_viewer.GetConstantInitializer(axes_name)) { - LOGS(logger, VERBOSE) << "Input axes of " << op_type << " must be a constant"; + LOGS(logger, VERBOSE) << "Input axes of " << node.OpType() << " must be a constant"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index 8cbb381e0f53e..0444ae3afb56a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -79,11 +79,6 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const auto& perm_name = input_defs[1]->Name(); const auto* perm_init = graph_viewer.GetConstantInitializer(perm_name); if (!perm_init) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index ae3d559023625..c2974bd988f6b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -86,8 +86,9 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const const std::string_view op_type = node.OpType(); - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index 5467e91761823..a7788cfd847e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -76,8 +76,9 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& return false; } const std::string_view op_type = node.OpType(); - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc 
b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 8853891ff8ed6..5efbfe932c602 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -136,10 +136,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } if (input_defs.size() < 3) { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data, starts, ends) but got " @@ -166,10 +162,17 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& input = *input_defs[0]; - const std::string_view op_type = node.OpType(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + int32_t input_type; - if (!GetType(input, input_type, logger)) + if (!GetType(input, input_type, logger)) { return false; + } + + const std::string_view op_type = node.OpType(); // If there is step < 0, check data type support of reverse. if (TensorExists(input_defs, 4)) { @@ -178,13 +181,15 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con if (!init || !ReadIntArrayFrom1DTensor(*init, steps, graph_viewer, logger)) return false; if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { - if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { + if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger) || + !IsInputRankSupported(wnn_limits, "reverse", "input", input_shape.size(), node.Name(), logger)) { return false; } } } - return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger); + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 23e73bb8f1e74..99d137f81864c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -18,11 +18,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -46,20 +41,6 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. 
- -bool SoftmaxOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - return true; -} - void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 1ba6df9febf14..7e34e35ebac16 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -127,9 +127,6 @@ bool SqueezeUnsqueezeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewe const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; if (input_defs.size() < 1) { LOGS(logger, ERROR) << op_type << " has no input tensor"; diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 5d6d59663da61..8973757a24e99 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -66,7 +66,8 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no return false; } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc index 29b232026d7df..24d96588559ae 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc @@ -77,15 +77,6 @@ bool TileOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, return false; } - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - if (input_shape.empty()) { - LOGS(logger, VERBOSE) << "Tile does not support empty input shape"; - return false; - } - return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc index 5a267557b9454..7a4d172c556fa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -76,15 +76,6 @@ bool TriangularOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, 
logger)) - return false; - const auto input_size = input_shape.size(); - if (input_size < 2) { - LOGS(logger, VERBOSE) << "Triangular only supports input size >= 2D shape, input is " - << input_size << "d shape"; - return false; - } const std::string diagonal_name = GetTensorName(input_defs, 1); // Inputs contain optional 'diagonal' input. diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index bf95527beb44e..1c30fed7a7916 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -47,6 +47,7 @@ constexpr std::array supported_fallback // Use ONNX-to-ONNX op mapping to improve the search complexity for WebNN ops in the op_inputs_map. const std::map> decomposed_op_map = { {"ConvInteger", {"Cast", "Conv", "DequantizeLinear"}}, + {"Einsum", {"MatMul", "Mul", "ReduceSum", "Reshape", "Transpose", "Trilu"}}, {"GroupQueryAttention", {"Add", "Cast", "Concat", "CumSum", "Div", "Expand", "Less", "MatMul", "Reshape", "ScatterND", "Softmax", "Transpose", "Where"}}, @@ -159,7 +160,6 @@ const std::unordered_map op_inputs_map = { {"Softsign", {"softsign", {{0, "input"}}}}, {"Unsqueeze", {"reshape", {{0, "input"}}}}, {"Or", {"logicalOr", {{0, "a"}, {1, "b"}}}}, - {"Einsum", {"matmul", {{0, "a"}, {1, "b"}}}}, {"HardSwish", {"hardSwish", {{0, "input"}}}}, {"LeakyRelu", {"leakyRelu", {{0, "input"}}}}, {"MatMul", {"matmul", {{0, "a"}, {1, "b"}}}}, From 1e5fdd1224a445afaff7741da5841464d51c7d83 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 16 Jul 2025 07:02:33 +1000 Subject: [PATCH 37/49] Fix some test issues when WebGPU and DML are enabled in the same build (#25401) ### Description Fix some test setups where both EPs being in the same build wasn't expected. ### Motivation and Context --- .../dml/DmlExecutionProvider/src/ExecutionProvider.cpp | 5 ++++- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 6 ++++-- onnxruntime/test/providers/cpu/math/softmax_test.cc | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index a5066a41981e5..9611cb82d5a62 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -781,7 +781,10 @@ namespace Dml // this branch could be reached with a bad custom operator or malformed file. If // a legitimate case reaches here and DML needs to support a new input/output type // besides tensors, then remove the assert. - assert(false); + + // If the model has nodes that use Optional we will arrive here. It's a valid ONNX model but + // TryGetTensorDataType doesn't handle Optional. 
+ // assert(false); nodeContainsSupportedDataTypes = false; return; } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 7b77ca8c69225..4c3f9e8dd4dbd 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -527,18 +527,20 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop if (std::is_same_v) { #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_WEBGPU execution_providers.push_back(DefaultWebGpuExecutionProvider()); -#endif - RunTest(opts, std::move(execution_providers)); +#endif } else { #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 649c9af7cc80b..215203b31f49c 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -61,7 +61,8 @@ TEST(SoftmaxOperator, webgpu_nan) { test.AddOutput("Y", dimensions, expected_result); // explicitly disable for EPs that do not handle NaN - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider, kCoreMLExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCpuExecutionProvider, kCoreMLExecutionProvider, kDmlExecutionProvider}); } #endif From f19bb3c77edd8445940021e0c25af4f97ddbc5d9 Mon Sep 17 00:00:00 2001 From: Nenad Banfic <46795300+nenad1002@users.noreply.github.com> Date: Tue, 15 Jul 2025 16:35:56 -0700 Subject: [PATCH 38/49] Fix SigLIP casual mask bug (#25360) ### Description SigLIP architecture inside the vision encoder should not use a causal mask on the attention. This change will fix Phi 4 MM accuracy issues we have seen. 
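For illustration only (not part of this patch, and using made-up toy shapes): a minimal NumPy sketch of single-head scaled dot-product attention showing that a causal mask changes the result for a bidirectional encoder such as SigLIP, which is why the fused Attention node must not be marked causal here.

```python
import numpy as np

def attention(q, k, v, causal=False):
    # q, k, v: (seq_len, head_size); single head, no batching.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if causal:
        # Hide future positions: correct for decoder self-attention,
        # wrong for a bidirectional vision encoder like SigLIP.
        scores = np.where(np.tril(np.ones_like(scores)) == 1, scores, -np.inf)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
# Non-zero difference: the causal mask is not a no-op for encoder attention.
print(np.abs(attention(q, k, v) - attention(q, k, v, causal=True)).max())
```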
### Motivation and Context --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../transformers/fusion_attention_clip.py | 70 ++++++++++-------- .../phi-4-v-instruct-vision-attention.onnx | Bin 0 -> 7729 bytes .../python/transformers/test_phi_vision.py | 70 +++++++++++++++--- 3 files changed, 97 insertions(+), 43 deletions(-) create mode 100644 onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index fe93f5cd358bf..8711e368cd1e6 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -269,42 +269,48 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): attention_last_node = reshape_qkv add_qk = "" + causal_mask_nodes_1 = None + causal_mask_nodes_2 = None if add_mask is not None: - # 4D Add after Q x K' - add_qk_nodes = self.model.match_parent_path( - add_mask, - [ - "Where", - "Sub", - "Cast", - "Expand", - "Unsqueeze", - "Unsqueeze", - "Reshape", - "Reshape", - "Cast", - ], - [1, 2, 1, 0, 0, 0, 0, 0, 0], - ) - if add_qk_nodes is not None: + if add_mask.input[1] == "attention_mask": add_qk = add_mask.input[1] else: - # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path - # of computing causal mask. - causal_mask_nodes_1 = self.model.match_parent_path( - add_mask, - ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0, 0], - ) - # If the model is exported with batch_size == 1, there is no Concat node - causal_mask_nodes_2 = self.model.match_parent_path( + # 4D Add after Q x K' + add_qk_nodes = self.model.match_parent_path( add_mask, - ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0], + [ + "Where", + "Sub", + "Cast", + "Expand", + "Unsqueeze", + "Unsqueeze", + "Reshape", + "Reshape", + "Cast", + ], + [1, 2, 1, 0, 0, 0, 0, 0, 0], ) - if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: - logger.debug("fuse_attention: failed to match causal mask subgraph") - return + if add_qk_nodes is not None: + add_qk = add_mask.input[1] + else: + # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path + # of computing causal mask. 
+ causal_mask_nodes_1 = self.model.match_parent_path( + add_mask, + ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0, 0], + ) + # If the model is exported with batch_size == 1, there is no Concat node + causal_mask_nodes_2 = self.model.match_parent_path( + add_mask, + ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0], + ) + + if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: + logger.debug("fuse_attention: failed to match causal mask subgraph") + return new_node = self.create_attention_node( mask_index=None, @@ -320,7 +326,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): output=attention_last_node.output[0], add_qk_str=add_qk, scale=None, - causal=(add_mask is not None), + causal=(causal_mask_nodes_1 is not None) or (causal_mask_nodes_2 is not None), ) if new_node is None: logger.debug("fuse_attention: failed to create fused node") diff --git a/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx b/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx new file mode 100644 index 0000000000000000000000000000000000000000..34cf26c13d3fc98f8a97aa3f9999e3d99e5bf847 GIT binary patch literal 7729 zcmeI1%Wl&^6hLEFY1~UGX5^MANT3E)mAtT%B7%r2p=KFTgkaarXp*?Ki9O@mV_qv( zhz&fg>%mqFw?GGRPD z6iMVIX>Gab?Cdy=_J>{gs6jd4aVF0ilS#>)A&nF9(&+^(g-Xct2SVJCyL*EHZBmg* zwT3ooH(m^b_z8RKB~O)b+Nf__>R@5;j>$l9`zBPpI1NI<*S~z*et4p3?dyFJIZ@D0 zL@Ev?eAZxw2H4n>(o;?dP8;-i_=>*vf+JsoHQApVTY?g-DHp~oB9;zG)gAfdKKD|e z#U6chVg0q=WYkyAUu+XrcotFLV}rD+pJ@7|twWeA6wFcF+wFZO_p^{TTP<>@FhB(@ z534&KIuGLd%<=kiF%Q0K@D~X)12;dDFf~M$nym-5TbFW2RjNA*fPYCUfrtg19wjXH zZQrm=tuv*&`>a%Y|9FwNO><3W;Eoh5_OidP8dl-WWU{-btBZ66Wi1vBj3>qu89)ZE zG6VLHEp@u=s$U>nhuiw&DIl29N<{02x3AkO5=>89)Y*0b~FfKn9QjWB?hk zKk4}&u9;Q5?oaK1W8~o8xByFPP&G7SL4}liO!j>!lcm%<2Hmg@?oZV=H{q_Defwgz XZm5cGv7%^tn=mTw{Yh>|H`jgvT1N^V literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_phi_vision.py b/onnxruntime/test/python/transformers/test_phi_vision.py index 67f89e633a146..d276366706af9 100644 --- a/onnxruntime/test/python/transformers/test_phi_vision.py +++ b/onnxruntime/test/python/transformers/test_phi_vision.py @@ -149,7 +149,7 @@ def __init__(self): self.attn = PhiVCLIPAttention() self.ln = torch.nn.LayerNorm(20, eps=1e-05) - def forward(self, x): + def forward(self, x, attention_mask=None): # SkipLayerNorm ------+ # | | # Attention | @@ -163,8 +163,7 @@ def forward(self, x): x = self.ln(x) residual = x - # Attention + MatMul - x = self.attn(x) + x = self.attn(x, attention_mask=attention_mask) # SkipLayerNorm x = residual + x @@ -194,14 +193,31 @@ def verify_fusion(self, optimized_model, expected_model_filename): ) def export(self, model, inputs): - torch.onnx.export( - model, - args=inputs, - f=os.path.join(os.path.dirname(__file__), "export.onnx"), - export_params=True, - opset_version=14, - do_constant_folding=True, - ) + path = os.path.join(os.path.dirname(__file__), "export.onnx") + + if len(inputs) == 2: + torch.onnx.export( + model, + args=inputs, + f=path, + export_params=True, + opset_version=14, + do_constant_folding=True, + input_names=["input", "attention_mask"], + dynamic_axes={ + "input": {0: "batch", 1: "seq"}, + "attention_mask": {0: "batch", 2: "seq", 3: "seq"}, + }, + ) + else: + torch.onnx.export( + model, + args=inputs, + f=path, + export_params=True, + opset_version=14, + do_constant_folding=True, + ) def tearDown(self): path = 
os.path.join(os.path.dirname(__file__), "export.onnx") @@ -249,6 +265,38 @@ def test_phi_vision_attention(self): ) self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-attention.onnx") + def test_phi_vision_attention_with_mask(self): + model = PhiVCLIPAttentionAndLayerNorm() + + batch, seq_len, dim = 1, 2, 20 + mask = torch.zeros(batch, 1, seq_len, seq_len) + mask[:, 1:] = float("-inf") + + inputs = (torch.randn(batch, seq_len, dim), mask) + self.export(model, inputs) + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) + options = FusionOptions("clip") + optimized_model = optimize_model( + original_model, + model_type="clip", + num_heads=2, + hidden_size=20, + optimization_options=options, + opt_level=0, + use_gpu=True, + ) + self.verify_fusion(optimized_model, "phi-4-v-instruct-vision-attention.onnx") + + graph = optimized_model.model.graph + attention_node = next((n for n in graph.node if n.name == "Attention_0"), None) + self.assertIsNotNone(attention_node, "Could not find the Attention fused node") + attr_names = [attr.name for attr in attention_node.attribute] + self.assertNotIn( + "unidirectional", + attr_names, + f"The attention node should not have a 'unidirectional' attribute: {attr_names}", + ) + if __name__ == "__main__": unittest.main() From c7250f4d27f202a7ac4b1b2662e29c218f40ca39 Mon Sep 17 00:00:00 2001 From: derdeljan-msft Date: Wed, 16 Jul 2025 07:30:20 +0200 Subject: [PATCH 39/49] [CPU] GQA supports attention scores output (#25319) ### Description 1. Add optional output to CPU impl of GQA op for storing attention scores (QK). Buffer is of shape (B, N, S, T) and can either be fp16 or fp32, depending on the type of other inputs 2. Add `qk_output` attribute to GQA, which controls if attention scores should be saved before or after softmax is applied 3. Add unit tests to cover this use case 4. Added asserts on other EPs if this feature is used --- docs/ContribOperators.md | 6 +- docs/OperatorKernels.md | 6 +- .../contrib_ops/cpu/bert/attention_common.h | 6 + .../contrib_ops/cpu/bert/gqa_attention_base.h | 69 ++- .../cpu/bert/group_query_attention.cc | 7 +- .../cpu/bert/group_query_attention_helper.h | 31 ++ .../cuda/bert/group_query_attention.cc | 6 + .../rocm/bert/group_query_attention.cu | 4 + .../webgpu/bert/group_query_attention.cc | 4 + .../core/graph/contrib_ops/bert_defs.cc | 80 +++- .../test/python/transformers/test_gqa_cpu.py | 410 +++++++++++++++--- 11 files changed, 522 insertions(+), 107 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9388e7e2a47cd..f3dcde1abe37a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2545,6 +2545,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
 <dt><tt>num_heads</tt> : int (required)</dt>
 <dd>Number of attention heads for q</dd>
+<dt><tt>qk_output</tt> : int</dt>
+<dd>Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).</dd>
 <dt><tt>rotary_interleaved</tt> : int</dt>
 <dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
 <dt><tt>scale</tt> : float</dt>
@@ -2584,7 +2586,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.</dd>
 </dl>
 
-#### Outputs
+#### Outputs (3 - 4)
 
 <dl>
 <dt><tt>output</tt> : T</dt>
@@ -2593,6 +2595,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
 <dt><tt>present_value</tt> : T</dt>
 <dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
+<dt><tt>output_qk</tt> (optional) : T</dt>
+<dd>Values of QK matrix multiplication, either before or after softmax normalization</dd>
 </dl>
 
 #### Type Constraints
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index e50702afe9975..fa6c731231405 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -538,7 +538,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)|
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)<br> **T3** = tensor(int64)|
@@ -942,7 +942,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1420,7 +1420,7 @@ Do not modify directly.*
 |FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br> **T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float), tensor(float16)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br>
**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 243f611da49e1..80d374d3f0b25 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -53,6 +53,12 @@ enum AttentionKernelType { AttentionKernel_Default }; +enum class QKOutputType : int { + NO_OUTPUT = 0, + BEFORE_SOFTMAX = 1, + AFTER_SOFTMAX = 2 +}; + constexpr bool LAYOUT_BSNH = false; constexpr bool LAYOUT_BNSH = true; diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index cec495ef7391e..0d5117709c18a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -35,6 +35,8 @@ class GQAAttentionBase { use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; local_window_size_ = has_local ? static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1; + + qk_output_ = static_cast(info.GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))); } int num_heads_; // number of attention heads of Q @@ -44,6 +46,7 @@ class GQAAttentionBase { bool do_rotary_; // whether or not to use rotary embeddings bool rotary_interleaved_; int local_window_size_; + int qk_output_; bool use_smooth_softmax_; @@ -58,6 +61,7 @@ class GQAAttentionBase { Tensor* output, // output tensor Tensor* present_key, // present K output tensor (if separating present KV) Tensor* present_value, // present V output tensor (if separating present KV) + Tensor* output_qk, // output QK buffer const Tensor* seqlens_k, // past sequence lengths tensor GroupQueryAttentionParameters& parameters, // attention parameters AllocatorPtr allocator, // allocator for temporary tensors @@ -65,6 +69,7 @@ class GQAAttentionBase { const bool is_prompt = parameters.is_first_prompt; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; const int head_size = parameters.head_size; const int hidden_size = parameters.hidden_size; const bool packed_qkv = parameters.is_packed_qkv; @@ -80,8 +85,7 @@ class GQAAttentionBase { // Compute the attention score. bool gqa_mlas_supported = MlasGQASupported(CblasNoTrans, CblasTrans) && MlasGQASupported(CblasNoTrans, CblasNoTrans); - size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * - (gqa_mlas_supported ? sizeof(T) : sizeof(float)); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * (gqa_mlas_supported ? sizeof(T) : sizeof(float)); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -97,11 +101,13 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + T* output_qk_buffer = output_qk != nullptr ? 
output_qk->MutableData() : nullptr; + if (gqa_mlas_supported) { ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, - head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, - tp, allocator); + batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, + seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, + past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -112,9 +118,9 @@ class GQAAttentionBase { is_prompt, tp, allocator); } else { ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, - head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, - tp, allocator); + batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, + seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, + past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -142,12 +148,14 @@ class GQAAttentionBase { const T* attention_bias, // optional attention bias const size_t batch_size, // batch size of self-attention const size_t sequence_length, // sequence length of self-attention (S) + const size_t total_sequence_length, // total sequence length (T) const gsl::span attention_bias_shape, // shape of the attention bias const size_t past_buffer_sequence_length, // sequence length of past state const size_t present_buffer_sequence_length, // sequence length of present state const size_t head_size, // head size of self-attention const T* past_key, // past key only T* present_key, // present key only + T* output_qk, // output QK buffer const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt @@ -199,6 +207,11 @@ class GQAAttentionBase { const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; U* output = attention_probs + output_offset; + T* output_qk_thread = nullptr; + if (output_qk != nullptr) { + const ptrdiff_t output_qk_offset = SafeInt(sequence_length) * total_sequence_length * (batch_index * num_heads_ + head_index); + output_qk_thread = output_qk + output_qk_offset; + } // Compute attention bias offset based on the batch and head indexes // Attention bias is of shape (B or 1, H or 1, S, T) so handle broadcasting @@ -312,13 +325,6 @@ class GQAAttentionBase { } } - if (use_smooth_softmax_ || head_sink != nullptr) { - float sink = (head_sink != nullptr) ? 
static_cast(head_sink[head_index]) : 0.0f; - ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast(window_size), sink, nullptr); - } else { - ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); - } - // set causal [seq_causal_length, total_seqlen) to 0.f for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { if constexpr (std::is_same::value) { @@ -328,11 +334,30 @@ class GQAAttentionBase { } } + if (qk_output_ == static_cast(QKOutputType::BEFORE_SOFTMAX)) { + WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); + } + + if (use_smooth_softmax_ || head_sink != nullptr) { + float sink = (head_sink != nullptr) ? static_cast(head_sink[head_index]) : 0.0f; + ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast(window_size), sink, nullptr); + } else { + ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); + } + + if (qk_output_ == static_cast(QKOutputType::AFTER_SOFTMAX)) { + WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); + } + output_softmax += present_buffer_sequence_length; if (attention_bias_thread != nullptr) { attention_bias_thread += attention_total_seqlen; } + + if (output_qk_thread != nullptr) { + output_qk_thread += total_sequence_length; + } } } }); @@ -458,6 +483,20 @@ class GQAAttentionBase { SafeInt(sequence_length) * batch_size * num_heads_ * head_size); } } + + template + void WriteOutputQKHeadChunk(T* output_qk, const U* attention_probs, size_t total_sequence_length) const { + if (output_qk == nullptr) { + return; + } + + if constexpr (std::is_same_v) { + std::memcpy(output_qk, attention_probs, SafeInt(total_sequence_length) * sizeof(T)); + } else { + static_assert(std::is_same_v && std::is_same_v); + MlasConvertFloatToHalfBuffer(static_cast(attention_probs), output_qk, total_sequence_length); + } + } }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 988151f778806..eb1560ac8e341 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -95,6 +95,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { Tensor* present_k = context->Output(1, present_k_shape); Tensor* present_v = context->Output(2, present_v_shape); + std::vector output_qk_shape{static_cast(batch_size), static_cast(num_heads_), static_cast(parameters.sequence_length), static_cast(parameters.total_sequence_length)}; + Tensor* output_qk = context->Output(3, output_qk_shape); + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckOutputs(output_qk, qk_output_)); + AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -211,7 +216,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? 
nullptr : V.Get().Data(), head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, - seqlens_k, parameters, allocator, context); + output_qk, seqlens_k, parameters, allocator, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 0f66119540b03..f01ce985658aa 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -398,6 +398,37 @@ Status CheckCustomAttentionInputs(const T* position_ids, return Status::OK(); } +template +Status CheckOutputs(const T* output_qk, int qk_output) { + const bool is_valid_qk_output = qk_output == static_cast(QKOutputType::NO_OUTPUT) || + qk_output == static_cast(QKOutputType::BEFORE_SOFTMAX) || + qk_output == static_cast(QKOutputType::AFTER_SOFTMAX); + if (!is_valid_qk_output) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "qk_output attribute received unsupported value ", qk_output); + } + + if (qk_output != static_cast(QKOutputType::NO_OUTPUT) && output_qk == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "qk_output attribute was configured but output buffer was not provided"); + } + + return Status::OK(); +} + +inline Status CheckNoQKOutput(int num_outputs, int qk_output) { + if (num_outputs > 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "output_qk optional output is not supported"); + } + + if (qk_output != static_cast(QKOutputType::NO_OUTPUT)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "qk_output attribute is not supported"); + } + + return Status::OK(); +} + } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 68c4b01d2db20..9cb93cbcd3f32 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -109,6 +109,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; + // The current GQA CUDA implementation will never be able to have a QK output. + // GQA CUDA uses either flash attention or memory efficient attention. Neither kernel supports returning the QK output. 
+ ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context->OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 85aef55908506..09a6550549614 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -213,6 +213,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context->OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index f3334b13dc645..1f039177b0a21 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -178,6 +178,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& head_sink, params)); + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context.OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + WebgpuAttentionParameters parameters(params); TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index c2371487d9187..e2b17aa84d2b1 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -6,6 +6,7 @@ #include "core/graph/contrib_ops/quantization_defs.h" #include "core/graph/contrib_ops/onnx_function_util.h" #include "core/graph/contrib_ops/shape_inference_functions.h" +#include "contrib_ops/cpu/bert/attention_common.h" // Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from // ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build #if defined(_WIN32) && !defined(NDEBUG) @@ -232,7 +233,8 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c // Type and shape inference for group query attention and sparse attention. 
void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index = -1, - int use_max_past_present_buffer = -1) { + int use_max_past_present_buffer = -1, + int output_qk_index = -1) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); int64_t kv_sequence_length = -1; @@ -277,13 +279,20 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte } } - if (ctx.getNumOutputs() > 1) { // has present output + if (ctx.getNumOutputs() >= 3) { // has present output // copy the type from query to present key ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); // copy the type from query to present value ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); + int64_t total_sequence_length_value = 0; + const auto* total_sequence_length_data = ctx.getInputData(6); + if (total_sequence_length_data != nullptr) { + const auto& data = ParseData(total_sequence_length_data); + total_sequence_length_value = static_cast(data[0]); + } + if (past_key_index >= 0 && hasInputShape(ctx, past_key_index)) { auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); @@ -299,30 +308,25 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); } else if (use_max_past_present_buffer == 0) { if (kv_sequence_length > 0 && past_dims[2].has_dim_value()) { - int64_t total_sequence_length = kv_sequence_length + past_dims[2].dim_value(); + const int64_t present_sequence_length = kv_sequence_length + past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { *present_shape.add_dim() = dim; } - // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size) - present_shape.mutable_dim(2)->set_dim_value(total_sequence_length); + // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size) + present_shape.mutable_dim(2)->set_dim_value(present_sequence_length); updateOutputShape(ctx, 1, present_shape); updateOutputShape(ctx, 2, present_shape); } } else if (use_max_past_present_buffer == -1) { - const auto* total_sequence_length_data = ctx.getInputData(6); - if (total_sequence_length_data != nullptr && past_dims[2].has_dim_value()) { - int64_t total_sequence_length_value = 0; - const auto& data = ParseData(total_sequence_length_data); - total_sequence_length_value = static_cast(data[0]); - + if (total_sequence_length_value > 0 && past_dims[2].has_dim_value()) { // present_sequence_length = max(past_sequence_length, total_sequence_length) - int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() - ? total_sequence_length_value - : past_dims[2].dim_value(); + const int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() + ? 
total_sequence_length_value + : past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { @@ -336,19 +340,50 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte updateOutputShape(ctx, 2, present_shape); } } + + if (output_qk_index >= 0) { + const bool did_supply_qk_buffer = ctx.hasOutput(output_qk_index); + const int64_t qk_output_type = getAttribute(ctx, "qk_output", static_cast(QKOutputType::NO_OUTPUT)); + + if (qk_output_type == static_cast(QKOutputType::NO_OUTPUT) && did_supply_qk_buffer) { + fail_shape_inference("Output QK buffer was provided but qk_output attribute was not configured"); + } + + if (qk_output_type != static_cast(QKOutputType::NO_OUTPUT) && !did_supply_qk_buffer) { + fail_shape_inference("Output QK buffer was not provided but qk_output attribute was configured"); + } + + int64_t num_heads = getAttribute(ctx, "num_heads", 0); + if (did_supply_qk_buffer && hasInputShape(ctx, 0) && total_sequence_length_value > 0 && num_heads > 0) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, output_qk_index); + + auto& query_shape = getInputShape(ctx, 0); + auto& query_dims = query_shape.dim(); + + if (query_dims[0].has_dim_value() && query_dims[1].has_dim_value()) { + ONNX_NAMESPACE::TensorShapeProto output_qk_shape; + *output_qk_shape.add_dim() = query_dims[0]; // batch_size + output_qk_shape.add_dim()->set_dim_value(num_heads); // num_heads + *output_qk_shape.add_dim() = query_dims[1]; // sequence_length + output_qk_shape.add_dim()->set_dim_value(total_sequence_length_value); // total_sequence_length + updateOutputShape(ctx, output_qk_index, output_qk_shape); + } + } + } } } } -void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { +void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index, int qk_output_index) { // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not constexpr int use_max_past_present_buffer = -1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); } void SparseAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { constexpr int use_max_past_present_buffer = 1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); + constexpr int qk_output_index = -1; + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); } constexpr const char* Attention_ver1_doc = R"DOC( @@ -1127,6 +1162,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Use a smooth factor in softmax.", AttributeProto::INT, static_cast(-1)) + .Attr("qk_output", + "Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).", + AttributeProto::INT, + static_cast(QKOutputType::NO_OUTPUT)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" @@ -1205,10 +1244,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "(k-v buffer), it is of length max_sequence_length... 
otherwise of length past_sequence_length +" "kv_sequence_length.", "T") + .Output(3, + "output_qk", + "Values of QK matrix multiplication, either before or after softmax normalization", + "T", + OpSchema::Optional) .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - GroupQueryAttentionTypeAndShapeInference(ctx, 3); + GroupQueryAttentionTypeAndShapeInference(ctx, 3, 3); })); constexpr const char* PagedAttention_ver1_doc = R"DOC( diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index ce0649e55f7c5..7f2134b2cda4f 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -13,6 +13,7 @@ import random import unittest from dataclasses import dataclass +from enum import Enum import numpy import torch @@ -38,11 +39,17 @@ ATOL = None -class Formats: +class Formats(Enum): BSNH = 0 BNSH = 1 +class QKOutputType(Enum): + NO_OUTPUT = 0 + BEFORE_SOFTMAX = 1 + AFTER_SOFTMAX = 2 + + @dataclass class Config: batch_size: int = 0 @@ -55,6 +62,7 @@ class Config: has_position_ids: bool = False has_attention_bias: bool = False has_head_sink: bool = False + qk_output: QKOutputType = QKOutputType.NO_OUTPUT @dataclass @@ -69,6 +77,7 @@ class PromptConfig: has_position_ids: bool = False has_attention_bias: bool = False has_head_sink: bool = False + qk_output: QKOutputType = QKOutputType.NO_OUTPUT # LLaMA Microsoft model @@ -153,6 +162,15 @@ def create_group_query_attention_graph_prompt( ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length + + output_names = [ + "output", + "present_key", + "present_value", + ] + if config.qk_output != QKOutputType.NO_OUTPUT: + output_names.append("output_qk") + nodes = [ helper.make_node( "GroupQueryAttention", @@ -170,7 +188,7 @@ def create_group_query_attention_graph_prompt( "attention_bias" if config.has_attention_bias else "", "head_sink" if config.has_head_sink else "", ], - ["output", "present_key", "present_value"], + output_names, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -179,6 +197,7 @@ def create_group_query_attention_graph_prompt( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, + qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -349,6 +368,15 @@ def create_group_query_attention_graph_prompt( ), ] + if config.qk_output != QKOutputType.NO_OUTPUT: + graph_output += [ + helper.make_tensor_value_info( + "output_qk", + ort_type, + [config.batch_size, config.num_heads, config.kv_sequence_length, config.kv_sequence_length], + ), + ] + graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -377,6 +405,15 @@ def create_group_query_attention_graph_past( present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length ) + + output_names = [ + "output", + "present_key", + "present_value", + ] + if config.qk_output != QKOutputType.NO_OUTPUT: + output_names.append("output_qk") + nodes = [ helper.make_node( "GroupQueryAttention", @@ 
-394,7 +431,7 @@ def create_group_query_attention_graph_past( "attention_bias" if config.has_attention_bias else "", "head_sink" if config.has_head_sink else "", ], - ["output", "present_key", "present_value"], + output_names, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -403,6 +440,7 @@ def create_group_query_attention_graph_past( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, + qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -550,6 +588,15 @@ def create_group_query_attention_graph_past( ), ] + if config.qk_output != QKOutputType.NO_OUTPUT: + graph_output += [ + helper.make_tensor_value_info( + "output_qk", + ort_type, + [config.batch_size, config.num_heads, config.sequence_length, present_kv_seqlen], + ), + ] + graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -739,6 +786,7 @@ def gqa_prompt_func( position_ids=None, attention_bias=None, head_sink=None, + output_qk=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -771,6 +819,9 @@ def gqa_prompt_func( if config.has_attention_bias: assert attention_bias is not None + if config.qk_output != QKOutputType.NO_OUTPUT: + assert output_qk is not None + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) @@ -778,6 +829,7 @@ def gqa_prompt_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() + ort_outputs = {} if share_buffer: ort_inputs = { @@ -787,7 +839,6 @@ def gqa_prompt_func( "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -830,7 +881,6 @@ def gqa_prompt_func( "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -862,11 +912,21 @@ def gqa_prompt_func( ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) + io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) + ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + + out_qk = None + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() + else: + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() ort_output = numpy.array(ort_output) output = torch.tensor(ort_output) - return output, present_k, present_v + + return output, present_k, present_v, out_qk def gqa_past_func( @@ -882,6 +942,7 @@ def gqa_past_func( position_ids=None, attention_bias=None, head_sink=None, + output_qk=None, past_kv_format=Formats.BSNH, 
share_buffer=True, window_size=-1, @@ -914,6 +975,9 @@ def gqa_past_func( if config.has_attention_bias: assert attention_bias is not None + if config.qk_output != QKOutputType.NO_OUTPUT: + assert output_qk is not None + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) @@ -921,6 +985,7 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() + ort_outputs = {} if share_buffer: ort_inputs = { @@ -933,7 +998,6 @@ def gqa_past_func( .cpu() .numpy(), } - if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -983,7 +1047,6 @@ def gqa_past_func( .cpu() .numpy(), } - if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -1016,11 +1079,21 @@ def gqa_past_func( ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) + io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) + ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + + out_qk = None + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() + else: + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() ort_output = numpy.array(ort_output) output = torch.tensor(ort_output) - return output, present_k, present_v + + return output, present_k, present_v, out_qk def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): @@ -1112,8 +1185,9 @@ def attention_ref( use_smooth_softmax: whether use smooth softmax or not head_sink: (num_heads) or None Output: - output: (batch_size, seqlen_q, num_heads, head_dim) - attention: (batch_size, num_heads, seqlen_q, seqlen_k), softmax after dropout + output: (batch_size, seqlen_q, nheads, head_dim) + masked_scores: (batch_size, nheads, seqlen_q, seqlen_k), before softmax + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: window_size = (window_size[0], 0) @@ -1132,8 +1206,10 @@ def attention_ref( scores = scores / softcap scores = scores.tanh() scores = scores * softcap + masked_scores = scores.clone() if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + masked_scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0) if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, @@ -1143,6 +1219,7 @@ def attention_ref( key_padding_mask, q.device, ) + masked_scores.masked_fill_(local_mask, 0.0) scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax or (head_sink is not None): @@ -1168,7 +1245,7 @@ def attention_ref( if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + return output.to(dtype=dtype_og), masked_scores.to(dtype=dtype_og), attention.to(dtype=dtype_og) def 
attention_qkvpacked_ref( @@ -1360,6 +1437,20 @@ def parity_check_gqa_prompt( else None ) + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.kv_sequence_length, + config.q_sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) @@ -1370,7 +1461,7 @@ def parity_check_gqa_prompt( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1393,7 +1484,7 @@ def parity_check_gqa_prompt( # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( packed_qkv, k, v, @@ -1406,6 +1497,7 @@ def parity_check_gqa_prompt( position_ids, attention_bias, head_sink, + output_qk, left_window_size, past_format, True, @@ -1416,7 +1508,7 @@ def parity_check_gqa_prompt( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( q, k, v, @@ -1429,6 +1521,7 @@ def parity_check_gqa_prompt( position_ids, attention_bias, head_sink, + output_qk, left_window_size, past_format, True, @@ -1442,6 +1535,22 @@ def parity_check_gqa_prompt( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1483,6 +1592,8 @@ def parity_check_gqa_prompt( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1591,12 +1702,26 @@ def parity_check_gqa_prompt_no_buff( head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.kv_sequence_length, + config.q_sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = 
rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1619,7 +1744,7 @@ def parity_check_gqa_prompt_no_buff( # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( packed_qkv, None, None, @@ -1632,6 +1757,7 @@ def parity_check_gqa_prompt_no_buff( position_ids, attention_bias, head_sink, + output_qk, left_window_size, past_format, False, @@ -1642,7 +1768,7 @@ def parity_check_gqa_prompt_no_buff( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( q, None, None, @@ -1655,6 +1781,7 @@ def parity_check_gqa_prompt_no_buff( position_ids, attention_bias, head_sink, + output_qk, left_window_size, past_format, False, @@ -1668,6 +1795,22 @@ def parity_check_gqa_prompt_no_buff( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1709,6 +1852,8 @@ def parity_check_gqa_prompt_no_buff( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1834,7 +1979,7 @@ def parity_check_gqa_past( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1873,10 +2018,24 @@ def parity_check_gqa_past( else None ) + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.sequence_length, + config.kv_sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( packed_qkv, k, v, @@ -1889,6 +2048,7 @@ def parity_check_gqa_past( position_ids, attention_bias, head_sink, + output_qk, past_format, 
True, left_window_size, @@ -1899,7 +2059,7 @@ def parity_check_gqa_past( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( q, k, v, @@ -1912,6 +2072,7 @@ def parity_check_gqa_past( position_ids, attention_bias, head_sink, + output_qk, past_format, True, left_window_size, @@ -1925,6 +2086,22 @@ def parity_check_gqa_past( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + 1 + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1968,6 +2145,8 @@ def parity_check_gqa_past( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -2099,7 +2278,7 @@ def parity_check_gqa_past_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -2138,10 +2317,24 @@ def parity_check_gqa_past_no_buff( else None ) + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.sequence_length, + config.kv_sequence_length + config.sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( packed_qkv, k, v, @@ -2154,6 +2347,7 @@ def parity_check_gqa_past_no_buff( position_ids, attention_bias, head_sink, + output_qk, past_format, False, window_size=left_window_size, @@ -2164,7 +2358,7 @@ def parity_check_gqa_past_no_buff( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( q, k, v, @@ -2177,6 +2371,7 @@ def parity_check_gqa_past_no_buff( position_ids, attention_bias, head_sink, + output_qk, past_format, False, window_size=left_window_size, @@ -2190,6 +2385,22 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = 
cache_seqlens[batch_idx] + 1 + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Compare results all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET @@ -2229,6 +2440,8 @@ def parity_check_gqa_past_no_buff( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -2257,7 +2470,16 @@ def setUp(self): ] def run_test_config( - self, test_func, config_class, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, additional_params=None + self, + test_func, + config_class, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, + additional_params=None, ): if additional_params is None: additional_params = {} @@ -2282,44 +2504,56 @@ def run_test_config( for head_sink in [False, True]: if use_smooth_softmax and head_sink: continue - if config_class == PromptConfig: - config = config_class( - b, - s, - s2, - s + s2 + 8, - n, - n2, - h, - has_pos, - has_attn, - head_sink, - ) - else: # Config - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = config_class( - b, s, s2, sp, n, n2, h, has_pos, has_attn, head_sink - ) - - params = { - "config": config, - "torch_type": precision["torch_type"], - "numpy_type": precision["numpy_type"], - "ort_type": precision["ort_type"], - "rtol": precision["rtol"], - "atol": precision["atol"], - "local": local, - "past_format": Formats.BNSH, - "rotary": rotary, - "rotary_interleaved": rotary_interleaved, - "packed": packed, - "softcap": softcap, - "use_smooth_softmax": use_smooth_softmax, - } - params.update(additional_params) - - all_close = test_func(**params) - self.assertTrue(all_close) + for output_qk in qk_output: + if config_class == PromptConfig: + config = config_class( + b, + s, + s2, + s + s2 + 8, + n, + n2, + h, + has_pos, + has_attn, + head_sink, + output_qk, + ) + else: # Config + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = config_class( + b, + s, + s2, + sp, + n, + n2, + h, + has_pos, + has_attn, + head_sink, + output_qk, + ) + + params = { + "config": config, + "torch_type": precision["torch_type"], + "numpy_type": precision["numpy_type"], + "ort_type": precision["ort_type"], + "rtol": precision["rtol"], + "atol": precision["atol"], + "local": local, + "past_format": Formats.BNSH, + "rotary": rotary, + "rotary_interleaved": rotary_interleaved, + "packed": packed, + "softcap": softcap, + "use_smooth_softmax": use_smooth_softmax, + } + params.update(additional_params) + + all_close = test_func(**params) + self.assertTrue(all_close) def test_gqa_no_past(self): print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") @@ -2336,12 +2570,33 @@ def test_gqa_no_past(self): ) num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + qk_output = ( + [QKOutputType.NO_OUTPUT] + if pipeline_mode + else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] + ) # Test with buffer - self.run_test_config(parity_check_gqa_prompt, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) + self.run_test_config( + parity_check_gqa_prompt, + PromptConfig, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, + ) # Test without buffer 
self.run_test_config( - parity_check_gqa_prompt_no_buff, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias + parity_check_gqa_prompt_no_buff, + PromptConfig, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, ) def test_gqa_past(self): @@ -2359,11 +2614,25 @@ def test_gqa_past(self): ) num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + qk_output = ( + [QKOutputType.NO_OUTPUT] + if pipeline_mode + else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] + ) # Test with buffer - self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) + self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, qk_output) # Test without buffer - self.run_test_config(parity_check_gqa_past_no_buff, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) + self.run_test_config( + parity_check_gqa_past_no_buff, + Config, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, + ) def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") @@ -2378,6 +2647,7 @@ def test_gqa_interactive_one_batch(self): if pipeline_mode else [(False, False), (True, True), (False, True), (True, False)] ) + qk_output = [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] @@ -2390,6 +2660,7 @@ def test_gqa_interactive_one_batch(self): num_h, h_sizes, pos_ids_attn_bias, + qk_output, additional_params={"softcap": 0.0, "use_smooth_softmax": False}, ) self.run_test_config( @@ -2400,6 +2671,7 @@ def test_gqa_interactive_one_batch(self): num_h, h_sizes, pos_ids_attn_bias, + qk_output, additional_params={"softcap": 0.0, "use_smooth_softmax": False}, ) From c7152ce867b9c9f5fbb855f650e16907c9a09683 Mon Sep 17 00:00:00 2001 From: qti-yuduo Date: Tue, 15 Jul 2025 22:38:35 -0700 Subject: [PATCH 40/49] [QNN-EP] Support GridSample of linear mode for ONNX opset 20+ (#25408) [QNN-EP] Support GridSample of linear mode for ONNX opset 20+ --- .../builder/opbuilder/simple_op_builder.cc | 10 +++--- .../test/providers/qnn/simple_op_htp_test.cc | 32 +++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 2650316dd07ac..1c61bda9aeb63 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace qnn { -// Operator which only need to hanle node inputs & outputs, no attributes or no need to handle attributes +// Operator which only need to handle node inputs & outputs, no attributes or no need to handle attributes class SimpleOpBuilder : public BaseOpBuilder { public: SimpleOpBuilder() : BaseOpBuilder("SimpleOpBuilder") {} @@ -38,7 +38,7 @@ class SimpleOpBuilder : public BaseOpBuilder { const logging::Logger& logger, bool do_op_validation) const ORT_MUST_USE_RESULT; - static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; + static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest", "linear"}; static constexpr 
std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; static constexpr std::array scatternd_supported_reduction = {"none", "add", "mul"}; }; @@ -233,12 +233,12 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, std::string mode = node_helper.Get("mode", "linear"); Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT; mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; - if ("bilinear" == mode) { + if ("linear" == mode || "bilinear" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_BILINEAR; } else if ("nearest" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_NEAREST; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support bilinear & nearest."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support [linear, bilinear, nearest]."); } QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_MODE, mode_qnn_scalar); param_tensor_names.push_back(mode_param.GetParamTensorName()); @@ -254,7 +254,7 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, } else if ("reflection" == padding_mode) { padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_REFLECTION; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support zeros, border & reflection."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support [zeros, border, reflection]."); } QnnParamWrapper padding_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_PADDING_MODE, padding_mode_qnn_scalar); param_tensor_names.push_back(padding_mode_param.GetParamTensorName()); diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 85f8250f70fc5..4c0a53e83e274 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1254,6 +1254,38 @@ TEST_F(QnnHTPBackendTests, GridSample_U16_Nearest) { true); } +// Test QDQ GridSample with `linear` mode on opset 20+. 
+TEST_F(QnnHTPBackendTests, GridSample_Linear_ZerosPadding) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "zeros")}, + /*opset_version=*/20, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, GridSample_Linear_AlignCorners_BorderPadding) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("align_corners", static_cast(1)), + utils::MakeAttribute("mode", "linear"), + utils::MakeAttribute("padding_mode", "border")}, + /*opset_version=*/20, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, GridSample_Linear_ReflectionPadding_U16) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "reflection")}, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*op_domain=*/kOnnxDomain, + /*use_contrib_qdq=*/true); +} + // Test QDQ GridSample with reflection padding mode // Inaccuracy detected for output 'output', element 2. // Output quant params: scale=0.024269860237836838, zero_point=0. From 7290edcefff666422d2354f31500a137386cb421 Mon Sep 17 00:00:00 2001 From: qti-yuduo Date: Tue, 15 Jul 2025 22:40:27 -0700 Subject: [PATCH 41/49] [QNN-EP] Update ScatterND op to reject only QNN-CPU (#25403) Current limitation is more than necessary -- only reject when targeting QNN CPU. --- .../core/providers/qnn/builder/opbuilder/simple_op_builder.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 1c61bda9aeb63..502ea86b689f4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -60,8 +60,8 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, // To DO: Remove once QNN CPU supports ScatterND const auto qnn_backend_type = qnn_model_wrapper.GetQnnBackendType(); if (op_type == "ScatterND") { - ORT_RETURN_IF_NOT(qnn_backend_type == QnnBackendType::HTP, - "QNN EP only supports ScatterND op on HTP backend. Falling back to ORT CPU."); + ORT_RETURN_IF(qnn_backend_type == QnnBackendType::CPU, + "QNN EP does not support ScatterND op on CPU backend. Falling back to ORT CPU."); } // ONNX's Min, Max, and Sum operators accept a variable number of inputs (i.e., variadic). From 4a730ca19e5c953e2276b3fac5b1e7e7d11f9317 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 16 Jul 2025 16:53:42 +1000 Subject: [PATCH 42/49] Fix 2 device discovery issues. (#25397) ### Description Fix vendor and device id conversion from SetupApi info. Detect Remote Display Adapter and skip. This results in a bogus device appearing when you're connected to a machine using remote desktop. 
### Motivation and Context --- .../core/platform/windows/device_discovery.cc | 79 +++++++++++++------ 1 file changed, 53 insertions(+), 26 deletions(-) diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index dcc030cb3467d..fa645939a6395 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -89,23 +89,10 @@ uint64_t GetLuidKey(LUID luid) { return (uint64_t(luid.HighPart) << 32) | luid.LowPart; } -// Converts a wide string (up to 4 characters) representing a hardware ID component (e.g., "ABCD" from "VEN_ABCD") -// into a uint32_t. The conversion is done in a little-endian manner, meaning the first character -// of the string becomes the least significant byte of the integer, and the fourth character -// becomes the most significant byte. -uint32_t WStringToUint32Id(const std::wstring& vendor_name) { - uint32_t vendor_id = 0; - for (size_t i = 0; i < 4 && i < vendor_name.size(); ++i) { - // For little-endian, place each character at the appropriate byte position - // First character goes into lowest byte, last character into highest byte - vendor_id |= static_cast(vendor_name[i] & 0xFF) << (i * 8); - } - return vendor_id; -} - // returns info for display and processor entries. key is (vendor_id << 32 | device_id) // npus: (vendor_id << 32 | device_id) for devices we think are NPUs from DXCORE -std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus) { +std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus, + bool& have_remote_display_adapter) { std::unordered_map device_info; const GUID local_DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML = {0xb71b0d41, 0x1088, 0x422f, 0xa2, 0x7c, 0x2, 0x50, 0xb7, 0xd3, 0xa9, 0x88}; @@ -151,8 +138,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde if (auto idx = hardware_id.find(prefix); idx != std::wstring::npos) { auto id = hardware_id.substr(idx + prefix.size(), 4); if (id.size() == 4) { - // DXCore reports vendor and device IDs as 32-bit integer representations of the ASCII string. - return WStringToUint32Id(id); + return static_cast(std::stoul(id, nullptr, 16)); } } @@ -170,6 +156,11 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. 
if (vendor_id == 0 && device_id == 0) { + static const std::wstring remote_display_adapter_id(L"RdpIdd_IndirectDisplay"); + if (guid == GUID_DEVCLASS_DISPLAY && remote_display_adapter_id == buffer) { + have_remote_display_adapter = true; + } + continue; } @@ -305,7 +296,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde } // returns LUID to DeviceInfo -std::unordered_map GetDeviceInfoD3D12() { +std::unordered_map GetDeviceInfoD3D12(bool have_remote_display_adapter) { std::unordered_map device_info; ComPtr factory; @@ -314,6 +305,8 @@ std::unordered_map GetDeviceInfoD3D12() { return device_info; } + UINT num_adapters = 0; + ComPtr adapter; for (UINT i = 0; factory->EnumAdapters1(i, adapter.ReleaseAndGetAddressOf()) != DXGI_ERROR_NOT_FOUND; ++i) { DXGI_ADAPTER_DESC1 desc; @@ -339,9 +332,12 @@ std::unordered_map GetDeviceInfoD3D12() { info.metadata[L"LUID"] = std::to_wstring(key); info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i); info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; + + ++num_adapters; } - // iterate by high-performance GPU preference to add that info + // iterate by high-performance GPU preference to add that info. + UINT cur_adapter = 0; for (UINT i = 0; factory->EnumAdapterByGpuPreference( i, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, IID_PPV_ARGS(adapter.ReleaseAndGetAddressOf())) != DXGI_ERROR_NOT_FOUND; @@ -352,12 +348,41 @@ std::unordered_map GetDeviceInfoD3D12() { } uint64_t key = GetLuidKey(desc.AdapterLuid); - auto it = device_info.find(key); - if (it != device_info.end()) { - DeviceInfo& info = it->second; - info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); + if (it == device_info.end()) { + continue; } + + DeviceInfo& info = it->second; + + // try and drop the Microsoft Remote Display Adapter. it does not have the DXGI_ADAPTER_FLAG_SOFTWARE flag set + // and the vendor id, device id and description are the same as the real device. the LUID is different to the real + // device. + // Assumption: it will have the worst performance index of the devices we're considering so we only check the + // last adapter + if (num_adapters > 1 && have_remote_display_adapter && cur_adapter == num_adapters - 1) { + ComPtr output; + if (adapter->EnumOutputs(0, &output) == DXGI_ERROR_NOT_FOUND) { + // D3D_DRIVER_TYPE_WARP. Software based or disabled adapter. + // An adapter can be disabled in an RDP session. e.g. integrated GPU is disabled if there's a discrete GPU + + // if we have seen this vendor_id+device_id combination with a different LUID before we drop it. + if (std::any_of(device_info.begin(), device_info.end(), + [key, &info](const auto& entry) { + const auto& entry_info = entry.second; + return key != entry.first && + info.vendor_id == entry_info.vendor_id && + info.device_id == entry_info.device_id; + })) { + device_info.erase(key); + continue; + } + } + } + + info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); + + ++cur_adapter; } return device_info; @@ -497,10 +522,12 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - // d3d12 info. key is luid - std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(); // setupapi_info. key is vendor_id+device_id - std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus); + bool have_remote_display_adapter = false; // set if we see the RdpIdd_IndirectDisplay hardware ID. + std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus, have_remote_display_adapter); + + // d3d12 info. 
key is luid + std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(have_remote_display_adapter); // Ensure we have at least one CPU bool found_cpu = false; From 0955ab24baa344cf9b5683ba93483d39c9a78c80 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 16 Jul 2025 08:12:39 -0700 Subject: [PATCH 43/49] [webgpu] fix Slice implementation (#25415) ### Description Bugfix: crash when dim_value is 0 ### Motivation and Context Thanks to @skottmckay who found the bug. --- onnxruntime/core/providers/webgpu/tensor/slice.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 39432db5113d1..7e8b434431781 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -172,8 +172,8 @@ Status Slice::ComputeInternal(ComputeContext& context) const { } if (step < 0) { // we are slicing in reverse - start = std::clamp(start, int64_t{0}, dim_value - 1); - end = std::clamp(end, int64_t{-1}, dim_value - 1); + start = dim_value > 0 ? std::clamp(start, int64_t{0}, dim_value - 1) : 0; + end = dim_value > 0 ? std::clamp(end, int64_t{-1}, dim_value - 1) : -1; // note that we are flipping start and end to switch to forward step signs.push_back(-1); steps.push_back(static_cast(-step)); From d22a7819fb8da16c89262743d06a365f664fbf42 Mon Sep 17 00:00:00 2001 From: John Paul Date: Wed, 16 Jul 2025 10:31:42 -0700 Subject: [PATCH 44/49] [QNN EP] Gpu backend test framework & test cases. (#25393) ### Description - Adding test framework and initial batch of test cases for the QNN EP GPU backend. ### Motivation and Context - To ensure the QNN EP GPU backend does not break as ongoing changes are committed to the other QNN backends mainly. --- .../test/providers/qnn/argmaxmin_op_test.cc | 64 +- .../test/providers/qnn/average_pool_test.cc | 57 +- onnxruntime/test/providers/qnn/cast_test.cc | 74 +- .../test/providers/qnn/clip_op_test.cc | 51 +- onnxruntime/test/providers/qnn/conv_test.cc | 741 ++++++++++++------ .../test/providers/qnn/einsum_op_test.cc | 86 ++ .../test/providers/qnn/flatten_op_test.cc | 96 ++- .../test/providers/qnn/gemm_op_test.cc | 307 ++++++-- .../test/providers/qnn/matmul_test.cpp | 148 +++- .../test/providers/qnn/qnn_test_utils.cc | 59 ++ .../test/providers/qnn/qnn_test_utils.h | 11 + 11 files changed, 1298 insertions(+), 396 deletions(-) diff --git a/onnxruntime/test/providers/qnn/argmaxmin_op_test.cc b/onnxruntime/test/providers/qnn/argmaxmin_op_test.cc index 1ee556dfea294..afa92c619e88d 100644 --- a/onnxruntime/test/providers/qnn/argmaxmin_op_test.cc +++ b/onnxruntime/test/providers/qnn/argmaxmin_op_test.cc @@ -36,15 +36,15 @@ static GetTestQDQModelFn BuildQDQArgMxxTestCase(const std::string& op_typ }; } -// Runs an ArgMax/ArgMin model on the QNN CPU backend. Checks the graph node assignment, and that inference +// Runs an ArgMax/ArgMin model on the specified QNN backend. Checks the graph node assignment, and that inference // outputs for QNN EP and CPU EP match. 
-static void RunCPUArgMxxOpTest(const std::string& op_type, TestInputDef input_def, - const std::vector& attrs, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 13) { +static void RunArgMxxOpTest(const std::string& op_type, TestInputDef input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + const std::string& backend_name = "cpu", int opset = 13) { ProviderOptions provider_options; - provider_options["backend_type"] = "cpu"; + provider_options["backend_type"] = backend_name; RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {}, attrs), provider_options, @@ -77,14 +77,14 @@ static void RunQDQArgMxxOpTest(const std::string& op_type, TestInputDef i // Test that ArgMax/ArgMin with default attributes works on QNN CPU backend. Compares output with CPU EP. TEST_F(QnnCPUBackendTests, ArgMaxMin_DefaultAttrs) { - RunCPUArgMxxOpTest("ArgMax", - TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Random input. - {}, // All default ONNX attributes. - ExpectedEPNodeAssignment::All, 13); - RunCPUArgMxxOpTest("ArgMin", - TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Random input. - {}, // All default ONNX attributes. - ExpectedEPNodeAssignment::All, 13); + RunArgMxxOpTest("ArgMax", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Random input. + {}, // All default ONNX attributes. + ExpectedEPNodeAssignment::All, "cpu", 13); + RunArgMxxOpTest("ArgMin", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Random input. + {}, // All default ONNX attributes. + ExpectedEPNodeAssignment::All, "cpu", 13); } #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) @@ -157,6 +157,42 @@ TEST_F(QnnHTPBackendTests, ArgMaxMinU8_RankGreaterThan4_Unsupported) { } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +#if defined(_M_ARM64) +// +// GPU tests: +// + +// Test that ArgMax/ArgMin with default attributes works on QNN GPU backend. Compares output with CPU EP. +// Disable Reason : Onnx Op need Int64 output. Can enable after CastOp Int32 to Int64 is done. +// Can enable after CastOp int32 to int64 is implemented in QnnGpu. +TEST_F(QnnGPUBackendTests, DISABLED_ArgMaxMin_DefaultAttrs) { + RunArgMxxOpTest("ArgMax", + TestInputDef({3, 4, 4}, false, -10.0f, 10.0f), // Random input. + {}, // All default ONNX attributes. + ExpectedEPNodeAssignment::All, "gpu", 13); + RunArgMxxOpTest("ArgMin", + TestInputDef({3, 4, 4}, false, -10.0f, 10.0f), // Random input. + {}, // All default ONNX attributes. + ExpectedEPNodeAssignment::All, "gpu", 13); +} + +// Test that ArgMax/ArgMin with axis attribute works on QNN GPU backend. Compares output with CPU EP. +// Disable Reason : Onnx Op need Int64 output. Can enable after CastOp Int32 to Int64 is done. +// Can enable after CastOp int32 to int64 is implemented in QnnGpu. +TEST_F(QnnGPUBackendTests, DISABLED_ArgMaxMin_AxisAttr) { + RunArgMxxOpTest("ArgMax", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Random input. + {utils::MakeAttribute("axis", static_cast(1))}, // axis is 1 + ExpectedEPNodeAssignment::All, "gpu", 13); + RunArgMxxOpTest("ArgMin", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), // Random input. 
+ {utils::MakeAttribute("axis", static_cast(1))}, // axis is 1 + ExpectedEPNodeAssignment::All, "gpu", 13); +} + +#endif // defined(_M_ARM64) GPU tests + } // namespace test } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/average_pool_test.cc b/onnxruntime/test/providers/qnn/average_pool_test.cc index 590694c6fa740..8a0dd60765612 100644 --- a/onnxruntime/test/providers/qnn/average_pool_test.cc +++ b/onnxruntime/test/providers/qnn/average_pool_test.cc @@ -24,9 +24,9 @@ static void RunAveragePoolOpTest(const std::string& op_type, const std::vector>& input_defs, const std::vector& attrs, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 18) { + const std::string& backend_name = "cpu", int opset = 18) { ProviderOptions provider_options; - provider_options["backend_type"] = "cpu"; + provider_options["backend_type"] = backend_name; provider_options["offload_graph_io_quantization"] = "0"; RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs), @@ -217,6 +217,59 @@ TEST_F(QnnHTPBackendTests, AveragePool_3D_AutoPad_SAME_LOWER_u8) { #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +#if defined(_M_ARM64) +// +// GPU tests: +// + +// AveragePool with kernel size equal to the spatial dimension of input tensor. +TEST_F(QnnGPUBackendTests, AveragePool_AsGlobal) { + RunAveragePoolOpTest("AveragePool", + {TestInputDef({1, 2, 3, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 18))}, + {utils::MakeAttribute("kernel_shape", std::vector{3, 3}), + utils::MakeAttribute("strides", std::vector{3, 3})}, + ExpectedEPNodeAssignment::All, "gpu"); +} + +// Test GlobalAveragePool on QNN GPU backend. +TEST_F(QnnGPUBackendTests, GlobalAveragePool) { + RunAveragePoolOpTest("GlobalAveragePool", + {TestInputDef({1, 2, 3, 3}, false, GetFloatDataInRange(-10.0f, 10.0f, 18))}, + {}, + ExpectedEPNodeAssignment::All, "gpu"); +} + +// AveragePool that counts padding. +TEST_F(QnnGPUBackendTests, AveragePool_CountIncludePad) { + RunAveragePoolOpTest("AveragePool", + {TestInputDef({1, 3, 4, 5}, false, GetFloatDataInRange(-10.0f, 10.0f, 3 * 4 * 5))}, + {utils::MakeAttribute("kernel_shape", std::vector{3, 3}), + utils::MakeAttribute("count_include_pad", static_cast(1))}, + ExpectedEPNodeAssignment::All, "gpu"); +} + +// AveragePool that use auto_pad 'SAME_UPPER'. +TEST_F(QnnGPUBackendTests, AveragePool_AutopadSameUpper) { + RunAveragePoolOpTest("AveragePool", + {TestInputDef({1, 3, 4, 5}, false, GetFloatDataInRange(-10.0f, 10.0f, 3 * 4 * 5))}, + {utils::MakeAttribute("kernel_shape", std::vector{3, 3}), + utils::MakeAttribute("count_include_pad", static_cast(1)), + utils::MakeAttribute("auto_pad", "SAME_UPPER")}, + ExpectedEPNodeAssignment::All, "gpu"); +} + +// AveragePool that use auto_pad 'SAME_LOWER'. 
+TEST_F(QnnGPUBackendTests, AveragePool_AutopadSameLower) { + RunAveragePoolOpTest("AveragePool", + {TestInputDef({1, 3, 4, 5}, false, GetFloatDataInRange(-10.0f, 10.0f, 3 * 4 * 5))}, + {utils::MakeAttribute("kernel_shape", std::vector{3, 3}), + utils::MakeAttribute("count_include_pad", static_cast(1)), + utils::MakeAttribute("auto_pad", "SAME_LOWER")}, + ExpectedEPNodeAssignment::All, "gpu"); +} + +#endif // defined(_M_ARM64) GPU tests + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/cast_test.cc b/onnxruntime/test/providers/qnn/cast_test.cc index 2a63d98ebb37e..98a0fff2b0700 100644 --- a/onnxruntime/test/providers/qnn/cast_test.cc +++ b/onnxruntime/test/providers/qnn/cast_test.cc @@ -50,16 +50,18 @@ static GetTestModelFn BuildCastTestCase(const std::vector& shape, template static void RunCastOpTest(const std::vector& shape, ONNX_NAMESPACE::TensorProto_DataType dst_type, ExpectedEPNodeAssignment expected_ep_assignment, - bool use_htp, + const std::string& backend_name = "cpu", bool enable_fp16_precision = true) { ProviderOptions provider_options; - provider_options["backend_type"] = use_htp ? "htp" : "cpu"; + provider_options["backend_type"] = backend_name; provider_options["offload_graph_io_quantization"] = "0"; - if (use_htp && enable_fp16_precision) { - provider_options["enable_htp_fp16_precision"] = "1"; - } else { - provider_options["enable_htp_fp16_precision"] = "0"; + if (backend_name == "htp") { + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } } RunQnnModelTest(BuildCastTestCase(shape, dst_type), @@ -99,20 +101,17 @@ static void RunCastFP16HTPTest(const std::vector& shape, // Cast int32_t to float on CPU TEST_F(QnnCPUBackendTests, TestCastInt32ToFloat) { - RunCastOpTest({2, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, - false); + RunCastOpTest({2, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All); } // Cast uint8_t to float on CPU TEST_F(QnnCPUBackendTests, TestCastUInt8ToFloat) { - RunCastOpTest({2, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, - false); + RunCastOpTest({2, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All); } // Cast float to int32_t on CPU TEST_F(QnnCPUBackendTests, TestCastFloatToInt32) { - RunCastOpTest({2, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All, - false); + RunCastOpTest({2, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All); } #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) @@ -123,7 +122,7 @@ TEST_F(QnnCPUBackendTests, TestCastFloatToInt32) { // Cast int32_t to float on HTP TEST_F(QnnHTPBackendTests, TestCastInt32ToFloatHTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, - true, false); + "htp", false); } // Cast uint8_t to float on HTP @@ -131,27 +130,27 @@ TEST_F(QnnHTPBackendTests, TestCastInt32ToFloatHTP) { // value pair (13, 1.00000012) at index #0 don't match, which is -12 from 13 TEST_F(QnnHTPBackendTests, DISABLED_TestCastUInt8ToFloatHTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, - true, false); + "htp", false); } 
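// Note on the test-helper changes in this patch: each former RunXxxOnCPU / use_htp helper now takes the QNN
// backend as a string ("cpu", "htp", or "gpu") and forwards it to the "backend_type" provider option. The
// parameter defaults to "cpu", which is why the pre-existing CPU tests simply drop their trailing backend
// argument rather than being rewritten, and HTP-only options such as "enable_htp_fp16_precision" are now set
// only when the selected backend is "htp".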
// Cast float to int32_t on HTP TEST_F(QnnHTPBackendTests, TestCastFloatToInt32HTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All, - true, false); + "htp", false); } // Cast int64_t to int32_t on HTP // Supported in QNN SDK 2.23 TEST_F(QnnHTPBackendTests, TestCastInt64ToInt32HTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, - ExpectedEPNodeAssignment::All, true); + ExpectedEPNodeAssignment::All, "htp"); } // Cast int32_t to int64_t on HTP // Supported in QNN SDK 2.23 TEST_F(QnnHTPBackendTests, TestCastInt32ToInt64HTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64, - ExpectedEPNodeAssignment::All, true); + ExpectedEPNodeAssignment::All, "htp"); } // Cast float to bool on HTP. @@ -159,7 +158,7 @@ TEST_F(QnnHTPBackendTests, TestCastFloatToBoolHTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL, ExpectedEPNodeAssignment::All, - true); + "htp"); } // Cast float16 to bool on HTP. @@ -170,6 +169,45 @@ TEST_F(QnnHTPBackendTests, TestCastFloat16ToBoolHTP) { } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +#if defined(_M_ARM64) +// +// GPU tests: +// + +// Cast int32 to float on GPU +TEST_F(QnnGPUBackendTests, TestCastInt32ToFloat) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, + "gpu", false); +} + +// Cast uint8 to float on GPU +TEST_F(QnnGPUBackendTests, TestCastUInt8ToFloat) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, + "gpu", false); +} + +// Cast float to int32 on GPU +TEST_F(QnnGPUBackendTests, TestCastFloatToInt32) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All, + "gpu", false); +} + +// Cast int64 to int32 on GPU +TEST_F(QnnGPUBackendTests, TestCastInt64ToInt32) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, + ExpectedEPNodeAssignment::All, "gpu"); +} + +// Cast int32 to int64 on GPU +// Disable Reason : Currently not supported. +// Can enable after CastOp int32 to int64 is implemented in QnnGpu. +TEST_F(QnnGPUBackendTests, DISABLED_TestCastInt32ToInt64) { + RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64, + ExpectedEPNodeAssignment::All, "gpu"); +} + +#endif // defined(_M_ARM64) GPU tests + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/clip_op_test.cc b/onnxruntime/test/providers/qnn/clip_op_test.cc index 83296d342e62b..918c477c75a3d 100644 --- a/onnxruntime/test/providers/qnn/clip_op_test.cc +++ b/onnxruntime/test/providers/qnn/clip_op_test.cc @@ -20,16 +20,18 @@ template static void RunClipTest(const TestInputDef& input_def, const std::vector>& min_max_defs, ExpectedEPNodeAssignment expected_ep_assignment, - bool on_cpu_backend = true, + const std::string& backend_name = "cpu", int opset = 13, bool enable_fp16_precision = true) { ProviderOptions provider_options; - provider_options["backend_type"] = on_cpu_backend ? 
"cpu" : "htp"; - - if (!on_cpu_backend && enable_fp16_precision) { - provider_options["enable_htp_fp16_precision"] = "1"; - } else { - provider_options["enable_htp_fp16_precision"] = "0"; + provider_options["backend_type"] = backend_name; + + if (backend_name == "htp") { + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } } RunQnnModelTest(BuildOpTestCase("Clip", {input_def}, min_max_defs, {}), @@ -79,24 +81,22 @@ TEST_F(QnnCPUBackendTests, Clip_5D_f32) { // Fails with QNN SDK 2.35.0: // value pair (-4.54545403, -4.54687548) at index #3 don't match, which is -0.00142145 from -4.54545 TEST_F(QnnHTPBackendTests, DISABLED_Clip_f32) { - bool on_cpu_backend = false; RunClipTest(TestInputDef({1, 1, 3, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 12)), {TestInputDef({}, true, {-5.0f}), TestInputDef({}, true, {5.0f})}, ExpectedEPNodeAssignment::All, - on_cpu_backend, + "htp", 13, false); } // Test Clip with int32 on HTP TEST_F(QnnHTPBackendTests, Clip_int32) { - bool on_cpu_backend = false; RunClipTest(TestInputDef({1, 1, 3, 2}, false, {1, 2, -5, 3, -10, 25}), {TestInputDef({}, true, {-5}), TestInputDef({}, true, {5})}, ExpectedEPNodeAssignment::All, - on_cpu_backend); + "htp"); } // Runs a QDQ Clip model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference @@ -235,6 +235,35 @@ TEST_F(QnnHTPBackendTests, Clip_FP16) { } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +#if defined(_M_ARM64) +// +// GPU tests: +// + +// Test Clip with float32 on GPU +TEST_F(QnnGPUBackendTests, Clip_fp32) { + RunClipTest(TestInputDef({1, 1, 3, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 12)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All, + "gpu", + 13, + false); +} + +// Test Clip with int32 on GPU +// Disable Reason : Doesn't work. +TEST_F(QnnGPUBackendTests, DISABLED_Clip_int32) { + RunClipTest(TestInputDef({1, 1, 3, 2}, false, {1, 2, -5, 3, -10, 25}), + {TestInputDef({}, true, {-5}), + TestInputDef({}, true, {5})}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +#endif // defined(_M_ARM64) GPU tests + } // namespace test } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index ab716382d3a10..16b18b835a3b1 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -75,19 +75,20 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons // Runs a Conv model on the QNN CPU backend. Checks the graph node assignment, and that inference // outputs for QNN EP and CPU EP match. 
-static void RunCPUConvOpTest(const std::string& conv_op_type, const TestInputDef& input_def, - const TestInputDef& weights_def, - const TestInputDef& bias_def, - const std::vector& strides, - const std::vector& pads, - const std::vector& dilations, - std::optional group, - const std::string& auto_pad, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 13, - float fp32_abs_err = 1e-5f) { +static void RunConvOpTest(const std::string& conv_op_type, const TestInputDef& input_def, + const TestInputDef& weights_def, + const TestInputDef& bias_def, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + std::optional group, + const std::string& auto_pad, + ExpectedEPNodeAssignment expected_ep_assignment, + const std::string& backend_name = "cpu", + int opset = 13, + float fp32_abs_err = 1e-5f) { ProviderOptions provider_options; - provider_options["backend_type"] = "cpu"; + provider_options["backend_type"] = backend_name; provider_options["offload_graph_io_quantization"] = "0"; auto build_fn = BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, @@ -358,128 +359,128 @@ static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const Te // TODO: Segfaults when calling graphFinalize(). v2.13 // fixed by QNN 2.32 TEST_F(QnnCPUBackendTests, Convf32_dynamic_bias) { - RunCPUConvOpTest("Conv", - TestInputDef({1, 1, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({2, 1, 2, 2}, true, 0.0f, 1.0f), // Random static weights - TestInputDef({2}, false, -1.0f, 1.0f), // Random dynamic bias - {1, 1}, // default strides - {0, 0, 0, 0}, // default pads - {1, 1}, // default dilations - 1, // default group - "NOTSET", // No auto-padding - ExpectedEPNodeAssignment::All); - - RunCPUConvOpTest("Conv", - TestInputDef({1, 1, 3, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({2, 1, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights - TestInputDef({2}, false, -1.0f, 1.0f), // Random dynamic bias - {1, 1, 1}, // default strides - {0, 0, 0, 0, 0, 0}, // default pads - {1, 1, 1}, // default dilations - 1, // default group - "NOTSET", // No auto-padding - ExpectedEPNodeAssignment::All); + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, false, -1.0f, 1.0f), // Random dynamic bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All); + + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, false, -1.0f, 1.0f), // Random dynamic bias + {1, 1, 1}, // default strides + {0, 0, 0, 0, 0, 0}, // default pads + {1, 1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All); } // Check that QNN compiles DQ -> Conv -> Q as a single unit. // Tests bias as an initializer. 
TEST_F(QnnCPUBackendTests, Convf32_bias_initializer) { - RunCPUConvOpTest("Conv", - TestInputDef({1, 1, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({2, 1, 2, 2}, true, 0.0f, 1.0f), // Random static weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1}, // default strides - {0, 0, 0, 0}, // default pads - {1, 1}, // default dilations - 1, // default group - "NOTSET", // No auto-padding - ExpectedEPNodeAssignment::All); - - RunCPUConvOpTest("Conv", - TestInputDef({1, 1, 3, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({2, 1, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1, 1}, // default strides - {0, 0, 0, 0, 0, 0}, // default pads - {1, 1, 1}, // default dilations - 1, // default group - "NOTSET", // No auto-padding - ExpectedEPNodeAssignment::All); + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All); + + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1, 1}, // default strides + {0, 0, 0, 0, 0, 0}, // default pads + {1, 1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All); } // Tests Conv's auto_pad value "SAME_UPPER" (compares to CPU EP). TEST_F(QnnCPUBackendTests, Convf32_AutoPadUpper) { - RunCPUConvOpTest("Conv", - TestInputDef({1, 1, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({2, 1, 2, 2}, true, -1.0f, 1.0f), // Random static weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - 1, // default group - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All); - - RunCPUConvOpTest("Conv", - TestInputDef({1, 1, 3, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({2, 1, 2, 2, 2}, true, -1.0f, 1.0f), // Random static weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1, 1}, // strides - {}, // pads - {1, 1, 1}, // dilations - 1, // default group - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All); + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2}, true, -1.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + 1, // default group + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All); + + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2, 2}, true, -1.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1, 1}, // strides + {}, // pads + {1, 1, 1}, // dilations + 1, // default group + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All); } // Tests ConvTranspose's auto_pad value "SAME_UPPER" (compares to CPU EP). 
TEST_F(QnnCPUBackendTests, ConvTransposef32_AutoPadUpper) { - RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 1, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({1, 2, 2, 2}, true, -1.0f, 1.0f), // Random static weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - 1, // default group - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All); - - RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 1, 3, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({1, 2, 2, 2, 2}, true, -1.0f, 1.0f), // Random static weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1, 1}, // strides - {}, // pads - {1, 1, 1}, // dilations - 1, // default group - "SAME_UPPER", // auto_pad - ExpectedEPNodeAssignment::All); + RunConvOpTest("ConvTranspose", + TestInputDef({1, 1, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({1, 2, 2, 2}, true, -1.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + 1, // default group + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All); + + RunConvOpTest("ConvTranspose", + TestInputDef({1, 1, 3, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({1, 2, 2, 2, 2}, true, -1.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1, 1}, // strides + {}, // pads + {1, 1, 1}, // dilations + 1, // default group + "SAME_UPPER", // auto_pad + ExpectedEPNodeAssignment::All); } // Tests Conv's auto_pad value "SAME_LOWER" (compares to CPU EP). TEST_F(QnnCPUBackendTests, Convf32_AutoPadLower) { - RunCPUConvOpTest("Conv", - TestInputDef({1, 1, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({2, 1, 2, 2}, false, -1.0f, 1.0f), // Random dynamic weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - 1, // default group - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All); - - RunCPUConvOpTest("Conv", - TestInputDef({1, 1, 3, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({2, 1, 2, 2, 2}, false, -1.0f, 1.0f), // Random dynamic weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1, 1}, // strides - {}, // pads - {1, 1, 1}, // dilations - 1, // default group - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All); + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2}, false, -1.0f, 1.0f), // Random dynamic weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + 1, // default group + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All); + + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2, 2}, false, -1.0f, 1.0f), // Random dynamic weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1, 1}, // strides + {}, // pads + {1, 1, 1}, // dilations + 1, // default group + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All); } // Tests ConvTranspose's auto_pad value "SAME_LOWER" (compares to CPU EP). 
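// For reference, the ONNX auto_pad semantics these tests exercise: with SAME_UPPER or SAME_LOWER the total
// padding is chosen so that, for Conv, output_size = ceil(input_size / stride) (and output_size =
// input_size * stride for ConvTranspose); the two modes differ only in where the odd padding element goes
// (SAME_UPPER adds the extra element at the end, SAME_LOWER at the beginning). Worked example for input 3,
// kernel 2, stride 1: total padding = (output_size - 1) * stride + kernel - input_size = (3 - 1) * 1 + 2 - 3 = 1,
// so SAME_UPPER uses pads (begin = 0, end = 1) while SAME_LOWER uses pads (begin = 1, end = 0).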
@@ -487,16 +488,16 @@ TEST_F(QnnCPUBackendTests, Convf32_AutoPadLower) { // unknown file: error: SEH exception with code 0xc0000005 thrown in the test body // fixed by QNN 2.32 TEST_F(QnnCPUBackendTests, ConvTransposef32_AutoPadLower) { - RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 1, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({1, 2, 2, 2}, false, -1.0f, 1.0f), // Random dynamic weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1}, // strides - {}, // pads - {1, 1}, // dilations - 1, // default group - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All); + RunConvOpTest("ConvTranspose", + TestInputDef({1, 1, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({1, 2, 2, 2}, false, -1.0f, 1.0f), // Random dynamic weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // strides + {}, // pads + {1, 1}, // dilations + 1, // default group + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All); } // Tests ConvTranspose's auto_pad value "SAME_LOWER" (compares to CPU EP). @@ -505,45 +506,47 @@ TEST_F(QnnCPUBackendTests, ConvTransposef32_AutoPadLower) { // 0xC0000005: Access violation reading location 0x0000000000000000. // fixed by QNN 2.32 TEST_F(QnnCPUBackendTests, ConvTranspose3D_f32_AutoPadLower) { - RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 1, 3, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({1, 2, 2, 2, 2}, false, -1.0f, 1.0f), // Random dynamic weights - TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias - {1, 1, 1}, // strides - {}, // pads - {1, 1, 1}, // dilations - 1, // default group - "SAME_LOWER", // auto_pad - ExpectedEPNodeAssignment::All); + RunConvOpTest("ConvTranspose", + TestInputDef({1, 1, 3, 3, 3}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({1, 2, 2, 2, 2}, false, -1.0f, 1.0f), // Random dynamic weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1, 1}, // strides + {}, // pads + {1, 1, 1}, // dilations + 1, // default group + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All); } // large input,output, pads TEST_F(QnnCPUBackendTests, Convf32_large_input1_pad_bias_initializer) { - RunCPUConvOpTest("Conv", - TestInputDef({1, 3, 60, 452}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({16, 3, 3, 3}, true, 0.0f, 1.0f), // Random dynamic weights - TestInputDef({16}, true, -1.0f, 1.0f), // Random static bias - {1, 1}, - {1, 1, 1, 1}, - {1, 1}, - 1, // default group - "NOTSET", - ExpectedEPNodeAssignment::All, - 13, - 1e-4f); - - RunCPUConvOpTest("Conv", - TestInputDef({1, 3, 60, 452, 20}, false, 0.0f, 10.0f), // Random dynamic input - TestInputDef({16, 3, 3, 3, 3}, true, 0.0f, 1.0f), // Random dynamic weights - TestInputDef({16}, true, -1.0f, 1.0f), // Random static bias - {1, 1, 1}, - {1, 1, 1, 1, 1, 1}, - {1, 1, 1}, - 1, // default group - "NOTSET", - ExpectedEPNodeAssignment::All, - 13, - 2e-4f); + RunConvOpTest("Conv", + TestInputDef({1, 3, 60, 452}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({16, 3, 3, 3}, true, 0.0f, 1.0f), // Random dynamic weights + TestInputDef({16}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, + {1, 1, 1, 1}, + {1, 1}, + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + "cpu", + 13, + 1e-4f); + + RunConvOpTest("Conv", + TestInputDef({1, 3, 60, 452, 20}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({16, 3, 3, 3, 3}, true, 0.0f, 1.0f), // Random dynamic weights + TestInputDef({16}, true, 
-1.0f, 1.0f), // Random static bias + {1, 1, 1}, + {1, 1, 1, 1, 1, 1}, + {1, 1, 1}, + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + "cpu", + 13, + 2e-4f); } TEST_F(QnnCPUBackendTests, Convf32_large_input2_nopad_bias_initializer) { @@ -555,76 +558,78 @@ TEST_F(QnnCPUBackendTests, Convf32_large_input2_nopad_bias_initializer) { float fp32_abs_err = 1e-5f; // default value #endif - RunCPUConvOpTest("Conv", - TestInputDef({1, 32, 16, 113}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({16, 32, 1, 1}, false, -1.0f, 1.0f), // Random dynamic weights - TestInputDef({16}, true, -1.0f, 1.0f), // Random static bias - {1, 1}, - {0, 0, 0, 0}, - {1, 1}, - 1, // default group - "NOTSET", - ExpectedEPNodeAssignment::All, - 13, // opset - fp32_abs_err); - - RunCPUConvOpTest("Conv", - TestInputDef({1, 32, 16, 113, 12}, false, -3.0f, 3.0f), // Random dynamic input - TestInputDef({16, 32, 1, 1, 1}, false, -1.0f, 1.0f), // Random dynamic weights - TestInputDef({16}, true, -1.0f, 1.0f), // Random static bias - {1, 1, 1}, - {0, 0, 0, 0, 0, 0}, - {1, 1, 1}, - 1, // default group - "NOTSET", - ExpectedEPNodeAssignment::All, - 13, // opset - fp32_abs_err); + RunConvOpTest("Conv", + TestInputDef({1, 32, 16, 113}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({16, 32, 1, 1}, false, -1.0f, 1.0f), // Random dynamic weights + TestInputDef({16}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, + {0, 0, 0, 0}, + {1, 1}, + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + "cpu", + 13, // opset + fp32_abs_err); + + RunConvOpTest("Conv", + TestInputDef({1, 32, 16, 113, 12}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({16, 32, 1, 1, 1}, false, -1.0f, 1.0f), // Random dynamic weights + TestInputDef({16}, true, -1.0f, 1.0f), // Random static bias + {1, 1, 1}, + {0, 0, 0, 0, 0, 0}, + {1, 1, 1}, + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + "cpu", + 13, // opset + fp32_abs_err); } // Test 1D Conv with static weights (implemented in QNN EP as 2D convolution with height of 1). TEST_F(QnnCPUBackendTests, Conv1Df32_StaticWeights_DefaultBias) { std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; - RunCPUConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, input_data), // Dynamic input - TestInputDef({1, 2, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights - TestInputDef({1}, true, {1.0f}), // Initializer Bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations - 1, // default group - "NOTSET", - ExpectedEPNodeAssignment::All); + RunConvOpTest("Conv", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer Bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All); } // Test 1D Conv with dynamic weights (implemented in QNN EP as 2D convolution with height of 1). 
TEST_F(QnnCPUBackendTests, Conv1Df32_DynamicWeights_DefaultBias) { std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; - RunCPUConvOpTest("Conv", - TestInputDef({1, 2, 4}, false, input_data), // Dynamic input - TestInputDef({1, 2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights - TestInputDef(), // Default bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations - 1, // default group - "NOTSET", - ExpectedEPNodeAssignment::All); + RunConvOpTest("Conv", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights + TestInputDef(), // Default bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All); } // Test 1D ConvTranspose with static weights (implemented in QNN EP as 2D convolution with height of 1). TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_StaticWeights_DefaultBias) { std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; - RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, input_data), // Dynamic input - TestInputDef({2, 1, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights - TestInputDef({1}, true, {0.0f}), // Zero bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations - 1, // default group - "NOTSET", - ExpectedEPNodeAssignment::All); + RunConvOpTest("ConvTranspose", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights + TestInputDef({1}, true, {0.0f}), // Zero bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All); } // Test 1D ConvTranspose with dynamic weights (implemented in QNN EP as 2D convolution with height of 1). @@ -633,16 +638,16 @@ TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_StaticWeights_DefaultBias) { // fixed by QNN 2.32 TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_DynamicWeights_DefaultBias) { std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; - RunCPUConvOpTest("ConvTranspose", - TestInputDef({1, 2, 4}, false, input_data), // Dynamic input - TestInputDef({2, 1, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights - TestInputDef({1}, true, {0.0f}), // Zero bias - {1}, // Strides - {0, 0}, // Pads - {1}, // Dilations - 1, // default group - "NOTSET", - ExpectedEPNodeAssignment::All); + RunConvOpTest("ConvTranspose", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), // Dynamic weights + TestInputDef({1}, true, {0.0f}), // Zero bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All); } #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) @@ -2154,6 +2159,288 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_LargeInput_Dilations_Pads) { } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +#if defined(_M_ARM64) +// +// GPU tests: +// + +// Convolution 2D GPU test. 
+TEST_F(QnnGPUBackendTests, Conv2D) { + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Convolution 3D GPU test. +// Disable Reason : 3D Conv is currently not supported by the GPU. +TEST_F(QnnGPUBackendTests, DISABLED_Conv3D) { + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1, 1}, // default strides + {0, 0, 0, 0, 0, 0}, // default pads + {1, 1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Convolution 2D dynamic bias GPU test. +TEST_F(QnnGPUBackendTests, Conv2D_biasDynamic) { + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, false, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Convolution 2D GPU test, large input feature map, more output feature maps. +TEST_F(QnnGPUBackendTests, Conv2D_largeInput) { + RunConvOpTest("Conv", + TestInputDef({1, 3, 60, 452}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({16, 3, 3, 3}, true, 0.0f, 1.0f), // Random dynamic weights + TestInputDef({16}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu", + 13, + 1e-4f); +} + +// Convolution 2D GPU test, reduce featuremaps with pointwise conv. +TEST_F(QnnGPUBackendTests, Conv2D_bottleneckSqueeze) { + RunConvOpTest("Conv", + TestInputDef({1, 32, 16, 113}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({16, 32, 1, 1}, false, -1.0f, 1.0f), // Random dynamic weights + TestInputDef({16}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu", + 13, + 1e-4f); +} + +// Convolution 2D SAME_UPPER pad GPU test. +TEST_F(QnnGPUBackendTests, Conv2D_padSameUpper) { + RunConvOpTest("Conv", + TestInputDef({1, 1, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({2, 1, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {}, // unspecified pads + {1, 1}, // default dilations + 1, // default group + "SAME_UPPER", // auto-padding + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Convolution Transpose 2D GPU test. 
+TEST_F(QnnGPUBackendTests, ConvTranspose2D) { + RunConvOpTest("ConvTranspose", + TestInputDef({1, 1, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({1, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 1, // default group + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Convolution Transpose 2D SAME_LOWER pad GPU test. +TEST_F(QnnGPUBackendTests, ConvTranspose2D_padSameLower) { + RunConvOpTest("ConvTranspose", + TestInputDef({1, 1, 3, 3}, false, 0.0f, 10.0f), // Random dynamic input + TestInputDef({1, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {}, // unspecified pads + {1, 1}, // default dilations + 1, // default group + "SAME_LOWER", // auto_pad + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Depthwise Convolution 2D GPU test, depthwise conv. +TEST_F(QnnGPUBackendTests, Conv2D_depthwise) { + RunConvOpTest("Conv", + TestInputDef({1, 3, 3, 3}, false, 0.0f, 1.0f), // Random dynamic input + TestInputDef({3, 1, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({3}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 3, // 3 groups + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Convolution 2D GPU test, reduce featuremaps with depthwise-pointwise conv. +TEST_F(QnnGPUBackendTests, Conv2D_depthwiseSeparable) { + RunConvOpTest("Conv", + TestInputDef({1, 6, 16, 16}, false, -3.0f, 3.0f), // Random dynamic input + TestInputDef({3, 2, 1, 1}, false, -1.0f, 1.0f), // Random dynamic weights + TestInputDef({3}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 3, // 3 groups + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu", + 13, + 1e-4f); +} + +// Convolution 2D groups GPU test, use grouping. +TEST_F(QnnGPUBackendTests, Conv2D_groups) { + RunConvOpTest("Conv", + TestInputDef({1, 4, 3, 3}, false, 0.0f, 1.0f), // Random dynamic input + TestInputDef({2, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 2, // 2 groups + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu", + 13, + 1e-4f); +} + +// Convolution 2D groups GPU test, use grouping, more than 1 output per group. +TEST_F(QnnGPUBackendTests, Conv2D_groupsExpand) { + RunConvOpTest("Conv", + TestInputDef({1, 4, 3, 3}, false, 0.0f, 1.0f), // Random dynamic input + TestInputDef({4, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({4}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 2, // 2 groups + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu", + 13, + 1e-4f); +} + +// Convolution 2D groups GPU test, use grouping, 1 group of 3. 
+TEST_F(QnnGPUBackendTests, Conv2D_1groupOf3) { + RunConvOpTest("Conv", + TestInputDef({1, 3, 3, 3}, false, 0.0f, 1.0f), // Random dynamic input + TestInputDef({2, 3, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 1, // 1 groups + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu", + 13, + 1e-4f); +} + +// Convolution 2D groups GPU test, use grouping, more than 1 group of 3. +// Disable Reason : doesn't work. +TEST_F(QnnGPUBackendTests, DISABLED_Conv2D_2groupsOf3) { + RunConvOpTest("Conv", + TestInputDef({1, 6, 3, 3}, false, 0.0f, 1.0f), // Random dynamic input + TestInputDef({2, 3, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({2}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 2, // 2 groups + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu", + 13, + 1e-4f); +} + +// Convolution 2D groups GPU test, use grouping, more than 1 group of 2. +TEST_F(QnnGPUBackendTests, Conv2D_3groupsOf2) { + RunConvOpTest("Conv", + TestInputDef({1, 6, 3, 3}, false, 0.0f, 1.0f), // Random dynamic input + TestInputDef({3, 2, 2, 2}, true, 0.0f, 1.0f), // Random static weights + TestInputDef({3}, true, -1.0f, 1.0f), // Random static bias + {1, 1}, // default strides + {0, 0, 0, 0}, // default pads + {1, 1}, // default dilations + 3, // 3 groups + "NOTSET", // No auto-padding + ExpectedEPNodeAssignment::All, + "gpu", + 13, + 1e-4f); +} + +// Convolution 1D GPU test. +TEST_F(QnnGPUBackendTests, Conv1D) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + RunConvOpTest("Conv", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({1, 2, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights + TestInputDef({1}, true, {1.0f}), // Initializer Bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Convolution Transpose 1D GPU test. 
+TEST_F(QnnGPUBackendTests, ConvTranspose1D) { + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + RunConvOpTest("ConvTranspose", + TestInputDef({1, 2, 4}, false, input_data), // Dynamic input + TestInputDef({2, 1, 2}, true, {1.0f, 2.0f, 3.0f, 4.0f}), // Static weights + TestInputDef({1}, true, {0.0f}), // Zero bias + {1}, // Strides + {0, 0}, // Pads + {1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + "gpu"); +} + +#endif // defined(_M_ARM64) GPU tests + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc index 55412a7b15d98..73acacb2c76cb 100644 --- a/onnxruntime/test/providers/qnn/einsum_op_test.cc +++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc @@ -37,6 +37,7 @@ constexpr char kEinsumOp[] = "Einsum"; constexpr char kEinsumEquation[] = "equation"; constexpr char kQnnBackendType[] = "backend_type"; constexpr char kQnnBackendTypeCpu[] = "cpu"; +constexpr char kQnnBackendTypeGpu[] = "gpu"; constexpr char kQnnBackendTypeHtp[] = "htp"; constexpr char kOffloadGraphIoQuantization[] = "offload_graph_io_quantization"; constexpr char kOffloadGraphIoQuantizationDisable[] = "0"; @@ -336,6 +337,91 @@ TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll2) { #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +#if defined(_M_ARM64) +// +// GPU tests: +// + +TEST_F(QnnGPUBackendTests, EinsumRank2) { + const std::vector shape0{2, 3}; + const std::vector shape1{3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"ij,jk->ik", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnGPUBackendTests, EinsumRank3MatMul) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hij,hjk->hik", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnGPUBackendTests, EinsumRank4MatMul) { + const std::vector shape0{3, 2, 5, 6}; + const std::vector shape1{3, 2, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhij,bhjd->bhid", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnGPUBackendTests, EinsumRank4MatMulTransposeY) { + const std::vector shape0{2, 3, 4, 6}; + const std::vector shape1{2, 3, 5, 6}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, 
+ /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhid,bhjd->bhij", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnGPUBackendTests, EinsumRank4MatMulTransposeAll1) { + const std::vector shape0{1, 9, 1, 7}; + const std::vector shape1{1, 7, 1, 9}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bchq,bkhc->bkhq", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnGPUBackendTests, EinsumRank4MatMulTransposeAll2) { + const std::vector shape0{1, 7, 1, 7}; + const std::vector shape1{1, 9, 1, 7}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bkhq,bchk->bchq", + /*tolerance=*/1e-4f); +} + +#endif // defined(_M_ARM64) GPU tests + } // namespace test } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/flatten_op_test.cc b/onnxruntime/test/providers/qnn/flatten_op_test.cc index da2d452c788cf..33849f98709d9 100644 --- a/onnxruntime/test/providers/qnn/flatten_op_test.cc +++ b/onnxruntime/test/providers/qnn/flatten_op_test.cc @@ -17,12 +17,13 @@ namespace test { // Runs a model with a Flatten operator on the QNN CPU backend. Checks the graph node assignment // and that inference outputs for QNN EP and CPU EP match. template -static void RunFlattenTestOnCPU(const TestInputDef& input_def, - const std::vector& attrs, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 13) { +static void RunFlattenTest(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + const std::string& backend_name = "cpu", + int opset = 13) { ProviderOptions provider_options; - provider_options["backend_type"] = "cpu"; + provider_options["backend_type"] = backend_name; RunQnnModelTest(BuildOpTestCase("Flatten", {input_def}, {}, attrs), provider_options, @@ -36,23 +37,23 @@ static void RunFlattenTestOnCPU(const TestInputDef& input_def, // Test that Flatten input (rank4) with axis == 0. TEST_F(QnnCPUBackendTests, Flatten_Rank4_Axis0) { - RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - {utils::MakeAttribute("axis", static_cast(0))}, - ExpectedEPNodeAssignment::All); + RunFlattenTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All); } // Test that Flatten input (rank4) with axis == -1. TEST_F(QnnCPUBackendTests, Flatten_Rank4_AxisNeg1) { - RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - {utils::MakeAttribute("axis", static_cast(-1))}, - ExpectedEPNodeAssignment::All); + RunFlattenTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); } // Test that Flatten input (rank5) with axis == 2. 
TEST_F(QnnCPUBackendTests, Flatten_Rank5_Axis2) { - RunFlattenTestOnCPU(TestInputDef({1, 2, 3, 4, 4}, false, -10.0f, 10.0f), - {utils::MakeAttribute("axis", static_cast(2))}, - ExpectedEPNodeAssignment::All); + RunFlattenTest(TestInputDef({1, 2, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); } #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) @@ -179,6 +180,73 @@ TEST_F(QnnHTPBackendTests, Flatten_Int32_Rank5_Axis2) { } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +#if defined(_M_ARM64) +// +// GPU tests: +// + +// float rank4 axis == 0. +TEST_F(QnnGPUBackendTests, Flatten_Rank4_Axis0) { + RunFlattenTest(TestInputDef({2, 3, 4, 5}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// float rank4 axis == -1. +TEST_F(QnnGPUBackendTests, Flatten_Rank4_AxisNeg1) { + RunFlattenTest(TestInputDef({2, 3, 4, 5}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// float rank4 axis == 1. +TEST_F(QnnGPUBackendTests, Flatten_Rank4_Axis1) { + RunFlattenTest(TestInputDef({2, 3, 4, 5}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(1))}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// float rank4 axis == 2. +TEST_F(QnnGPUBackendTests, Flatten_Rank4_Axis2) { + RunFlattenTest(TestInputDef({2, 3, 4, 5}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// float rank5 axis == 2. +TEST_F(QnnGPUBackendTests, Flatten_Rank5_Axis2) { + RunFlattenTest(TestInputDef({1, 2, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// int32 rank4 Flatten. +TEST_F(QnnGPUBackendTests, Flatten_Int32_Rank4_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunFlattenTest(TestInputDef({1, 3, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// int32 rank4 Flatten. +TEST_F(QnnGPUBackendTests, Flatten_Int32_Rank5_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + RunFlattenTest(TestInputDef({1, 3, 2, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +#endif // defined(_M_ARM64) GPU tests + } // namespace test } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc index ef19b37c1eb30..f127e14bde635 100644 --- a/onnxruntime/test/providers/qnn/gemm_op_test.cc +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -18,13 +18,14 @@ namespace test { // Runs a model with a Gemm operator on the QNN CPU backend. Checks the graph node assignment // and that inference outputs for QNN EP and CPU EP match. 
template -static void RunGemmTestOnCPU(const std::vector>& input_defs, - const std::vector& attrs, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 13) { +static void RunGemmTest(const std::vector>& input_defs, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + const std::string& backend_name = "cpu", + int opset = 13) { ProviderOptions provider_options; - provider_options["backend_type"] = "cpu"; + provider_options["backend_type"] = backend_name; provider_options["offload_graph_io_quantization"] = "0"; RunQnnModelTest(BuildOpTestCase("Gemm", input_defs, {}, attrs), @@ -40,17 +41,17 @@ static void RunGemmTestOnCPU(const std::vector>& input_de // Test that Gemm with non-default 'alpha' or 'beta' attributes is not supported by QNN EP. TEST_F(QnnCPUBackendTests, Gemm_NonDefaultAlphaBeta_Unsupported) { // Check that alpha != 1.0f is not supported. - RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), - TestInputDef({2, 4}, false, -10.0f, 10.0f)}, - {utils::MakeAttribute("alpha", 1.5f)}, - ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + RunGemmTest({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("alpha", 1.5f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. // Check that beta != 1.0f is not supported. - RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), - TestInputDef({2, 4}, false, -10.0f, 10.0f), - TestInputDef({1, 4}, false, -1.0f, 1.0f)}, - {utils::MakeAttribute("beta", 1.2f)}, - ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + RunGemmTest({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 4}, false, -1.0f, 1.0f)}, + {utils::MakeAttribute("beta", 1.2f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. } // Test that Gemm with general 2D bias (M, N) is NOT supported (unless M == 1). @@ -60,17 +61,17 @@ TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) { std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 12); // 2D matrix mul with bias not supported. - RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), - TestInputDef({3, 4}, false, input_b_data), - TestInputDef({2, 4}, false, -1.0f, 1.0f)}, - {}, - ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + RunGemmTest({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data), + TestInputDef({2, 4}, false, -1.0f, 1.0f)}, + {}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. // However, 2D matrix mul without a bias is supported. Input A's 0th dimension is interpreted as `batch_size`. - RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), - TestInputDef({3, 4}, false, input_b_data)}, - {}, - ExpectedEPNodeAssignment::All); // Assigned to QNN EP. + RunGemmTest({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data)}, + {}, + ExpectedEPNodeAssignment::All); // Assigned to QNN EP. 
} // since Qnn v2.34 value pair (120.73912, 121.73912) at index #0 don't match, which is 1 from 120.739 @@ -79,11 +80,11 @@ TEST_F(QnnCPUBackendTests, DISABLED_Gemm_Dynamic_A_B_Bias) { std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); - RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), - TestInputDef({6, 4}, false, input_b_data), - TestInputDef({1, 4}, false, input_c_data)}, - {}, - ExpectedEPNodeAssignment::All); + RunGemmTest({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); } // Test Gemm with static B and Bias inputs. @@ -91,11 +92,11 @@ TEST_F(QnnCPUBackendTests, Gemm_Static_B_And_Bias) { std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); - RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), - TestInputDef({6, 4}, true, input_b_data), - TestInputDef({1, 4}, true, input_c_data)}, - {}, - ExpectedEPNodeAssignment::All); + RunGemmTest({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); } // Test Gemm with transposed A/B and static B and Bias inputs. @@ -103,12 +104,12 @@ TEST_F(QnnCPUBackendTests, Gemm_TransAB_Static_B_And_Bias) { std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); - RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), - TestInputDef({4, 6}, true, input_b_data), - TestInputDef({1, 4}, true, input_c_data)}, - {utils::MakeAttribute("transA", static_cast(1)), - utils::MakeAttribute("transB", static_cast(1))}, - ExpectedEPNodeAssignment::All); + RunGemmTest({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); } // Since Qnn 2.34 value pair (29.4347763, 30.4347763) at index #0 don't match, which is 1 from 29.4348 @@ -117,12 +118,12 @@ TEST_F(QnnCPUBackendTests, DISABLED_Gemm_TransAB_Dynamic_B_And_Bias) { std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); - RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), - TestInputDef({4, 6}, false, input_b_data), - TestInputDef({1, 4}, false, input_c_data)}, - {utils::MakeAttribute("transA", static_cast(1)), - utils::MakeAttribute("transB", static_cast(1))}, - ExpectedEPNodeAssignment::All); + RunGemmTest({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); } // Since Qnn 2.34 value pair (11, 10) at index #0 don't match, which is -1 from 11 @@ -135,11 +136,11 @@ TEST_F(QnnCPUBackendTests, DISABLED_Gemm_Broadcast_Bias_DynamicInputs) { // -9.0f, -8.0f, -7.0f // All 
dynamic inputs - RunGemmTestOnCPU({TestInputDef({2, 4}, false, input_a_data), - TestInputDef({4, 3}, false, input_b_data), - TestInputDef({3}, false, input_c_data)}, - {}, - ExpectedEPNodeAssignment::All); + RunGemmTest({TestInputDef({2, 4}, false, input_a_data), + TestInputDef({4, 3}, false, input_b_data), + TestInputDef({3}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); } // TODO: When this is fixed, enable GemmOpTypedTests/0.TestGemmBroadcast test in cpu/math/gemm_test.cc @@ -154,11 +155,11 @@ TEST_F(QnnCPUBackendTests, DISABLED_Gemm_Broadcast_Bias_DynamicA_StaticB_Dynamic // -9.0f, -8.0f, -7.0f // Dynamic A, static B, dynamic C - RunGemmTestOnCPU({TestInputDef({2, 4}, false, input_a_data), - TestInputDef({4, 3}, true, input_b_data), - TestInputDef({3}, false, input_c_data)}, - {}, - ExpectedEPNodeAssignment::All); + RunGemmTest({TestInputDef({2, 4}, false, input_a_data), + TestInputDef({4, 3}, true, input_b_data), + TestInputDef({3}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); } TEST_F(QnnCPUBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_StaticC) { @@ -170,11 +171,11 @@ TEST_F(QnnCPUBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_StaticC) { // -9.0f, -8.0f, -7.0f // Dynamic A, static B, static C - RunGemmTestOnCPU({TestInputDef({2, 4}, false, input_a_data), - TestInputDef({4, 3}, true, input_b_data), - TestInputDef({3}, true, input_c_data)}, - {}, - ExpectedEPNodeAssignment::All); + RunGemmTest({TestInputDef({2, 4}, false, input_a_data), + TestInputDef({4, 3}, true, input_b_data), + TestInputDef({3}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); } namespace { @@ -192,12 +193,13 @@ GetTestModelFn BuildReshapeGemmTestCase(const TestInputDef& input, const }; } -void RunCPUReshapeGemmTest(const TestInputDef& input, const TestInputDef& shape, - const TestInputDef& weight, const TestInputDef& bias, - ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-5f) { +void RunReshapeGemmTest(const TestInputDef& input, const TestInputDef& shape, + const TestInputDef& weight, const TestInputDef& bias, + ExpectedEPNodeAssignment expected_ep_assignment, + const std::string& backend_name = "cpu", float fp32_abs_err = 1e-5f) { ProviderOptions provider_options; - provider_options["backend_type"] = "cpu"; + provider_options["backend_type"] = backend_name; auto build_fn = BuildReshapeGemmTestCase(input, shape, weight, bias); RunQnnModelTest(build_fn, provider_options, 18, expected_ep_assignment, fp32_abs_err); } @@ -209,9 +211,9 @@ TEST_F(QnnCPUBackendTests, ReshapeGemmFusion) { std::vector shape_data = {4, 2}; std::vector weight_data(6, 1.0f); std::vector bias_data = {1.0f, 2.0f, 3.0f}; - RunCPUReshapeGemmTest(TestInputDef({2, 2, 2}, false, input_data), TestInputDef({2}, true, shape_data), - TestInputDef({2, 3}, true, weight_data), TestInputDef({3}, true, bias_data), - ExpectedEPNodeAssignment::All); + RunReshapeGemmTest(TestInputDef({2, 2, 2}, false, input_data), TestInputDef({2}, true, shape_data), + TestInputDef({2, 3}, true, weight_data), TestInputDef({3}, true, bias_data), + ExpectedEPNodeAssignment::All); } #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) @@ -488,6 +490,177 @@ TEST_F(QnnHTPBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +#if defined(_M_ARM64) +// +// GPU tests: +// + +// Gemm basic default attributes. 
+// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector` +// Input A's 0th dimension is interpreted as `batch_size`. +TEST_F(QnnGPUBackendTests, Gemm_Basic) { + RunGemmTest({TestInputDef({2, 3}, false, -10.0f, 10.0f), + TestInputDef({3, 4}, false, -10.0f, 10.0f)}, + {}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Gemm with 'alpha' or 'beta' attributes is not supported by QNN EP. +TEST_F(QnnGPUBackendTests, Gemm_AlphaBetaUnsupported) { + // Check that alpha != 1.0f is not supported. + RunGemmTest({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("alpha", 1.5f)}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + "gpu"); + + // Check that beta != 1.0f is not supported. + RunGemmTest({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 4}, false, -1.0f, 1.0f)}, + {utils::MakeAttribute("beta", 1.2f)}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + "gpu"); +} + +// Gemm with matrix bias ie 2D (M, N) is NOT supported. (Note: vector bias is supported ie when M == 1). +// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector` +TEST_F(QnnGPUBackendTests, Gemm_2DBiasUnsupported) { + // 2D matrix mul with 2D bias not supported. + RunGemmTest({TestInputDef({2, 3}, false, -10.0f, 10.0f), + TestInputDef({3, 4}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -1.0f, 1.0f)}, + {}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + "gpu"); +} + +// Gemm with vector bias is supported ie when M == 1. +// Bias is broadcast across input batches. +// `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector` +TEST_F(QnnGPUBackendTests, Gemm_1DBiasBcast) { + // 2D matrix mul with 1D bias supported. + RunGemmTest({TestInputDef({2, 3}, false, -10.0f, 10.0f), + TestInputDef({3, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 4}, false, -1.0f, 1.0f)}, + {}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Test Gemm with dynamic (i.e., not initializer) inputs (A, B, Bias). +TEST_F(QnnGPUBackendTests, Gemm_Dynamic_A_B_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTest({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Test Gemm with static B and Bias inputs. +TEST_F(QnnGPUBackendTests, Gemm_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTest({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Test Gemm with transposed A/B and static B and Bias inputs. 
+TEST_F(QnnGPUBackendTests, Gemm_TransposeAB_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTest({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Test Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. +TEST_F(QnnGPUBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTest({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Bias broadcast across batches. +TEST_F(QnnGPUBackendTests, Gemm_Broadcast_Bias_DynamicInputs) { + std::vector input_a_data = {1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector input_b_data(12, 1.0f); + std::vector input_c_data = {1.0f, 2.0f, 3.0f}; + + // All dynamic inputs + RunGemmTest({TestInputDef({2, 4}, false, input_a_data), + TestInputDef({4, 3}, false, input_b_data), + TestInputDef({3}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +TEST_F(QnnGPUBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_DynamicC) { + std::vector input_a_data = {1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector input_b_data(12, 1.0f); + std::vector input_c_data = {1.0f, 2.0f, 3.0f}; + + // Dynamic A, static B, dynamic C + RunGemmTest({TestInputDef({2, 4}, false, input_a_data), + TestInputDef({4, 3}, true, input_b_data), + TestInputDef({3}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +TEST_F(QnnGPUBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_StaticC) { + std::vector input_a_data = {1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector input_b_data(12, 1.0f); + std::vector input_c_data = {1.0f, 2.0f, 3.0f}; + + // Dynamic A, static B, static C + RunGemmTest({TestInputDef({2, 4}, false, input_a_data), + TestInputDef({4, 3}, true, input_b_data), + TestInputDef({3}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + "gpu"); +} + +// Tests fusion of Reshape inpout followed by Gemm. 
+TEST_F(QnnGPUBackendTests, ReshapeGemmFusion) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector shape_data = {4, 2}; + std::vector weight_data(6, 1.0f); + std::vector bias_data = {1.0f, 2.0f, 3.0f}; + RunReshapeGemmTest(TestInputDef({2, 2, 2}, false, input_data), TestInputDef({2}, true, shape_data), + TestInputDef({2, 3}, true, weight_data), TestInputDef({3}, true, bias_data), + ExpectedEPNodeAssignment::All, + "gpu"); +} + +#endif // defined(_M_ARM64) GPU tests + } // namespace test } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 723717351ea86..eb06643cfc119 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -26,12 +26,13 @@ static GetTestModelFn BuildMatMulOpTestCase(const TestInputDef& input1_de }; } -static void RunMatMulOpTest(bool is_htp_backend, const std::vector& shape_0, +static void RunMatMulOpTest(const std::vector& shape_0, const std::vector& shape_1, bool is_initializer_0, bool is_initializer_1, ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All, + const std::string& backend_name = "cpu", int opset = 18, float f32_abs_err = 1e-4f) { ProviderOptions provider_options; - provider_options["backend_type"] = is_htp_backend ? "htp" : "cpu"; + provider_options["backend_type"] = backend_name; provider_options["offload_graph_io_quantization"] = "0"; RunQnnModelTest(BuildMatMulOpTestCase( @@ -184,33 +185,33 @@ static void RunQDQPerChannelMatMulOpTest( // CPU tests: // TEST_F(QnnCPUBackendTests, MatMulOp) { - // RunMatMulOpTest(is_htp_backend, shape_0, shape_1, is_initializer_0, is_initializer_1) - RunMatMulOpTest(false, {2, 3}, {3, 2}, false, false); - RunMatMulOpTest(false, {2, 3}, {3, 2}, false, true); - RunMatMulOpTest(false, {2, 3}, {3, 2}, true, false); - RunMatMulOpTest(false, {2, 3}, {3, 2}, true, true); // constant folding - RunMatMulOpTest(false, {2, 3}, {2, 3, 2}, false, false); - RunMatMulOpTest(false, {3, 3, 3}, {3, 2}, true, false); - RunMatMulOpTest(false, {2, 3, 3, 3}, {3, 2}, false, true); - RunMatMulOpTest(false, {2, 3, 3, 3}, {2, 3, 3, 2}, false, true); - - RunMatMulOpTest(false, {2, 1, 2, 3}, {3, 3, 2}, false, false); - RunMatMulOpTest(false, {3}, {3}, false, false); - RunMatMulOpTest(false, {3}, {3}, false, true); - RunMatMulOpTest(false, {3}, {3}, true, false); - RunMatMulOpTest(false, {3}, {3, 2}, false, false); - RunMatMulOpTest(false, {3}, {3, 2}, false, true); - RunMatMulOpTest(false, {3}, {3, 3, 2}, true, false); - RunMatMulOpTest(false, {2, 3}, {3}, false, false); - RunMatMulOpTest(false, {2, 3}, {3}, true, false); - RunMatMulOpTest(false, {2, 3, 3, 3}, {3}, false, false); + // RunMatMulOpTest(shape_0, shape_1, is_initializer_0, is_initializer_1) + RunMatMulOpTest({2, 3}, {3, 2}, false, false); + RunMatMulOpTest({2, 3}, {3, 2}, false, true); + RunMatMulOpTest({2, 3}, {3, 2}, true, false); + RunMatMulOpTest({2, 3}, {3, 2}, true, true); // constant folding + RunMatMulOpTest({2, 3}, {2, 3, 2}, false, false); + RunMatMulOpTest({3, 3, 3}, {3, 2}, true, false); + RunMatMulOpTest({2, 3, 3, 3}, {3, 2}, false, true); + RunMatMulOpTest({2, 3, 3, 3}, {2, 3, 3, 2}, false, true); + + RunMatMulOpTest({2, 1, 2, 3}, {3, 3, 2}, false, false); + RunMatMulOpTest({3}, {3}, false, false); + RunMatMulOpTest({3}, {3}, false, true); + RunMatMulOpTest({3}, {3}, true, false); + RunMatMulOpTest({3}, {3, 2}, false, 
false); + RunMatMulOpTest({3}, {3, 2}, false, true); + RunMatMulOpTest({3}, {3, 3, 2}, true, false); + RunMatMulOpTest({2, 3}, {3}, false, false); + RunMatMulOpTest({2, 3}, {3}, true, false); + RunMatMulOpTest({2, 3, 3, 3}, {3}, false, false); // Failed randomly on Linux // Expected: contains 36 values, where each value and its corresponding value in 16-byte object // <24-00 00-00 00-00 00-00 40-4A 47-42 4D-56 00-00> are an almost-equal pair // Actual: 16-byte object <24-00 00-00 00-00 00-00 80-39 2B-42 4D-56 00-00>, where the value pair (0.104199991, 0) // at index #18 don't match, which is -0.1042 from 0.1042 - // RunMatMulOpTest(false, {2, 3, 3, 3}, {3, 2}, true, false); + // RunMatMulOpTest({2, 3, 3, 3}, {3, 2}, true, false); } #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) @@ -220,33 +221,33 @@ TEST_F(QnnCPUBackendTests, MatMulOp) { // // Disable this for now as the QNN HTP backend is not stable on different versions and platforms so it failed randomly. TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp) { - // RunMatMulOpTest(is_htp_backend, shape_0, shape_1, is_initializer_0, is_initializer_1, expected_ep_assignment, + // RunMatMulOpTest(shape_0, shape_1, is_initializer_0, is_initializer_1, expected_ep_assignment, // opset, f32_abs_err) - RunMatMulOpTest(true, {2, 3}, {3, 2}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 3}, {3, 2}, true, true); // constant folding - RunMatMulOpTest(true, {2, 3}, {2, 3, 2}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 3, 3, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 3, 3, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 3, 3, 3}, {2, 3, 3, 2}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 1, 2, 3}, {3, 3, 2}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {3}, {3}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {3}, {3}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {3}, {3}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {3}, {3, 2}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {3}, {3, 3, 2}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 3}, {3}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 3}, {3}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); - RunMatMulOpTest(true, {2, 3, 3, 3}, {3}, false, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + RunMatMulOpTest({2, 3}, {3, 2}, false, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 3}, {3, 2}, true, true, ExpectedEPNodeAssignment::All, "htp"); // constant folding + RunMatMulOpTest({2, 3}, {2, 3, 2}, false, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 3, 3, 3}, {3, 2}, true, false, 
ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 3, 3, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 3, 3, 3}, {2, 3, 3, 2}, false, true, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 1, 2, 3}, {3, 3, 2}, false, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({3}, {3}, false, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({3}, {3}, false, true, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({3}, {3}, true, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({3}, {3, 2}, false, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({3}, {3, 3, 2}, true, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 3}, {3}, false, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 3}, {3}, true, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); + RunMatMulOpTest({2, 3, 3, 3}, {3}, false, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); // Failed randomly on Linux // Expected: contains 18 values, where each value and its corresponding value in 16-byte object // <12-00 00-00 00-00 00-00 40-3D CC-A5 5A-7A 00-00> are an almost-equal pair // Actual: 16-byte object <12-00 00-00 00-00 00-00 80-E8 CF-8F 5B-7A 00-00>, where the value pair // (0.0393999927, 98304.0078) at index #6 don't match, which is 98304 from 0.0394 - // RunMatMulOpTest(true, {3, 3, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, 18, 1e-2f); + // RunMatMulOpTest({3, 3, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, "htp", 18, 1e-2f); } TEST_F(QnnHTPBackendTests, MatMulOp_QDQ) { @@ -390,6 +391,67 @@ TEST_F(QnnHTPBackendTests, MatMulOp_QDQ_Regression_uint16_static_weight) { #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +#if defined(_M_ARM64) +// +// GPU tests: +// + +// RunMatMulOpTest(shape_0, shape_1, is_initializer_0, is_initializer_1, expected_ep_assignment, backend); + +TEST_F(QnnGPUBackendTests, MatMulOp_simple) { + RunMatMulOpTest({2, 3}, {3, 2}, false, false, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({2, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({2, 3}, {3, 2}, true, false, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({2, 3}, {3, 2}, true, true, ExpectedEPNodeAssignment::All, "gpu"); // constant folding +} + +TEST_F(QnnGPUBackendTests, MatMulOp_batches) { + RunMatMulOpTest({3, 3, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({2, 3, 3, 3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, "gpu"); +} + +TEST_F(QnnGPUBackendTests, MatMulOp_batchesWtsSameDim) { + RunMatMulOpTest({3, 3, 3}, {3, 3, 2}, false, true, ExpectedEPNodeAssignment::All, "gpu"); +} + +TEST_F(QnnGPUBackendTests, MatMulOp_batchesWtsSameDim2) { + RunMatMulOpTest({2, 3, 3, 3}, {2, 3, 3, 2}, false, true, ExpectedEPNodeAssignment::All, "gpu"); +} + +TEST_F(QnnGPUBackendTests, MatMulOp_wtsDimBcast) { + RunMatMulOpTest({3, 3, 3}, {1, 3, 2}, false, true, ExpectedEPNodeAssignment::All, "gpu"); +} + +TEST_F(QnnGPUBackendTests, DISABLED_MatMulOp_batchesDimBcast) { + RunMatMulOpTest({1, 3, 3}, {3, 3, 2}, false, true, ExpectedEPNodeAssignment::All, "gpu"); +} + +TEST_F(QnnGPUBackendTests, DISABLED_MatMulOp_batchesDimBcast2) { + RunMatMulOpTest({2, 1, 3, 3}, 
{3, 3, 2}, false, true, ExpectedEPNodeAssignment::All, "gpu"); +} + +TEST_F(QnnGPUBackendTests, MatMulOp_inp0DimBcast) { + RunMatMulOpTest({3, 3}, {3, 3, 2}, false, false, ExpectedEPNodeAssignment::All, "gpu"); +} + +TEST_F(QnnGPUBackendTests, MatMulOp_inp1DimBcast) { + RunMatMulOpTest({2, 3, 3}, {3, 2}, false, false, ExpectedEPNodeAssignment::All, "gpu"); +} + +TEST_F(QnnGPUBackendTests, MatMulOp_rank1) { + RunMatMulOpTest({3}, {3}, false, false, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({3}, {3}, false, true, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({3}, {3}, true, false, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({3}, {3, 2}, false, false, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({3}, {3, 2}, false, true, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({3}, {3, 3, 2}, true, false, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({2, 3}, {3}, false, false, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({2, 3}, {3}, true, false, ExpectedEPNodeAssignment::All, "gpu"); + RunMatMulOpTest({2, 3, 3, 3}, {3}, false, false, ExpectedEPNodeAssignment::All, "gpu"); +} + +#endif // defined(_M_ARM64) GPU tests + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index cd163b044911c..da6e9c3288328 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -309,6 +309,64 @@ void QnnHTPBackendTests::SetUp() { } } +// Checks if Qnn Gpu backend can run a graph on the system. +// Creates a one node graph with relu op, +// then calls QNN EP's GetCapability() function +// to check if the GPU backend is available. +static BackendSupport GetGPUSupport(const onnxruntime::logging::Logger& logger) { + onnxruntime::Model model("Check if GPU is available", false, logger); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + + // Build simple QDQ graph: DQ -> InstanceNormalization -> Q + auto build_test_case = BuildOpTestCase( + "Relu", + {TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f)}, + {}, + {}); + + build_test_case(helper); + helper.SetGraphOutputs(); + auto status = model.MainGraph().Resolve(); + + if (!status.IsOK()) { + return BackendSupport::SUPPORT_ERROR; + } + + // Create QNN EP and call GetCapability(). + MockKernelLookup kernel_lookup; + onnxruntime::GraphViewer graph_viewer(graph); + std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( + {{"backend_type", "gpu"}, {"offload_graph_io_quantization", "0"}}); + GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability + + qnn_ep->SetLogger(&logger); + auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); + + return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED; +} + +void QnnGPUBackendTests::SetUp() { + if (cached_gpu_support_ == BackendSupport::SUPPORTED) { + return; + } + + const auto& logger = DefaultLoggingManager().DefaultLogger(); + + // Determine if GPU backend is supported only if we haven't done so before. + if (cached_gpu_support_ == BackendSupport::SUPPORT_UNKNOWN) { + cached_gpu_support_ = GetGPUSupport(logger); // BackendSupport::SUPPORTED; + } + + if (cached_gpu_support_ == BackendSupport::UNSUPPORTED) { + LOGS(logger, WARNING) << "QNN GPU backend is not available! 
Skipping test.";
+    GTEST_SKIP();
+  } else if (cached_gpu_support_ == BackendSupport::SUPPORT_ERROR) {
+    LOGS(logger, ERROR) << "Failed to check if QNN GPU backend is available.";
+    FAIL();
+  }
+}
+
 static BackendSupport GetIRSupport(const onnxruntime::logging::Logger& logger);
 BackendSupport QnnHTPBackendTests::IsIRBackendSupported() const {
@@ -425,6 +483,7 @@ BackendSupport QnnCPUBackendTests::cached_cpu_support_ = BackendSupport::SUPPORT
 BackendSupport QnnHTPBackendTests::cached_ir_support_ = BackendSupport::SUPPORT_UNKNOWN;
 BackendSupport QnnIRBackendTests::cached_ir_support_ = BackendSupport::SUPPORT_UNKNOWN;
+BackendSupport QnnGPUBackendTests::cached_gpu_support_ = BackendSupport::SUPPORT_UNKNOWN;
 
 bool ReduceOpHasAxesInput(const std::string& op_type, int opset_version) {
   static const std::unordered_map opset_with_axes_as_input = {
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h
index 9fe48ddabd427..fdcfcfd417424 100644
--- a/onnxruntime/test/providers/qnn/qnn_test_utils.h
+++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -547,6 +547,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
     std::string f32_model_data;
     f32_model_fn(f32_helper);
     f32_helper.SetGraphOutputs();
+    ASSERT_STATUS_OK(f32_model.MainGraph().Resolve());
     f32_model.ToProto().SerializeToString(&f32_model_data);
 
@@ -589,6 +590,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
     std::string qdq_model_data;
     qdq_model_fn(qdq_helper, output_qparams);
     qdq_helper.SetGraphOutputs();
+    ASSERT_STATUS_OK(qdq_model.MainGraph().Resolve());
     qdq_model.ToProto().SerializeToString(&qdq_model_data);
 
@@ -1107,6 +1109,15 @@ class QnnHTPBackendTests : public ::testing::Test {
   static BackendSupport cached_ir_support_;
 };
 
+// Testing fixture class for tests that require the QNN GPU backend. Checks if QNN GPU is available before the test
+// begins. The test is skipped if the GPU backend is unavailable (may occur on Windows ARM64).
+class QnnGPUBackendTests : public ::testing::Test {
+ protected:
+  void SetUp() override;
+
+  static BackendSupport cached_gpu_support_;  // Set by the first test using this fixture.
+};
+
 // Testing fixture class for tests that require the QNN CPU backend. Checks if QNN CPU is available before the test
 // begins. The test is skipped if the CPU backend is unavailable (may occur on Windows ARM64 VM).
 // TODO: Remove once QNN CPU backend works on Windows ARM64 pipeline VM.

From 2d6a52536d96da57e119462e0e5b8096205f27c6 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Wed, 16 Jul 2025 11:06:56 -0700
Subject: [PATCH 45/49] Fix Build Error when tensor dumping is enabled (#25414)

### Description

Fix a CUDA build error when DEBUG_GENERATION is defined.

### Motivation and Context
In https://github.com/microsoft/onnxruntime/pull/24821, a dumping API was removed:
`void Print(const char* name, int index, bool end_line)`
but the code that called it was not updated.

In MatMulNBits, a recent change added bfloat16 support, but the tensor dumper only supports BFloat16, not __nv_bfloat16. This PR adds functions to the CUDA tensor dumper to support __nv_bfloat16.
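For reviewers, the call-site migration applied throughout this patch follows a single pattern. A minimal sketch (taken from the `sequences.cc` hunk below; `dumper`, `sequence`, and `current_length_` are the surrounding ORT-internal variables, so this is not a standalone program):

```c++
// Before: this Print overload was removed in #24821, so these calls no longer compile.
//   dumper->Print("sequences", i, /*end_line=*/false);
//   dumper->Print(nullptr, sequence.data(), 1, current_length_);

// After: compose the label with MakeString and pass it to the remaining overload.
auto name = ::onnxruntime::MakeString("sequences[", i, "]");
dumper->Print(name.c_str(), sequence.data(), 1, current_length_);
```

The new `__nv_bfloat16` support in `dump_cuda_tensor.cc/.h` follows the existing `CUDA_DUMPER_PRINT_TYPE(half, MLFloat16)` mapping pattern.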
--- .../cpu/transformers/beam_search_impl_base.h | 2 +- .../cpu/transformers/beam_search_impl_gpt.h | 9 ++---- .../cpu/transformers/beam_search_impl_t5.h | 32 +++++++++---------- .../transformers/beam_search_impl_whisper.h | 24 +++++++------- .../cpu/transformers/greedy_search_impl_gpt.h | 2 +- .../cpu/transformers/sampling_cpu_helper.h | 3 +- .../contrib_ops/cpu/transformers/sequences.cc | 4 +-- .../cuda/quantization/matmul_nbits.cc | 6 ++-- .../cuda/transformers/sampling_cuda_helper.h | 2 +- .../cuda/utils/dump_cuda_tensor.cc | 3 ++ .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 2 +- 11 files changed, 45 insertions(+), 44 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index 29b38fc234de5..c00caf9d8e044 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -289,7 +289,7 @@ Status BeamSearchBase::GenerateNextToken( auto sequences_buffer = cpu_state.sequences.GetCurrentDeviceSequences(); for (int i = 0; i < parameters_->batch_size * parameters_->num_beams; i++) { gsl::span sequence = sequences_buffer.subspan(i * parameters_->max_length, cpu_state.sequences.GetSequenceLength()); - cuda_dumper_->Print("sequences", i, false); + cuda_dumper_->Print(::onnxruntime::MakeString("sequences[", i, "]")); cuda_dumper_->Print(nullptr, sequence.data(), 1, static_cast(sequence.size())); } #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index b18e122980eda..1e2af6394f79c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -278,16 +278,13 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch int iteration_counter = 0; while (current_length < parameters->max_length) { #ifdef DEBUG_GENERATION - auto cur_len = std::to_string(current_length); - dumper->Print("***CurrentLength", cur_len, true); - dumper->Print("iteration", iteration_counter, true); - + dumper->Print(::onnxruntime::MakeString("***CurrentLength=", current_length, ", iteration=", iteration_counter)); dumper->Print("input_ids", feeds[0]); dumper->Print("position_ids", feeds[1]); dumper->Print("attention_mask", feeds[2]); for (size_t i = 3; i < feeds.size(); i++) { - dumper->Print("past", static_cast(i) - 3, true); - dumper->Print("", feeds[i]); + auto name = ::onnxruntime::MakeString("past[", static_cast(i) - 3, "]"); + dumper->Print(name.c_str(), feeds[i]); } #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index c9646cf0fab2e..0fd931e3da150 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -185,13 +185,13 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches #ifdef DEBUG_GENERATION const IConsoleDumper* dumper = this->GetConsoleDumper(); for (int i = 0; i < this->encoder_subgraph_.num_subgraph_inputs; i++) { - dumper->Print("encoder_feeds", static_cast(i), true); - dumper->Print("", encoder_feeds[i]); + auto name = ::onnxruntime::MakeString("encoder_feeds[", i, "]"); + dumper->Print(name.c_str(), encoder_feeds[i]); } for (int i = 0; i <= encoder_subgraph_.GetFirstPresentOutputIndex(); i++) { - 
dumper->Print("encoder_fetches", i, true); - dumper->Print("", encoder_fetches[i]); + auto name = ::onnxruntime::MakeString("encoder_fetches[", i, "]"); + dumper->Print(name.c_str(), encoder_fetches[i]); } #endif @@ -326,23 +326,23 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches ", start_token_id=", parameters->decoder_start_token_id)); for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) { - dumper->Print("decoder_feeds", i, true); - dumper->Print("", decoder_feeds[i]); + auto name = ::onnxruntime::MakeString("decoder_feeds[", i, "]"); + dumper->Print(name.c_str(), decoder_feeds[i]); } for (int i = 0; i < decoder_subgraph_.num_layers; i++) { int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i; int self_value_idx = self_key_idx + 1; - dumper->Print("past_key_self", i, true); - dumper->Print("", decoder_feeds[self_key_idx]); - dumper->Print("past_value_self", i + 1, true); - dumper->Print("", decoder_feeds[self_value_idx]); + auto name = ::onnxruntime::MakeString("past_key_self[", i, "]"); + dumper->Print(name.c_str(), decoder_feeds[self_key_idx]); + name = ::onnxruntime::MakeString("past_value_self[", i + 1, "]"); + dumper->Print(name.c_str(), decoder_feeds[self_value_idx]); int cross_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * decoder_subgraph_.num_layers + 2 * i; int cross_value_idx = cross_key_idx + 1; - dumper->Print("past_key_cross", i, true); - dumper->Print("", decoder_feeds[cross_key_idx]); - dumper->Print("past_value_cross", i, true); - dumper->Print("", decoder_feeds[cross_value_idx]); + name = ::onnxruntime::MakeString("past_key_cross[", i, "]"); + dumper->Print(name.c_str(), decoder_feeds[cross_key_idx]); + name = ::onnxruntime::MakeString("past_value_cross[", i, "]"); + dumper->Print(name.c_str(), decoder_feeds[cross_value_idx]); } #endif @@ -363,8 +363,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches #ifdef DEBUG_GENERATION for (int i = 0; i <= decoder_subgraph_.GetFirstPresentOutputIndex(); i++) { - dumper->Print("decoder_fetches", i, true); - dumper->Print("", decoder_fetches[i]); + auto name = ::onnxruntime::MakeString("decoder_fetches[", i, "]"); + dumper->Print(name.c_str(), decoder_fetches[i]); } #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index af0904b7d6e4b..fe0c735792a74 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -191,13 +191,13 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe #ifdef DEBUG_GENERATION const IConsoleDumper* dumper = this->GetConsoleDumper(); for (int i = 0; i < this->encoder_subgraph_.num_subgraph_inputs; i++) { - dumper->Print("encoder_feeds", static_cast(i), true); - dumper->Print("", encoder_feeds[i]); + auto name = ::onnxruntime::MakeString("encoder_feeds[", i, "]"); + dumper->Print(name.c_str(), encoder_feeds[i]); } for (int i = 0; i <= encoder_subgraph_.GetFirstPresentOutputIndex(); i++) { - dumper->Print("encoder_fetches", i, true); - dumper->Print("", encoder_fetches[i]); + auto name = ::onnxruntime::MakeString("encoder_fetches[", i, "]"); + dumper->Print(name.c_str(), encoder_fetches[i]); } #endif @@ -355,16 +355,16 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe while (current_length < parameters->max_length) { iteration_counter++; #ifdef 
DEBUG_GENERATION - auto cur_len = std::to_string(current_length); - dumper->Print("***CurrentLength", cur_len, true); + auto name = ::onnxruntime::MakeString("***CurrentLength=", current_length, ", iteration_counter=", iteration_counter); for (int i = 0; i <= decoder_subgraph_.GetFirstPastInputIndex(); i++) { - dumper->Print("decoder_feeds", i, true); - dumper->Print("", decoder_feeds[i]); + name = ::onnxruntime::MakeString("decoder_feeds[", i, "]"); + dumper->Print(name.c_str(), decoder_feeds[i]); } + auto offset = decoder_subgraph_.GetFirstPastInputIndex() + 4 * decoder_subgraph_.num_layers; - dumper->Print("past_sequence_length", offset, true); - dumper->Print("", decoder_feeds[offset]); + name = ::onnxruntime::MakeString("past_sequence_length[", offset, "]"); + dumper->Print(name.c_str(), decoder_feeds[offset]); #endif #ifdef DEBUG_NODE_INPUTS_OUTPUTS @@ -399,8 +399,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe #ifdef DEBUG_GENERATION for (int i = 0; i <= decoder_subgraph_.GetFirstPresentOutputIndex(); i++) { - dumper->Print("decoder_fetches", i, true); - dumper->Print("", decoder_fetches[i]); + auto name = ::onnxruntime::MakeString("decoder_fetches[", i, "]"); + dumper->Print(name.c_str(), decoder_fetches[i]); } #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 69d25eaabbe02..781a2ec0068ef 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -273,7 +273,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ while (current_length < parameters->max_length) { #ifdef DEBUG_GENERATION auto cur_len = std::to_string(current_length); - dumper->Print("***CurrentLength", cur_len, true); + dumper->Print(::onnxruntime::MakeString("***CurrentLength=", cur_len)); dumper->Print("input_ids", feeds[0]); dumper->Print("position_ids", feeds[1]); dumper->Print("attention_mask", feeds[2]); diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h index 2f41746c1d4e7..3fcbc37bd5eeb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h @@ -93,7 +93,8 @@ Status Sample(AllocatorPtr& allocator, #ifdef DEBUG_GENERATION dumper->Print("sorted_scores", sorted_scores.data(), parameters->batch_size, parameters->vocab_size); - dumper->Print("sorted_indices", sorted_indices.data(), parameters->batch_size, parameters->vocab_size); + std::vector sorted_indices_copy(sorted_indices.begin(), sorted_indices.end()); + dumper->Print("sorted_indices", sorted_indices_copy.data(), parameters->batch_size, parameters->vocab_size); #endif gsl::span& cumulative_probs = sampling_state->cumulative_probs; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index ecad146da6777..2b2f81e970c78 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -44,8 +44,8 @@ int Sequences::GetMaxLength() const { void Sequences::PrintSequences(const IConsoleDumper* dumper) const { for (int i = 0; i < batch_beam_size_; i++) { gsl::span sequence = GetSequence(i); - dumper->Print("sequences", i, false); - dumper->Print(nullptr, sequence.data(), 1, current_length_); + auto name = 
::onnxruntime::MakeString("sequences[", i, "]"); + dumper->Print(name.c_str(), sequence.data(), 1, current_length_); } } #endif diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 8509892919639..2e862ff816bef 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -196,7 +196,7 @@ Status MatMulNBits::PrePack_Scale([[maybe_unused]] const Tensor& tensor, CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("transposed_scales", transposed_scales, k_blocks, n); + DUMP_TENSOR_D("transposed_scales", transposed_scales, static_cast(k_blocks), static_cast(n)); } return Status::OK(); } @@ -242,7 +242,7 @@ Status MatMulNBits::PrePack_ZeroPoint([[maybe_unused]] const Tensor& tensor, CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("scaled_zero_points", scaled_zero_points, k_blocks, n); + DUMP_TENSOR_D("scaled_zero_points", scaled_zero_points, static_cast(k_blocks), static_cast(n)); } return Status::OK(); } @@ -465,7 +465,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { } } - DUMP_TENSOR_D("DeQuantized", b_data, N_, K_padded); + DUMP_TENSOR_D("DeQuantized", b_data, static_cast(N_), static_cast(K_padded)); const CudaT alpha = onnxruntime::cuda::OrtToCudaType::FromFloat(1.f); const CudaT zero = onnxruntime::cuda::OrtToCudaType::FromFloat(0.f); diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h index d1c904987e217..7be3a4851aaed 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -68,7 +68,7 @@ Status Sample(AllocatorPtr& allocator, gsl::span& d_index_out = sampling_state->d_index_out; #ifdef DEBUG_GENERATION - dumper->Print("temp_storage_bytes", sampling_state->temp_storage_bytes, true); + dumper->Print(::onnxruntime::MakeString("temp_storage_bytes=", sampling_state->temp_storage_bytes)); #endif cuda::LaunchSortPairs(storage_buffer.get(), diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index b986f0ae3edad..980299f85f88f 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -309,7 +309,9 @@ CUDA_DUMPER_PRINT_TYPE(BFloat16, BFloat16) CUDA_DUMPER_PRINT_TYPE(UInt4x2, UInt4x2) CUDA_DUMPER_PRINT_TYPE(Int4x2, Int4x2) +// Map some cuda type to ORT type for printing. 
CUDA_DUMPER_PRINT_TYPE(half, MLFloat16) +CUDA_DUMPER_PRINT_TYPE(__nv_bfloat16, BFloat16) #undef DUMPER_PRINT_TYPE #else @@ -345,6 +347,7 @@ CUDA_DUMPER_PRINT_TYPE(BFloat16) CUDA_DUMPER_PRINT_TYPE(UInt4x2) CUDA_DUMPER_PRINT_TYPE(Int4x2) CUDA_DUMPER_PRINT_TYPE(half) +CUDA_DUMPER_PRINT_TYPE(__nv_bfloat16) #undef DUMPER_PRINT_TYPE #endif diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index ec034bc15341e..3937ce3948de9 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -36,7 +36,7 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { CUDA_DUMPER_PRINT_TYPE(UInt4x2) CUDA_DUMPER_PRINT_TYPE(Int4x2) CUDA_DUMPER_PRINT_TYPE(half) - + CUDA_DUMPER_PRINT_TYPE(__nv_bfloat16) #undef CUDA_DUMPER_PRINT_TYPE }; From 58954ba7a969d7200365498f0d77b24da51e98ac Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Thu, 17 Jul 2025 02:21:06 +0800 Subject: [PATCH 46/49] [webgpu] Apply template to `MatMulNBitsWideTile` (#25353) ### Description This commit applies WGSL template to `MatMulNBitsWideTile` to improve code readability and enables more flexible data handling. As part of this change, support for 4-bit and 8-bit shaders has been consolidated, and a common `CEIL_DIV` utility has been introduced. The previous `ShaderUsage::UseUniform` and `ShaderUsage::UseIndicesTypeAlias` flags are no longer necessary and have been removed. ### Motivation and Context See above --- cmake/onnxruntime_providers_webgpu.cmake | 20 +- .../webgpu/quantization/matmul_nbits.cc | 225 +++++------------- .../webgpu/quantization/matmul_nbits.h | 9 +- .../matmul_nbits_wide_tile.wgsl.template | 192 +++++++++++++++ .../test/contrib_ops/matmul_8bits_test.cc | 9 + 5 files changed, 277 insertions(+), 178 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 2865ad33b39f4..4fdfeae927d9d 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -172,12 +172,24 @@ file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR}) # Find all WGSL template input files - file(GLOB_RECURSE WGSL_TEMPLATE_FILES - "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template" - "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template") + set(WGSL_SEARCH_PATHS "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template") + if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + list(APPEND WGSL_SEARCH_PATHS "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template") + endif() + file(GLOB_RECURSE WGSL_TEMPLATE_FILES ${WGSL_SEARCH_PATHS}) # Set wgsl-gen command line options as a list - set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") + set(WGSL_GEN_OPTIONS + "--output" "${WGSL_GENERATED_DIR}" + "-I" "wgsl_template_gen/" + "--preserve-code-ref" + "--verbose" + "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu" + ) + if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + list(APPEND WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu") + endif() + if (onnxruntime_WGSL_TEMPLATE STREQUAL "static") if (CMAKE_BUILD_TYPE STREQUAL "Debug") list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal") diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc 
b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index b5de8578af6ba..02ddde00cdabb 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -22,6 +22,11 @@ namespace { constexpr unsigned int kMinMForTileOptimization = 4; +template +inline T ceil_div(T numerator, T denominator) { + return (numerator + denominator - 1) / denominator; +} + } // namespace ONNX_OPERATOR_KERNEL_EX( @@ -37,165 +42,24 @@ ONNX_OPERATOR_KERNEL_EX( MatMulNBits); Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddInput("scales", ShaderUsage::UseUniform); + shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("input_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddInput("scales", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); if (has_zero_points_) { - shader.AddInput("zero_points", ShaderUsage::UseUniform); + shader.AddInput("zero_points", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); } - const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); - - // Bock size 32, `a` component size 4, 8 `a` components per block. - constexpr uint32_t kAComponentsForBlock32 = 8; + shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); ORT_ENFORCE(tile_m_ == workgroup_size / 8, "tile_m must be workgroup_size / 8."); ORT_ENFORCE(tile_n_ == workgroup_size, "tile_n must be workgroup_size."); + ORT_ENFORCE(nbits_ == 4 || nbits_ == 8, "Only 4/8 bits are supported for webgpu matmulnbits."); - // memory read/write helpers - shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - << " if (batch < uniforms.input_a_shape[0] && row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" - << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" - << " }\n" - << " return input_a_value_t(0);\n" - << "}\n"; - if (nbits_ == 4) { - shader.AdditionalImplementation() << "\n" - << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n" - << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" - << " return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n" - << " }\n" - << " return input_b_value_t(0);\n" - << "}\n"; - - shader.AdditionalImplementation() << R"( -fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scale : output_element_t) -> mat2x4 { - let lower_values: vec4 = unpack4xU8(packed_value & 0x0F0F0F0Fu); - let upper_values: vec4 = unpack4xU8((packed_value >> 4u) & 0x0F0F0F0Fu); - - let zero_matrix: mat2x4 = mat2x4( - zero_point, zero_point, zero_point, zero_point, - zero_point, zero_point, zero_point, zero_point - ); - - var dequantized_values: mat2x4 = mat2x4( - output_element_t(lower_values[0]), output_element_t(upper_values[0]), - output_element_t(lower_values[1]), output_element_t(upper_values[1]), - 
output_element_t(lower_values[2]), output_element_t(upper_values[2]), - output_element_t(lower_values[3]), output_element_t(upper_values[3]) - ); - - dequantized_values = (dequantized_values - zero_matrix) * scale; - return dequantized_values; -} -)"; - } - - shader.AdditionalImplementation() << "\n" - << "fn mm_read_scale(row : u32, col : u32) -> output_element_t {\n" - << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" - << " return scales[row * uniforms.input_b_shape[1] + col];\n" - << " }\n" - << " return output_element_t(0);\n" - << "}\n" - << GenerateZeroPointReadingCode(nbits_, has_zero_points_); - - shader.AdditionalImplementation() << "\n" - << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n" - << " if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n" - << " }\n" - << "}\n"; - - // declare const variables - shader.AdditionalImplementation() << "\n" - << "// A block32 containing 8 components of `a`." << "\n" - << "const kAComponentsForBlock32 = " << kAComponentsForBlock32 << "u;\n" - << "const kTileM = " << tile_m_ << "u;\n" - << "const kTileN = " << tile_n_ << "u;\n"; - - // declare workgroup memory - shader.AdditionalImplementation() << "\n" - << "var a_data_tile: array, kTileM>;\n" - << "\n"; - - // main - shader.MainFunctionBody() << R"MAIN_FN( - let batch = workgroup_idx / (uniforms.num_M_tile * uniforms.num_N_tile); - let row = ((workgroup_idx / uniforms.num_N_tile) % uniforms.num_M_tile) * kTileM; - let col = (workgroup_idx % uniforms.num_N_tile) * kTileN; - - let a_elements_per_col = uniforms.input_a_shape[2]; - let a_blocks_per_col = (a_elements_per_col + kAComponentsForBlock32 - 1) / kAComponentsForBlock32; - - // Utilizing an f32 accumulator mitigated precision loss with minimal - // performance impact compared to an f16 accumulator. - var results : array; - for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) { - // Load `a` elements into workgroup memory, TileM x kAComponentsForBlock32 (block32) - let a_row_idx = local_idx / kAComponentsForBlock32; - let a_col_idx = local_idx % kAComponentsForBlock32; - a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * kAComponentsForBlock32 + a_col_idx); - workgroupBarrier(); - - let b_row = col + local_idx; - let b_col = a_block_idx; - - let scale = mm_read_scale(b_row, b_col); - let zero_point = mm_read_zero(b_row, b_col, uniforms.input_b_shape[0], uniforms.zero_blocks_per_col); -)MAIN_FN"; - - if (nbits_ == 4) { - shader.MainFunctionBody() << R"MAIN_FN( - let b_data = mm_read_b(b_row, b_col); - // `b` component size is 4. 
- for (var b_idx = 0u; b_idx < 4u; b_idx++) { - let b_dequantized = dequantize_packed8xU4(b_data[b_idx], zero_point, scale); - for (var m_idx = 0u; m_idx < kTileM; m_idx++) { - let a_data0 = a_data_tile[m_idx][b_idx * 2u]; - let a_data1 = a_data_tile[m_idx][b_idx * 2u + 1u]; - - results[m_idx] += f32(dot(a_data0, b_dequantized[0])) + f32(dot(a_data1, b_dequantized[1])); - } - } -)MAIN_FN"; - } else { - shader.MainFunctionBody() << " var b_data0 = vec4(0);\n" - " var b_data1 = vec4(0);\n" - " if (b_row < uniforms.input_b_shape[0] && b_col < uniforms.input_b_shape[1]) {\n" - << " b_data0 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 0)") << ";\n" - << " b_data1 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 1)") << ";\n" - " }" - << R"MAIN_FN( - for (var b_idx = 0u; b_idx < 4u; b_idx++) { - let b_dequantized0 = (vec4(unpack4xU8(b_data0[b_idx])) - vec4(zero_point)) * scale; - let b_dequantized1 = (vec4(unpack4xU8(b_data1[b_idx])) - vec4(zero_point)) * scale; - for (var m_idx = 0u; m_idx < kTileM; m_idx++) { - let a_data0 = a_data_tile[m_idx][b_idx]; - let a_data1 = a_data_tile[m_idx][b_idx + 4u]; - - results[m_idx] += f32(dot(a_data0, b_dequantized0)) + f32(dot(a_data1, b_dequantized1)); - } - } -)MAIN_FN"; - } - - shader.MainFunctionBody() << R"MAIN_FN( - - workgroupBarrier(); - } - - if (batch >= uniforms.input_a_shape[0]) { - return; - } - - // Write the results. - for (var m_idx = 0u; m_idx < kTileM; m_idx++) { - mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx])); - } -)MAIN_FN"; - - return Status::OK(); + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_wide_tile.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), + WGSL_TEMPLATE_PARAMETER(nbits, nbits_), + WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_), + WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_)); } // Apply similar idea with DP4AMatMulNBitsSmallMProgram algorithm. @@ -408,7 +272,10 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. - const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization; + const bool use_wide_tile_program = block_size == 32 && + components_a == 4 && + components_b == 4 && + M >= kMinMForTileOptimization; if (use_wide_tile_program) { // Enforce output components to 1. 
components = 1; @@ -416,30 +283,44 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t workgroup_size = 128; constexpr uint32_t tile_m = workgroup_size / 8; constexpr uint32_t tile_n = workgroup_size; - uint32_t num_N_tile = (N + tile_n - 1) / tile_n; - uint32_t num_M_tile = (M + tile_m - 1) / tile_m; + const uint32_t num_N_tile = ceil_div(N, tile_n); + const uint32_t num_M_tile = ceil_div(M, tile_m); MatMulNBitsWideTileProgram program{has_zero_points, tile_m, tile_n, nbits}; program.SetWorkgroupSize(workgroup_size); - program.SetDispatchGroupSize((N + tile_n - 1) / tile_n, - (M + tile_m - 1) / tile_m, - batch_count); - program.CacheHint("Tile" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "_Block32"); - - TensorShape reshaped_a_shape{batch_count, M, K / components_a}; - TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; - TensorShape reshaped_y_shape{batch_count, M, N / components}; - - program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, onnxruntime::narrow(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, onnxruntime::narrow(components_b * 4)}, - {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, onnxruntime::narrow(components)}) - .AddUniformVariables({{block_size}, {zero_blocks_per_col}, {num_N_tile}, {num_M_tile}}) - .CacheHint(nbits, has_zero_points); + program.SetDispatchGroupSize(num_N_tile, num_M_tile, batch_count); + + constexpr uint32_t kU32Components = 4; + const uint32_t components_b_with_u32 = components_b * kU32Components; + const uint32_t K_of_b = n_blocks_per_col * blob_size / components_b_with_u32; + const uint32_t K_of_a = K / components_a; + + program.AddInput({a, + ProgramTensorMetadataDependency::TypeAndRank, + onnxruntime::narrow(components_a)}); + program.AddInput({b, + ProgramTensorMetadataDependency::TypeAndRank, + onnxruntime::narrow(components_b_with_u32)}); + program.AddInput({scales, ProgramTensorMetadataDependency::TypeAndRank}); if (has_zero_points) { - program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + program.AddInput({zero_points, + ProgramTensorMetadataDependency::TypeAndRank, + {ceil_div(zero_points->Shape().Size(), static_cast(4))}, + 4}); } + program.AddOutput({y, + ProgramTensorMetadataDependency::TypeAndRank, + onnxruntime::narrow(components)}); + program.AddUniformVariables({{batch_count}, + {M}, + {N}, + {K_of_a}, + {K_of_b}, + {n_blocks_per_col}, + {zero_blocks_per_col}, + {num_N_tile}, + {num_M_tile}}); + program.CacheHint(nbits, has_zero_points); return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 807576c91752b..aabc73ca05d03 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -15,10 +15,15 @@ using namespace onnxruntime::webgpu; class MatMulNBitsWideTileProgram final : public Program { public: MatMulNBitsWideTileProgram(bool has_zero_points, uint32_t tile_m, uint32_t tile_n, uint32_t nbits) - : Program{"MatMulNBitsWideTileProgram"}, has_zero_points_{has_zero_points}, tile_m_(tile_m), tile_n_(tile_n), nbits_(nbits) {} + : Program{"MatMulNBitsWideTile"}, has_zero_points_{has_zero_points}, tile_m_(tile_m), tile_n_(tile_n), 
nbits_(nbits) {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}, + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"Batch", ProgramUniformVariableDataType::Uint32}, + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K_of_a", ProgramUniformVariableDataType::Uint32}, + {"K_of_b", ProgramUniformVariableDataType::Uint32}, + {"n_blocks_per_col", ProgramUniformVariableDataType::Uint32}, {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}, {"num_N_tile", ProgramUniformVariableDataType::Uint32}, {"num_M_tile", ProgramUniformVariableDataType::Uint32}); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template new file mode 100644 index 0000000000000..e030f00c084e9 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template @@ -0,0 +1,192 @@ + +#param has_zero_points +#param nbits +#param tile_m +#param tile_n + +// Only support Block32 at the moment. +const KAVecSizeForBlock32 = 8u; + +const kTileM : u32 = tile_m; +const kTileN : u32 = tile_n; + +// TODO: Move to matmulnbits_common template +#if has_zero_points +fn load_zero(row : u32, col : u32, r_dim : u32, c_dim : u32) -> output_element_t { +#if nbits == 4 + const elements_in_uint32 = 8u; +#else // nbits == 8 + const elements_in_uint32 = 4u; +#endif + + const bits : u32 = nbits; + + if (row < r_dim && col < c_dim) { + let offset = row * c_dim + col; + + // u32 holds elements_in_uint32 packed nbits. + let array_index = offset / elements_in_uint32; + let component_index = offset % elements_in_uint32; + let packed_value = zero_points[array_index]; + + // Extract the nbits component + let shift_amount = component_index * bits; + +#if nbits == 4 + let masked_value = (packed_value >> shift_amount) & 0xF; +#else // nbits == 8 + let masked_value = (packed_value >> shift_amount) & 0xFF; +#endif + + return output_element_t(masked_value); + } + + return output_element_t(); +} +#else +fn load_zero(row : u32, col : u32, r_dim : u32, c_dim : u32) -> output_element_t { +#if nbits == 4 + return output_element_t(8); +#else // nbits == 8 + return output_element_t(128); +#endif +} +#endif + +fn load_a(batch : u32, row : u32, col : u32) -> input_a_value_t { + if (batch < uniforms.Batch && row < uniforms.M && col < uniforms.K_of_a) { + let offset = batch * uniforms.M * uniforms.K_of_a + row * uniforms.K_of_a + col; + return input_a[offset]; + } + return input_a_value_t(); +} + +fn load_scale(row : u32, block_idx : u32) -> output_element_t { + if (row < uniforms.N && block_idx < uniforms.n_blocks_per_col) { + let offset = row * uniforms.n_blocks_per_col + block_idx; + return scales[offset]; + } + return output_element_t(); +} + +fn write_output(batch : u32, row : u32, col : u32, value : output_element_t) { + if (batch < uniforms.Batch && row < uniforms.M && col < uniforms.N) { + let offset = batch * uniforms.M * uniforms.N + row * uniforms.N + col; + output[offset] = value; + } +} + +#if nbits == 4 +fn load_b(row : u32, block_idx : u32) -> vec4 { + if (row < uniforms.N && block_idx < uniforms.K_of_b) { + let offset = row * uniforms.K_of_b + block_idx; + return input_b[offset]; + } + return vec4(); +} + +// packed8xU4 +fn dequantize(packed_data : u32, + zero_point : output_element_t, + scale : output_element_t) -> mat2x4 { + let lower : vec4 = 
unpack4xU8(packed_data & 0x0F0F0F0Fu); + let upper : vec4 = unpack4xU8((packed_data >> 4u) & 0x0F0F0F0Fu); + + let zero_matrix : mat2x4 = mat2x4( + zero_point, zero_point, zero_point, zero_point, + zero_point, zero_point, zero_point, zero_point); + + var dequantized_values : mat2x4 = mat2x4( + output_element_t(lower[0]), output_element_t(upper[0]), + output_element_t(lower[1]), output_element_t(upper[1]), + output_element_t(lower[2]), output_element_t(upper[2]), + output_element_t(lower[3]), output_element_t(upper[3])); + + dequantized_values = (dequantized_values - zero_matrix) * scale; + return dequantized_values; +} +#else // nbits == 8 +fn load_b(row : u32, block_idx : u32) -> array, 4> { + if (row < uniforms.N) { + let offset = 2 * block_idx; + let b_data_0 = select(input_b_value_t(), + input_b[row * uniforms.K_of_b + offset], + offset < uniforms.K_of_b); + let b_data_1 = select(input_b_value_t(), + input_b[row * uniforms.K_of_b + offset + 1], + offset + 1 < uniforms.K_of_b); + + let b_data = array, 4>( + vec2(b_data_0[0], b_data_0[1]), + vec2(b_data_0[2], b_data_0[3]), + vec2(b_data_1[0], b_data_1[1]), + vec2(b_data_1[2], b_data_1[3])); + return b_data; + } + return array, 4>(); +} + +// 2x packed4xU8 +fn dequantize(packed_data : vec2, + zero_point : output_element_t, + scale : output_element_t) -> mat2x4 { + let lower : vec4 = unpack4xU8(packed_data[0]); + let upper : vec4 = unpack4xU8(packed_data[1]); + + let zero_matrix : mat2x4 = mat2x4( + zero_point, zero_point, zero_point, zero_point, + zero_point, zero_point, zero_point, zero_point); + + var dequantized_values : mat2x4 = mat2x4( + output_element_t(lower[0]), output_element_t(lower[1]), + output_element_t(lower[2]), output_element_t(lower[3]), + output_element_t(upper[0]), output_element_t(upper[1]), + output_element_t(upper[2]), output_element_t(upper[3])); + + dequantized_values = (dequantized_values - zero_matrix) * scale; + return dequantized_values; +} +#endif + +var a_data_tile : array, kTileM>; + +$MAIN { + let batch = workgroup_idx / (uniforms.num_M_tile * uniforms.num_N_tile); + let row = ((workgroup_idx / uniforms.num_N_tile) % uniforms.num_M_tile) * kTileM; + let col = (workgroup_idx % uniforms.num_N_tile) * kTileN; + + // Utilizing an f32 accumulator mitigated precision loss with minimal + // performance impact compared to an f16 accumulator. + var results : array; + for (var block_idx = 0u; block_idx < uniforms.n_blocks_per_col; block_idx++) { + // Load `a` elements into workgroup memory, TileM x KAVecSizeForBlock32 (block32) + let a_row_idx = local_idx / KAVecSizeForBlock32; + let a_col_idx = local_idx % KAVecSizeForBlock32; + a_data_tile[a_row_idx][a_col_idx] = load_a(batch, + row + a_row_idx, + block_idx * KAVecSizeForBlock32 + a_col_idx); + workgroupBarrier(); + + let b_row = col + local_idx; + let scale = load_scale(b_row, block_idx); + let zero_point = load_zero(b_row, block_idx, uniforms.N, uniforms.zero_blocks_per_col); + let b_data = load_b(b_row, block_idx); + + for (var b_idx = 0u; b_idx < 4u; b_idx++) { + let b_dequantized = dequantize(b_data[b_idx], zero_point, scale); + for (var m_idx = 0u; m_idx < kTileM; m_idx++) { + let a_data0 = a_data_tile[m_idx][b_idx * 2u]; + let a_data1 = a_data_tile[m_idx][b_idx * 2u + 1u]; + + results[m_idx] += f32(dot(a_data0, b_dequantized[0])) + + f32(dot(a_data1, b_dequantized[1])); + } + } + workgroupBarrier(); + } + + // Write the results. 
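+  // Each invocation owns output column (col + local_idx) and stores kTileM accumulated values,
+  // downcasting the f32 accumulators to output_element_t. write_output() bounds-checks
+  // batch/row/col, so out-of-range lanes simply skip the store.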
+ for (var m_idx = 0u; m_idx < kTileM; m_idx++) { + write_output(batch, row + m_idx, col + local_idx, output_element_t(results[m_idx])); + } +} // MAIN diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index 8151f9fb3dcc7..a7df3b7bbec54 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -310,6 +310,9 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { constexpr float abs_error = 0.1f * 1.02f; constexpr float rel_error = 0.02f * 1.02f; TestMatMul8BitsTyped(abs_error, rel_error); + + // Test case where K (260) is divisible by 16 but not by the block size (32). + TestMatMul8BitsTyped(); } TEST(MatMulNBits, Float32_8b_AccuracyLevel1) { @@ -357,6 +360,9 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel1) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); + + // Test case where K (260) is divisible by 16 but not by the block size (32). + TestMatMul8BitsTyped(); } #if defined(USE_WEBGPU) @@ -367,6 +373,9 @@ TEST(MatMulNBits, Float16_8b_AccuracyLevel4) { TestMatMul8BitsTyped(abs_error, rel_error); TestMatMul8BitsTyped(abs_error, rel_error); TestMatMul8BitsTyped(abs_error, rel_error); + + // Test case where K (260) is divisible by 16 but not by the block size (32). + TestMatMul8BitsTyped(abs_error, rel_error); } #endif From 5af86e52b229cbd654782e3bee949707b03a9f93 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 16 Jul 2025 15:27:05 -0700 Subject: [PATCH 47/49] Update docker images (#25418) 1. Update the docker images to install system updates(per vulnerability management requirements) 2. Disable DNNL pipelines since a. There was no active development. b. The code is incompatible with CMake 4.x. 3. Disable migraphx pipeline due to license issues(conda is not free unless you only use conda-forge packages). 4. Change all UBI8 based images to use AlmaLinux8. I will make the base images public. They are under internal review. 
--- .github/workflows/linux-dnnl.yml | 40 ------ .github/workflows/linux_migraphx_ci.yml | 40 ------ ...ows_x64_release_dnnl_build_x64_release.yml | 132 ------------------ ...-gpu-tensorrt-cuda-minimal-ci-pipeline.yml | 4 +- .../py-cuda-package-test-pipeline.yml | 2 +- .../stages/java-cuda-packaging-stage.yml | 4 +- .../jobs/py-linux-cuda-package-test-job.yml | 4 +- .../stages/py-gpu-packaging-stage.yml | 4 +- .../linux/docker/Dockerfile.manylinux2_28_cpu | 2 +- .../docker/Dockerfile.manylinux2_28_rocm | 2 +- .../docker/Dockerfile.manylinux2_28_webgpu | 2 +- .../inference/aarch64/default/cpu/Dockerfile | 2 +- .../inference/aarch64/python/cpu/Dockerfile | 2 +- .../python/cpu/scripts/install_centos.sh | 13 +- .../python/cpu/scripts/install_deps.sh | 22 +-- .../inference/x86_64/default/cpu/Dockerfile | 2 +- .../x86_64/default/cuda12/Dockerfile | 2 +- .../inference/x86_64/python/cpu/Dockerfile | 2 +- .../python/cpu/scripts/install_centos.sh | 13 +- .../python/cuda/scripts/install_centos.sh | 13 +- .../x86_64/python/openvino/Dockerfile | 2 +- 21 files changed, 22 insertions(+), 287 deletions(-) delete mode 100644 .github/workflows/linux-dnnl.yml delete mode 100644 .github/workflows/linux_migraphx_ci.yml delete mode 100644 .github/workflows/windows_x64_release_dnnl_build_x64_release.yml diff --git a/.github/workflows/linux-dnnl.yml b/.github/workflows/linux-dnnl.yml deleted file mode 100644 index da393c1af3cee..0000000000000 --- a/.github/workflows/linux-dnnl.yml +++ /dev/null @@ -1,40 +0,0 @@ -# This workflow builds and tests the ONNX Runtime for Linux for DNNL EP -# It leverages a reusable workflow (`reusable_linux_build.yml`) to handle the core build and test logic -# within Docker containers, ensuring a consistent environment. -# This file is very similar to linux_ci.yml, but much simpler - - -name: Linux DNNL CI - -on: - push: - branches: [main, 'rel-*'] - pull_request: - branches: [main, 'rel-*'] - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} - cancel-in-progress: true - -permissions: - contents: read - packages: write - attestations: write - id-token: write - -jobs: - build-linux-x64-release-dnnl: - name: Build Linux x64 Release (DNNL EP) - uses: ./.github/workflows/reusable_linux_build.yml - with: - pool_name: "onnxruntime-github-Ubuntu2204-AMD-CPU" - build_config: Release - architecture: x64 - dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu - docker_image_repo: onnxruntimecpubuildpythonx64 - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --build_nuget' - python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' - execution_providers: 'dnnl' - secrets: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/linux_migraphx_ci.yml b/.github/workflows/linux_migraphx_ci.yml deleted file mode 100644 index ee5e8bf12d651..0000000000000 --- a/.github/workflows/linux_migraphx_ci.yml +++ /dev/null @@ -1,40 +0,0 @@ -# This workflow builds and tests the ONNX Runtime for Linux for migraphx EP -# It leverages a reusable workflow (`reusable_linux_build.yml`) to handle the core build and test logic -# within Docker containers, ensuring a consistent environment. 
-# This file is very similar to linux_ci.yml, but much simpler - - -name: Linux MigraphX CI - -on: - push: - branches: [main, 'rel-*'] - pull_request: - branches: [main, 'rel-*'] - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} - cancel-in-progress: true - -permissions: - contents: read - packages: write - attestations: write - id-token: write - -jobs: - build-linux-x64-release-migraphx: - name: Build Linux x64 Release (migraphx EP) - uses: ./.github/workflows/reusable_linux_build.yml - with: - pool_name: "onnxruntime-github-Ubuntu2204-AMD-CPU" - build_config: Release - architecture: x64 - dockerfile_path: tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile - docker_image_repo: onnxruntimetrainingmigraphx-cibuild-rocm - extra_build_flags: '--enable_training --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --rocm_version=6.4 --rocm_home /opt/rocm --nccl_home /opt/rocm --enable_nccl --skip_submodule_sync' - run_tests: false - execution_providers: 'migraphx' - secrets: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/windows_x64_release_dnnl_build_x64_release.yml b/.github/workflows/windows_x64_release_dnnl_build_x64_release.yml deleted file mode 100644 index 4c74505ad183d..0000000000000 --- a/.github/workflows/windows_x64_release_dnnl_build_x64_release.yml +++ /dev/null @@ -1,132 +0,0 @@ -name: windows_x64_dnnl_release - -on: - push: - branches: [main, 'rel-*'] - pull_request: - branches: [main] - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} - cancel-in-progress: true - -jobs: - build_x64_dnnl_release: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - timeout-minutes: 300 - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - architecture: x64 - - - name: Locate vcvarsall and Setup Env - uses: ./.github/actions/locate-vcvarsall-and-setup-env - with: - architecture: x64 - - - name: Install python modules - shell: cmd - run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20.x' - - - name: Setup Java - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '17' - architecture: x64 - - - name: API Documentation Check and generate - shell: cmd - run: | - set ORT_DOXY_SRC=${{ github.workspace }} - set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo - mkdir %ORT_DOXY_SRC% - mkdir %ORT_DOXY_OUT% - "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg - working-directory: ${{ github.workspace }} - - - name: Use .NET 8.x - uses: actions/setup-dotnet@v4 - with: - dotnet-version: '8.x' - env: - PROCESSOR_ARCHITECTURE: x64 - - - name: Use Nuget 6.x - uses: nuget/setup-nuget@v2 - with: - nuget-version: '6.x' - - - name: NuGet restore - shell: cmd - run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - - uses: actions/cache@v4 - id: onnx-node-tests-cache - with: - path: ${{ github.workspace }}/js/test/ - key: onnxnodetests-${{ 
hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Build and Test - shell: pwsh - run: | - python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_wheel --build_java --build_nodejs --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_dnnl - if ($LASTEXITCODE -ne 0) { - exit $LASTEXITCODE - } - Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse - env: - ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' - - - name: Validate C# native delegates - shell: cmd - run: python tools\ValidateNativeDelegateAttributes.py - working-directory: ${{ github.workspace }}\\csharp - - - name: Install onnxruntime wheel - shell: pwsh - run: | - python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq - Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} - working-directory: "${{ github.workspace }}\\build\\RelWithDebInfo\\RelWithDebInfo" - - - name: Publish OperatorKernels.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: OperatorKernels.md - path: ${{ github.workspace }}/docs/OperatorKernels.md - - - name: Publish ContribOperators.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: ContribOperators.md - path: ${{ github.workspace }}/docs/ContribOperators.md - - env: - OrtPackageId: Microsoft.ML.OnnxRuntime - OnnxRuntimeBuildDirectory: ${{ github.workspace }}\build - DOTNET_SKIP_FIRST_TIME_EXPERIENCE: 'true' diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml index 0ec05909b846f..caef61f68ab6b 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml @@ -39,9 +39,9 @@ variables: - template: templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250714.2 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: value: ${{ variables.linux_trt_version_cuda11 }} diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml index 021f7c5ece140..b10d15432ed5b 
100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml @@ -18,7 +18,7 @@ stages: machine_pool: 'Onnxruntime-Linux-GPU' python_wheel_suffix: '_gpu' timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 cuda_version: '12.2' - stage: Republish_Wheels diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml index 3f800212509de..1d5fb8e682b73 100644 --- a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml @@ -137,9 +137,9 @@ stages: value: false - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250714.2 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 timeoutInMinutes: 60 steps: diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index 4e42afe0da96e..e7cd1bf536afc 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -44,9 +44,9 @@ jobs: - template: ../../templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250714.2 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: value: ${{ variables.linux_trt_version_cuda11 }} diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml index c865048456f3f..c75b44d56e7fd 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -67,9 +67,9 @@ stages: cmake_build_type: ${{ parameters.cmake_build_type }} cuda_version: ${{ parameters.cuda_version }} ${{ if eq(parameters.cuda_version, '11.8') }}: - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250714.2 ${{ if 
eq(parameters.cuda_version, '12.2') }}: - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 - ${{ if eq(parameters.enable_windows_dml, true) }}: - ${{ each python_version in parameters.PythonVersions }}: diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index 02938f015ec57..db8668fa9eafe 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250124.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index 22e6ca5f50d13..d20da1867926b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250124.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 ARG ROCM_VERSION=6.2.3 #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu index 8749502461ac5..0ff52a0a75dc8 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250124.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index 5db0e32e0df8b..e6e362ade897d 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc14_dotnet:20250124.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14_dotnet:20250714.2 ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile index 20b9a6c224120..267fc1e661242 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc14:20250124.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20250714.2 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh index d0b58ed28b8c9..1ced7cd2f90c8 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_centos.sh @@ -4,15 +4,4 @@ set -e os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for os major version : $os_major_version" -dnf install -y glibc-langpack-\* -yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget - -echo "installing rapidjson for AzureEP" -wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz -tar zxvf v1.1.0.tar.gz -cd rapidjson-1.1.0 -mkdir build -cd build -cmake .. -cmake --install . -cd ../.. +dnf install -y glibc-langpack-\* which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_deps.sh index 81de2abf3ff87..b1e74df6fb92b 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/install_deps.sh @@ -13,24 +13,4 @@ export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=ON -DONNX_WERROR=OFF" for PYTHON_EXE in "${PYTHON_EXES[@]}" do ${PYTHON_EXE} -m pip install -r requirements.txt -done - -# No release binary for ccache aarch64, so we need to build it from source. -if ! [ -x "$(command -v ccache)" ]; then - ccache_url="https://github.com/ccache/ccache/archive/refs/tags/v4.8.tar.gz" - pushd . - curl -sSL --retry 5 --retry-delay 10 --create-dirs --fail -L -o ccache_src.tar.gz $ccache_url - mkdir ccache_main - cd ccache_main - tar -zxf ../ccache_src.tar.gz --strip=1 - - mkdir build - cd build - cmake -DCMAKE_INSTALL_PREFIX=/usr/local _DCMAKE_BUILD_TYPE=Release .. 
- make - make install - which ccache - popd - rm -f ccache_src.tar.gz - rm -rf ccache_src -fi +done \ No newline at end of file diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index 6052096877ac5..7981210af14a1 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14_dotnet:20250124.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20250714.2 ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 764a79135d7a3..894802dfc8675 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12_dotnet:20250124.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20250714.2 ARG TRT_VERSION #Install TensorRT only if TRT_VERSION is not empty diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile index 7590d5dd18347..fc376e33d6d10 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250124.1 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh index d0b58ed28b8c9..1ced7cd2f90c8 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/scripts/install_centos.sh @@ -4,15 +4,4 @@ set -e os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for os major version : $os_major_version" -dnf install -y glibc-langpack-\* -yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget - -echo "installing rapidjson for AzureEP" -wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz -tar zxvf v1.1.0.tar.gz -cd rapidjson-1.1.0 -mkdir build -cd build -cmake .. -cmake --install . -cd ../.. 
+dnf install -y glibc-langpack-\* which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh index d0b58ed28b8c9..1ced7cd2f90c8 100755 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/scripts/install_centos.sh @@ -4,15 +4,4 @@ set -e os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for os major version : $os_major_version" -dnf install -y glibc-langpack-\* -yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget - -echo "installing rapidjson for AzureEP" -wget https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.tar.gz -tar zxvf v1.1.0.tar.gz -cd rapidjson-1.1.0 -mkdir build -cd build -cmake .. -cmake --install . -cd ../.. +dnf install -y glibc-langpack-\* which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile index dd049d7260bdf..607b3e693b624 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile @@ -1,5 +1,5 @@ # Use the specified UBI8 base image with GCC 14 -ARG BASEIMAGE="onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250124.1" +ARG BASEIMAGE="onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2" FROM ${BASEIMAGE} ARG BUILD_UID=1000 From 3a5f75bf4e29899c29919a266ca8f31016c4f31c Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 16 Jul 2025 20:36:48 -0400 Subject: [PATCH 48/49] revert qnn sdk version (#25426) Fixes Error: Could not find com.qualcomm.qti:qnn-runtime:2.36.1 The nuget packaging pipeline fails with Could not find com.qualcomm.qti:qnn-runtime:2.36.1 https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=866702&view=results --- .../azure-pipelines/c-api-noopenmp-packaging-pipelines.yml | 2 +- .../azure-pipelines/templates/android-java-api-aar-test.yml | 2 +- .../github/azure-pipelines/templates/android-java-api-aar.yml | 2 +- tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 202aa61da0b80..aa25e3f31166a 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 74f7f782fe1b2..ab779e164b36e 100644 --- 
a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.1.250708' + default: '2.36.0.250627' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 92e862bd79008..110f83ff587c8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.1.250708' + default: '2.36.0.250627' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 5b48a14e2afc3..535784933a087 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 - name: is1ES displayName: Is 1ES pipeline From 7fe617ccbf9ffc281d2bd1b63c61c4349b56541a Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 17 Jul 2025 15:32:26 +1000 Subject: [PATCH 49/49] Restore ability to handle non-hex string in device discovery vendor/device id. (#25427) ### Description Restore ability to handle "VEN_QCOM" from an ACPI entry. ### Motivation and Context --- .../core/platform/windows/device_discovery.cc | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index fa645939a6395..ff904ddb3e7e0 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -77,6 +77,29 @@ struct DriverInfo { } }; +bool IsHexString(const std::wstring& str) { + for (const wchar_t& c : str) { + if (!((c >= L'0' && c <= L'9') || (c >= L'A' && c <= L'F') || (c >= L'a' && c <= L'f'))) { + return false; + } + } + return true; +} + +// Converts a wide string ACPI (up to 4 characters) representing a hardware ID component from into a uint32_t. +// e.g., "QCOM" from "VEN_QCOM". The conversion is done in a little-endian manner, meaning the first character +// of the string becomes the least significant byte of the integer, and the fourth character +// becomes the most significant byte. +uint32_t AcpiWStringToUint32Id(const std::wstring& vendor_name) { + uint32_t vendor_id = 0; + for (size_t i = 0; i < 4 && i < vendor_name.size(); ++i) { + // For little-endian, place each character at the appropriate byte position + // First character goes into lowest byte, last character into highest byte + vendor_id |= static_cast(vendor_name[i] & 0xFF) << (i * 8); + } + return vendor_id; +} + uint64_t GetDeviceKey(uint32_t vendor_id, uint32_t device_id) { return (uint64_t(vendor_id) << 32) | device_id; } @@ -134,25 +157,35 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde // PCI\VEN_xxxx&DEV_yyyy&... // ACPI\VEN_xxxx&DEV_yyyy&... if we're lucky. 
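  // For example, an ACPI id of "VEN_QCOM" carries no hex vendor value; AcpiWStringToUint32Id maps
  // "QCOM" to 0x4D4F4351 ('Q' = 0x51 in the least significant byte, 'M' = 0x4D in the most
  // significant byte).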
// ACPI values seem to be very inconsistent, so we check fairly carefully and always require a device id. - const auto get_id = [](const std::wstring& hardware_id, const std::wstring& prefix) -> uint32_t { + const auto get_id = [](bool is_pci, const std::wstring& hardware_id, const std::wstring& prefix) -> uint32_t { if (auto idx = hardware_id.find(prefix); idx != std::wstring::npos) { auto id = hardware_id.substr(idx + prefix.size(), 4); + if (id.size() == 4) { - return static_cast(std::stoul(id, nullptr, 16)); + if (is_pci || IsHexString(id)) { + // PCI entries have hex numbers. ACPI might. + return static_cast(std::stoul(id, nullptr, 16)); + } else { + // ACPI can have things like "VEN_QCOM". Fallback to using this conversion where the characters + // are converted in little-endian order. + return AcpiWStringToUint32Id(id); + } } } return 0; }; - // Processor ID should come from CPUID mapping. + const bool is_pci = std::wstring(buffer, 3) == std::wstring(L"PCI"); + if (guid == GUID_DEVCLASS_PROCESSOR) { + // Processor ID should come from CPUID mapping. vendor_id = CPUIDInfo::GetCPUIDInfo().GetCPUVendorId(); } else { - vendor_id = get_id(buffer, L"VEN_"); + vendor_id = get_id(is_pci, buffer, L"VEN_"); } - device_id = get_id(buffer, L"DEV_"); + device_id = get_id(is_pci, buffer, L"DEV_"); // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. if (vendor_id == 0 && device_id == 0) {