diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 0d1f47f195ba5..84e3132b1deec 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -367,23 +367,23 @@ if (CPUINFO_SUPPORTED) set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "") if (onnxruntime_target_platform STREQUAL "ARM64EC" OR onnxruntime_target_platform STREQUAL "ARM64") - message(STATUS "Applying a patch for Windows ARM64/ARM64EC in cpuinfo") - onnxruntime_fetchcontent_declare( - pytorch_cpuinfo - URL ${DEP_URL_pytorch_cpuinfo} - URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - EXCLUDE_FROM_ALL - PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch - FIND_PACKAGE_ARGS NAMES cpuinfo - ) + message(STATUS "Applying a patch for Windows ARM64/ARM64EC in cpuinfo") + onnxruntime_fetchcontent_declare( + pytorch_cpuinfo + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + EXCLUDE_FROM_ALL + PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch + FIND_PACKAGE_ARGS NAMES cpuinfo + ) else() - onnxruntime_fetchcontent_declare( - pytorch_cpuinfo - URL ${DEP_URL_pytorch_cpuinfo} - URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - EXCLUDE_FROM_ALL - FIND_PACKAGE_ARGS NAMES cpuinfo - ) + onnxruntime_fetchcontent_declare( + pytorch_cpuinfo + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + EXCLUDE_FROM_ALL + FIND_PACKAGE_ARGS NAMES cpuinfo + ) endif() set(ONNXRUNTIME_CPUINFO_PROJ pytorch_cpuinfo) onnxruntime_fetchcontent_makeavailable(${ONNXRUNTIME_CPUINFO_PROJ}) diff --git a/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch b/cmake/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch similarity index 53% rename from cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch rename to cmake/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch index 7785621965b00..23ceeb8f758cc 100644 --- a/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch +++ b/cmake/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch @@ -1,5 +1,5 @@ diff --git a/include/cpuinfo.h b/include/cpuinfo.h -index 6eb4b8c..4346a5a 100644 +index f1d35d4..9e454d2 100644 --- a/include/cpuinfo.h +++ b/include/cpuinfo.h @@ -18,7 +18,7 @@ @@ -20,16 +20,3 @@ index 6eb4b8c..4346a5a 100644 #define CPUINFO_ARCH_ARM64 1 #endif -diff --git a/src/arm/windows/init.c b/src/arm/windows/init.c -index de2f6cc..c3a7835 100644 ---- a/src/arm/windows/init.c -+++ b/src/arm/windows/init.c -@@ -175,7 +175,7 @@ static struct woa_chip_info* get_system_info_from_registry(void) { - if (chip_info == NULL) { - /* No match was found, so print a warning and assign the unknown - * case. */ -- cpuinfo_log_error( -+ cpuinfo_log_debug( - "Unknown chip model name '%ls'.\nPlease add new Windows on Arm SoC/chip support to arm/windows/init.c!", - text_buffer); - } else { diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs index 098a18b7444cf..2467475b6b189 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs @@ -23,8 +23,8 @@ internal enum ErrorCode ModelLoaded = 8, NotImplemented = 9, InvalidGraph = 10, - ShapeInferenceNotRegistered = 11, - RequirementNotRegistered = 12, + ShapeInferenceNotRegistered = 11, // TODO: should be ORT_EP_FAIL + RequirementNotRegistered = 12, // TODO: should be ORT_MODEL_LOAD_CANCELED } /// diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index da9735aa4e418..8cf6420f2d0f7 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -46,6 +46,7 @@ enum StatusCode { EP_FAIL = 11, MODEL_LOAD_CANCELED = 12, MODEL_REQUIRES_COMPILATION = 13, + NOT_FOUND = 14, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -78,6 +79,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "MODEL_LOAD_CANCELED"; case StatusCode::MODEL_REQUIRES_COMPILATION: return "MODEL_REQUIRES_COMPILATION"; + case StatusCode::NOT_FOUND: + return "NOT_FOUND"; default: return "GENERAL ERROR"; } @@ -114,6 +117,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_CANCELLED); case StatusCode::MODEL_REQUIRES_COMPILATION: return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED); + case StatusCode::NOT_FOUND: + return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); default: return E_FAIL; } diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 0d920ab7dac89..21aa797ce16eb 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, } onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); } } @@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -758,6 +758,103 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr break; } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. + + OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + + Ort::Value tensor(ort_value); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + size_t element_size = 0; + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + element_size = sizeof(float); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + element_size = sizeof(uint8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + element_size = sizeof(int8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + element_size = sizeof(uint16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + element_size = sizeof(int16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + element_size = sizeof(int32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + element_size = sizeof(int64_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + element_size = sizeof(bool); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + element_size = sizeof(double); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + element_size = sizeof(uint32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + element_size = sizeof(uint64_t); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + auto shape = type_shape_info.GetShape(); + + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } + + size_t element_count = type_shape_info.GetElementCount(); + size_t data_bytes = element_count * element_size; + const void* data = tensor.GetTensorData(); + + // Copy the Ortvalue to TensorProto as raw data + tensor_proto.set_raw_data(data, data_bytes); + + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; + } default: { std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); return Ort::Status(err_msg.c_str(), ORT_FAIL); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d87e9e083185b..2899a219bdda0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -264,6 +264,7 @@ typedef enum OrtErrorCode { ORT_EP_FAIL, ORT_MODEL_LOAD_CANCELED, ORT_MODEL_REQUIRES_COMPILATION, + ORT_NOT_FOUND, } OrtErrorCode; typedef enum OrtOpAttrType { @@ -275,6 +276,7 @@ typedef enum OrtOpAttrType { ORT_OP_ATTR_STRING, ORT_OP_ATTR_STRINGS, ORT_OP_ATTR_GRAPH, + ORT_OP_ATTR_TENSOR, } OrtOpAttrType; //! @} @@ -6031,6 +6033,11 @@ struct OrtApi { * Typical usage sets this to the result of Node_GetNumAttributes(). An error status is * returned if `num_attributes` is less than the number of node attributes. * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. This function + * does not provide any unset optional attributes without a constant default value. + * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. @@ -6042,14 +6049,36 @@ struct OrtApi { * * \param[in] node The OrtNode instance. * \param[in] attribute_name The name of the attribute - * \param[out] attribute Output the attribute if its name matches 'attribute_name', otherwise output nullptr. + * \param[out] attribute Output parameter set to the OrtOpAttr instance if an attribute by the given name exists. + * For an unset optional attribute, `attribute` is set to NULL and a non-error status is + * returned. For an invalid attribute name, `attribute` is set to NULL and an error status with + * code ORT_NOT_FOUND is returned. + * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. This function + * does not provide any unset optional attributes without a constant default value. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); + + /** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue. + * + * \param[in] node The OrtNode instance. + * \param[in] attribute The OrtOpAttr instance. + * \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue. + Must be freed with OrtApi::ReleaseValue. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + _Outptr_result_maybenull_ OrtValue** attr_tensor); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. * diff --git a/java/src/main/java/ai/onnxruntime/OrtException.java b/java/src/main/java/ai/onnxruntime/OrtException.java index 5ec58ea137124..06c3d3cbc770c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtException.java +++ b/java/src/main/java/ai/onnxruntime/OrtException.java @@ -81,11 +81,17 @@ public enum OrtErrorCode { /** The ONNX graph is invalid. */ ORT_INVALID_GRAPH(10), /** The ORT execution provider failed. */ - ORT_EP_FAIL(11); + ORT_EP_FAIL(11), + /** Model load was canceled. */ + ORT_MODEL_LOAD_CANCELED(12), + /** Model requires compilation. */ + ORT_MODEL_REQUIRES_COMPILATION(13), + /** Item was not found. */ + ORT_NOT_FOUND(14); private final int value; - private static final OrtErrorCode[] values = new OrtErrorCode[12]; + private static final OrtErrorCode[] values = new OrtErrorCode[15]; static { for (OrtErrorCode ot : OrtErrorCode.values()) { diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index fe19015d642f0..5d8efd7b476cb 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1051,6 +1051,12 @@ jint convertErrorCode(OrtErrorCode code) { return 10; case ORT_EP_FAIL: return 11; + case ORT_MODEL_LOAD_CANCELED: + return 12; + case ORT_MODEL_REQUIRES_COMPILATION: + return 13; + case ORT_NOT_FOUND: + return 14; default: return -1; // Unknown error code } diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 71161c120a306..39b0ccdc7fe6a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -50,6 +50,7 @@ struct WebgpuAttentionParameters { v_hidden_size_(parameters.kv_hidden_size), v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads), num_heads_(parameters.num_heads), + is_unidirectional_(true), do_rotary_(parameters.do_rotary), scale_(parameters.scale), seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index c9e182bf10f2f..dbe2614099be1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -275,8 +275,23 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { var previous_max : q_element_t = min_value; var previous_denom : q_element_t = 0; +)MAIN_FN"; - for(var k_start = 0u; k_start < uniforms.total_sequence_length; k_start+=capped_sg_size) + if (is_unidirectional_) { + // If attention is unidirectional, set the loop bound to enforce causal masking. + shader.MainFunctionBody() << R"MAIN_FN( + let max_causal_len_for_workgroup = uniforms.past_sequence_length + + (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x; + let loop_bound = min(uniforms.total_sequence_length, max_causal_len_for_workgroup); +)MAIN_FN"; + } else { + shader.MainFunctionBody() << R"MAIN_FN( + let loop_bound = uniforms.total_sequence_length; +)MAIN_FN"; + } + + shader.MainFunctionBody() << R"MAIN_FN( + for(var k_start = 0u; k_start < loop_bound; k_start+=capped_sg_size) { workgroupBarrier(); loadk(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size); @@ -337,7 +352,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start+12, head_idx); } - let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_gqa > 0); + let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_unidirectional > 0); // Neuter qk values where K is out of bounds. qk_1[0] = select(min_value, qk_1[0], k_start+0 < seq_causal_length); qk_1[1] = select(min_value, qk_1[1], k_start+1 < seq_causal_length); @@ -903,7 +918,13 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - FlashAttentionProgram program{"FlashAttention", has_attention_bias, is_qualcomm, is_fp16, parameters.head_size_, parameters.num_heads_}; + FlashAttentionProgram program{"FlashAttention", + has_attention_bias, + is_qualcomm, + is_fp16, + parameters.head_size_, + parameters.num_heads_, + parameters.is_unidirectional_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}}); @@ -916,12 +937,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile) .SetWorkgroupSize(tile_size) - .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, is_qualcomm) + .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(present_sequence_length)}, {static_cast(parameters.total_sequence_length_ - parameters.kv_sequence_length_)}, - {static_cast(parameters.is_gqa_ ? 1 : 0)}, + {static_cast(parameters.is_unidirectional_)}, {static_cast(parameters.n_reps)}, {alpha}, {num_seq_tile}}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 3f79b80fb73bc..9908b33a38372 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -39,13 +39,15 @@ class FlashAttentionProgram final : public Program { bool is_qualcomm, bool is_fp16, int qkv_head_size, - int qkv_num_heads) + int qkv_num_heads, + bool is_unidirectional) : Program{kernel_name}, has_attention_bias_(has_attention_bias), is_qualcomm_(is_qualcomm), is_fp16_(is_fp16), qkv_head_size_(qkv_head_size), - qkv_num_heads_(qkv_num_heads) { + qkv_num_heads_(qkv_num_heads), + is_unidirectional_(is_unidirectional) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -54,7 +56,7 @@ class FlashAttentionProgram final : public Program { {"total_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"is_gqa", ProgramUniformVariableDataType::Uint32}, + {"is_unidirectional", ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, {"alpha", ProgramUniformVariableDataType::Float32}, {"num_seq_tile", ProgramUniformVariableDataType::Uint32}); @@ -65,6 +67,7 @@ class FlashAttentionProgram final : public Program { bool is_fp16_; int qkv_head_size_; int qkv_num_heads_; + bool is_unidirectional_; }; class FlashAttentionDecodeQKTProgram final : public Program { diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 6383d29d7a2bc..504b102e782fd 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -251,6 +251,16 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; + /// + /// Gets the node's 'TENSOR' attribute as an OrtValue. + /// + /// Node's 'TENSOR' attribute. + /// Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value, + /// only if the attribute is of type 'TENSOR' + /// A status indicating success or an error. + virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, + OrtValue*& value) const = 0; + /// /// Gets the number of node subgraphs. /// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 4ceadb6191a9b..eb7fb6937c29e 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -87,6 +87,24 @@ static void ConvertNodeArgsToValueInfos(const EpGraph* ep_graph, } } +#if !defined(ORT_MINIMAL_BUILD) +static bool IsOptionalAttribute(const Node& node, const std::string& attr_name) { + const ONNX_NAMESPACE::OpSchema* op_schema = node.Op(); + if (op_schema == nullptr) { + return false; + } + + auto attr_schema_iter = op_schema->attributes().find(attr_name); + if (attr_schema_iter == op_schema->attributes().end()) { + return false; // Not an attribute for this operator type. + } + + const ONNX_NAMESPACE::OpSchema::Attribute& attr_schema = attr_schema_iter->second; + + return !attr_schema.required; +} +#endif // !defined(ORT_MINIMAL_BUILD) + // // EpNode // @@ -230,6 +248,32 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } +Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const { + const auto* attr_proto = reinterpret_cast(attribute); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); + } + + const auto& graph_viewer = ep_graph_->GetGraphViewer(); + const auto& tensor_proto = attr_proto->t(); + + // Check that TensorProto is valid. + ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); + ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); + ORT_ENFORCE(!utils::HasExternalData(tensor_proto), + "Tensor proto with external data for value attribute is not supported."); + + // Initialize OrtValue for tensor attribute. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + result = tensor_attribute_value.release(); + return Status::OK(); +} + Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { num_subgraphs = subgraphs_.size(); return Status::OK(); @@ -268,13 +312,20 @@ gsl::span EpNode::GetOutputsSpan() const { return outputs_; } -const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { +const OrtOpAttr* EpNode::GetAttribute(const std::string& name, bool& is_unset_optional_attr) const { auto iter = attributes_map_.find(name); - if (iter == attributes_map_.end()) { - return nullptr; - } else { + if (iter != attributes_map_.end()) { + is_unset_optional_attr = false; return reinterpret_cast(iter->second.get()); } + +#if !defined(ORT_MINIMAL_BUILD) + is_unset_optional_attr = IsOptionalAttribute(node_, name); +#else + // This is not properly set in a minimal build because it does not have access to the operator schema. + is_unset_optional_attr = false; +#endif // !defined(ORT_MINIMAL_BUILD) + return nullptr; } const std::string& EpNode::GetEpName() const { diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 243bdc2944ffb..be78d77360cb8 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -183,6 +183,9 @@ struct EpNode : public OrtNode { // Gets the node's attributes. Status GetAttributes(gsl::span attrs) const override; + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, + OrtValue*& attr_tensor) const override; + // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; @@ -209,8 +212,9 @@ struct EpNode : public OrtNode { // Helper that returns this node's outputs as a span of EpValueInfo pointers. gsl::span GetOutputsSpan() const; - // Helper that gets the node's attributes by name. - const OrtOpAttr* GetAttribute(const std::string& name) const; + // Helper that gets the node's attributes by name. If the attribute is not set, returns NULL and sets the + // output parameter `is_unset_optional_attr` to true if this is an unset optional attribute. + const OrtOpAttr* GetAttribute(const std::string& name, bool& is_unset_optional_attr) const; // Helper that gets the execution provider name that this node is assigned to run on. const std::string& GetEpName() const; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 5d84e48182bfe..d3795d911b22f 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -137,6 +137,11 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); + } + Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index dbf86e2bb7fc7..7aba3b9549f23 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -67,6 +67,7 @@ #include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h" #endif #include "core/optimizer/qdq_transformer/weight_bias_quantization.h" +#include "core/optimizer/qdq_transformer/where_dummy_dq.h" #include "core/optimizer/qdq_transformer/clip_quantizelinear.h" #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" @@ -271,6 +272,7 @@ InlinedVector> GenerateTransformers( // It runs unconditionally in InferenceSession::TransformGraph() prior to Level1 optimizers. // We also put it here with other Level1 optimizers so that it can fix things up after their changes. transformers.emplace_back(std::make_unique()); + transformers.emplace_back(std::make_unique()); } // add __backwardpass attribute to nodes after YieldOp, ROCm-only diff --git a/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc new file mode 100644 index 0000000000000..a8b9814f1020c --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/qdq_transformer/where_dummy_dq.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/common/common.h" +#include "core/util/qmath.h" +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" + +namespace onnxruntime { +bool WhereDummyDq::SatisfyCondition(const Graph& graph, const Node& node) const { + if (!(node.OpType() == "Where")) { + return false; + } + const auto& where_inputs = node.InputDefs(); + const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name()); + const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name()); + + bool is_p1_dq = (parent_node_1 && parent_node_1->OpType() == QDQ::DQOpName); + bool is_p2_dq = (parent_node_2 && parent_node_2->OpType() == QDQ::DQOpName); + + // WhereDummyDq focus on WhereOp with one DQ input and one scalar initializer input + if (is_p1_dq && !parent_node_2) { + return (where_inputs[2]->Shape()->dim_size() == 0); + } + if (!parent_node_1 && is_p2_dq) { + return (where_inputs[1]->Shape()->dim_size() == 0); + } + return false; +} + +Status WhereDummyDq::InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const { + const auto& where_inputs = node.InputDefs(); + const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name()); + const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name()); + + // With SatisfyCondition, we must have one DQ and one initializer + const Node* dq_node = parent_node_1 ? parent_node_1 : parent_node_2; + int const_idx = parent_node_1 ? 2 : 1; + + const ONNX_NAMESPACE::TensorProto* dq_node_scale_proto = nullptr; + graph.GetInitializedTensor(dq_node->InputDefs()[1]->Name(), dq_node_scale_proto); + const ONNX_NAMESPACE::TensorProto* dq_node_zp_proto = nullptr; + graph.GetInitializedTensor(dq_node->InputDefs()[2]->Name(), dq_node_zp_proto); + + // Dummy data initializer. + ONNX_NAMESPACE::TensorProto dummy_data_proto; + dummy_data_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_data")); + // Set data type to dq node's zp dtype + dummy_data_proto.set_data_type(dq_node_zp_proto->data_type()); + + // Dummy zero point initializer. + ONNX_NAMESPACE::TensorProto dummy_zp_proto; + dummy_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_zp")); + dummy_zp_proto.set_data_type(dq_node_zp_proto->data_type()); + + switch (dummy_zp_proto.data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + int8_t zp = 0; + int8_t dummy_data = 1; + dummy_zp_proto.set_raw_data(&zp, 1); + dummy_data_proto.set_raw_data(&dummy_data, 1); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + uint8_t zp = 0; + uint8_t dummy_data = 1; + dummy_zp_proto.set_raw_data(&zp, 1); + dummy_data_proto.set_raw_data(&dummy_data, 1); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + int16_t zp = 0; + int16_t dummy_data = 1; + dummy_zp_proto.set_raw_data(&zp, 2); + dummy_data_proto.set_raw_data(&dummy_data, 2); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + uint16_t zp = 0; + uint16_t dummy_data = 1; + dummy_zp_proto.set_raw_data(&zp, 2); + dummy_data_proto.set_raw_data(&dummy_data, 2); + break; + } + default: + LOGS(logger, WARNING) << "Currently support existing DQ's zero point with INT8, UINT8, INT16, UINT16"; + return Status::OK(); + } + + // Set dummy scale to the original value + const ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr; + graph.GetInitializedTensor(where_inputs[const_idx]->Name(), const_node_data_proto); + Initializer initializer(graph, *const_node_data_proto, graph.ModelPath()); + if (dq_node_scale_proto->data_type() != const_node_data_proto->data_type()) { + // WhereDummyDq fills the const value to the dummy DQ's scale + LOGS(logger, WARNING) << "Currently only support existing DQ's scale with same datatype as scalar"; + return Status::OK(); + } + + // Dummy scale initializer. + ONNX_NAMESPACE::TensorProto dummy_scale_proto; + dummy_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_scale")); + dummy_scale_proto.set_data_type(dq_node_scale_proto->data_type()); + switch (initializer.data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + float* where_const_scalar = initializer.data(); + dummy_scale_proto.set_raw_data(where_const_scalar, sizeof(float)); + break; + } + default: + LOGS(logger, WARNING) << "Currently support scalar with FLOAT"; + return Status::OK(); + } + + // Start editing the graph + NodeArg& dummy_data_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_data_proto); + NodeArg& dummy_scale_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_scale_proto); + NodeArg& dummy_zp_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_zp_proto); + + ONNX_NAMESPACE::TypeProto dummy_dq_type_proto = utils::TypeProtoFromTensorProto(*const_node_data_proto); + dummy_dq_type_proto.mutable_tensor_type()->set_elem_type(const_node_data_proto->data_type()); + NodeArg& dummy_dq_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_dummy_dq"), &dummy_dq_type_proto); + Node& dummy_dq_node = + graph.AddNode( + graph.GenerateNodeArgName(node.Name() + "_dummy_dq"), + QDQ::DQOpName, + "DeQuantizeLinear from WhereDummyDq GraphTransformer", + {&dummy_data_arg, &dummy_scale_arg, &dummy_zp_arg}, + {&dummy_dq_arg}, + nullptr, + dq_node->Domain()); + + node.MutableInputDefs()[const_idx] = &dummy_dq_arg; + if (graph.GetConsumerNodes(where_inputs[const_idx]->Name()).size() == 0) { + graph.RemoveInitializedTensor(where_inputs[const_idx]->Name()); + } + graph.AddEdge(dummy_dq_node.Index(), node.Index(), 0, const_idx); + modified = true; + + return Status::OK(); +} + +Status WhereDummyDq::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + const GraphViewer graph_viewer{graph}; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto node_idx : node_indices) { + auto* node_ptr = graph.GetNode(node_idx); + if (!node_ptr) { + continue; + } + + Node& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (this->SatisfyCondition(graph, node)) { + ORT_RETURN_IF_ERROR(WhereDummyDq::InsertDummyDQ(node, graph, modified, logger)); + } + } + + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h new file mode 100644 index 0000000000000..3260a865f8c4b --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** + @Class WhereDummyDq + + Graph transformer that inserts a dummy DQ on Where node's initializer input + to form Node Unit when Where node has one DQ and one scalar initializer input +*/ +class WhereDummyDq : public GraphTransformer { + public: + WhereDummyDq() noexcept : GraphTransformer("WhereDummyDq") {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + bool SatisfyCondition(const Graph& graph, const Node& node) const; + Status InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index cf9f44f4cd8f0..1cac133ab0c2c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -18,7 +18,7 @@ void MIGraphXAllocator::CheckDevice() const { int current_device; auto hip_err = hipGetDevice(¤t_device); if (hip_err == hipSuccess) { - ORT_ENFORCE(current_device == Info().id); + ORT_ENFORCE(current_device == Info().device.Id()); } #endif } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 84c7fe3e4d4ab..e9b1fe0f39da5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -151,10 +151,10 @@ struct MIGraphX_Provider : Provider { const OrtSessionOptions& session_options, const OrtLogger& logger, std::unique_ptr& ep) override { + ORT_UNUSED_PARAMETER(num_devices); const ConfigOptions* config_options = &session_options.GetConfigOptions(); std::array configs_array = {&provider_options, config_options}; - const void* arg = reinterpret_cast(&configs_array); auto ep_factory = CreateExecutionProviderFactory(&provider_options); ep = ep_factory->CreateProvider(session_options, logger); @@ -181,26 +181,47 @@ struct MigraphXEpFactory : OrtEpFactory { const char* ep_name, OrtHardwareDeviceType hw_type, const OrtLogger& default_logger_in) - : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, default_logger{default_logger_in} { + : ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; } // Returns the name for the EP. Each unique factory configuration must have a unique name. // Ex: a factory that supports NPU should have a different than a factory that supports GPU. - static const char* GetNameImpl(const OrtEpFactory* this_ptr) { + static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->ep_name.c_str(); } - static const char* GetVendorImpl(const OrtEpFactory* this_ptr) { + static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->vendor.c_str(); } + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id; + } + + static const char* GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->version.c_str(); + } + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. // An EP created with this factory is expected to be able to execute a model with *all* supported // hardware devices at once. A single instance of MigraphX EP is not currently setup to partition a model among @@ -212,7 +233,7 @@ struct MigraphXEpFactory : OrtEpFactory { size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, - size_t* p_num_ep_devices) { + size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); @@ -237,18 +258,56 @@ struct MigraphXEpFactory : OrtEpFactory { _In_ size_t /*num_devices*/, _In_ const OrtSessionOptions* /*session_options*/, _In_ const OrtLogger* /*logger*/, - _Out_ OrtEp** /*ep*/) { + _Out_ OrtEp** /*ep*/) noexcept { return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[MigraphX/AMDGPU EP] EP factory does not support this method."); } - static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) { + static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept { // no-op as we never create an EP here. } + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto* factory = static_cast(this_ptr); + + *allocator = nullptr; + return factory->ort_api.CreateStatus( + ORT_INVALID_ARGUMENT, + "CreateAllocator should not be called as we did not add OrtMemoryInfo to our OrtEpDevice."); + } + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* /*allocator*/) noexcept { + // should never be called as we don't implement CreateAllocator + } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; // not implemented + return nullptr; + } + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + auto* factory = static_cast(this_ptr); + + *stream = nullptr; + return factory->ort_api.CreateStatus( + ORT_INVALID_ARGUMENT, "CreateSyncStreamForDevice should not be called as IsStreamAware returned false."); + } + const OrtApi& ort_api; const OrtLogger& default_logger; const std::string ep_name; const std::string vendor{"AMD"}; + const std::string version{"1.0.0"}; // MigraphX EP version const uint32_t vendor_id{0x1002}; const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 0152ad27c0ba2..e248034f225ec 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -8,6 +8,13 @@ namespace onnxruntime { namespace qnn { +namespace { +bool IsOptionalNodeUnitIODef(const NodeUnitIODef& node_io_def) { + const NodeArg& arg = node_io_def.node_arg; + return !arg.Exists() || arg.Name().empty(); +} +} // namespace + std::string BaseOpBuilder::GetOpBuilderType() const { return op_builder_type_; } @@ -46,12 +53,18 @@ Status BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, const auto& inputs = node_unit.Inputs(); const auto& outputs = node_unit.Outputs(); for (auto input : inputs) { + if (IsOptionalNodeUnitIODef(input)) { + continue; + } TensorInfo tensor_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, tensor_info)); Qnn_DataType_t qnn_data_type = tensor_info.qnn_data_type; input_qnn_dtypes.push_back(qnn_data_type); } for (auto output : outputs) { + if (IsOptionalNodeUnitIODef(output)) { + continue; + } TensorInfo tensor_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, tensor_info)); Qnn_DataType_t qnn_data_type = tensor_info.qnn_data_type; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc index 7e17addf2f577..51c38b4483cb9 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc @@ -192,6 +192,50 @@ bool IsEquationMatMulBroadcastTransposeY(const Equation& equation) { return true; } +bool IsEquationReduceSumMulBroadcastX(const Equation& equation) { + // E.g., bhwc,wkc->bhwk + const auto& [term_1, term_2, result] = equation; + if (term_1.size() != 4) { + return false; + } + if (term_2.size() != 3) { + return false; + } + if (result.size() != 4) { + return false; + } + + // Check contraction over last axis (c) + char c1 = term_1[3]; + char c2 = term_2[2]; + if (c1 != c2) { + return false; + } + + // Check w axis alignment + if (term_1[2] != term_2[0]) { + return false; + } + if (term_1[2] != result[2]) { + return false; + } + + // Check k axis alignment + if (term_2[1] != result[3]) { + return false; + } + + // Check batch dimensions + if (term_1[0] != result[0]) { + return false; + } + if (term_1[1] != result[1]) { + return false; + } + + return true; +} + /** * @brief Sets the parameter tensor names for a MatMul op. * @@ -305,6 +349,113 @@ Status CreateMatMulTransposeAll( return Status::OK(); } +/** + * @brief Creates a ReduceSum, Multiply on broadcasted input X and original input Y. + * + * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance used to manage the QNN model. + * @param node_unit The NodeUnit representing the ONNX node to be converted. + * @param do_op_validation A boolean flag indicating whether to perform operation validation. + * @return Status indicating success or failure of the operation. + */ +Status CreateReduceSumMulBroadcastX( + onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper, + const onnxruntime::NodeUnit& node_unit, + std::vector&& input_names, + bool do_op_validation) { + // Reshape in0 to shape (b, h, w, 1, c) to expand dimension before the contraction axis 'c'. + // Allowing broadcast with in1 for multiplication, aligning the contraction axis for reduce. + onnxruntime::qnn::TensorInfo tensor_info_in0{}, tensor_info_in1{}, tensor_info_out{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[0], tensor_info_in0)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[1], tensor_info_in1)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Outputs()[0], tensor_info_out)); + const std::vector& shape_in0 = tensor_info_in0.shape; + const std::vector& shape_in1 = tensor_info_in1.shape; + ORT_RETURN_IF_NOT(shape_in0.size() == 4, "CreateReduceSumMulBroadcastX expects input 0 to be rank 4"); + ORT_RETURN_IF_NOT(shape_in1.size() == 3, "CreateReduceSumMulBroadcastX expects input 1 to be rank 3"); + const std::vector new_shape_in0{shape_in0[0], shape_in0[1], shape_in0[2], 1, shape_in0[3]}; + const std::string reshape_out_name = input_names[0] + "_reshaped"; + ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddReshapeNode( + /*input_name=*/input_names[0], + /*output_name=*/reshape_out_name, + /*input_shape=*/shape_in0, + /*output_shape=*/new_shape_in0, + /*tensor_data_type=*/tensor_info_in0.qnn_data_type, + /*quantize_param=*/tensor_info_in0.quant_param.Copy(), + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[0]))); + + // Multiply: reshaped in0 * in1 + // The output shape of the multiplication is determined by broadcasting the reshaped in0 of + // (b, h, w, 1, c) and in1 (w, k, c) along the matching axes, resulting in (b, h, w, k, c). + const std::string mul_out_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_mul"; + std::vector shape_out_mul{new_shape_in0[0], new_shape_in0[1], new_shape_in0[2], shape_in1[1], new_shape_in0[4]}; + onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_mul(mul_out_name, + QNN_TENSOR_TYPE_NATIVE, + tensor_info_in0.qnn_data_type, + tensor_info_in0.quant_param.Copy(), + std::move(shape_out_mul)); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_mul)), + "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper"); + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode( + /*qnn_node_name=*/mul_out_name, + /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW, + /*qnn_node_type=*/QNN_OP_ELEMENT_WISE_MULTIPLY, + /*input_names=*/{reshape_out_name, input_names[1]}, + /*output_names=*/{mul_out_name}, + /*param_tensor_names=*/{}, + /*do_op_validation=*/do_op_validation), + "CreateReduceSumMulBroadcastX: failed to create Mul node"); + + std::vector param_tensor_names{}; + + // ReduceSum on last axes={4}, keep_dims=False + // Axis '4' corresponds to the last dimension ('c') of the reshaped tensor (b, h, w, k, c), + // which is the contraction axis for reduce sum op in the einsum equation (bhwc,wkc->bhwk). + std::vector axes_shape{SafeInt(1)}; + std::vector axes_value{SafeInt(4)}; + onnxruntime::qnn::QnnParamWrapper param_axes(node_unit.Index(), + node_unit.Name(), + QNN_OP_REDUCE_SUM_PARAM_AXES, + std::move(axes_shape), + std::move(axes_value)); + param_tensor_names.push_back(param_axes.GetParamTensorName()); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_axes)), + "CreateReduceSumMulBroadcastX: failed to add param axes"); + + Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT; + keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8; + keep_dims_scalar.bool8Value = SafeInt(0); + onnxruntime::qnn::QnnParamWrapper param_keep_dims(node_unit.Index(), + node_unit.Name(), + QNN_OP_REDUCE_SUM_PARAM_KEEP_DIMS, + keep_dims_scalar); + param_tensor_names.push_back(param_keep_dims.GetParamTensorName()); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_keep_dims)), + "CreateReduceSumMulBroadcastX: failed to add param keep_dims"); + + const std::string out_name = node_unit.Outputs()[0].node_arg.Name(); + Qnn_TensorType_t out_tensor_type = qnn_model_wrapper->IsGraphOutput(out_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_out(out_name, + out_tensor_type, + tensor_info_out.qnn_data_type, + tensor_info_out.quant_param.Copy(), + std::move(tensor_info_out.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_out)), + "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper"); + + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode( + /*qnn_node_name=*/out_name, + /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW, + /*qnn_node_type=*/QNN_OP_REDUCE_SUM, + /*input_names=*/{mul_out_name}, + /*output_names=*/{out_name}, + /*param_tensor_names=*/std::move(param_tensor_names), + /*do_op_validation=*/do_op_validation), + "CreateReduceSumMulBroadcastX: failed to create ReduceSum node"); + + return Status::OK(); +} + } // namespace namespace onnxruntime { @@ -356,9 +507,20 @@ Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, if (!IsEquationMatMul(parsed_equation.value()) && !IsEquationMatMulTransposeY(parsed_equation.value()) && !IsEquationMatMulBroadcastTransposeY(parsed_equation.value()) && - !IsEquationMatMulTransposeAll(parsed_equation.value())) { + !IsEquationMatMulTransposeAll(parsed_equation.value()) && + !IsEquationReduceSumMulBroadcastX(parsed_equation.value())) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); } + if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) { + if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) { + // QAIRT 3.36.1: Failed to validate on GPU. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " on backend GPU"); + } + if (node_unit.Inputs()[0].quant_param.has_value()) { + // QAIRT 3.36.1: Failed to finalize QNN graph 1002. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " for quantized inputs"); + } + } return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } @@ -408,6 +570,11 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w /*node_unit=*/node_unit, /*input_names=*/std::move(input_names), /*do_op_validation=*/do_op_validation)); + } else if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) { + ORT_RETURN_IF_ERROR(CreateReduceSumMulBroadcastX(/*qnn_model_wrapper=*/&qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_names=*/std::move(input_names), + /*do_op_validation=*/do_op_validation)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); } diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index a22d21d8d798b..bdeea726a2cf5 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -491,16 +491,29 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha ss << ","; } - auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; - ss << "\n " << alignment << name << ": "; + // The actual variable type for the uniform variable depends on the data type (T) and length (N). + // + // For T in [i32, u32, f32]: + // - If N == 1, the type is simply i32, u32, or f32. + // - If 2 < N <= 4, the type is vecN, vecN, or vecN where N is the length. + // - If N > 4, the type is array, ceil(N / 4)>. + // + // For T is f16: + // - If N == 1 or N == 2, the type is u32. + // - If 2 < N <= 8, the type is vecX where X is ceil(N / 2). + // - If N > 8, the type is array, X> where X is ceil(N / 8). + // + // Note: Using f16 type in uniforms is not generally supported on all devices. We use a u32 variable to represent + // 2 f16 values. + + if (data_type == ProgramUniformVariableDataType::Float16) { + data_type = ProgramUniformVariableDataType::Uint32; // f16 is represented as u32 + length = (length + 1) / 2; // each u32 can hold 2 f16 values + } + ss << "\n " << name << ": "; if (length > 4) { - if (data_type == ProgramUniformVariableDataType::Float16) { - size_t array_size = (length + 7) / 8; - ss << "array, " << array_size << ">"; - } else { - size_t array_size = (length + 3) / 4; - ss << "array, " << array_size << ">"; - } + size_t array_size = (length + 3) / 4; + ss << "array, " << array_size << ">"; } else if (length > 1) { ss << "vec" << length << "<" << data_type << ">"; } else { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 2aba2a59d157f..78c98ab26f5b8 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -17,18 +17,34 @@ template || std::is_same_v>> std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) { - // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. - if (var.rfind("uniforms.", 0) == 0) { - if (rank > 4) { - if constexpr (std::is_integral_v) { - if (is_f16) { - return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]"); + if (var.starts_with("uniforms.")) { + if (is_f16) { + if (rank > 8) { + // array, N> + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, "[", idx / 8, "][", (idx % 8) / 2, "])[", (idx % 8) % 2, "]"); } else { - return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); + return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 8][((", idx, ") % 8) / 2])[((", idx, ") % 8) % 2]"); + } + } else if (rank > 2) { + // vecN + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, "[", idx / 2, "])[", idx % 2, "]"); + } else { + return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 2])[(", idx, ") % 2]"); } } else { - if (is_f16) { - return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]"); + // u32 + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, ")[", idx % 2, "]"); + } else { + return MakeStringWithClassicLocale("bitcast>(", var, ")[(", idx, ") % 2]"); + } + } + } else { + if (rank > 4) { + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); } else { return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 4bd79a627df22..a9557f7b9aa87 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -373,26 +373,57 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { continue; } - bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; - - size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // Calculate the size and alignment of the uniform variable. + // // https://www.w3.org/TR/WGSL/#alignof - size_t base_alignment = is_f16 - ? (length > 4 ? 16 : length > 2 ? 8 - : length * element_size) - : (length > 2 ? 16 : length * element_size); - size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; - - current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + // + // For f16: + // - length > 8 : array, N> (align 16) (size 16 * N, N = ceil(length / 8)) + // - length == 7 or 8: vec4 (align 16) (size 16) + // - length == 5 or 6: vec3 (align 16) (size 12) + // - length == 3 or 4: vec2 (align 8) (size 8) + // - length == 1 or 2: u32 (align 4) (size 4) + // + // For other types (i32, u32, f32): + // - length > 4 : array, N> (align 16) (size 16 * N, N = ceil(length / 4)) + // - length == 4 : vec4 (align 16) (size 16) + // - length == 3 : vec3 (align 16) (size 12) + // - length == 2 : vec2 (align 8) (size 8) + // - length == 1 : T (align 4) (size 4) + // + + const bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + + size_t variable_alignment = 4; // default alignment for scalar types + size_t variable_size = 4; // default size for scalar types + + if (is_f16) { + if (length > 6) { + variable_alignment = 16; + variable_size = 16 * ((length + 7) / 8); + } else if (length > 4) { + variable_alignment = 16; + variable_size = 12; + } else if (length > 2) { + variable_alignment = 8; + variable_size = 8; + } + } else { + if (length > 3) { + variable_alignment = 16; + variable_size = 16 * ((length + 3) / 4); + } else if (length > 2) { + variable_alignment = 16; + variable_size = 12; + } else if (length > 1) { + variable_alignment = 8; + variable_size = 8; + } + } + current_offset = (current_offset + variable_alignment - 1) / variable_alignment * variable_alignment; uniform_and_offsets.emplace_back(uniform, current_offset); - // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). - // For float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). - size_t element_per_struct = is_f16 ? 8 : 4; - current_offset += - length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + current_offset += variable_size; } // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 37f4fe7312bb4..4c7b4d7b29c2f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2993,7 +2993,8 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, + _Outptr_result_maybenull_ const OrtOpAttr** attribute) { API_IMPL_BEGIN if (attribute == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attribute' argument is NULL"); @@ -3004,14 +3005,30 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName."); } - *attribute = ep_node->GetAttribute(attribute_name); + bool is_unset_optional_attr = false; + *attribute = ep_node->GetAttribute(attribute_name, is_unset_optional_attr); - if (*attribute) { + if (*attribute || is_unset_optional_attr) { return nullptr; } else { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); + std::ostringstream oss; + oss << "Node attribute does not exist: " << attribute_name; + return OrtApis::CreateStatus(OrtErrorCode::ORT_NOT_FOUND, oss.str().c_str()); + } + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { + API_IMPL_BEGIN + if (attr_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); + } + if (attribute == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); } + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); + return nullptr; API_IMPL_END } @@ -3052,6 +3069,10 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _O *type = OrtOpAttrType::ORT_OP_ATTR_GRAPH; break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + *type = OrtOpAttrType::ORT_OP_ATTR_TENSOR; + break; + } default: return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type."); } @@ -4034,6 +4055,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, + &OrtApis::Node_GetTensorAttributeAsOrtValue, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, &OrtApis::Node_GetNumSubgraphs, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index d2f22397bf82c..3eee174ff81f4 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -678,7 +678,9 @@ ORT_API_STATUS_IMPL(Node_GetNumAttributes, _In_ const OrtNode* node, _Out_ size_ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); +ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + _Outptr_result_maybenull_ OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.cc b/onnxruntime/python/onnxruntime_pybind_exceptions.cc index 8f3b97c8c7786..6b3062205b52e 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.cc +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.cc @@ -37,6 +37,7 @@ void RegisterExceptions(pybind11::module& m) { pybind11::register_exception(m, "EPFail"); pybind11::register_exception(m, "ModelLoadCanceled"); pybind11::register_exception(m, "ModelRequiresCompilation"); + pybind11::register_exception(m, "NotFound"); } void OrtPybindThrowIfError(onnxruntime::common::Status status) { @@ -67,6 +68,8 @@ void OrtPybindThrowIfError(onnxruntime::common::Status status) { throw ModelLoadCanceled(std::move(msg)); case onnxruntime::common::StatusCode::MODEL_REQUIRES_COMPILATION: throw ModelRequiresCompilation(std::move(msg)); + case onnxruntime::common::StatusCode::NOT_FOUND: + throw NotFound(std::move(msg)); default: throw std::runtime_error(std::move(msg)); } diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.h b/onnxruntime/python/onnxruntime_pybind_exceptions.h index 86bc4a5da8d46..7680c06c59d79 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.h +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.h @@ -50,6 +50,9 @@ struct ModelLoadCanceled : std::runtime_error { struct ModelRequiresCompilation : std::runtime_error { explicit ModelRequiresCompilation(const std::string& what) : std::runtime_error(what) {} }; +struct NotFound : std::runtime_error { + explicit NotFound(const std::string& what) : std::runtime_error(what) {} +}; void RegisterExceptions(pybind11::module& m); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 45314f8f39eea..188edad572182 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -87,6 +87,92 @@ TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, GetAttributeByName) { + // Load model with a single Conv that has no explicit attributes set. + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_default_attrs.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // + // Pre-check + // + + // Original Conv has no explicit attributes but Graph::Resolve() fills in default values for + // 'auto_pad' and 'group'. The other optional attributes (i.e. dilations, kernel_shape, pads, strides) do not + // have statically computable default values, so will not be filled in by Graph::Resolve(). + const OrtGraph& ort_graph = test_graph->GetOrtGraph(); + const OrtApi& ort_api = Ort::GetApi(); + + size_t num_nodes = 0; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); + ASSERT_EQ(num_nodes, 1); + + std::vector nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + + const OrtNode* conv_node = nodes[0]; + const char* op_type = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetOperatorType(conv_node, &op_type)); + ASSERT_STREQ(op_type, "Conv"); + + size_t num_attrs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(conv_node, &num_attrs)); + ASSERT_EQ(num_attrs, 2); + + std::vector attrs(num_attrs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(conv_node, attrs.data(), attrs.size())); + for (const OrtOpAttr* attr : attrs) { + const char* attr_name_cstr = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(attr, &attr_name_cstr)); + std::string_view attr_name = attr_name_cstr; + ASSERT_TRUE(attr_name == "auto_pad" || attr_name == "group"); // Only 'auto_pad' and 'group' have been set + } + + // + // Test 1: Get optional attribute that is not set (e.g., dilations). Should not get an error. + // + { + const OrtOpAttr* attr = nullptr; + Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "dilations", &attr)}; + ASSERT_TRUE(status.IsOK()); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 2: Get attribute that does not exist in operator schema. Should get a ORT_NOT_FOUND error. + // + { + const OrtOpAttr* attr = nullptr; + Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "_does_not_exist_", &attr)}; + ASSERT_FALSE(status.IsOK()); + ASSERT_EQ(status.GetErrorCode(), ORT_NOT_FOUND); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 3: Get attribute that is known to be set. + // + { + const OrtOpAttr* attr = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(conv_node, "auto_pad", &attr)); + ASSERT_NE(attr, nullptr); + + OrtOpAttrType attr_type = ORT_OP_ATTR_UNDEFINED; + ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetType(attr, &attr_type)); + ASSERT_EQ(attr_type, ORT_OP_ATTR_STRING); + + std::string auto_pad_val; + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + size_t total_attr_bytes = 0; + Ort::Status status2{ort_api.ReadOpAttr(attr, attr_type, nullptr, 0, &total_attr_bytes)}; + auto_pad_val.resize(total_attr_bytes); + + ASSERT_ORTSTATUS_OK(ort_api.ReadOpAttr(attr, attr_type, auto_pad_val.data(), total_attr_bytes, + &total_attr_bytes)); + ASSERT_EQ(auto_pad_val, "NOTSET"); + } +} + // Check correctness of an OrtGraph that has external initializers. TEST(EpGraphTest, CheckModelExternalInitializers) { auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_qdq_external_ini.onnx")); @@ -220,6 +306,39 @@ static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& outpu output_data.assign(output_values, output_values + num_output_elems); } +static void RunConstantOfShapeModel(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {3}; + std::vector input_data = {2, 3, 4}; + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'x' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("x"); + + // Run session and get outputs + std::array output_names{"y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 24); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + // Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. // Checks that the outputs of the serialized and original models are identical. TEST(EpGraphTest, SerializeToProto_Mnist) { @@ -350,6 +469,65 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { } } +// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_ConstantOfShape) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("constant_of_shape.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "constant_of_shape_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + static_cast(value_info); + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, + handle_initializer_data)); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + RunConstantOfShapeModel(original_model_path, output_original); + RunConstantOfShapeModel(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -892,6 +1070,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH); break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_TENSOR); + break; + } default: // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit.")); diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 98640bb2f6b4c..1baa6e529cbde 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -12,6 +12,7 @@ #include "core/mlas/inc/mlas.h" #include "core/optimizer/double_qdq_pairs_remover.h" #include "core/optimizer/qdq_transformer/weight_bias_quantization.h" +#include "core/optimizer/qdq_transformer/where_dummy_dq.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -3220,6 +3221,79 @@ TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) { test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128 } +template +void TestWhereWithDqInput(bool is_dq_1, + bool is_dq_2, + int expected_num_where, + int expected_num_dq, + int expected_num_q, + bool expected_modified) { + auto& logger = DefaultLoggingManager().DefaultLogger(); + Model model("WhereDummyDqTester", false, logger); + Graph& graph = model.MainGraph(); + ModelTestBuilder builder(graph); + + NodeArg* where_in1 = nullptr; + NodeArg* where_in2 = nullptr; + if (is_dq_1) { + // DQ + auto* dq_Input = builder.MakeInput({4, 3, 32}, 0.0, 1.0); + auto* dq_scale = builder.MakeInitializer({}, 0.0, 1.0); + auto* dq_zp = builder.MakeInitializer({}, 0.0, 1.0); + where_in1 = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {dq_Input, dq_scale, dq_zp}, {where_in1}); + } else { + where_in1 = builder.MakeInitializer({}, 0.0, 1.0); + } + if (is_dq_2) { + // DQ + auto* dq_Input = builder.MakeInput({4, 3, 32}, 0.0, 1.0); + auto* dq_scale = builder.MakeInitializer({}, 0.0, 1.0); + auto* dq_zp = builder.MakeInitializer({}, 0.0, 1.0); + where_in2 = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {dq_Input, dq_scale, dq_zp}, {where_in2}); + } else { + where_in2 = builder.MakeInitializer({}, 0.0, 1.0); + } + + // Where + auto* where_cond = builder.MakeInputBool({4, 3, 32}); + auto* where_out = builder.MakeIntermediate(); + builder.AddNode("Where", {where_cond, where_in1, where_in2}, {where_out}); + + // Q + auto* q_scale = builder.MakeInitializer({}, 0.0, 1.0); + auto* q_zp = builder.MakeInitializer({}, 0.0, 1.0); + auto* q_out = builder.MakeOutput(); + builder.AddNode("QuantizeLinear", {where_out, q_scale, q_zp}, {q_out}); + + builder.SetGraphOutputs(); + ASSERT_STATUS_OK(graph.Resolve()); + + auto where_optimizer = std::make_unique(); + bool modified = false; + ASSERT_STATUS_OK(where_optimizer->Apply(graph, modified, logger)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Where"], expected_num_where); + ASSERT_EQ(op_to_count["DequantizeLinear"], expected_num_dq); + ASSERT_EQ(op_to_count["QuantizeLinear"], expected_num_q); + ASSERT_EQ(modified, expected_modified); + + return; +}; + +TEST(QDQTransformerTests, WhereDummyDqTest) { + TestWhereWithDqInput(true, true, 1, 2, 1, false); + TestWhereWithDqInput(true, false, 1, 2, 1, true); + TestWhereWithDqInput(false, true, 1, 2, 1, true); + TestWhereWithDqInput(false, false, 1, 0, 1, false); + TestWhereWithDqInput(true, true, 1, 2, 1, false); + TestWhereWithDqInput(true, false, 1, 2, 1, true); + TestWhereWithDqInput(false, true, 1, 2, 1, true); + TestWhereWithDqInput(false, false, 1, 0, 1, false); +} + TEST(QDQTransformerTests, Concat) { auto test_case = [&](const std::vector>& input_shapes, int64_t axis, diff --git a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc index 5de7885a9452a..761ddf1975d15 100644 --- a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc +++ b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc @@ -188,6 +188,24 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) { ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true); } +static bool SessionHasEp(Ort::Session& session, const char* ep_name) { + // Access the underlying InferenceSession. + const OrtSession* ort_session = session; + const InferenceSession* s = reinterpret_cast(ort_session); + bool has_ep = false; + + for (const auto& provider : s->GetRegisteredProviderTypes()) { + if (provider == ep_name) { + has_ep = true; + break; + } + } + return has_ep; +} + +#if defined(WIN32) +// Tests autoEP feature to automatically select an EP that supports the GPU. +// Currently only works on Windows. TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) { PathString model_name = ORT_TSTR("migraphx_basic_test.onnx"); @@ -212,6 +230,7 @@ TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) { env.UnregisterExecutionProviderLibrary(kMIGraphXExecutionProvider); } +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc index d8dbbd799a427..11a3d5a083aab 100644 --- a/onnxruntime/test/providers/qnn/einsum_op_test.cc +++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc @@ -189,6 +189,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) { /*tolerance=*/1e-4f); } +TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) { + const std::vector shape0{1, 7, 1, 7}; + const std::vector shape1{1, 9, 1, 7}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bkhq,bchk->bchq", + /*tolerance=*/1e-4f); +} + TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) { const std::vector shape0{2, 3, 3, 4}; const std::vector shape1{3, 3, 4}; @@ -202,16 +215,16 @@ TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) { /*tolerance=*/1e-4f); } -TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) { - const std::vector shape0{1, 7, 1, 7}; - const std::vector shape1{1, 9, 1, 7}; +TEST_F(QnnCPUBackendTests, EinsumReduceSumMulBroadcastX) { + const std::vector shape0{2, 3, 4, 5}; + const std::vector shape1{4, 6, 5}; const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); RunQnnEinsum( /*backend=*/kQnnBackendTypeCpu, /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), - /*equation=*/"bkhq,bchk->bchq", + /*equation=*/"bhwc,wkc->bhwk", /*tolerance=*/1e-4f); } @@ -299,6 +312,19 @@ TEST_F(QnnHTPBackendTests, EinsumF16MatMulBroadcastTransposeY) { /*tolerance=*/1e-2f); } +TEST_F(QnnHTPBackendTests, EinsumF16ReduceSumMulBroadcastX) { + const std::vector shape0{1, 3, 2, 4}; + const std::vector shape1{2, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,wkc->bhwk", + /*tolerance=*/1e-2f); +} + // // QNN HTP QDQ // @@ -375,6 +401,19 @@ TEST_F(QnnHTPBackendTests, EinsumQdqMatMulBroadcastTransposeY) { /*tolerance=*/QDQTolerance()); } +// TODO: Re-enable. QAIRT 3.36.1: failed to finalize QNN graph 1002. +TEST_F(QnnHTPBackendTests, DISABLED_EinsumQdqReduceSumMulBroadcastX) { + const std::vector shape0{1, 3, 2, 4}; + const std::vector shape1{2, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,wkc->bhwk", + /*tolerance=*/QDQTolerance()); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #if defined(_M_ARM64) @@ -474,6 +513,20 @@ TEST_F(QnnGPUBackendTests, DISABLED_EinsumMatMulBroadcastTransposeY) { /*tolerance=*/1e-4f); } +// TODO: Re-enable. Failed on QAIRT 3.36.1. +TEST_F(QnnGPUBackendTests, DISABLED_EinsumReduceSumMulBroadcastX) { + const std::vector shape0{1, 3, 2, 4}; + const std::vector shape1{2, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,wkc->bhwk", + /*tolerance=*/1e-4f); +} + #endif // defined(_M_ARM64) GPU tests } // namespace test diff --git a/onnxruntime/test/testdata/conv_default_attrs.onnx b/onnxruntime/test/testdata/conv_default_attrs.onnx new file mode 100644 index 0000000000000..fc7ee58dee15e Binary files /dev/null and b/onnxruntime/test/testdata/conv_default_attrs.onnx differ diff --git a/onnxruntime/test/testdata/make_conv_default_attrs.py b/onnxruntime/test/testdata/make_conv_default_attrs.py new file mode 100644 index 0000000000000..fc092bf8b25fb --- /dev/null +++ b/onnxruntime/test/testdata/make_conv_default_attrs.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + + +def main(): + inp_shape = (1, 2, 8, 8) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, inp_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None) + + weight_data = [ + [[[-1.5, 0.0], [0.2, 1.5]], [[-1.5, 0.0], [0.2, 1.5]]], + [[[-1.0, 0.0], [0.1333, 1.0]], [[-1.0, 0.0], [0.1333, 1.0]]], + ] + weight = onnx.numpy_helper.from_array(np.array(weight_data, dtype=np.float32), "weight") + bias = onnx.numpy_helper.from_array(np.array([0.0, 0.0], dtype=np.float32), "bias") + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convf32", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + + onnx.checker.check_model(model, True) + onnx.save_model(model, "conv_default_attrs.onnx") + + +if __name__ == "__main__": + main() diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 561a76be5fa89..56fd3f1323e92 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -620,6 +620,7 @@ def generate_build_tree( ) generate_vcpkg_triplets_for_emscripten( build_dir, + configs, emscripten_root_path, not args.disable_rtti, not args.disable_wasm_exception_catching, @@ -627,25 +628,21 @@ def generate_build_tree( args.enable_address_sanitizer, ) elif args.android: - generate_android_triplets(build_dir, args.android_cpp_shared, args.android_api) + generate_android_triplets(build_dir, configs, args.android_cpp_shared, args.android_api) elif is_windows(): - generate_windows_triplets(build_dir, args.msvc_toolset) + generate_windows_triplets(build_dir, configs, args.msvc_toolset) elif is_macOS(): osx_target = args.apple_deploy_target if args.apple_deploy_target is None: osx_target = os.environ.get("MACOSX_DEPLOYMENT_TARGET") if osx_target is not None: log.info(f"Setting VCPKG_OSX_DEPLOYMENT_TARGET to {osx_target}") - generate_macos_triplets(build_dir, osx_target) + generate_macos_triplets(build_dir, configs, osx_target) else: # Linux, *BSD, AIX or other platforms - generate_linux_triplets(build_dir) + generate_linux_triplets(build_dir, configs) add_default_definition(cmake_extra_defines, "CMAKE_TOOLCHAIN_FILE", str(vcpkg_toolchain_path)) - vcpkg_install_options = generate_vcpkg_install_options(build_dir, args) - # VCPKG_INSTALL_OPTIONS is a CMake list. It must be joined by semicolons - # Therefore, if any of the option string contains a semicolon, it must be escaped - add_default_definition(cmake_extra_defines, "VCPKG_INSTALL_OPTIONS", ";".join(vcpkg_install_options)) # Choose the cmake triplet triplet = None if args.build_wasm: @@ -1251,6 +1248,16 @@ def generate_build_tree( ] env = {} if args.use_vcpkg: + # append VCPKG_INSTALL_OPTIONS + # + # VCPKG_INSTALL_OPTIONS is a CMake list. It must be joined by semicolons + # Therefore, if any of the option string contains a semicolon, it must be escaped + temp_cmake_args += [ + "-DVCPKG_INSTALL_OPTIONS={}".format( + ";".join(generate_vcpkg_install_options(Path(build_dir) / config, args)) + ) + ] + vcpkg_keep_env_vars = ["TRT_UPLOAD_AUTH_TOKEN"] if args.build_wasm: diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml index b304ccdb4c533..0410001d77d13 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml @@ -65,6 +65,12 @@ jobs: clean: true submodules: none + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + architecture: 'x64' + - template: templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index e08de4be17574..586f7a2496eb4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -83,6 +83,9 @@ jobs: versionSpec: '3.12' addToPath: true architecture: $(buildArch) + - task: NodeTool@0 + inputs: + versionSpec: '22.x' - ${{if eq(parameters.WithCache, true)}}: - script: | diff --git a/tools/python/util/vcpkg_helpers.py b/tools/python/util/vcpkg_helpers.py index 1fbb4a06b3c2b..a9f1420946354 100644 --- a/tools/python/util/vcpkg_helpers.py +++ b/tools/python/util/vcpkg_helpers.py @@ -93,8 +93,29 @@ def add_copyright_header(f) -> None: ) +def add_build_type(f, build_type: str) -> None: + """ + Add build type to the triplet file. + + Args: + f (file object): The file object to write the build type. + build_type (str): The build type to add. Must be one of "Debug", "Release", "RelWithDebInfo", or "MinSizeRel". + """ + if build_type not in ["Debug", "Release", "RelWithDebInfo", "MinSizeRel"]: + raise ValueError( + f"Invalid build type: {build_type}. Must be one of 'Debug', 'Release', 'RelWithDebInfo', or 'MinSizeRel'." + ) + + if build_type != "Debug": + f.write( + """set(VCPKG_BUILD_TYPE release) +""" + ) + + def generate_triplet_for_android( build_dir: str, + configs: set[str], target_abi: str, enable_rtti: bool, enable_exception: bool, @@ -130,100 +151,102 @@ def generate_triplet_for_android( file_name = f"{target_abi}-android.cmake" - dest_path = Path(build_dir) / folder_name / file_name + for config in configs: + dest_path = Path(build_dir) / config / folder_name / file_name + + os.makedirs(dest_path.parent, exist_ok=True) - os.makedirs(dest_path.parent, exist_ok=True) + with open(dest_path, "w", encoding="utf-8") as f: + add_copyright_header(f) - with open(dest_path, "w", encoding="utf-8") as f: - add_copyright_header(f) + # Set target architecture for Android + if target_abi == "arm-neon": + f.write("set(VCPKG_TARGET_ARCHITECTURE arm)\n") + f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=armv7a-linux-androideabi)\n") + f.write( + "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=armeabi-v7a -DANDROID_ARM_NEON=ON -DCMAKE_ANDROID_ARM_NEON=ON)\n" + ) + elif target_abi == "arm64": + f.write("set(VCPKG_TARGET_ARCHITECTURE arm64)\n") + f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=aarch64-linux-android)\n") + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=arm64-v8a)\n") + elif target_abi == "x64": + f.write("set(VCPKG_TARGET_ARCHITECTURE x64)\n") + f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=x86_64-linux-android)\n") + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=x86_64)\n") + elif target_abi == "x86": + f.write("set(VCPKG_TARGET_ARCHITECTURE x86)\n") + f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=i686-linux-android)\n") + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=x86)\n") - # Set target architecture for Android - if target_abi == "arm-neon": - f.write("set(VCPKG_TARGET_ARCHITECTURE arm)\n") - f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=armv7a-linux-androideabi)\n") f.write( - "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=armeabi-v7a -DANDROID_ARM_NEON=ON -DCMAKE_ANDROID_ARM_NEON=ON)\n" + f"list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_USE_LEGACY_TOOLCHAIN_FILE=false -DANDROID_PLATFORM=android-{android_api_level} -DANDROID_MIN_SDK={android_api_level})\n" ) - elif target_abi == "arm64": - f.write("set(VCPKG_TARGET_ARCHITECTURE arm64)\n") - f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=aarch64-linux-android)\n") - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=arm64-v8a)\n") - elif target_abi == "x64": - f.write("set(VCPKG_TARGET_ARCHITECTURE x64)\n") - f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=x86_64-linux-android)\n") - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=x86_64)\n") - elif target_abi == "x86": - f.write("set(VCPKG_TARGET_ARCHITECTURE x86)\n") - f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=i686-linux-android)\n") - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=x86)\n") - f.write( - f"list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_USE_LEGACY_TOOLCHAIN_FILE=false -DANDROID_PLATFORM=android-{android_api_level} -DANDROID_MIN_SDK={android_api_level})\n" - ) - - # Set CRT linkage - # VCPKG_CRT_LINKAGE specifies the desired CRT linkage (for MSVC). - # Valid options are dynamic and static. - crt_linkage = "static" - f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") + # Set CRT linkage + # VCPKG_CRT_LINKAGE specifies the desired CRT linkage (for MSVC). + # Valid options are dynamic and static. + crt_linkage = "static" + f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") - # Set library linkage - # VCPKG_LIBRARY_LINKAGE specifies the preferred library linkage. - # Valid options are dynamic and static. Libraries can ignore this setting if they do not support the preferred linkage type. In our case, we prefer to use static libs. - f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") - if not enable_rtti: - f.write("set(CMAKE_ANDROID_RTTI OFF)\n") - if not enable_exception: - f.write("set(CMAKE_ANDROID_EXCEPTIONS OFF)\n") - if use_cpp_shared: - f.write("set(ANDROID_STL c++_shared)\n") + # Set library linkage + # VCPKG_LIBRARY_LINKAGE specifies the preferred library linkage. + # Valid options are dynamic and static. Libraries can ignore this setting if they do not support the preferred linkage type. In our case, we prefer to use static libs. + f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") + if not enable_rtti: + f.write("set(CMAKE_ANDROID_RTTI OFF)\n") + if not enable_exception: + f.write("set(CMAKE_ANDROID_EXCEPTIONS OFF)\n") + if use_cpp_shared: + f.write("set(ANDROID_STL c++_shared)\n") - ldflags = [] + ldflags = [] - cflags = ["-g", "-ffunction-sections", "-fdata-sections"] - cflags_release = ["-DNDEBUG", "-O3"] + cflags = ["-g", "-ffunction-sections", "-fdata-sections"] + cflags_release = ["-DNDEBUG", "-O3"] - if enable_asan: - cflags += ["-fsanitize=address"] - ldflags += ["-fsanitize=address"] + if enable_asan: + cflags += ["-fsanitize=address"] + ldflags += ["-fsanitize=address"] - ldflags.append("-g") + ldflags.append("-g") - cxxflags = cflags.copy() + cxxflags = cflags.copy() - if not enable_rtti: - cxxflags.append("-fno-rtti") + if not enable_rtti: + cxxflags.append("-fno-rtti") - if not enable_exception: - cxxflags += ["-fno-exceptions", "-fno-unwind-tables", "-fno-asynchronous-unwind-tables"] + if not enable_exception: + cxxflags += ["-fno-exceptions", "-fno-unwind-tables", "-fno-asynchronous-unwind-tables"] - if cflags: - f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') + if cflags: + f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') - if cxxflags: - f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') + if cxxflags: + f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') - if cflags_release: - f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_C_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') + if cflags_release: + f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_C_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') - # Set target platform - # VCPKG_CMAKE_SYSTEM_NAME specifies the target platform. - f.write("set(VCPKG_CMAKE_SYSTEM_NAME Android)\n") - f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") - f.write( - "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" - ) + # Set target platform + # VCPKG_CMAKE_SYSTEM_NAME specifies the target platform. + f.write("set(VCPKG_CMAKE_SYSTEM_NAME Android)\n") + f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") + f.write( + "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" + ) - if ldflags: - f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=17)\n") - add_port_configs(f, enable_exception, False, enable_minimal_build) # Pass enable_minimal_build + if ldflags: + f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=17)\n") + add_build_type(f, config) + add_port_configs(f, enable_exception, False, enable_minimal_build) # Pass enable_minimal_build -def generate_android_triplets(build_dir: str, use_cpp_shared: bool, android_api_level: int) -> None: +def generate_android_triplets(build_dir: str, configs: set[str], use_cpp_shared: bool, android_api_level: int) -> None: """ Generate triplet files for POSIX platforms (Linux, macOS, Android). @@ -240,6 +263,7 @@ def generate_android_triplets(build_dir: str, use_cpp_shared: bool, android_api_ for target_abi in target_abis: generate_triplet_for_android( build_dir, + configs, target_abi, enable_rtti, enable_exception, @@ -252,6 +276,7 @@ def generate_android_triplets(build_dir: str, use_cpp_shared: bool, android_api_ def generate_triplet_for_posix_platform( build_dir: str, + configs: set[str], os_name: str, enable_rtti: bool, enable_exception: bool, @@ -293,115 +318,118 @@ def generate_triplet_for_posix_platform( file_name = f"{target_abi}-{os_name}.cmake" - dest_path = Path(build_dir) / folder_name / file_name - - os.makedirs(dest_path.parent, exist_ok=True) - - with open(dest_path, "w", encoding="utf-8") as f: - add_copyright_header(f) - - # Set target architecture based on `os_name` and `target_abi`. - # - # In most cases VCPKG itself can help automatically detect the target architecture, but sometimes it is not as what we want. The following code process the special cases. - if target_abi == "universal2": - # Assume the host machine is Intel based - f.write("set(VCPKG_TARGET_ARCHITECTURE x64)\n") - else: - f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") - - # Set CRT linkage - # VCPKG_CRT_LINKAGE specifies the desired CRT linkage (for MSVC). - # Valid options are dynamic and static. - f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") - - # Set library linkage - # VCPKG_LIBRARY_LINKAGE specifies the preferred library linkage. - # Valid options are dynamic and static. Libraries can ignore this setting if they do not support the preferred linkage type. - f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") - - ldflags = [] - - if enable_binskim and os_name == "linux": - # BinSkim rule 3005: Enable stack clash protection - # This check ensures that stack clash protection is enabled. Each program running on a computer uses a special memory region called the stack. - # This memory region is special because it grows automatically when the program needs more stack memory. But if it grows too much and gets too close to another memory region, - # the program may confuse the stack with the other memory region. An attacker can exploit this confusion to overwrite the stack with the other memory region, or the other way around. - # Use the compiler flags '-fstack-clash-protection' to enable this. - # BinSkim rule BA3011: Enable BIND_NOW - # This check ensures that some relocation data is marked as read-only after the executable is loaded, and moved below the '.data' section in memory. - # This prevents them from being overwritten, which can redirect control flow. Use the compiler flags '-Wl,-z,now' to enable this. - ldflags = ["-Wl,-Bsymbolic-functions", "-Wl,-z,relro", "-Wl,-z,now", "-Wl,-z,noexecstack"] - - cflags = ["-g", "-ffunction-sections", "-fdata-sections"] - cflags_release = ["-DNDEBUG", "-O3"] - - if enable_binskim: - cflags_release += ["-Wp,-D_FORTIFY_SOURCE=2", "-Wp,-D_GLIBCXX_ASSERTIONS", "-fstack-protector-strong"] - if target_abi == "x64": - cflags_release += ["-fstack-clash-protection", "-fcf-protection"] - - elif enable_asan: - cflags += ["-fsanitize=address"] - ldflags += ["-fsanitize=address"] - - ldflags.append("-g") - - if not enable_rtti: - cflags.append("-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0") - - cxxflags = cflags.copy() - if os_name == "osx": - cxxflags += ["-fvisibility=hidden", "-fvisibility-inlines-hidden"] - if not enable_rtti: - cxxflags.append("-fno-rtti") - - if not enable_exception: - cxxflags += ["-fno-exceptions", "-fno-unwind-tables", "-fno-asynchronous-unwind-tables"] - - if cflags: - f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') - - if cxxflags: - f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') - - if cflags_release: - f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_C_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') - - # Set target platform - # VCPKG_CMAKE_SYSTEM_NAME specifies the target platform. - if os_name == "linux": - f.write("set(VCPKG_CMAKE_SYSTEM_NAME Linux)\n") - else: - f.write("set(VCPKG_CMAKE_SYSTEM_NAME Darwin)\n") - osx_abi = None - if target_abi == "x64": - osx_abi = "x86_64" - elif target_abi == "universal2": - osx_abi = "x86_64;arm64" + for config in configs: + dest_path = Path(build_dir) / config / folder_name / file_name + + os.makedirs(dest_path.parent, exist_ok=True) + + with open(dest_path, "w", encoding="utf-8") as f: + add_copyright_header(f) + + # Set target architecture based on `os_name` and `target_abi`. + # + # In most cases VCPKG itself can help automatically detect the target architecture, but sometimes it is not as what we want. The following code process the special cases. + if target_abi == "universal2": + # Assume the host machine is Intel based + f.write("set(VCPKG_TARGET_ARCHITECTURE x64)\n") else: - osx_abi = target_abi - f.write(f'set(VCPKG_OSX_ARCHITECTURES "{osx_abi}")\n') - if osx_deployment_target: - f.write(f'set(VCPKG_OSX_DEPLOYMENT_TARGET "{osx_deployment_target}")\n') - f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") - f.write( - "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" - ) + f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") - if ldflags: - f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') - if os_name == "osx": - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=20)\n") - else: - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=17)\n") - add_port_configs(f, enable_exception, False, enable_minimal_build) # Pass enable_minimal_build + # Set CRT linkage + # VCPKG_CRT_LINKAGE specifies the desired CRT linkage (for MSVC). + # Valid options are dynamic and static. + f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") + + # Set library linkage + # VCPKG_LIBRARY_LINKAGE specifies the preferred library linkage. + # Valid options are dynamic and static. Libraries can ignore this setting if they do not support the preferred linkage type. + f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") + + ldflags = [] + + if enable_binskim and os_name == "linux": + # BinSkim rule 3005: Enable stack clash protection + # This check ensures that stack clash protection is enabled. Each program running on a computer uses a special memory region called the stack. + # This memory region is special because it grows automatically when the program needs more stack memory. But if it grows too much and gets too close to another memory region, + # the program may confuse the stack with the other memory region. An attacker can exploit this confusion to overwrite the stack with the other memory region, or the other way around. + # Use the compiler flags '-fstack-clash-protection' to enable this. + # BinSkim rule BA3011: Enable BIND_NOW + # This check ensures that some relocation data is marked as read-only after the executable is loaded, and moved below the '.data' section in memory. + # This prevents them from being overwritten, which can redirect control flow. Use the compiler flags '-Wl,-z,now' to enable this. + ldflags = ["-Wl,-Bsymbolic-functions", "-Wl,-z,relro", "-Wl,-z,now", "-Wl,-z,noexecstack"] + + cflags = ["-g", "-ffunction-sections", "-fdata-sections"] + cflags_release = ["-DNDEBUG", "-O3"] + + if enable_binskim: + cflags_release += ["-Wp,-D_FORTIFY_SOURCE=2", "-Wp,-D_GLIBCXX_ASSERTIONS", "-fstack-protector-strong"] + if target_abi == "x64": + cflags_release += ["-fstack-clash-protection", "-fcf-protection"] + + elif enable_asan: + cflags += ["-fsanitize=address"] + ldflags += ["-fsanitize=address"] + + ldflags.append("-g") + + if not enable_rtti: + cflags.append("-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0") + + cxxflags = cflags.copy() + if os_name == "osx": + cxxflags += ["-fvisibility=hidden", "-fvisibility-inlines-hidden"] + if not enable_rtti: + cxxflags.append("-fno-rtti") + + if not enable_exception: + cxxflags += ["-fno-exceptions", "-fno-unwind-tables", "-fno-asynchronous-unwind-tables"] + + if cflags: + f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') + + if cxxflags: + f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') + + if cflags_release: + f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_C_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') + + # Set target platform + # VCPKG_CMAKE_SYSTEM_NAME specifies the target platform. + if os_name == "linux": + f.write("set(VCPKG_CMAKE_SYSTEM_NAME Linux)\n") + else: + f.write("set(VCPKG_CMAKE_SYSTEM_NAME Darwin)\n") + osx_abi = None + if target_abi == "x64": + osx_abi = "x86_64" + elif target_abi == "universal2": + osx_abi = "x86_64;arm64" + else: + osx_abi = target_abi + f.write(f'set(VCPKG_OSX_ARCHITECTURES "{osx_abi}")\n') + if osx_deployment_target: + f.write(f'set(VCPKG_OSX_DEPLOYMENT_TARGET "{osx_deployment_target}")\n') + f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") + f.write( + "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" + ) + + if ldflags: + f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') + if os_name == "osx": + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=20)\n") + else: + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=17)\n") + add_build_type(f, config) + add_port_configs(f, enable_exception, False, enable_minimal_build) # Pass enable_minimal_build def generate_vcpkg_triplets_for_emscripten( build_dir: str, + configs: set[str], emscripten_root: str, # Parameters defining the specific build configuration enable_rtti: bool, @@ -449,105 +477,108 @@ def generate_vcpkg_triplets_for_emscripten( for target_abi in ["wasm32", "wasm64"]: os_name = "emscripten" file_name = f"{target_abi}-{os_name}.cmake" - dest_path = Path(build_dir) / folder_name / file_name - os.makedirs(dest_path.parent, exist_ok=True) + for config in configs: + dest_path = Path(build_dir) / config / folder_name / file_name + os.makedirs(dest_path.parent, exist_ok=True) - with open(dest_path, "w", encoding="utf-8") as f: - add_copyright_header(f) - f.write(r""" + with open(dest_path, "w", encoding="utf-8") as f: + add_copyright_header(f) + f.write(r""" set(VCPKG_CRT_LINKAGE dynamic) set(VCPKG_LIBRARY_LINKAGE static) set(VCPKG_CMAKE_SYSTEM_NAME Emscripten) """) - f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") - emscripten_root_path_cmake_path = emscripten_root.replace("\\", "/") - f.write(f'set(EMSCRIPTEN_ROOT_PATH "{emscripten_root_path_cmake_path}")\n') - - # Define the path to the intermediate toolchain file used by vcpkg for wasm - vcpkg_toolchain_file = (Path(build_dir) / "emsdk_vcpkg_toolchain.cmake").absolute() - vcpkg_toolchain_file_cmake_path = str(vcpkg_toolchain_file).replace("\\", "/") - f.write(f'set(VCPKG_CHAINLOAD_TOOLCHAIN_FILE "{vcpkg_toolchain_file_cmake_path}")\n') - - # --- Configure Flags based on Parameters --- - cflags_release = ["-DNDEBUG", "-O3", "-flto"] - ldflags = [] # Initialize linker flags list - # Base flags applicable to both C and C++ - base_flags = [ - "-ffunction-sections", - "-fdata-sections", - "-msimd128", - "-pthread", - "-Wno-pthreads-mem-growth", - ] - - # ASan (apply to Base, Linker) - if enable_asan: - asan_flag = "-fsanitize=address" - base_flags.append(asan_flag) - ldflags.append(asan_flag) # Add to linker flags - - # Wasm Exception Catching Runtime (-s flag, apply to Base and Linker flags) - exception_catching_flag = "" - if enable_wasm_exception_catching: - exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=0" - else: - exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=1" - - base_flags.append(exception_catching_flag) # Add to base C/C++ flags - ldflags.append(exception_catching_flag) # Add to linker flags - - # Wasm64 Memory (apply to Base, Linker) - if target_abi == "wasm64": - memory_flag = "-sMEMORY64" - base_flags.append(memory_flag) - ldflags.append(memory_flag) # Add to linker flags - - # --- C Flags --- - # VCPKG_C_FLAGS applies only base flags - f.write(f'set(VCPKG_C_FLAGS "{" ".join(base_flags)}")\n') - - # --- CXX Flags --- - # Start with base flags - cxxflags = list(base_flags) # Create a copy - - # C++ RTTI Compiler Flag - if not enable_rtti: - cxxflags.append("-fno-rtti") - - # C++ Exceptions Compiler Flag (Derived from enable_minimal_onnx_build) - if not cpp_exceptions_enabled: # i.e., if enable_minimal_onnx_build is True - cxxflags.append("-fno-exceptions") - # If cpp_exceptions_enabled=True, we assume -fexceptions is the default - # or handled by the Emscripten toolchain/CMake settings elsewhere. - - f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') - - # --- Linker Flags --- - # Apply Linker flags (now includes exception and memory flags explicitly) - if len(ldflags) >= 1: - f.write('set(VCPKG_LINKER_FLAGS "{}")\n'.format(" ".join(ldflags))) - - # --- Release / RelWithDebInfo Flags --- - # Combine base flags with release-specific flags - c_combined_release_flags = cflags_release + base_flags - cxx_combined_release_flags = cflags_release + cxxflags # Use the derived cxxflags - - f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(c_combined_release_flags)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cxx_combined_release_flags)}")\n') - - f.write("set(VCPKG_LINKER_FLAGS_RELEASE -flto)\n") - - # --- Add Port Specific Configs --- - # Pass the derived C++ exception status and the original minimal build flag - add_port_configs( - f, - has_exception=cpp_exceptions_enabled, # Derived value - is_emscripten=True, - enable_minimal_build=enable_minimal_onnx_build, - ) # Original parameter - - -def generate_windows_triplets(build_dir: str, toolset_version: str) -> None: + f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") + emscripten_root_path_cmake_path = emscripten_root.replace("\\", "/") + f.write(f'set(EMSCRIPTEN_ROOT_PATH "{emscripten_root_path_cmake_path}")\n') + + # Define the path to the intermediate toolchain file used by vcpkg for wasm + vcpkg_toolchain_file = (Path(build_dir) / "emsdk_vcpkg_toolchain.cmake").absolute() + vcpkg_toolchain_file_cmake_path = str(vcpkg_toolchain_file).replace("\\", "/") + f.write(f'set(VCPKG_CHAINLOAD_TOOLCHAIN_FILE "{vcpkg_toolchain_file_cmake_path}")\n') + + # --- Configure Flags based on Parameters --- + cflags_release = ["-DNDEBUG", "-O3", "-flto"] + ldflags = [] # Initialize linker flags list + # Base flags applicable to both C and C++ + base_flags = [ + "-ffunction-sections", + "-fdata-sections", + "-msimd128", + "-pthread", + "-Wno-pthreads-mem-growth", + ] + + # ASan (apply to Base, Linker) + if enable_asan: + asan_flag = "-fsanitize=address" + base_flags.append(asan_flag) + ldflags.append(asan_flag) # Add to linker flags + + # Wasm Exception Catching Runtime (-s flag, apply to Base and Linker flags) + exception_catching_flag = "" + if enable_wasm_exception_catching: + exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=0" + else: + exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=1" + + base_flags.append(exception_catching_flag) # Add to base C/C++ flags + ldflags.append(exception_catching_flag) # Add to linker flags + + # Wasm64 Memory (apply to Base, Linker) + if target_abi == "wasm64": + memory_flag = "-sMEMORY64" + base_flags.append(memory_flag) + ldflags.append(memory_flag) # Add to linker flags + + # --- C Flags --- + # VCPKG_C_FLAGS applies only base flags + f.write(f'set(VCPKG_C_FLAGS "{" ".join(base_flags)}")\n') + + # --- CXX Flags --- + # Start with base flags + cxxflags = list(base_flags) # Create a copy + + # C++ RTTI Compiler Flag + if not enable_rtti: + cxxflags.append("-fno-rtti") + + # C++ Exceptions Compiler Flag (Derived from enable_minimal_onnx_build) + if not cpp_exceptions_enabled: # i.e., if enable_minimal_onnx_build is True + cxxflags.append("-fno-exceptions") + # If cpp_exceptions_enabled=True, we assume -fexceptions is the default + # or handled by the Emscripten toolchain/CMake settings elsewhere. + + f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') + + # --- Linker Flags --- + # Apply Linker flags (now includes exception and memory flags explicitly) + if len(ldflags) >= 1: + f.write('set(VCPKG_LINKER_FLAGS "{}")\n'.format(" ".join(ldflags))) + + # --- Release / RelWithDebInfo Flags --- + # Combine base flags with release-specific flags + c_combined_release_flags = cflags_release + base_flags + cxx_combined_release_flags = cflags_release + cxxflags # Use the derived cxxflags + + f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(c_combined_release_flags)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cxx_combined_release_flags)}")\n') + + f.write("set(VCPKG_LINKER_FLAGS_RELEASE -flto)\n") + + add_build_type(f, config) + + # --- Add Port Specific Configs --- + # Pass the derived C++ exception status and the original minimal build flag + add_port_configs( + f, + has_exception=cpp_exceptions_enabled, # Derived value + is_emscripten=True, + enable_minimal_build=enable_minimal_onnx_build, + ) # Original parameter + + +def generate_windows_triplets(build_dir: str, configs: set[str], toolset_version: str) -> None: """ Generate triplet files for Windows platforms. @@ -593,54 +624,56 @@ def generate_windows_triplets(build_dir: str, toolset_version: str) -> None: if crt_linkage == "dynamic": file_name_parts.append("md") file_name = "-".join(file_name_parts) + ".cmake" - dest_path = Path(build_dir) / folder_name / file_name - os.makedirs(dest_path.parent, exist_ok=True) - with open(dest_path, "w", encoding="utf-8") as f: - add_copyright_header(f) - f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") - f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") - f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") - if toolset_version: - f.write(f"set(VCPKG_PLATFORM_TOOLSET_VERSION {toolset_version})\n") - cflags = ["/MP", "/DWIN32", "/D_WINDOWS"] - if enable_binskim: - cflags += [ - "/DWINAPI_FAMILY=100", - "/DWINVER=0x0A00", - "/D_WIN32_WINNT=0x0A00", - "/DNTDDI_VERSION=0x0A000000", - ] - ldflags = [] - if enable_binskim: - cflags += ["/guard:cf", "/Qspectre", "/W3"] - ldflags = ["/profile", "/DYNAMICBASE"] - elif enable_asan: - cflags.append("/fsanitize=address") - cxxflags = cflags.copy() - cxxflags.append("/Zc:__cplusplus") - if enable_exception: - cxxflags.append("/EHsc") - # MSVC doesn't have a specific flag to disable exceptions like /EHs-c- - # but relies on _HAS_EXCEPTIONS=0 and potentially other flags managed by ORT's main CMake. - # Vcpkg doesn't directly control this via a simple triplet flag AFAIK. - # ORT's CMake handles this via CMAKE_CXX_FLAGS adjustment. - if not enable_rtti: - cxxflags += ["/GR-", "/we4541"] - if cflags: - f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') - if cxxflags: - f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') - f.write( - "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DCMAKE_CXX_STANDARD=17)\n" - ) - if ldflags: - f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') - add_port_configs( - f, enable_exception, False, enable_minimal_build - ) # Pass enable_minimal_build - - -def generate_linux_triplets(build_dir: str) -> None: + for config in configs: + dest_path = Path(build_dir) / config / folder_name / file_name + os.makedirs(dest_path.parent, exist_ok=True) + with open(dest_path, "w", encoding="utf-8") as f: + add_copyright_header(f) + f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") + f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") + f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") + if toolset_version: + f.write(f"set(VCPKG_PLATFORM_TOOLSET_VERSION {toolset_version})\n") + cflags = ["/MP", "/DWIN32", "/D_WINDOWS"] + if enable_binskim: + cflags += [ + "/DWINAPI_FAMILY=100", + "/DWINVER=0x0A00", + "/D_WIN32_WINNT=0x0A00", + "/DNTDDI_VERSION=0x0A000000", + ] + ldflags = [] + if enable_binskim: + cflags += ["/guard:cf", "/Qspectre", "/W3"] + ldflags = ["/profile", "/DYNAMICBASE"] + elif enable_asan: + cflags.append("/fsanitize=address") + cxxflags = cflags.copy() + cxxflags.append("/Zc:__cplusplus") + if enable_exception: + cxxflags.append("/EHsc") + # MSVC doesn't have a specific flag to disable exceptions like /EHs-c- + # but relies on _HAS_EXCEPTIONS=0 and potentially other flags managed by ORT's main CMake. + # Vcpkg doesn't directly control this via a simple triplet flag AFAIK. + # ORT's CMake handles this via CMAKE_CXX_FLAGS adjustment. + if not enable_rtti: + cxxflags += ["/GR-", "/we4541"] + if cflags: + f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') + if cxxflags: + f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') + f.write( + "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DCMAKE_CXX_STANDARD=17)\n" + ) + if ldflags: + f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') + add_build_type(f, config) + add_port_configs( + f, enable_exception, False, enable_minimal_build + ) # Pass enable_minimal_build + + +def generate_linux_triplets(build_dir: str, configs: set[str]) -> None: """ Generate triplet files for Linux platforms. @@ -660,6 +693,7 @@ def generate_linux_triplets(build_dir: str) -> None: for target_abi in target_abis: generate_triplet_for_posix_platform( build_dir, + configs, "linux", enable_rtti, enable_exception, @@ -672,7 +706,7 @@ def generate_linux_triplets(build_dir: str) -> None: ) -def generate_macos_triplets(build_dir: str, osx_deployment_target: str) -> None: +def generate_macos_triplets(build_dir: str, configs: set[str], osx_deployment_target: str) -> None: """ Generate triplet files for macOS platforms. @@ -694,6 +728,7 @@ def generate_macos_triplets(build_dir: str, osx_deployment_target: str) -> None: for target_abi in target_abis: generate_triplet_for_posix_platform( build_dir, + configs, "osx", enable_rtti, enable_exception,