From 0e66c18db1f919d8f549c6ea30a4909e9c5e48e5 Mon Sep 17 00:00:00 2001 From: LiangGao Date: Thu, 18 Sep 2025 15:42:36 +0800 Subject: [PATCH 1/3] Try to support scalar for input --- .../core/providers/utils/ort_graph_to_proto.h | 15 ++++++++++----- .../onnxruntime/core/session/onnxruntime_c_api.h | 11 +++++++++++ .../core/session/onnxruntime_cxx_api.h | 1 + .../core/session/onnxruntime_cxx_inline.h | 7 +++++++ .../core/framework/tensor_type_and_shape.cc | 7 +++++++ .../core/framework/tensor_type_and_shape.h | 1 + onnxruntime/core/session/onnxruntime_c_api.cc | 1 + onnxruntime/core/session/ort_apis.h | 1 + 8 files changed, 39 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index e2b2aff2011fe..9119554d4a779 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -225,7 +225,8 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims); + /*out*/ std::vector& symbolic_dims, + /*out*/ bool& has_shape); static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto); static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto); @@ -390,9 +391,10 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph, std::vector initializer_dims; std::vector initializer_sym_dims; ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + bool has_shape = false; ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false, initializer_elem_type, initializer_dims, - initializer_sym_dims)); + initializer_sym_dims, has_shape)); onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); tensor_proto->set_name(initializer_name); @@ -493,7 +495,8 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims) { + /*out*/ std::vector& symbolic_dims, + /*out*/ bool& has_shape) { try { Ort::ConstTypeInfo ort_type_info = vi.TypeInfo(); ONNXType ort_onnx_type = ort_type_info.GetONNXType(); @@ -505,6 +508,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, size_t num_dims = ort_type_shape.GetDimensionsCount(); std::vector ort_dims = ort_type_shape.GetShape(); + has_shape = ort_type_shape.GetHasShape(); elem_type = ort_elem_type; dims = std::move(ort_dims); @@ -531,10 +535,11 @@ static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, std::vector ort_dims; std::vector ort_dim_syms; ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + bool has_shape = false; // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, - ort_elem_type, ort_dims, ort_dim_syms)); + ort_elem_type, ort_dims, ort_dim_syms, has_shape)); value_info_proto.set_name(ort_value_info.GetName()); @@ -543,7 +548,7 @@ static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, // If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks // like a scalar value. - if (!ort_dims.empty()) { + if (!ort_dims.empty() || has_shape) { onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 81caf5069bb6e..408d5bb165a19 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6580,6 +6580,17 @@ struct OrtApi { _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out); /// @} + /// \name OrtTensorTypeAndShapeInfo + /// @{ + + /** \brief Get the attribute `has_shape` from ::OrtTensorTypeAndShapeInfo object + * + * \param[out] out Returns bool + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetHasShape, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ bool* out); + /// @} }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 4a8c67e2215ec..ce4f3d867a245 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1768,6 +1768,7 @@ struct TensorTypeAndShapeInfoImpl : Base { void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions std::vector GetSymbolicDimensions() const; + bool GetHasShape() const; ///< Wraps OrtApi::GetHasShape std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 9c42bf34b5b0f..9349e84e55474 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1982,6 +1982,13 @@ inline size_t TensorTypeAndShapeInfoImpl::GetElementCount() const { return static_cast(out); } +template +inline bool TensorTypeAndShapeInfoImpl::GetHasShape() const { + bool out; + ThrowOnError(GetApi().GetHasShape(this->p_, &out)); + return static_cast(out); +} + template inline size_t TensorTypeAndShapeInfoImpl::GetDimensionsCount() const { size_t out; diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index bef8df51f6d03..3347a0357ca6d 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -102,6 +102,12 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::GetHasShape, _In_ const struct OrtTensorTypeAndShapeInfo* info, + _Out_ bool* out) { + *out = info->has_shape; + return nullptr; +} + ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions, _In_ struct OrtTensorTypeAndShapeInfo* info, _In_ const char** names, _In_ size_t dim_params_length) { @@ -228,6 +234,7 @@ std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorS if (dim_params != nullptr) { type_and_shape->dim_params = *dim_params; + type_and_shape->has_shape = true; } else { type_and_shape->dim_params.resize(type_and_shape->shape.NumDimensions(), ""); } diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h index 4bc1f46c00132..bb72e1ff35fe4 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -24,6 +24,7 @@ struct OrtTensorTypeAndShapeInfo { // dim_param values. empty string if dim_value or no dim_param was specified. // one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs std::vector dim_params; + bool has_shape = false; OrtTensorTypeAndShapeInfo(); ~OrtTensorTypeAndShapeInfo(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index d0fe6291c2e03..91e9eeca12ad7 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4228,6 +4228,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetModelMetadata, &OrtApis::GetModelCompatibilityForEpDevices, &OrtApis::CreateExternalInitializerInfo, + &OrtApis::GetHasShape, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 78616c7b3973e..2ce00719285c6 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -150,6 +150,7 @@ ORT_API_STATUS_IMPL(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* in ORT_API_STATUS_IMPL(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); ORT_API_STATUS_IMPL(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); +ORT_API_STATUS_IMPL(GetHasShape, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ bool* out); ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); From d9007558b8a278688dd2cdce55e7ef81e29e450a Mon Sep 17 00:00:00 2001 From: LiangGao Date: Tue, 23 Sep 2025 11:09:01 +0800 Subject: [PATCH 2/3] Change has_shape to parameter --- .../core/framework/onnxruntime_typeinfo.cc | 10 ++++---- .../core/framework/tensor_type_and_shape.cc | 24 +++++++++++-------- .../core/framework/tensor_type_and_shape.h | 9 ++++--- onnxruntime/core/session/custom_ops.cc | 2 +- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 1c446840b7938..eedacec5b6de8 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -170,7 +170,7 @@ std::unique_ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { const Tensor& tensor = value.Get(); const auto* tensor_data_type = tensor.DataType(); if (tensor_data_type != nullptr) { - auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.Shape(), *tensor_data_type); + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, true); return MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); } return MakePtr(ONNX_TYPE_TENSOR); @@ -181,7 +181,7 @@ std::unique_ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { const SparseTensor& tensor = value.Get(); const auto* tensor_data_type = tensor.DataType(); if (tensor_data_type != nullptr) { - auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type); + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type, true); return MakePtr(ONNX_TYPE_SPARSETENSOR, std::move(type_shape)); } return MakePtr(ONNX_TYPE_SPARSETENSOR); @@ -195,7 +195,7 @@ std::unique_ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { ORT_ENFORCE(tensor_data_type != nullptr, "OrtValue is TensorSequence type but has no element Tensor DataType."); TensorShape void_shape = {}; - auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type); + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type, false); auto type_info = MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); auto sequence_type_info = std::make_unique(std::move(type_info)); return MakePtr(std::move(sequence_type_info)); @@ -303,9 +303,9 @@ std::unique_ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::Ty assert(false); } } - type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(std::move(shape_data), &dim_params, input); + type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(std::move(shape_data), &dim_params, input, true); } else { - type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input); + type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input, false); } result = MakePtr(ten_type, std::move(type_shape)); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 3347a0357ca6d..a41b62cf948a1 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -227,14 +227,15 @@ ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType( std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper( ONNXTensorElementDataType type, onnxruntime::TensorShape shape, - const std::vector* dim_params) { + const std::vector* dim_params, + bool has_shape) { auto type_and_shape = std::make_unique(); type_and_shape->type = type; type_and_shape->shape = std::move(shape); + type_and_shape->has_shape = has_shape; if (dim_params != nullptr) { type_and_shape->dim_params = *dim_params; - type_and_shape->has_shape = true; } else { type_and_shape->dim_params.resize(type_and_shape->shape.NumDimensions(), ""); } @@ -244,18 +245,20 @@ std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorS std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( onnxruntime::TensorShape shape, - const onnxruntime::DataTypeImpl& tensor_data_type) { + const onnxruntime::DataTypeImpl& tensor_data_type, + bool has_shape) { ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { ORT_NOT_IMPLEMENTED("Tensor type is undefined"); } - return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr); + return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr, has_shape); } std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( onnxruntime::TensorShape shape, const std::vector* dim_params, - const ONNX_NAMESPACE::TypeProto& type_proto) { + const ONNX_NAMESPACE::TypeProto& type_proto, + bool has_shape) { auto value_case = type_proto.value_case(); assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType || value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType); @@ -266,7 +269,8 @@ std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorS if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { ORT_NOT_IMPLEMENTED("Tensor type is undefined"); } - return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params); + + return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params, has_shape); } ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, @@ -283,14 +287,14 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, const Tensor& tensor = v->Get(); shape = &tensor.Shape(); data_type = tensor.DataType(); - auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type, true); *out = ptr.release(); } else { #if !defined(DISABLE_SPARSE_TENSORS) const SparseTensor& tensor = v->Get(); shape = &tensor.DenseShape(); data_type = tensor.DataType(); - auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type, true); *out = ptr.release(); #else ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build."); @@ -309,7 +313,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorValuesTypeAndShape, _In_ const OrtVa #if !defined(DISABLE_SPARSE_TENSORS) const auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v); const auto& values = sparse_tensor.Values(); - auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(values.Shape(), *values.DataType()); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(values.Shape(), *values.DataType(), true); *out = ptr.release(); return nullptr; #else @@ -351,7 +355,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndicesTypeShape, _In_ const OrtValu API_IMPL_BEGIN #if !defined(DISABLE_SPARSE_TENSORS) const Tensor& indices_tensor = GetIndicesTensor(*v, indices_format); - auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType()); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType(), true); *out = ptr.release(); return nullptr; #else diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h index bb72e1ff35fe4..024c3805d316b 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -33,16 +33,19 @@ struct OrtTensorTypeAndShapeInfo { static std::unique_ptr GetTensorShapeAndTypeHelper( ONNXTensorElementDataType type, onnxruntime::TensorShape shape, - const std::vector* dim_params); + const std::vector* dim_params, + bool has_shape); static std::unique_ptr GetTensorShapeAndType( onnxruntime::TensorShape shape, - const onnxruntime::DataTypeImpl& tensor_data_type); + const onnxruntime::DataTypeImpl& tensor_data_type, + bool has_shape); static std::unique_ptr GetTensorShapeAndType( onnxruntime::TensorShape shape, const std::vector* dim_params, - const ONNX_NAMESPACE::TypeProto&); + const ONNX_NAMESPACE::TypeProto&, + bool has_shape); // We provide Clone() here to satisfy the existing coding pattern // as we need copies made on the heap even though we achieve that diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 2a898a2b0bf9f..5c56eee2ca3b5 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -80,7 +80,7 @@ struct OrtShapeInferContext { auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto); auto symbolic_dims = GetSymbolicDims(shape_proto); input_type_shapes_.emplace_back( - OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims).release()); + OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims, type_proto.has_shape()).release()); } } From df4adc70007e5158e935338aaea7698fecf5abe1 Mon Sep 17 00:00:00 2001 From: LiangGao Date: Wed, 24 Sep 2025 09:36:00 +0800 Subject: [PATCH 3/3] Change code based on code review --- onnxruntime/core/session/custom_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 5c56eee2ca3b5..b8d66281bf60b 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -80,7 +80,7 @@ struct OrtShapeInferContext { auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto); auto symbolic_dims = GetSymbolicDims(shape_proto); input_type_shapes_.emplace_back( - OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims, type_proto.has_shape()).release()); + OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims, onnxruntime::utils::HasShape(type_proto)).release()); } }