Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions include/onnxruntime/core/providers/utils/ort_graph_to_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi,
bool get_symbolic_dims,
/*out*/ ONNXTensorElementDataType& elem_type,
/*out*/ std::vector<int64_t>& dims,
/*out*/ std::vector<std::string>& symbolic_dims);
/*out*/ std::vector<std::string>& symbolic_dims,
/*out*/ bool& has_shape);
static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto);
static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto);

Expand Down Expand Up @@ -390,9 +391,10 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph,
std::vector<int64_t> initializer_dims;
std::vector<std::string> initializer_sym_dims;
ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
bool has_shape = false;
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false,
initializer_elem_type, initializer_dims,
initializer_sym_dims));
initializer_sym_dims, has_shape));

onnx::TensorProto* tensor_proto = graph_proto.add_initializer();
tensor_proto->set_name(initializer_name);
Expand Down Expand Up @@ -493,7 +495,8 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi,
bool get_symbolic_dims,
/*out*/ ONNXTensorElementDataType& elem_type,
/*out*/ std::vector<int64_t>& dims,
/*out*/ std::vector<std::string>& symbolic_dims) {
/*out*/ std::vector<std::string>& symbolic_dims,
/*out*/ bool& has_shape) {
try {
Ort::ConstTypeInfo ort_type_info = vi.TypeInfo();
ONNXType ort_onnx_type = ort_type_info.GetONNXType();
Expand All @@ -505,6 +508,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi,
size_t num_dims = ort_type_shape.GetDimensionsCount();
std::vector<int64_t> ort_dims = ort_type_shape.GetShape();

has_shape = ort_type_shape.GetHasShape();
elem_type = ort_elem_type;
dims = std::move(ort_dims);

Expand All @@ -531,10 +535,11 @@ static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info,
std::vector<int64_t> ort_dims;
std::vector<std::string> ort_dim_syms;
ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
bool has_shape = false;

// We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later.
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true,
ort_elem_type, ort_dims, ort_dim_syms));
ort_elem_type, ort_dims, ort_dim_syms, has_shape));

value_info_proto.set_name(ort_value_info.GetName());

Expand All @@ -543,7 +548,7 @@ static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info,

// If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks
// like a scalar value.
if (!ort_dims.empty()) {
if (!ort_dims.empty() || has_shape) {
onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape();

for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) {
Expand Down
11 changes: 11 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6580,6 +6580,17 @@ struct OrtApi {
_In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out);

/// @}
/// \name OrtTensorTypeAndShapeInfo
/// @{

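/** \brief Get the `has_shape` attribute of an ::OrtTensorTypeAndShapeInfo object.
*
* \param[in] info The ::OrtTensorTypeAndShapeInfo instance to query.
* \param[out] out Set to the value of the `has_shape` attribute.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/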
ORT_API2_STATUS(GetHasShape, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ bool* out);
/// @}
};

/*
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1768,6 +1768,7 @@ struct TensorTypeAndShapeInfoImpl : Base<T> {
void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
std::vector<const char*> GetSymbolicDimensions() const;

bool GetHasShape() const; ///< Wraps OrtApi::GetHasShape
std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
};

Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,13 @@ inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
return static_cast<size_t>(out);
}

// Wraps OrtApi::GetHasShape: queries the `has_shape` flag of the wrapped
// OrtTensorTypeAndShapeInfo. The status returned by the C API call is handed to ThrowOnError.
template <typename T>
inline bool TensorTypeAndShapeInfoImpl<T>::GetHasShape() const {
  bool has_shape = false;
  ThrowOnError(GetApi().GetHasShape(this->p_, &has_shape));
  return has_shape;
}

template <typename T>
inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
size_t out;
Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/core/framework/onnxruntime_typeinfo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ std::unique_ptr<OrtTypeInfo> OrtTypeInfo::FromOrtValue(const OrtValue& value) {
const Tensor& tensor = value.Get<onnxruntime::Tensor>();
const auto* tensor_data_type = tensor.DataType();
if (tensor_data_type != nullptr) {
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.Shape(), *tensor_data_type);
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, true);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The decision to pass true or false in this file is not clear to me. The type_protos here are not from the original protobuf, but rather from generic type registrations. How do we know that this particular tensor had a real shape in the original protobuf, given that all we have here is the OrtValue, which obviously has a real shape?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I set this value based on http://github.com/intel/onnxruntime/blob/ovep-develop/include/onnxruntime/core/framework/tensor_shape.h#L109 , because the tensor doesn't have an API to indicate that it is a scalar. Do you know of an API that can do this?

I think the purpose of adding has_shape is to fix the information lost at the model conversion stage. If we already have the shape, it is not a dynamic-rank case. I think this lets us get the correct shapes when dumping models.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a mix of two different things here.

Tensor and TensorShape classes are runtime classes. They will always have a shape that will be inferred, otherwise we cannot run.

IsScalar() is the method on TensorShape that tells you whether the shape, as inferred or stored, is a scalar. It is not exposed via the public API, since the same thing can also be determined from the shape vector that is returned.

The public APIs that you are dealing with concern specific, concrete OrtValues containing tensors that result from either creating an input to feed to the model or receiving an output. So they do not have anything to do with the protobuf of the original model. The same applies to intermediate values, which is a new path that appeared with EP plug-ins.

They are not direct counterparts of what is found in the protobuf, which is metadata, and that metadata is probably what you are trying to get.

It would be helpful to see an example of the model that you are trying to deal with and what exactly you are trying to achieve (serialize the model?)

As far as the model inputs/outputs go, the Session API can return a ValueInfo that contains TypeInfo. That can tell you the dimension count but does not tell you whether it has a shape; perhaps that is an omission.

Copy link

@adrianlizarraga adrianlizarraga Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand, an OrtValue is already known to have a shape, so passing true here makes sense to me.

It would be helpful to see an example of the model that you are trying to deal with and what exactly you are trying to achieve (serialize the model?)

The onnx model file has been provided in this comment: #816 (comment)

image

The plugin EP gets the model as an OrtGraph via the new APIs. The EP then converts the OrtGraph to protobuf so that it can be further processed by the OpenVINO compiler, which currently only supports protobuf models (not OrtGraph). Please correct me if I'm wrong @sgbihu

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, OV only handles protobuf models for now. I also pasted a prototxt in the PR description, where the shape information is lost. That means a scalar turns into a dynamic-rank tensor.

return MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape));
}
return MakePtr(ONNX_TYPE_TENSOR);
Expand All @@ -181,7 +181,7 @@ std::unique_ptr<OrtTypeInfo> OrtTypeInfo::FromOrtValue(const OrtValue& value) {
const SparseTensor& tensor = value.Get<onnxruntime::SparseTensor>();
const auto* tensor_data_type = tensor.DataType();
if (tensor_data_type != nullptr) {
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type);
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type, true);
return MakePtr(ONNX_TYPE_SPARSETENSOR, std::move(type_shape));
}
return MakePtr(ONNX_TYPE_SPARSETENSOR);
Expand All @@ -195,7 +195,7 @@ std::unique_ptr<OrtTypeInfo> OrtTypeInfo::FromOrtValue(const OrtValue& value) {
ORT_ENFORCE(tensor_data_type != nullptr, "OrtValue is TensorSequence type but has no element Tensor DataType.");

TensorShape void_shape = {};
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type);
auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type, false);
auto type_info = MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape));
auto sequence_type_info = std::make_unique<OrtSequenceTypeInfo>(std::move(type_info));
return MakePtr(std::move(sequence_type_info));
Expand Down Expand Up @@ -303,9 +303,9 @@ std::unique_ptr<OrtTypeInfo> OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::Ty
assert(false);
}
}
type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(std::move(shape_data), &dim_params, input);
type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(std::move(shape_data), &dim_params, input, true);
} else {
type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input);
type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input, false);
}

result = MakePtr(ten_type, std::move(type_shape));
Expand Down
29 changes: 20 additions & 9 deletions onnxruntime/core/framework/tensor_type_and_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@
return nullptr;
}

// Copies the `has_shape` flag stored on `info` into `*out`. Always returns nullptr.
ORT_API_STATUS_IMPL(OrtApis::GetHasShape, _In_ const struct OrtTensorTypeAndShapeInfo* info,
                    _Out_ bool* out) {
  const bool shape_flag = info->has_shape;
  *out = shape_flag;
  return nullptr;
}

ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions,
_In_ struct OrtTensorTypeAndShapeInfo* info,
_In_ const char** names, _In_ size_t dim_params_length) {
Expand Down Expand Up @@ -221,10 +227,12 @@
std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(
ONNXTensorElementDataType type,
onnxruntime::TensorShape shape,
const std::vector<std::string>* dim_params) {
const std::vector<std::string>* dim_params,
bool has_shape) {
auto type_and_shape = std::make_unique<OrtTensorTypeAndShapeInfo>();
type_and_shape->type = type;
type_and_shape->shape = std::move(shape);
type_and_shape->has_shape = has_shape;

if (dim_params != nullptr) {
type_and_shape->dim_params = *dim_params;
Expand All @@ -237,18 +245,20 @@

std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(
onnxruntime::TensorShape shape,
const onnxruntime::DataTypeImpl& tensor_data_type) {
const onnxruntime::DataTypeImpl& tensor_data_type,
bool has_shape) {
ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type);
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
ORT_NOT_IMPLEMENTED("Tensor type is undefined");
}
return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr);
return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr, has_shape);
}

std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(
onnxruntime::TensorShape shape,
const std::vector<std::string>* dim_params,
const ONNX_NAMESPACE::TypeProto& type_proto) {
const ONNX_NAMESPACE::TypeProto& type_proto,
bool has_shape) {
auto value_case = type_proto.value_case();
assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType ||
value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType);
Expand All @@ -259,7 +269,8 @@
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
ORT_NOT_IMPLEMENTED("Tensor type is undefined");
}
return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params);

return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params, has_shape);

Check warning on line 273 in onnxruntime/core/framework/tensor_type_and_shape.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/framework/tensor_type_and_shape.cc:273: Add #include <utility> for move [build/include_what_you_use] [4]
}

ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape,
Expand All @@ -276,14 +287,14 @@
const Tensor& tensor = v->Get<onnxruntime::Tensor>();
shape = &tensor.Shape();
data_type = tensor.DataType();
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type);
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type, true);
*out = ptr.release();
} else {
#if !defined(DISABLE_SPARSE_TENSORS)
const SparseTensor& tensor = v->Get<onnxruntime::SparseTensor>();
shape = &tensor.DenseShape();
data_type = tensor.DataType();
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type);
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type, true);
*out = ptr.release();
#else
ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build.");
Expand All @@ -302,7 +313,7 @@
#if !defined(DISABLE_SPARSE_TENSORS)
const auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v);
const auto& values = sparse_tensor.Values();
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(values.Shape(), *values.DataType());
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(values.Shape(), *values.DataType(), true);
*out = ptr.release();
return nullptr;
#else
Expand Down Expand Up @@ -344,7 +355,7 @@
API_IMPL_BEGIN
#if !defined(DISABLE_SPARSE_TENSORS)
const Tensor& indices_tensor = GetIndicesTensor(*v, indices_format);
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType());
auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType(), true);
*out = ptr.release();
return nullptr;
#else
Expand Down
10 changes: 7 additions & 3 deletions onnxruntime/core/framework/tensor_type_and_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct OrtTensorTypeAndShapeInfo {
// dim_param values. empty string if dim_value or no dim_param was specified.
// one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs
std::vector<std::string> dim_params;
bool has_shape = false;

OrtTensorTypeAndShapeInfo();
~OrtTensorTypeAndShapeInfo();
Expand All @@ -32,16 +33,19 @@ struct OrtTensorTypeAndShapeInfo {
static std::unique_ptr<OrtTensorTypeAndShapeInfo> GetTensorShapeAndTypeHelper(
ONNXTensorElementDataType type,
onnxruntime::TensorShape shape,
const std::vector<std::string>* dim_params);
const std::vector<std::string>* dim_params,
bool has_shape);

static std::unique_ptr<OrtTensorTypeAndShapeInfo> GetTensorShapeAndType(
onnxruntime::TensorShape shape,
const onnxruntime::DataTypeImpl& tensor_data_type);
const onnxruntime::DataTypeImpl& tensor_data_type,
bool has_shape);

static std::unique_ptr<OrtTensorTypeAndShapeInfo> GetTensorShapeAndType(
onnxruntime::TensorShape shape,
const std::vector<std::string>* dim_params,
const ONNX_NAMESPACE::TypeProto&);
const ONNX_NAMESPACE::TypeProto&,
bool has_shape);

// We provide Clone() here to satisfy the existing coding pattern
// as we need copies made on the heap even though we achieve that
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct OrtShapeInferContext {
auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto);
auto symbolic_dims = GetSymbolicDims(shape_proto);
input_type_shapes_.emplace_back(
OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims).release());
OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims, onnxruntime::utils::HasShape(type_proto)).release());
}
}

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4228,6 +4228,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::Graph_GetModelMetadata,
&OrtApis::GetModelCompatibilityForEpDevices,
&OrtApis::CreateExternalInitializerInfo,
&OrtApis::GetHasShape,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ ORT_API_STATUS_IMPL(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* in
ORT_API_STATUS_IMPL(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
ORT_API_STATUS_IMPL(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info,
_Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length);
ORT_API_STATUS_IMPL(GetHasShape, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ bool* out);
ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out);
ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out);
Expand Down
Loading