From ae93f2e30d2e31afbe6022978871067691aa1dfb Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 28 Mar 2023 17:23:34 -0700 Subject: [PATCH 1/9] rebase off main --- .../InferenceSession.shared.cs | 381 ++++++++++++++++-- .../NativeMethods.shared.cs | 64 +++ .../NativeOnnxValueHelper.shared.cs | 22 +- .../OrtValue.shared.cs | 1 + .../Tensors/Tensor.shared.cs | 24 +- .../InferenceTest.cs | 67 ++- .../TestDataLoader.cs | 141 +++---- .../InferenceTest.netcore.cs | 167 ++++---- .../Program.cs | 1 - .../core/session/onnxruntime_c_api.h | 70 +++- .../core/session/onnxruntime_cxx_api.h | 14 + .../core/session/onnxruntime_cxx_inline.h | 22 +- .../framework/onnxruntime_map_type_info.cc | 45 +-- .../framework/onnxruntime_map_type_info.h | 13 +- .../onnxruntime_optional_type_info.cc | 40 ++ .../onnxruntime_optional_type_info.h | 28 ++ .../onnxruntime_sequence_type_info.cc | 41 +- .../onnxruntime_sequence_type_info.h | 17 +- .../core/framework/onnxruntime_typeinfo.cc | 310 +++++++------- .../core/framework/onnxruntime_typeinfo.h | 56 ++- .../core/framework/tensor_type_and_shape.cc | 85 ++-- .../core/framework/tensor_type_and_shape.h | 31 +- onnxruntime/core/session/custom_ops.cc | 8 +- onnxruntime/core/session/onnxruntime_c_api.cc | 6 +- onnxruntime/core/session/ort_apis.h | 6 + winml/adapter/winml_adapter_model.cpp | 20 +- 26 files changed, 1133 insertions(+), 547 deletions(-) create mode 100644 onnxruntime/core/framework/onnxruntime_optional_type_info.cc create mode 100644 onnxruntime/core/framework/onnxruntime_optional_type_info.h diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs index 98e1833aed8a..91366db5a245 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs @@ -773,8 +773,8 @@ public ulong ProfilingStartTimeNs if (prepackedWeightsContainer == null) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath), - options.Handle, out session)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath), + options.Handle, out session)); } else @@ -979,18 +979,109 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOnnxTypeFromTypeInfo(typeInfo, out valType)); valueType = (OnnxValueType)valType; } - if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR) + + switch (valueType) + { + case OnnxValueType.ONNX_TYPE_TENSOR: + case OnnxValueType.ONNX_TYPE_SPARSETENSOR: + return GetTensorNodeMetadata(valueType, typeInfo); + case OnnxValueType.ONNX_TYPE_SEQUENCE: + return GetSequenceMetadataFromTypeInfo(typeInfo); + case OnnxValueType.ONNX_TYPE_MAP: + return GetMapMetadataFromTypeInfo(typeInfo); + case OnnxValueType.ONNX_TYPE_OPTIONAL: + return GetOptionalMetadataFromTypeInfo(typeInfo); + } + + throw new NotImplementedException("Value type not supported in this code"); + } + + internal static NodeMetadata GetSequenceMetadataFromTypeInfo(IntPtr typeInfo) + { + IntPtr sequenceTypeInfo; + NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToSequenceTypeInfo(typeInfo, out sequenceTypeInfo)); + // Casts API are broken. Always return success, but may return null for the result. + if (sequenceTypeInfo == IntPtr.Zero) { - return new NodeMetadata(valueType, new int[] { }, new string[] { }, typeof(NamedOnnxValue)); + throw new InvalidOperationException("TypeInfo cast to SequenceTypeInfo failed. The object does not represent a sequence"); } - // This should not be released + IntPtr elementType; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetSequenceElementType(sequenceTypeInfo, out elementType)); + try + { + var elementMeta = GetMetadataFromTypeInfo(elementType); + var seqMeta = new SequenceMetadata(elementMeta); + return new NodeMetadata(seqMeta); + } + finally + { + NativeMethods.OrtReleaseTypeInfo(elementType); + } + } + + internal static NodeMetadata GetMapMetadataFromTypeInfo(IntPtr typeInfo) + { + IntPtr mapTypeInfo; + NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToMapTypeInfo(typeInfo, out mapTypeInfo)); + // Casts API are broken. Always return success, but may return null for the result. + if (mapTypeInfo == IntPtr.Zero) + { + throw new InvalidOperationException("TypeInfo cast to MapTypeInfo failed. The object does not represent a map"); + } + + IntPtr keyType; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMapKeyType(mapTypeInfo, out keyType)); + + IntPtr valueTypeInfo; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMapValueType(mapTypeInfo, out valueTypeInfo)); + try + { + var valueMetadata = GetMetadataFromTypeInfo(valueTypeInfo); + var mapMeta = new MapMetadata((TensorElementType)keyType, valueMetadata); + return new NodeMetadata(mapMeta); + } + finally + { + NativeMethods.OrtReleaseTypeInfo(valueTypeInfo); + } + } + + internal static NodeMetadata GetOptionalMetadataFromTypeInfo(IntPtr typeInfo) + { + // This should not be destroyed + IntPtr optTypeInfo; + NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToOptionalTypeInfo(typeInfo, out optTypeInfo)); + // Casts API are broken. Always return success, but may return null for the result. + if (optTypeInfo == IntPtr.Zero) + { + throw new InvalidOperationException("TypeInfo cast to OptionalTypeInfo failed. The object does not represent a optional value"); + } + + IntPtr elementTypeInfo; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOptionalContainedTypeInfo(optTypeInfo, out elementTypeInfo)); + try + { + var elementMetadata = GetMetadataFromTypeInfo(elementTypeInfo); + var optMetadata = new OptionalMetadata(elementMetadata); + return new NodeMetadata(optMetadata); + } + finally + { + NativeMethods.OrtReleaseTypeInfo(elementTypeInfo); + } + } + + internal static NodeMetadata GetTensorNodeMetadata(OnnxValueType valueType, IntPtr typeInfo) + { + // Fetch tensor type and shape from the TypeInfo IntPtr tensorInfo; NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo, out tensorInfo)); //(IntPtr)(int)(uint) - // Convert the newly introduced OrtTypeInfo* to the older OrtTypeAndShapeInfo* - + // Casts API are broken. Always return success, but may return null for the result. if (tensorInfo == IntPtr.Zero) - return null; + { + throw new InvalidOperationException("TypeInfo cast to TensorTypeInfo failed. The object does not represent a tensor"); + } TensorElementType type; { @@ -999,14 +1090,6 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) type = (TensorElementType)el_type; } - Type dotnetType = null; - int width = 0; - if (!TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width)) - { - throw new OnnxRuntimeException(ErrorCode.InvalidArgument, - "Unable to query type information for data type: " + type.ToString()); - } - UIntPtr numDimensions; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions)); @@ -1028,7 +1111,8 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) symbolicDimensions[i] = NativeOnnxValueHelper.StringFromNativeUtf8(dimensionNamePtrs[i]); } - return new NodeMetadata(valueType, intDimensions, symbolicDimensions, dotnetType); + var tensorTypeAndShape = new TensorTypeAndShape(type, intDimensions, symbolicDimensions); + return new NodeMetadata(valueType, tensorTypeAndShape); } /// @@ -1105,23 +1189,27 @@ protected virtual void Dispose(bool disposing) /// - /// Resembles type and shape information of session-graph nodes, used for communicating the shape/type of input/output nodes + /// Represents tensor element type and its shapes /// - public class NodeMetadata + public class TensorTypeAndShape { - internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] symbolicDimensions, Type type) + internal TensorTypeAndShape(TensorElementType elementType, int[] dimensions, string[] symbolicDimensions) { - OnnxValueType = onnxValueType; + ElementTypeInfo = TensorBase.GetElementTypeInfo(elementType); + if (ElementTypeInfo == null) + { + throw new ArgumentException("Unregistered TensorElementType value of: " + elementType.ToString()); + } + ElementDataType = elementType; Dimensions = dimensions; SymbolicDimensions = symbolicDimensions; - ElementType = type; } /// - /// Type value of the node + /// Tensor Element type /// - /// A value of OnnxValueType enum - public OnnxValueType OnnxValueType { get; } + /// TensorElementType enum + public TensorElementType ElementDataType { get; } /// /// Shape @@ -1136,10 +1224,247 @@ internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] sy public string[] SymbolicDimensions { get; } /// - /// .NET type that corresponds to this Node. + /// Tensor element metadata + /// + public TensorElementTypeInfo ElementTypeInfo { get; } + } + + /// + /// Represents sequnce metdata + /// + public class SequenceMetadata + { + /// + /// __ctor + /// + /// + internal SequenceMetadata(NodeMetadata elementData) + { + ElementMeta = elementData; + } + /// + /// Element Metatada, recursive definition with a Tensor being a base case + /// may contain maps, tensors and other sequences + /// + public NodeMetadata ElementMeta { get; } + } + + /// + /// The class contains metadata for an optional input/output + /// + public class OptionalMetadata + { + /// + /// __ctor + /// + /// + internal OptionalMetadata(NodeMetadata elementData) + { + ElementMeta = elementData; + } + + /// + /// Element Metatada, recursive definition with a Tensor being a base case + /// may contain maps, tensors and sequences + /// + public NodeMetadata ElementMeta { get; } + } + + /// + /// Represents Map MetaData. + /// Key is always a tensor denoted by an element type + /// with value type being a recursive structure that may + /// contain other maps, sequences or tensors. + /// + public class MapMetadata + { + internal MapMetadata(TensorElementType keyDataType, NodeMetadata valueMetadata) + { + KeyDataType = keyDataType; + ValueMetadata = valueMetadata; + } + + /// + /// Key tensor data type + /// + /// A value of TensorElementType enum + public TensorElementType KeyDataType { get; } + + /// + /// Value metadata + /// + /// /// Instance of Nodemetadata for the value of the map + public NodeMetadata ValueMetadata { get; } + } + + /// + /// Resembles type and shape information of session-graph nodes, used for communicating the shape/type of input/output nodes + /// + public class NodeMetadata + { + private readonly Object _metadata; + /// + /// Constructs NodeMetadata for tensor + /// + /// either ONNX_TYPE_TENSOR or ONNX_TYPE_SPARSETENSOR + /// Tensor type and shape information + internal NodeMetadata(OnnxValueType onnxValueType, TensorTypeAndShape typeAndShape) + { + OnnxValueType = onnxValueType; + CheckTensor(); + _metadata = typeAndShape; + } + + /// + /// __ctor for map metadata + /// + /// + internal NodeMetadata(MapMetadata mapMetadata) + { + OnnxValueType = OnnxValueType.ONNX_TYPE_MAP; + _metadata = mapMetadata; + } + + /// + /// __ctor for sequence metadata + /// + /// + internal NodeMetadata(SequenceMetadata sequenceMetadata) + { + OnnxValueType = OnnxValueType.ONNX_TYPE_SEQUENCE; + _metadata = sequenceMetadata; + } + + /// + /// __ctor + /// + /// + internal NodeMetadata(OptionalMetadata optMetadata) + { + OnnxValueType = OnnxValueType.ONNX_TYPE_OPTIONAL; + _metadata = optMetadata; + } + + private void CheckTensor() + { + if (!IsTensor) + { + throw new InvalidOperationException("OnnxValueType must either be a tensor or sparse tensor"); + } + } + + /// + /// Retrieves MapMetadata, valid only if this node represents a Map. + /// + /// + /// when the instance does not contain map metadata + public MapMetadata AsMapMetadata() + { + if (OnnxValueType != OnnxValueType.ONNX_TYPE_MAP) + { + throw new InvalidOperationException("Instance does not contain Map metadata"); + } + return _metadata as MapMetadata; + } + + /// + /// Retrieves SequenceMetadata, valid only if this node represents a Sequence + /// + /// + /// when the instance does not contain sequence metadata + public SequenceMetadata AsSequenceMetadata() + { + if (OnnxValueType != OnnxValueType.ONNX_TYPE_SEQUENCE) + { + throw new InvalidOperationException("Instance does not contain Sequence metadata"); + } + return _metadata as SequenceMetadata; + } + + /// + /// Retrieves Optional type metadata, valid if this node is optional + /// Optional metadata is nothing more than just a container for all the usual + /// element types. + /// + /// + /// + public OptionalMetadata AsOptionalMetadata() + { + if (OnnxValueType != OnnxValueType.ONNX_TYPE_OPTIONAL) + { + throw new InvalidOperationException("Instance does not contain Optional metadata"); + } + return _metadata as OptionalMetadata; + } + + /// + /// Type value of the node + /// + /// A value of OnnxValueType enum + public OnnxValueType OnnxValueType { get; } + + /// + /// Tensor shape valid only if this is a Tensor. + /// Preserved for API compatibility + /// + /// Array of dimensions + public int[] Dimensions + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).Dimensions; + } + } + + /// + /// Symbolic dimensions valid only if this is a Tensor. + /// Preserved for API compatibility + /// + /// Array of symbolic dimensions if present. + public string[] SymbolicDimensions + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).SymbolicDimensions; + } + } + + /// + /// .NET type that corresponds to the primitive Tensor data type. + /// Valid only if this is a Tensor. /// /// System.Type - public System.Type ElementType { get; } + public System.Type ElementType + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).ElementTypeInfo.TensorType; + } + } + + /// + /// Tensor Element Type. Valid if tensor + /// + public TensorElementType ElementDataType + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).ElementDataType; + } + } + + public bool IsString + { + get + { + CheckTensor(); + return (_metadata as TensorTypeAndShape).ElementTypeInfo.IsString; + } + } /// /// Whether it is a Tensor @@ -1149,7 +1474,7 @@ public bool IsTensor { get { - return true; // currently only Tensor nodes are supported + return (OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) || (OnnxValueType == OnnxValueType.ONNX_TYPE_SPARSETENSOR); } } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 57ae4c290531..fcba15a7b280 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -261,6 +261,25 @@ public struct OrtApi public IntPtr ReleaseCANNProviderOptions; public IntPtr MemoryInfoGetDeviceType; public IntPtr UpdateEnvWithCustomLogLevel; + public IntPtr SetGlobalIntraOpThreadAffinity; + public IntPtr RegisterCustomOpsLibrary_V2; + public IntPtr RegisterCustomOpsUsingFunction; + public IntPtr KernelInfo_GetInputCount; + public IntPtr KernelInfo_GetOutputCount; + public IntPtr KernelInfo_GetInputName; + public IntPtr KernelInfo_GetOutputName; + public IntPtr KernelInfo_GetInputTypeInfo; + public IntPtr KernelInfo_GetOutputTypeInfo; + public IntPtr KernelInfoGetAttribute_tensor; + public IntPtr HasSessionConfigEntry; + public IntPtr GetSessionConfigEntry; + public IntPtr SessionOptionsAppendExecutionProvider_Dnnl; + public IntPtr CreateDnnlProviderOptions; + public IntPtr UpdateDnnlProviderOptions; + public IntPtr GetDnnlProviderOptionsAsString; + public IntPtr ReleaseDnnlProviderOptions; + public IntPtr CastTypeInfoToOptionalTypeInfo; + public IntPtr GetOptionalContainedTypeInfo; } internal static class NativeMethods @@ -405,6 +424,16 @@ static NativeMethods() OrtGetDimensions = (DOrtGetDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetDimensions, typeof(DOrtGetDimensions)); OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions)); OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount)); + // MapTypeInfo + OrtGetMapKeyType = (DGetMapKeyType)Marshal.GetDelegateForFunctionPointer(api_.GetMapKeyType, typeof(DGetMapKeyType)); + OrtCastTypeInfoToMapTypeInfo = (DCastTypeInfoToMapTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToMapTypeInfo, typeof(DCastTypeInfoToMapTypeInfo)); + OrtGetMapValueType = (DGetMapValueType)Marshal.GetDelegateForFunctionPointer(api_.GetMapValueType, typeof(DGetMapValueType)); + // SequenceTypeInfo + OrtCastTypeInfoToSequenceTypeInfo = (DCastTypeInfoToSequenceTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToSequenceTypeInfo, typeof(DCastTypeInfoToSequenceTypeInfo)); + OrtGetSequenceElementType = (DGetSequenceElementType)Marshal.GetDelegateForFunctionPointer(api_.GetSequenceElementType, typeof(DGetSequenceElementType)); + // Optional Type info + OrtCastTypeInfoToOptionalTypeInfo = (DOrtCastTypeInfoToOptionalTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToOptionalTypeInfo, typeof(DOrtCastTypeInfoToOptionalTypeInfo)); + OrtGetOptionalContainedTypeInfo = (DGetOptionalContainedTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.GetOptionalContainedTypeInfo, typeof(DGetOptionalContainedTypeInfo)); OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue)); OrtSessionGetModelMetadata = (DOrtSessionGetModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.SessionGetModelMetadata, typeof(DOrtSessionGetModelMetadata)); @@ -1698,6 +1727,41 @@ internal class NativeLib public static DOrtGetTensorShapeElementCount OrtGetTensorShapeElementCount; + /// Map Type API + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DCastTypeInfoToMapTypeInfo(IntPtr /*(const struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*const OrtMapTypeInfo** */ mapTypeInfo); + + public static DCastTypeInfoToMapTypeInfo OrtCastTypeInfoToMapTypeInfo; + + public delegate IntPtr /*(OrtStatus*)*/ DGetMapKeyType(IntPtr /*const OrtMapTypeInfo* */ mapTypeInfo, out IntPtr /*(TensorElementType*)*/ tensorElementType); + + public static DGetMapKeyType OrtGetMapKeyType; + + public delegate IntPtr /*(OrtStatus*)*/ DGetMapValueType(IntPtr /* const OrtMapTypeInfo* */ map_type_info, out IntPtr /* OrtTypeInfo** */ type_info); + + public static DGetMapValueType OrtGetMapValueType; + + // Sequence TypeInfo + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DCastTypeInfoToSequenceTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const OrtSequenceTypeInfo** */ sequenceTypeInfo); + + public static DCastTypeInfoToSequenceTypeInfo OrtCastTypeInfoToSequenceTypeInfo; + + public delegate IntPtr /*(OrtStatus*)*/ DGetSequenceElementType(IntPtr /* const OrtSequenceTypeInfo* */ sequenceTypeInfo, out IntPtr /* OrtTypeInfo** */ elementTypeInfo); + + public static DGetSequenceElementType OrtGetSequenceElementType; + + // OptionalTypeInfo + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToOptionalTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const struct OrtOptionalTypeInfo** */ optionalTypeInfo); + + public static DOrtCastTypeInfoToOptionalTypeInfo OrtCastTypeInfoToOptionalTypeInfo; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DGetOptionalContainedTypeInfo(IntPtr /* const struct OrtOptionalTypeInfo*/ optTypeInfo, out IntPtr /* struct OrtTypeInfo** */ containedTypeInfo); + + public static DGetOptionalContainedTypeInfo OrtGetOptionalContainedTypeInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseValue(IntPtr /*(OrtValue*)*/ value); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs index feff70782f83..0a1ae1912a1e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs @@ -43,7 +43,7 @@ public void Dispose() // No need for the finalizer // If this is not disposed timely GC can't help us #endregion - } + } /// /// This helper class contains methods to create native OrtValue from a managed value object @@ -58,8 +58,12 @@ internal static class NativeOnnxValueHelper /// UTF-8 encoded equivalent internal static byte[] StringToZeroTerminatedUtf8(string s) { - byte[] utf8Bytes = UTF8Encoding.UTF8.GetBytes(s); - Array.Resize(ref utf8Bytes, utf8Bytes.Length + 1); + int arraySize = UTF8Encoding.UTF8.GetByteCount(s); + byte[] utf8Bytes = new byte[arraySize + 1]; + if (arraySize != UTF8Encoding.UTF8.GetBytes(s, 0, s.Length, utf8Bytes, 0)) + { + throw new OnnxRuntimeException(ErrorCode.RuntimeException, "Failed to convert to UTF8"); + } utf8Bytes[utf8Bytes.Length - 1] = 0; return utf8Bytes; } @@ -72,7 +76,7 @@ internal static byte[] StringToZeroTerminatedUtf8(string s) /// internal static string StringFromNativeUtf8(IntPtr nativeUtf8) { - // .NET 5.0 has Marshal.PtrToStringUTF8 that does the below + // .NET 8.0 has Marshal.PtrToStringUTF8 that does the below int len = 0; while (Marshal.ReadByte(nativeUtf8, len) != 0) ++len; byte[] buffer = new byte[len]; @@ -80,6 +84,14 @@ internal static string StringFromNativeUtf8(IntPtr nativeUtf8) return Encoding.UTF8.GetString(buffer, 0, buffer.Length); } + internal static string StringFromUtf8Span(ReadOnlySpan utf8Span) + { + // For now we have to copy into byte[], this produces a copy + // Converting from span is available in later versions + var utf8Bytes = utf8Span.ToArray(); + return Encoding.UTF8.GetString(utf8Bytes, 0, utf8Bytes.Length); + } + /// /// Run helper /// @@ -126,7 +138,7 @@ public static bool GetTypeAndWidth(TensorElementType elemType, out Type type, ou { bool result = true; TensorElementTypeInfo typeInfo = TensorBase.GetElementTypeInfo(elemType); - if(typeInfo != null) + if (typeInfo != null) { type = typeInfo.TensorType; width = typeInfo.TypeSize; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 08609bb4826a..c279c31a3770 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -20,6 +20,7 @@ public enum OnnxValueType ONNX_TYPE_MAP = 3, // It's a map ONNX_TYPE_OPAQUE = 4, // It's an experimental Opaque object ONNX_TYPE_SPARSETENSOR = 5, // It's a Sparse Tensor + ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (but unknown) } /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs index bb7eea2ad188..0bc5ca7240e6 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs @@ -71,7 +71,7 @@ public Float16(ushort v) /// /// instance of Float16 /// value member - public static implicit operator ushort (Float16 f) { return f.value; } + public static implicit operator ushort(Float16 f) { return f.value; } /// /// Converts a 16-bit unsigned integer to a Float16. /// @@ -191,7 +191,7 @@ public bool Equals(BFloat16 other) /// represent the same type and value. /// /// An System.Object. - /// true if obj is BFloat16 its value is equal to this instance; otherwise, false. + /// true if obj is BFloat16 its value is equal to this instance; otherwise, false. public override bool Equals(object obj) { bool result = false; @@ -286,7 +286,8 @@ public class TensorBase private static readonly Dictionary tensorElementTypeInfoMap; - static TensorBase () { + static TensorBase() + { typeInfoMap = new Dictionary() { { typeof(float), new TensorTypeInfo( TensorElementType.Float, sizeof(float)) }, @@ -306,11 +307,11 @@ public class TensorBase }; tensorElementTypeInfoMap = new Dictionary(); - foreach(var info in typeInfoMap) + foreach (var info in typeInfoMap) { tensorElementTypeInfoMap.Add(info.Value.ElementType, new TensorElementTypeInfo(info.Key, info.Value.TypeSize)); } - } + } private readonly Type _primitiveType; /// @@ -559,7 +560,10 @@ internal static T Zero { return (T)(object)(ushort)(0); } - + else if (typeof(T) == typeof(string)) + { + return (T)(object)("0"); + } throw new NotSupportedException(); } } @@ -619,8 +623,8 @@ internal static T One else if (typeof(T) == typeof(ushort)) { return (T)(object)(ushort)(1); - } - else if(typeof(T) == typeof(Float16)) + } + else if (typeof(T) == typeof(Float16)) { return (T)(object)(ushort)(15360); } @@ -628,6 +632,10 @@ internal static T One { return (T)(object)(ushort)(16256); } + else if (typeof(T) == typeof(string)) + { + return (T)(object)("1"); + } throw new NotSupportedException(); } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index 357d1ca8621b..4ec453eef851 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -1322,8 +1322,29 @@ private void TestModelSequenceOfMapIntFloat() { var outMeta = session.OutputMetadata; - Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType); - Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType); + var label_meta = outMeta["label"]; + Assert.True(label_meta.IsTensor); + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, label_meta.OnnxValueType); + Assert.Equal(TensorElementType.Int64, label_meta.ElementDataType); + Assert.NotEmpty(label_meta.Dimensions); + + // sequence> + var probabilities_meta = outMeta["probabilities"]; + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, probabilities_meta.OnnxValueType); + var seqElementMetata = probabilities_meta.AsSequenceMetadata().ElementMeta; + Assert.Equal(OnnxValueType.ONNX_TYPE_MAP, seqElementMetata.OnnxValueType); + var mapMetadata = seqElementMetata.AsMapMetadata(); + // Map + Assert.Equal(Tensors.TensorElementType.Int64, mapMetadata.KeyDataType); + var valueTensorMeta = mapMetadata.ValueMetadata; + Assert.True(valueTensorMeta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Float, valueTensorMeta.ElementDataType); + + // tensor + var inputMeta = session.InputMetadata["input"]; + Assert.True(inputMeta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Float, inputMeta.ElementDataType); + Assert.Equal(2, inputMeta.Dimensions.Length); var container = new List(); var tensorIn = new DenseTensor(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 }); @@ -1392,8 +1413,31 @@ private void TestModelSequenceOfMapStringFloat() using (var session = new InferenceSession(model)) { var outMeta = session.OutputMetadata; - Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outMeta["label"].OnnxValueType); - Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["probabilities"].OnnxValueType); + var label_meta = outMeta["label"]; + Assert.True(label_meta.IsTensor); + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, label_meta.OnnxValueType); + Assert.True(label_meta.IsString); + Assert.Equal(TensorElementType.String, label_meta.ElementDataType); + Assert.NotEmpty(label_meta.Dimensions); + + // sequence> + var probabilities_meta = outMeta["probabilities"]; + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, probabilities_meta.OnnxValueType); + var seqElementMetata = probabilities_meta.AsSequenceMetadata().ElementMeta; + Assert.Equal(OnnxValueType.ONNX_TYPE_MAP, seqElementMetata.OnnxValueType); + var mapMetadata = seqElementMetata.AsMapMetadata(); + Assert.Equal(Tensors.TensorElementType.String, mapMetadata.KeyDataType); + var valueTensorMeta = mapMetadata.ValueMetadata; + Assert.True(valueTensorMeta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Float, valueTensorMeta.ElementDataType); + + + // tensor + var inputMeta = session.InputMetadata["input"]; + Assert.True(inputMeta.IsTensor); + Assert.False(inputMeta.IsString); + Assert.Equal(Tensors.TensorElementType.Float, inputMeta.ElementDataType); + Assert.Equal(2, inputMeta.Dimensions.Length); var container = new List(); var tensorIn = new DenseTensor(new float[] { 5.8f, 2.8f }, new int[] { 1, 2 }); @@ -1415,7 +1459,7 @@ private void TestModelSequenceOfMapStringFloat() // Label 1 should have highest probability Assert.Equal("1", outLabelTensor[0]); - // second output is a sequence> + // second output is a sequence> // try-cast to an sequence of NOV var outNode1 = outputs.ElementAtOrDefault(1); Assert.Equal("probabilities", outNode1.Name); @@ -1443,7 +1487,18 @@ private void TestModelSequenceOfTensors() using (var session = new InferenceSession(model)) { var outMeta = session.OutputMetadata; - Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outMeta["output_sequence"].OnnxValueType); + var output_seq = outMeta["output_sequence"]; + Assert.False(output_seq.IsTensor); + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, output_seq.OnnxValueType); + var elemMeta = output_seq.AsSequenceMetadata().ElementMeta; + Assert.True(elemMeta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Int64, elemMeta.ElementDataType); + + // Inputs + var tensor1Meta = session.InputMetadata["tensor1"]; + Assert.True(tensor1Meta.IsTensor); + Assert.Equal(Tensors.TensorElementType.Int64, tensor1Meta.ElementDataType); + Assert.Equal(2, tensor1Meta.Dimensions.Length); var container = new List(); var firstInputTensor = new DenseTensor(new Int64[] { 1, 2, 3, 4, 5, 6 }, new int[] { 2, 3 }); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs index 2086fa0ec316..d913aaa5b966 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs @@ -3,6 +3,7 @@ using System.IO; using System.Linq; using Microsoft.ML.OnnxRuntime.Tensors; +using Xunit; namespace Microsoft.ML.OnnxRuntime.Tests { @@ -46,27 +47,9 @@ internal static float[] LoadTensorFromEmbeddedResource(string path) return tensorData.ToArray(); } - internal static void GetTypeAndWidth(Tensors.TensorElementType elemType, out Type type, out int width) - { - TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType); - if (result != null) - { - type = result.TensorType; - width = result.TypeSize; - } - else - { - throw new ArgumentException("Unable to get information for type: " + elemType.ToString()); - } - } - static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary nodeMetaDict) { - Type tensorElemType = null; - int width = 0; - GetTypeAndWidth((Tensors.TensorElementType)tensor.DataType, out tensorElemType, out width); var intDims = new int[tensor.Dims.Count]; - for (int i = 0; i < tensor.Dims.Count; i++) { intDims[i] = (int)tensor.Dims[i]; @@ -86,6 +69,8 @@ static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary< { nodeMeta = nodeMetaDict[tensor.Name]; nodeName = tensor.Name; + if (!nodeMeta.IsTensor) + throw new Exception("LoadTensorFromFile can load Tensor types only: " + nodeName); } else { @@ -94,7 +79,10 @@ static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary< foreach (var key in nodeMetaDict.Keys) { var meta = nodeMetaDict[key]; - if (tensorElemType == meta.ElementType && tensor.Dims.Count == meta.Dimensions.Length) + if (!meta.IsTensor) + throw new Exception("LoadTensorFromFile can load Tensor types only"); + + if ((Tensors.TensorElementType)tensor.DataType == meta.ElementDataType && tensor.Dims.Count == meta.Dimensions.Length) { int i = 0; for (; i < meta.Dimensions.Length; i++) @@ -126,11 +114,14 @@ static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary< throw new Exception($"While reading the serliazed tensor specified, metaDataDict has 0 elements"); } - if (!nodeMeta.IsTensor) - throw new Exception("LoadTensorFromFile can load Tensor types only"); + if (nodeMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + throw new Exception("LoadTensorFromFile can load Dense Tensor types only"); - if (tensorElemType != nodeMeta.ElementType) - throw new Exception($"{nameof(tensorElemType)} is expected to be equal to {nameof(nodeMeta.ElementType)}"); + var protoDt = (Tensors.TensorElementType)tensor.DataType; + if (!((protoDt == nodeMeta.ElementDataType) || + (protoDt == TensorElementType.UInt16 && + (nodeMeta.ElementDataType == TensorElementType.BFloat16 || nodeMeta.ElementDataType == TensorElementType.Float16)))) + throw new Exception($"{tensor.DataType.ToString()} is expected to be equal to: " + nodeMeta.ElementDataType.ToString()); if (nodeMeta.Dimensions.Length != tensor.Dims.Count) throw new Exception($"{nameof(nodeMeta.Dimensions.Length)} is expected to be equal to {nameof(tensor.Dims.Count)}"); @@ -141,62 +132,39 @@ static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary< throw new Exception($"{nameof(nodeMeta.Dimensions)}[{i}] is expected to either be -1 or {nameof(intDims)}[{i}]"); } - if (nodeMeta.ElementType == typeof(float)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(float), intDims); - } - else if (nodeMeta.ElementType == typeof(double)) + var elementType = nodeMeta.ElementDataType; + switch (elementType) { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(double), intDims); - } - else if (nodeMeta.ElementType == typeof(int)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(int), intDims); - } - else if (nodeMeta.ElementType == typeof(uint)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(uint), intDims); - } - else if (nodeMeta.ElementType == typeof(long)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(long), intDims); - } - else if (nodeMeta.ElementType == typeof(ulong)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ulong), intDims); - } - else if (nodeMeta.ElementType == typeof(short)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(short), intDims); - } - else if (nodeMeta.ElementType == typeof(ushort)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); - } - else if (nodeMeta.ElementType == typeof(byte)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(byte), intDims); - } - else if (nodeMeta.ElementType == typeof(sbyte)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(sbyte), intDims); - } - else if (nodeMeta.ElementType == typeof(bool)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(bool), intDims); - } - else if (nodeMeta.ElementType == typeof(Float16)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); - } - else if (nodeMeta.ElementType == typeof(BFloat16)) - { - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); - } - else - { - //TODO: Add support for remaining types - throw new Exception($"Tensors of type {nameof(nodeMeta.ElementType)} not currently supported in the LoadTensorFromEmbeddedResource"); + case TensorElementType.Float: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(float), intDims); + case TensorElementType.Double: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(double), intDims); + case TensorElementType.Int32: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(int), intDims); + case TensorElementType.UInt32: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(uint), intDims); + case TensorElementType.Int16: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(short), intDims); + case TensorElementType.UInt16: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); + case TensorElementType.Int64: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(long), intDims); + case TensorElementType.UInt64: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ulong), intDims); + case TensorElementType.UInt8: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(byte), intDims); + case TensorElementType.Int8: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(sbyte), intDims); + case TensorElementType.Bool: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(bool), intDims); + case TensorElementType.Float16: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); + case TensorElementType.BFloat16: + return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); + case TensorElementType.String: + return CreateNamedOnnxValueFromString(tensor, intDims); + default: + throw new Exception($"Tensors of type: " + nodeMeta.ElementType.ToString() + " not currently supported in the LoadTensorFromEmbeddedResource"); } } @@ -250,6 +218,23 @@ internal static NamedOnnxValue CreateNamedOnnxValueFromRawData(string name, b return NamedOnnxValue.CreateFromTensor(name, dt); } + internal static NamedOnnxValue CreateNamedOnnxValueFromString(Onnx.TensorProto tensor, int[] dimensions) + { + if (tensor.DataType != (int)Onnx.TensorProto.Types.DataType.String) + { + throw new ArgumentException("Expecting string data"); + } + + string[] strArray = new string[tensor.StringData.Count]; + for (int i = 0; i < tensor.StringData.Count; ++i) + { + strArray[i] = System.Text.Encoding.UTF8.GetString(tensor.StringData[i].ToByteArray()); + } + + var dt = new DenseTensor(strArray, dimensions); + return NamedOnnxValue.CreateFromTensor(tensor.Name, dt); + } + internal static float[] LoadTensorFromFile(string filename, bool skipheader = true) { var tensorData = new List(); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index 61ff2a43ffd4..351a80e86854 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -8,27 +8,27 @@ namespace Microsoft.ML.OnnxRuntime.Tests { - /// - /// This is compensate for the absence of string.Contains() in .NET Standard 2.0 - /// Contains(String, StringComparison) - /// - public static class StringExtensions - { - public static bool Contains(this String str, String substring, - StringComparison comp) + /// + /// This is compensate for the absence of string.Contains() in .NET Standard 2.0 + /// Contains(String, StringComparison) + /// + public static class StringExtensions { - if (substring == null) - throw new ArgumentNullException("substring", - "substring cannot be null."); - else if (!Enum.IsDefined(typeof(StringComparison), comp)) - throw new ArgumentException("comp is not a member of StringComparison", - "comp"); - - return str.IndexOf(substring, comp) >= 0; + public static bool Contains(this String str, String substring, + StringComparison comp) + { + if (substring == null) + throw new ArgumentNullException("substring", + "substring cannot be null."); + else if (!Enum.IsDefined(typeof(StringComparison), comp)) + throw new ArgumentException("comp is not a member of StringComparison", + "comp"); + + return str.IndexOf(substring, comp) >= 0; + } } - } - public partial class InferenceTest - { + public partial class InferenceTest + { private const string module = "onnxruntime.dll"; private const string propertiesFile = "Properties.txt"; @@ -256,21 +256,17 @@ private void TestTensorRTProviderOptions() { "cntk_simple_seg", "Bad onnx test output caused by wrong SAME_UPPER/SAME_LOWER for ConvTranspose" }, { "coreml_Imputer-LogisticRegression_sklearn_load_breast_cancer", "Can't determine model file name" }, { "mask_rcnn_keras", "Model should be edited to remove the extra outputs" }, - { "test_strnormalizer_export_monday_casesensintive_lower", "ElementType not currently supported"}, { "test_max_float64", "node test error"}, { "test_min_uint8", "node test error"}, { "test_mod_mixed_sign_float64", "node test error"}, { "test_momentum", "node test error"}, { "test_max_uint16", "node test error"}, { "test_resize_downsample_scales_linear_align_corners", "node test error"}, - { "test_strnormalizer_nostopwords_nochangecase", "node test error"}, { "test_adagrad_multiple", "node test error"}, { "test_einsum_inner_prod", "node test error"}, { "test_sequence_insert_at_back", "node test error"}, { "test_mod_mixed_sign_int8", "node test error"}, { "test_maxunpool_export_with_output_shape", "node test error"}, - { "test_strnormalizer_export_monday_empty_output", "node test error"}, - { "test_strnormalizer_export_monday_insensintive_upper_twodim", "ElementType not currently supported"}, { "test_min_int16", "node test error"}, { "test_adagrad", "node test error"}, { "test_min_float64", "node test error"}, @@ -283,19 +279,18 @@ private void TestTensorRTProviderOptions() { "test_clip_default_int8_inbounds", "node test error"}, { "test_eyelike_with_dtype", "node test error"}, { "test_cast_STRING_to_FLOAT", "node test error"}, - { "test_cast_FLOAT16_to_DOUBLE", "node test error"}, { "test_cast_FLOAT_to_DOUBLE", "node test error"}, { "test_cast_BFLOAT16_to_FLOAT", "node test error"}, { "test_cast_FLOAT_to_BFLOAT16", "node test error"}, - { "test_cast_FLOAT_to_STRING", "node test error"}, + { "test_cast_FLOAT_to_STRING", "Output strings can not be compared exactly"}, { "test_castlike_STRING_to_FLOAT", "node test error"}, { "test_castlike_STRING_to_FLOAT_expanded", "node test error"}, { "test_castlike_FLOAT16_to_DOUBLE", "node test error"}, { "test_castlike_FLOAT16_to_DOUBLE_expanded", "node test error"}, { "test_castlike_FLOAT_to_DOUBLE", "node test error"}, { "test_castlike_FLOAT_to_DOUBLE_expanded", "node test error"}, - { "test_castlike_BFLOAT16_to_FLOAT", "node test error"}, - { "test_castlike_BFLOAT16_to_FLOAT_expanded", "node test error"}, + { "test_castlike_BFLOAT16_to_FLOAT", "Length is expected to be equal to Count (metadata and expected data mismatch) "}, + { "test_castlike_BFLOAT16_to_FLOAT_expanded", "Length is expected to be equal to Count metadata and expected data mismatch"}, { "test_castlike_FLOAT_to_BFLOAT16", "node test error"}, { "test_castlike_FLOAT_to_BFLOAT16_expanded", "node test error"}, { "test_castlike_FLOAT_to_STRING", "node test error"}, @@ -304,14 +299,12 @@ private void TestTensorRTProviderOptions() { "test_bitshift_left_uint16", "node test error"}, { "test_pow_types_float32_uint64", "node test error"}, { "test_max_uint8", "node test error"}, - { "test_strnormalizer_export_monday_casesensintive_nochangecase", "ElementType not currently supported"}, { "test_momentum_multiple", "node test error"}, { "test_pow_types_float32_uint32", "node test error"}, { "test_if_seq", "sequence type is not supported in test infra."}, { "test_resize_downsample_scales_cubic_align_corners", "node test error"}, { "test_einsum_batch_matmul", "node test error"}, { "test_nesterov_momentum", "node test error"}, - { "test_strnormalizer_export_monday_casesensintive_upper", "node test error"}, { "test_min_uint16", "node test error"}, { "test_adam_multiple", "node test error"}, { "test_loop13_seq", "sequence type is not supported in test infra." }, @@ -538,68 +531,13 @@ private void TestPreTrainedModels(string opsetDir, string modelName) { outputValue = outputContainer.First(); // in case the output data file does not contain the name } - if (outputMeta.IsTensor) + if (outputMeta.OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) // Only Dense tensors now { - if (outputMeta.ElementType == typeof(float)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new FloatComparer()); - } - else if (outputMeta.ElementType == typeof(double)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new DoubleComparer()); - } - else if (outputMeta.ElementType == typeof(int)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(uint)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(short)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(ushort)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(long)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(ulong)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(byte)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(sbyte)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(bool)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); - } - else if (outputMeta.ElementType == typeof(Float16)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new Float16Comparer { tolerance = 2 }); - } - else if (outputMeta.ElementType == typeof(BFloat16)) - { - Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new BFloat16Comparer { tolerance = 2 }); - } - else - { - Assert.True(false, $"{nameof(TestPreTrainedModels)} does not yet support output of type {outputMeta.ElementType}"); - } + VerifyTensorResults(outputMeta.ElementDataType, result, outputValue); } else { - Assert.True(false, $"{nameof(TestPreTrainedModels)} cannot handle non-tensor outputs yet"); + Assert.True(false, "TestPreTrainedModels cannot handle Onnxtype: " + outputMeta.OnnxValueType.ToString()); } } } @@ -624,6 +562,58 @@ private void TestPreTrainedModels(string opsetDir, string modelName) } } + private void VerifyTensorResults(TensorElementType elementType, DisposableNamedOnnxValue result, NamedOnnxValue outputValue) + { + switch (elementType) + { + case TensorElementType.Float: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new FloatComparer()); + break; + case TensorElementType.Double: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new DoubleComparer()); + break; + case TensorElementType.Int32: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.UInt32: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Int16: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.UInt16: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Int64: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.UInt64: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.UInt8: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Int8: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Bool: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + case TensorElementType.Float16: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new Float16Comparer { tolerance = 2 }); + break; + case TensorElementType.BFloat16: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new BFloat16Comparer { tolerance = 2 }); + break; + case TensorElementType.String: + Assert.Equal(result.AsTensor(), outputValue.AsTensor(), new ExactComparer()); + break; + default: + Assert.True(false, "TestPreTrainedModels does not yet support output of type: " + elementType.ToString()); + break; + } + } + // Hint: .NET Core 3.1 has a 'NativeLibrary' class that can be used to free the library handle private void UnloadLibrary(IntPtr libraryHandle) { @@ -669,7 +659,8 @@ private void TestRegisterCustomOpLibrary() var ortEnvInstance = OrtEnv.Instance(); string[] providers = ortEnvInstance.GetAvailableProviders(); - if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider")) { + if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider")) + { option.AppendExecutionProvider_CUDA(0); } diff --git a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs index 4f29d72b0b14..9370a03f7fbe 100644 --- a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs +++ b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Program.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using CommandLine; -using Google.Protobuf; using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Collections.Generic; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b56497ea3231..f4975f0047d8 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -263,10 +263,11 @@ ORT_RUNTIME_CLASS(Value); ORT_RUNTIME_CLASS(RunOptions); ORT_RUNTIME_CLASS(TypeInfo); ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); -ORT_RUNTIME_CLASS(SessionOptions); -ORT_RUNTIME_CLASS(CustomOpDomain); ORT_RUNTIME_CLASS(MapTypeInfo); ORT_RUNTIME_CLASS(SequenceTypeInfo); +ORT_RUNTIME_CLASS(OptionalTypeInfo); +ORT_RUNTIME_CLASS(SessionOptions); +ORT_RUNTIME_CLASS(CustomOpDomain); ORT_RUNTIME_CLASS(ModelMetadata); ORT_RUNTIME_CLASS(ThreadPoolParams); ORT_RUNTIME_CLASS(ThreadingOptions); @@ -563,9 +564,7 @@ typedef struct OrtMIGraphXProviderOptions { */ typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus - OrtOpenVINOProviderOptions() : device_type{}, enable_vpu_fast_compile{}, device_id{}, - num_of_threads{}, cache_dir{}, - context{}, enable_opencl_throttling{}, enable_dynamic_shapes{} {} + OrtOpenVINOProviderOptions() : device_type{}, enable_vpu_fast_compile{}, device_id{}, num_of_threads{}, cache_dir{}, context{}, enable_opencl_throttling{}, enable_dynamic_shapes{} {} #endif /** \brief Device type string * @@ -574,8 +573,8 @@ typedef struct OrtOpenVINOProviderOptions { const char* device_type; unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled const char* device_id; - size_t num_of_threads; ///< 0 = Use default number of threads - const char* cache_dir; // path is set to empty by default + size_t num_of_threads; ///< 0 = Use default number of threads + const char* cache_dir; // path is set to empty by default void* context; unsigned char enable_opencl_throttling; ///< 0 = disabled, nonzero = enabled unsigned char enable_dynamic_shapes; ///< 0 = disabled, nonzero = enabled @@ -1325,8 +1324,9 @@ struct OrtApi { * * \param[in] type_info * \param[out] out Do not free this value, it will be valid until type_info is freed. + * If type_info does not represent tensor, this value will be set to nullptr. * - * \snippet{doc} snippets.dox OrtStatus Return Value + * \snippet{doc} snippets.dox OrtStatus Return Value. Always returns nullptr. */ ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); @@ -1835,9 +1835,10 @@ struct OrtApi { * This is used by WinML to support model reflection APIs. * * \param[out] type_info - * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value + * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value. If type_info + * does not contain a map, this value will be set to nullptr. * - * \snippet{doc} snippets.dox OrtStatus Return Value + * \snippet{doc} snippets.dox OrtStatus Return Value. Always returns nullptr. */ ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtMapTypeInfo** out); @@ -1850,9 +1851,10 @@ struct OrtApi { * This is used by WinML to support model reflection APIs. * * \param[in] type_info - * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value + * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value. If type_info + * doesn not contain a sequence, this value will be set to nullptr. * - * \snippet{doc} snippets.dox OrtStatus Return Value + * \snippet{doc} snippets.dox OrtStatus Return Value. Always returns nullptr. */ ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out); @@ -3917,7 +3919,7 @@ struct OrtApi { ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, _In_ const OrtDnnlProviderOptions* dnnl_options); - /** \brief Create an OrtDnnlProviderOptions + /** \brief Create an OrtDnnlProviderOptions * * \param[out] out Newly created ::OrtDnnlProviderOptions. Must be released with OrtApi::ReleaseDnnlProviderOptions * @@ -4085,6 +4087,48 @@ struct OrtApi { * \since Version 1.15. */ ORT_API2_STATUS(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); + + /** \brief Get Optional Type information from an ::OrtTypeInfo + * + * This augments ::OrtTypeInfo to return an ::OrtOptionalTypeInfo when the type is optional. + * The OrtOptionalTypeInfo also has a nested ::OrtTypeInfo that describes the type of the optional value. + * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. + * The actual OrtValues that are supplied in place of optional type inputs should contain + * specific type that is described by ::OrtOptionalTypeInfo. + * + * So the picture: ::OrtTypeInfo -> ::OrtOptionalTypeInfo -> ::OrtTypeInfo (describes the type that can be supplied + * in place of the optional type when creating the actual ::OrtValue). + * + * \param[in] type_info + * \param[out] out A pointer to the ::OrtOptionalTypeInfo. Do not free this value, + * it is owned by OrtTypeInfo instance. When the type_info does not represent + * optional type, nullptr is returned in out. + * + * \snippet{doc} snippets.dox OrtStatus Return Value. Always returns nullptr. + * + * \since Version 1.15. + */ + ORT_API2_STATUS(CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out); + + /** \brief Get OrtTypeInfo for the allowed contained type from an ::OrtOptionalTypeInfo. + * + * This augments ::OrtOptionalTypeInfo to return an ::OrtTypeInfo for the contained type. + * The OrtOptionalTypeInfo has a nested ::OrtTypeInfo that describes the type of the optional value. + * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. + * The actual OrtValues that are supplied in place of optional type inputs should contain + * specific type that is described by the returned ::OrtTypeInfo. + * + * \param[in] optional_type_info + * \param[out] out A pointer to the ::OrtTypeInfo for what the optional value could be. + * it is owned by OrtOptionalTypeInfo instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value. + * + * \since Version 1.15. + */ + ORT_API2_STATUS(GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, + _Outptr_ OrtTypeInfo** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 5af45b8ff38e..2086193f0c39 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -890,6 +890,19 @@ struct SequenceTypeInfo : detail::SequenceTypeInfoImpl { ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } }; +namespace detail { +template +struct OptionalTypeInfoImpl : Base { + using B = Base; + using B::B; + TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo +}; + +} // namespace detail + +// This is always owned by the TypeInfo and can only be obtained from it. +using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl>; + namespace detail { template struct MapTypeInfoImpl : detail::Base { @@ -921,6 +934,7 @@ struct TypeInfoImpl : detail::Base { ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo + ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< wraps OrtApi::CastTypeInfoToOptionalTypeInfo ONNXType GetONNXType() const; }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 8e2d845d2f62..899c6c331a2c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1074,9 +1074,6 @@ inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { return out; } -} // namespace detail - -namespace detail { template inline ConstTensorTypeAndShapeInfo TypeInfoImpl::GetTensorTypeAndShapeInfo() const { const OrtTensorTypeAndShapeInfo* out; @@ -1105,9 +1102,6 @@ inline ONNXType TypeInfoImpl::GetONNXType() const { return out; } -} // namespace detail - -namespace detail { template inline TypeInfo SequenceTypeInfoImpl::GetSequenceElementType() const { OrtTypeInfo* output; @@ -1115,9 +1109,13 @@ inline TypeInfo SequenceTypeInfoImpl::GetSequenceElementType() const { return TypeInfo{output}; } -} // namespace detail +template +inline TypeInfo OptionalTypeInfoImpl::GetOptionalElementType() const { + OrtTypeInfo* info; + ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info)); + return TypeInfo{info}; +} -namespace detail { template inline ONNXTensorElementDataType MapTypeInfoImpl::GetMapKeyType() const { ONNXTensorElementDataType out; @@ -1131,6 +1129,14 @@ inline TypeInfo MapTypeInfoImpl::GetMapValueType() const { ThrowOnError(GetApi().GetMapValueType(this->p_, &output)); return TypeInfo{output}; } + +template +inline ConstOptionalTypeInfo TypeInfoImpl::GetOptionalTypeInfo() const { + const OrtOptionalTypeInfo* info; + ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info)); + return ConstOptionalTypeInfo{info}; +} + } // namespace detail namespace detail { diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc index 9b18ba670369..bcf925dce48e 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -6,9 +6,12 @@ #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" -OrtMapTypeInfo::OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type) noexcept : map_key_type_(map_key_type), map_value_type_(map_value_type, &OrtApis::ReleaseTypeInfo) { +OrtMapTypeInfo::OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, std::unique_ptr map_value_type) noexcept + : map_key_type_(map_key_type), map_value_type_(std::move(map_value_type)) { } +OrtMapTypeInfo::~OrtMapTypeInfo() = default; + static ONNXTensorElementDataType ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { using TensorType = ONNX_NAMESPACE::TensorProto_DataType; @@ -35,36 +38,26 @@ ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(disable : 26409) #endif -OrtStatus* OrtMapTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtMapTypeInfo** out) { - auto value_case = type_proto->value_case(); - if (value_case != ONNX_NAMESPACE::TypeProto::kMapType) - { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type map!");; +OrtMapTypeInfo::Ptr OrtMapTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& type_proto) { + + auto value_case = type_proto.value_case(); + if (value_case != ONNX_NAMESPACE::TypeProto::kMapType) { + ORT_THROW("type_proto is not of type map!"); } // Get the key type of the map - auto type_proto_map = type_proto->map_type(); - auto map_key_type = ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType(type_proto_map.key_type())); + const auto& type_proto_map = type_proto.map_type(); + const auto map_key_type = ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType(type_proto_map.key_type())); // Get the value type of the map - OrtTypeInfo* map_value_type_info = nullptr; - if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_map.value_type(), &map_value_type_info)) - { - return status; - } + auto map_value_type_info = OrtTypeInfo::FromTypeProto(type_proto_map.value_type()); - *out = new OrtMapTypeInfo(map_key_type, map_value_type_info); - return nullptr; + return std::make_unique(map_key_type, std::move(map_value_type_info)); } -OrtStatus* OrtMapTypeInfo::Clone(OrtMapTypeInfo** out) { - OrtTypeInfo* map_value_type_copy = nullptr; - if (auto status = map_value_type_->Clone(&map_value_type_copy)) - { - return status; - } - *out = new OrtMapTypeInfo(map_key_type_, map_value_type_copy); - return nullptr; +OrtMapTypeInfo::Ptr OrtMapTypeInfo::Clone() const { + auto map_value_type_copy = map_value_type_->Clone(); + return std::make_unique(map_key_type_, std::move(map_value_type_copy)); } // OrtMapTypeInfo Accessors @@ -78,10 +71,12 @@ ORT_API_STATUS_IMPL(OrtApis::GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_ ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** out) { API_IMPL_BEGIN - return map_type_info->map_value_type_->Clone(out); + auto clone = map_type_info->map_value_type_->Clone(); + *out = clone.release(); + return nullptr; API_IMPL_END } ORT_API(void, OrtApis::ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo* ptr) { - delete ptr; + OrtMapTypeInfo::Ptr p(ptr); } \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.h b/onnxruntime/core/framework/onnxruntime_map_type_info.h index 46477d8f04fa..9a72be3db490 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.h +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.h @@ -12,15 +12,20 @@ class TypeProto; struct OrtMapTypeInfo { public: + + using Ptr = std::unique_ptr; + ONNXTensorElementDataType map_key_type_ = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; - std::unique_ptr map_value_type_; + std::unique_ptr map_value_type_; + + static Ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); - static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtMapTypeInfo** out); + Ptr Clone() const; - OrtStatus* Clone(OrtMapTypeInfo** out); + OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, std::unique_ptr map_value_type) noexcept; + ~OrtMapTypeInfo(); private: - OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, OrtTypeInfo* map_value_type)noexcept; OrtMapTypeInfo(const OrtMapTypeInfo& other) = delete; OrtMapTypeInfo& operator=(const OrtMapTypeInfo& other) = delete; diff --git a/onnxruntime/core/framework/onnxruntime_optional_type_info.cc b/onnxruntime/core/framework/onnxruntime_optional_type_info.cc new file mode 100644 index 000000000000..0ad5fc1a9ca2 --- /dev/null +++ b/onnxruntime/core/framework/onnxruntime_optional_type_info.cc @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/framework/onnxruntime_optional_type_info.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/ort_apis.h" +#include "core/framework/error_code_helper.h" + +OrtOptionalTypeInfo::OrtOptionalTypeInfo(OrtTypeInfo::Ptr contained_type) noexcept + : contained_type_(std::move(contained_type)) { +} + +OrtOptionalTypeInfo::~OrtOptionalTypeInfo() = default; + +OrtOptionalTypeInfo::Ptr OrtOptionalTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& type_proto) { + const auto value_case = type_proto.value_case(); + + if (value_case != ONNX_NAMESPACE::TypeProto::kOptionalType) { + ORT_THROW("type_proto is not of optional type"); + } + + const auto& type_proto_optional = type_proto.optional_type(); + auto contained_type_info = OrtTypeInfo::FromTypeProto(type_proto_optional.elem_type()); + + return std::make_unique(std::move(contained_type_info)); +} + +OrtOptionalTypeInfo::Ptr OrtOptionalTypeInfo::Clone() const { + auto contained_type_copy = contained_type_->Clone(); + return std::make_unique(std::move(contained_type_copy)); +} + + ORT_API_STATUS_IMPL(OrtApis::GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, + _Outptr_ OrtTypeInfo** out) { + API_IMPL_BEGIN + auto type_info = optional_type_info->contained_type_->Clone(); + *out = type_info.release(); + return nullptr; + API_IMPL_END + } diff --git a/onnxruntime/core/framework/onnxruntime_optional_type_info.h b/onnxruntime/core/framework/onnxruntime_optional_type_info.h new file mode 100644 index 000000000000..9a44839f110f --- /dev/null +++ b/onnxruntime/core/framework/onnxruntime_optional_type_info.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include + +#include "core/framework/onnxruntime_typeinfo.h" + +namespace ONNX_NAMESPACE { +class TypeProto; +} + +struct OrtOptionalTypeInfo { + + using Ptr = std::unique_ptr; + + explicit OrtOptionalTypeInfo(OrtTypeInfo::Ptr contained_type) noexcept; + ~OrtOptionalTypeInfo(); + + OrtTypeInfo::Ptr contained_type_; + + Ptr Clone() const; + + static Ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); + + OrtOptionalTypeInfo(const OrtOptionalTypeInfo& other) = delete; + OrtOptionalTypeInfo& operator=(const OrtOptionalTypeInfo& other) = delete; +}; diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc index acae583c7a21..4022aa6e4f1a 100644 --- a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc @@ -6,43 +6,44 @@ #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" -OrtSequenceTypeInfo::OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept : sequence_key_type_(sequence_key_type, &OrtApis::ReleaseTypeInfo) { +OrtSequenceTypeInfo::OrtSequenceTypeInfo(OrtTypeInfo::Ptr sequence_key_type) noexcept + : sequence_key_type_(std::move(sequence_key_type)) { } + +OrtSequenceTypeInfo::~OrtSequenceTypeInfo() = default; + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(disable : 26409) #endif -OrtStatus* OrtSequenceTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* type_proto, OrtSequenceTypeInfo** out) { - auto value_case = type_proto->value_case(); + + OrtSequenceTypeInfo::Ptr OrtSequenceTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& type_proto) { + + const auto value_case = type_proto.value_case(); + if (value_case != ONNX_NAMESPACE::TypeProto::kSequenceType) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_proto is not of type sequence!"); + ORT_THROW("type_proto is not of type sequence!"); } - auto type_proto_sequence = type_proto->sequence_type(); - OrtTypeInfo* sequence_key_type_info = nullptr; - if (auto status = OrtTypeInfo::FromTypeProto(&type_proto_sequence.elem_type(), &sequence_key_type_info)) { - return status; - } + const auto& type_proto_sequence = type_proto.sequence_type(); + auto key_type_info = OrtTypeInfo::FromTypeProto(type_proto_sequence.elem_type()); - *out = new OrtSequenceTypeInfo(sequence_key_type_info); - return nullptr; + return std::make_unique(std::move(key_type_info)); } -OrtStatus* OrtSequenceTypeInfo::Clone(OrtSequenceTypeInfo** out) { - OrtTypeInfo* sequence_key_type_copy = nullptr; - if (auto status = sequence_key_type_->Clone(&sequence_key_type_copy)) { - return status; - } - *out = new OrtSequenceTypeInfo(sequence_key_type_copy); - return nullptr; +OrtSequenceTypeInfo::Ptr OrtSequenceTypeInfo::Clone() const { + auto key_type_copy = sequence_key_type_->Clone(); + return std::make_unique(std::move(key_type_copy)); } ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** out) { API_IMPL_BEGIN - return sequence_type_info->sequence_key_type_->Clone(out); + auto key_type_copy = sequence_type_info->sequence_key_type_->Clone(); + *out = key_type_copy.release(); + return nullptr; API_IMPL_END } ORT_API(void, OrtApis::ReleaseSequenceTypeInfo, _Frees_ptr_opt_ OrtSequenceTypeInfo* ptr) { - delete ptr; + OrtSequenceTypeInfo::Ptr p(ptr); } \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.h b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h index 5378dd578abe..dc016f44c380 100644 --- a/onnxruntime/core/framework/onnxruntime_sequence_type_info.h +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h @@ -2,25 +2,28 @@ // Licensed under the MIT License. #pragma once -#include "onnxruntime_c_api.h" - #include +#include "core/framework/onnxruntime_typeinfo.h" + namespace ONNX_NAMESPACE { class TypeProto; } struct OrtSequenceTypeInfo { public: - explicit OrtSequenceTypeInfo(OrtTypeInfo* sequence_key_type) noexcept; - std::unique_ptr sequence_key_type_; + using Ptr = std::unique_ptr; + + explicit OrtSequenceTypeInfo(OrtTypeInfo::Ptr sequence_key_type) noexcept; + ~OrtSequenceTypeInfo(); + + OrtTypeInfo::Ptr sequence_key_type_; - OrtStatus* Clone(OrtSequenceTypeInfo** out); + Ptr Clone() const; - static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtSequenceTypeInfo** out); + static Ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); - private: OrtSequenceTypeInfo(const OrtSequenceTypeInfo& other) = delete; OrtSequenceTypeInfo& operator=(const OrtSequenceTypeInfo& other) = delete; }; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index e3a07f84ef32..833e2da3ca42 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -//this file contains implementations of the C API +// this file contains implementations of the C API #include #include "onnxruntime_typeinfo.h" @@ -15,6 +15,7 @@ #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_map_type_info.h" #include "core/framework/onnxruntime_sequence_type_info.h" +#include "core/framework/onnxruntime_optional_type_info.h" #include "core/framework/TensorSeq.h" using onnxruntime::BFloat16; @@ -27,55 +28,64 @@ using onnxruntime::Tensor; using onnxruntime::TensorShape; namespace on = ONNX_NAMESPACE; + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(disable : 26409) #endif + OrtTypeInfo::OrtTypeInfo(ONNXType type1) noexcept : type(type1) { } -OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtTensorTypeAndShapeInfo* data1) noexcept : type(type1), data(data1) { -} +OrtTypeInfo::OrtTypeInfo(std::unique_ptr map_type_info1) noexcept + : type(ONNX_TYPE_MAP), map_type_info(std::move(map_type_info1)) {} -OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtMapTypeInfo* map_type_info1) noexcept : type(type1), map_type_info(map_type_info1) { -} +OrtTypeInfo::OrtTypeInfo(std::unique_ptr sequence_type_info1) noexcept + : type(ONNX_TYPE_SEQUENCE), sequence_type_info(std::move(sequence_type_info1)) {} -OrtTypeInfo::OrtTypeInfo(ONNXType type1, OrtSequenceTypeInfo* sequence_type_info1) noexcept : type(type1), sequence_type_info(sequence_type_info1) { -} - -OrtTypeInfo::~OrtTypeInfo() { - OrtApis::ReleaseTensorTypeAndShapeInfo(data); +OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info1) noexcept + : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info1)) {} - if (map_type_info) { - OrtApis::ReleaseMapTypeInfo(map_type_info); - } - if (sequence_type_info) { - OrtApis::ReleaseSequenceTypeInfo(sequence_type_info); - } +OrtTypeInfo::OrtTypeInfo(ONNXType type1, std::unique_ptr data1) noexcept + : type(type1), data(std::move(data1)) { } +OrtTypeInfo::~OrtTypeInfo() = default; + ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input, _Out_ ONNXType* out) { + API_IMPL_BEGIN *out = input->type; return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, - _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { - *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data : nullptr; + _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { + API_IMPL_BEGIN + *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data.get() : nullptr; return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtMapTypeInfo** out) { + _Outptr_result_maybenull_ const OrtMapTypeInfo** out) { API_IMPL_BEGIN - *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info : nullptr; + *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info.get() : nullptr; return nullptr; API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out) { + _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out) { API_IMPL_BEGIN - *out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info : nullptr; + *out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info.get() : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out) { + API_IMPL_BEGIN + *out = (type_info->type != ONNX_TYPE_OPTIONAL) ? type_info->optional_type_info.get() : nullptr; return nullptr; API_IMPL_END } @@ -90,19 +100,13 @@ ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* } ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { - delete ptr; + std::unique_ptr p(ptr); } -OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const onnxruntime::DataTypeImpl& tensor_data_type, - OrtTensorTypeAndShapeInfo** out); -OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const std::vector* dim_params, - const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out); - -OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { +OrtTypeInfo::Ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { onnxruntime::MLDataType type = value.Type(); if (type == nullptr) { - *out = new OrtTypeInfo(ONNX_TYPE_UNKNOWN); - return nullptr; + return MakePtr(ONNX_TYPE_UNKNOWN); } // GetType and GetType do not have TypeProto populated because they return a static @@ -110,50 +114,39 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { // unless they are primitive data types, in which case we as before return them not implemented // however, this way we can support Opaque and we can avoid excessive calls to GetType() if (type->IsTensorType()) { - OrtTensorTypeAndShapeInfo* info = nullptr; const Tensor& tensor = value.Get(); const auto* tensor_data_type = tensor.DataType(); if (tensor_data_type != nullptr) { - OrtStatus* st = GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, &info); - if (st != nullptr) - return st; + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.Shape(), *tensor_data_type); + return MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); } - *out = new OrtTypeInfo(ONNX_TYPE_TENSOR, info); - return nullptr; + return MakePtr(ONNX_TYPE_TENSOR); } if (type->IsSparseTensorType()) { #if !defined(DISABLE_SPARSE_TENSORS) - OrtTensorTypeAndShapeInfo* info = nullptr; const SparseTensor& tensor = value.Get(); const auto* tensor_data_type = tensor.DataType(); if (tensor_data_type != nullptr) { - OrtStatus* st = GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type, &info); - if (st != nullptr) return st; + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type); + return MakePtr(ONNX_TYPE_SPARSETENSOR, std::move(type_shape)); } - *out = new OrtTypeInfo(ONNX_TYPE_SPARSETENSOR, info); - return nullptr; + return MakePtr(ONNX_TYPE_SPARSETENSOR); #else - return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build."); + ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build."); #endif } if (type->IsTensorSequenceType()) { - OrtTensorTypeAndShapeInfo* info = nullptr; const auto* tensor_data_type = value.Get().DataType(); if (tensor_data_type != nullptr) { TensorShape void_shape = {}; - OrtStatus* st = GetTensorShapeAndType(void_shape, *tensor_data_type, &info); - if (st != nullptr) { - return st; - } - - auto element_type_info = new OrtTypeInfo(ONNX_TYPE_TENSOR, info); - auto sequence_type_info = new OrtSequenceTypeInfo(element_type_info); - *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, sequence_type_info); - return nullptr; + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type); + auto type_info = MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); + auto sequence_type_info = std::make_unique(std::move(type_info)); + return MakePtr(std::move(sequence_type_info)); } else { - return OrtApis::CreateStatus(ORT_FAIL, "OrtValue is TensorSequence type but has no element Tensor DataType."); + ORT_THROW("OrtValue is TensorSequence type but has no element Tensor DataType."); } } @@ -162,74 +155,53 @@ OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { // Place Opaque first as tensors will be mostly handled above and maps and sequences are not common switch (type_proto->value_case()) { case on::TypeProto::kOpaqueType: { - *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE); - return nullptr; - } -#if !defined(DISABLE_ML_OPS) + return MakePtr(ONNX_TYPE_OPAQUE); + } break; case on::TypeProto::kMapType: { - return OrtTypeInfo::FromTypeProto(type_proto, out); - } +#if !defined(DISABLE_ML_OPS) + auto map_type_info = OrtMapTypeInfo::FromTypeProto(*type_proto); + return MakePtr(std::move(map_type_info)); + } break; +#else + ORT_NOT_IMPLEMENTED("Map types are not supported in this build"); #endif case on::TypeProto::kSequenceType: { - return OrtTypeInfo::FromTypeProto(type_proto, out); - } + auto seq_info = OrtSequenceTypeInfo::FromTypeProto(*type_proto); + return MakePtr(std::move(seq_info)); + } break; // Real Tensor support - case on::TypeProto::kTensorType: #if !defined(DISABLE_SPARSE_TENSORS) - case on::TypeProto::kSparseTensorType: { - return OrtApis::CreateStatus(ORT_FAIL, "Tensor types should have been handled already"); - } + case on::TypeProto::kSparseTensorType: + [[fallthrough]]; +#else + ORT_NOT_IMPLEMENTED("SparseTensor types are not supported in this build"); #endif + case on::TypeProto::kTensorType: { + ORT_THROW("Tensor types should have been handled already"); + } break; default: // NOT_IMPLEMENTED break; } } - - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); + ORT_NOT_IMPLEMENTED("This OrtValue is neither Tensor, SparseTensor, Map or Sequence type"); } const DataTypeImpl* OrtTypeInfo::ElementTypeFromProto(int type) { - switch (type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_STRING: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_UINT16: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_INT16: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - return DataTypeImpl::GetType(); - case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: - return DataTypeImpl::GetType(); - - default: - ORT_NOT_IMPLEMENTED(__FUNCTION__, ":tensor type ", type, " is not supported"); - } + auto tensor_type = DataTypeImpl::TensorTypeFromONNXEnum(type); + return tensor_type->GetElementType(); } -OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, OrtTypeInfo** out) { - auto value_case = input->value_case(); +OrtTypeInfo::Ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& input) { + auto value_case = input.value_case(); switch (value_case) { - case on::TypeProto::kTensorType: - case on::TypeProto::kSparseTensorType: { + case on::TypeProto::kSparseTensorType: +#if !defined(DISABLE_SPARSE_TENSORS) + [[fallthrough]]; +#else + ORT_NOT_IMPLEMENTED("SparseTensor types are not supported in this build"); +#endif + case on::TypeProto::kTensorType: { ONNXType ten_type = ONNX_TYPE_UNKNOWN; const on::TypeProto_Tensor* tensor_type = nullptr; #if !defined(DISABLE_SPARSE_TENSORS) @@ -237,14 +209,14 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or #endif const on::TensorShapeProto* sp = nullptr; if (value_case == on::TypeProto::kTensorType) { - tensor_type = &input->tensor_type(); + tensor_type = &input.tensor_type(); ten_type = ONNX_TYPE_TENSOR; if (onnxruntime::utils::HasShape(*tensor_type)) { sp = &tensor_type->shape(); } } else if (value_case == on::TypeProto::kSparseTensorType) { #if !defined(DISABLE_SPARSE_TENSORS) - sparse_type = &input->sparse_tensor_type(); + sparse_type = &input.sparse_tensor_type(); ten_type = ONNX_TYPE_SPARSETENSOR; if (onnxruntime::utils::HasShape(*sparse_type)) { sp = &sparse_type->shape(); @@ -252,14 +224,13 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or #endif } - OrtStatus* st = nullptr; - OrtTensorTypeAndShapeInfo* info = nullptr; + OrtTensorTypeAndShapeInfo::Ptr type_shape; if (sp != nullptr) { const on::TensorShapeProto& s = *sp; std::vector dims(s.dim_size()); std::vector dim_params(s.dim_size()); TensorShape shape_data(std::move(dims)); - for (int i = 0; i < s.dim_size(); ++i) { + for (int i = 0, dim_size = s.dim_size(); i < dim_size; ++i) { auto& t = s.dim(i); switch (t.value_case()) { case on::TensorShapeProto::Dimension::kDimValue: @@ -275,97 +246,94 @@ OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, Or assert(false); } } - st = GetTensorShapeAndType(shape_data, &dim_params, *input, &info); + type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(std::move(shape_data), &dim_params, input); } else { - st = GetTensorShapeAndType(TensorShape(), nullptr, *input, &info); + type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input); } - if (st != nullptr) return st; - auto type_info = new OrtTypeInfo(ten_type, info); - type_info->denotation = input->denotation(); - *out = type_info; - return nullptr; + + auto type_info = MakePtr(ten_type, std::move(type_shape)); + type_info->denotation = input.denotation(); + return type_info; } break; case on::TypeProto::kSequenceType: { - OrtSequenceTypeInfo* sequence_type_info = nullptr; - - if (auto status = OrtSequenceTypeInfo::FromTypeProto(input, &sequence_type_info)) { - return status; - } - - auto type_info = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, sequence_type_info); - type_info->denotation = input->denotation(); - *out = type_info; - return nullptr; + auto sequence_type_info = OrtSequenceTypeInfo::FromTypeProto(input); + auto type_info = MakePtr(std::move(sequence_type_info)); + type_info->denotation = input.denotation(); + return type_info; } break; +#if !defined(DISABLE_ML_OPS) case on::TypeProto::kMapType: { - OrtMapTypeInfo* map_type_info = nullptr; - - if (auto status = OrtMapTypeInfo::FromTypeProto(input, &map_type_info)) { - return status; - } - - auto type_info = new OrtTypeInfo(ONNX_TYPE_MAP, map_type_info); - type_info->denotation = input->denotation(); - *out = type_info; - return nullptr; + auto map_type_info = OrtMapTypeInfo::FromTypeProto(input); + auto type_info = MakePtr(std::move(map_type_info)); + type_info->denotation = input.denotation(); + return type_info; + } break; +#endif + case on::TypeProto::kOptionalType: { + auto optional_type_info = OrtOptionalTypeInfo::FromTypeProto(input); + auto type_info = MakePtr(std::move(optional_type_info)); + type_info->denotation = input.denotation(); + return type_info; } break; case on::TypeProto::kOpaqueType: { - auto type_info = new OrtTypeInfo(ONNX_TYPE_OPAQUE); - type_info->denotation = input->denotation(); - *out = type_info; - return nullptr; + auto type_info = MakePtr(ONNX_TYPE_OPAQUE); + type_info->denotation = input.denotation(); + return type_info; } break; case on::TypeProto::VALUE_NOT_SET: + ORT_THROW("This TypeProto does not have ValueCase set"); break; default: // Not implemented break; } - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); + ORT_NOT_IMPLEMENTED("The type is not tensor, sparse tensor, sequence, map or optional type"); } -OrtStatus* OrtTypeInfo::Clone(OrtTypeInfo** out) { +OrtTypeInfo::Ptr OrtTypeInfo::Clone() const { switch (type) { - case ONNX_TYPE_TENSOR: - case ONNX_TYPE_SPARSETENSOR: { + case ONNX_TYPE_SPARSETENSOR: #if !defined(DISABLE_SPARSE_TENSORS) - OrtTensorTypeAndShapeInfo* clone; - if (auto status = data->Clone(&clone)) { - return status; - } - *out = new OrtTypeInfo(type, clone); - (*out)->denotation = denotation; - return nullptr; + [[fallthrough]]; #else - return OrtApis::CreateStatus(ORT_FAIL, "SparseTensor is not supported in this build."); + ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build."); #endif + case ONNX_TYPE_TENSOR: { + OrtTensorTypeAndShapeInfo::Ptr info; + if (data) { + info = data->Clone(); + } + auto type_info = MakePtr(type, std::move(info)); + type_info->denotation = denotation; + return type_info; } + case ONNX_TYPE_SEQUENCE: { - OrtSequenceTypeInfo* clone; - if (auto status = sequence_type_info->Clone(&clone)) { - return status; - } - *out = new OrtTypeInfo(type, clone); - (*out)->denotation = denotation; - return nullptr; + auto seq_clone = sequence_type_info->Clone(); + auto type_info = MakePtr(std::move(seq_clone)); + type_info->denotation = denotation; + return type_info; } case ONNX_TYPE_MAP: { - OrtMapTypeInfo* clone; - if (auto status = map_type_info->Clone(&clone)) { - return status; - } - *out = new OrtTypeInfo(type, clone); - (*out)->denotation = denotation; - return nullptr; + auto map_clone = map_type_info->Clone(); + auto type_info = MakePtr(std::move(map_clone)); + type_info->denotation = denotation; + return type_info; + } + case ONNX_TYPE_OPTIONAL: { + auto opt_clone = optional_type_info->Clone(); + auto type_info = MakePtr(std::move(opt_clone)); + type_info->denotation = denotation; + return type_info; } case ONNX_TYPE_OPAQUE: { - *out = new OrtTypeInfo(type); - (*out)->denotation = denotation; - return nullptr; + auto type_info = MakePtr(type); + type_info->denotation = denotation; + return type_info; } default: // Not implemented break; } - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); + ORT_NOT_IMPLEMENTED("The type is not tensor, sparse tensor, sequence, map or optional type"); } diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 5b9145d32e28..207df877ff66 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -2,13 +2,16 @@ // Licensed under the MIT License. #pragma once + #include +#include #include +#include + #include "core/session/onnxruntime_c_api.h" namespace onnxruntime { class DataTypeImpl; -class TensorShape; } // namespace onnxruntime namespace ONNX_NAMESPACE { @@ -18,33 +21,50 @@ class TypeProto; // These types are only present in the winml adapter c api, so they are forward declared. struct OrtMapTypeInfo; struct OrtSequenceTypeInfo; +struct OrtOptionalTypeInfo; +struct OrtTensorTypeAndShapeInfo; /** * the equivalent of ONNX_NAMESPACE::TypeProto * This class is mainly for the C API */ struct OrtTypeInfo { - public: - ONNXType type = ONNX_TYPE_UNKNOWN; - std::string denotation; + // Provide default construction + using Ptr = std::unique_ptr; - ~OrtTypeInfo(); + ONNXType type; + std::string denotation; - //owned by this - OrtTensorTypeAndShapeInfo* data = nullptr; - OrtMapTypeInfo* map_type_info = nullptr; - OrtSequenceTypeInfo* sequence_type_info = nullptr; - OrtTypeInfo(const OrtTypeInfo& other) = delete; - OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete; + std::unique_ptr data; + std::unique_ptr map_type_info; + std::unique_ptr sequence_type_info; + std::unique_ptr optional_type_info; - OrtStatus* Clone(OrtTypeInfo** out); + Ptr Clone() const; - static OrtStatus* FromOrtValue(const OrtValue& value, OrtTypeInfo** out); - static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out); + static Ptr FromOrtValue(const OrtValue& value); + static Ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type); - OrtTypeInfo(ONNXType type) noexcept; - OrtTypeInfo(ONNXType type, OrtTensorTypeAndShapeInfo* data) noexcept; - OrtTypeInfo(ONNXType type, OrtMapTypeInfo* map_type_info) noexcept; - OrtTypeInfo(ONNXType type, OrtSequenceTypeInfo* sequence_type_info) noexcept; + explicit OrtTypeInfo(ONNXType type1) noexcept; + + explicit OrtTypeInfo(std::unique_ptr map_type_info1) noexcept; + + OrtTypeInfo(ONNXType type1, std::unique_ptr data1) noexcept; + + explicit OrtTypeInfo(std::unique_ptr sequence_type_info1) noexcept; + + explicit OrtTypeInfo(std::unique_ptr optional_type_info1) noexcept; + + + OrtTypeInfo(const OrtTypeInfo&) = delete; + OrtTypeInfo& operator=(const OrtTypeInfo&) = delete; + + ~OrtTypeInfo(); + + template + static Ptr MakePtr(Args... args) { + return std::make_unique(std::forward(args)...); + } + }; diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index bebfa72c546a..f2a4f457b68a 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -24,20 +24,23 @@ using onnxruntime::MLFloat16; #if !defined(DISABLE_SPARSE_TENSORS) using onnxruntime::SparseTensor; #endif -using onnxruntime::Tensor; using onnxruntime::narrow; +using onnxruntime::Tensor; + +OrtTensorTypeAndShapeInfo::~OrtTensorTypeAndShapeInfo() = default; + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(disable : 26409) #endif ORT_API_STATUS_IMPL(OrtApis::CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN - *out = new OrtTensorTypeAndShapeInfo(); + *out = std::make_unique().release(); return nullptr; API_IMPL_END } ORT_API(void, OrtApis::ReleaseTensorTypeAndShapeInfo, _Frees_ptr_opt_ OrtTensorTypeAndShapeInfo* ptr) { - delete ptr; + std::unique_ptr p(ptr); } ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo* this_ptr, @@ -151,45 +154,34 @@ ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType( return TensorDataTypeToOnnxRuntimeTensorElementDataType(prim_type->GetDataType()); } -OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape shape, - const std::vector* dim_params, OrtTensorTypeAndShapeInfo** out) { - OrtTensorTypeAndShapeInfo* ret; - if (auto* status = OrtApis::CreateTensorTypeAndShapeInfo(&ret)) - return status; - if (auto* status = OrtApis::SetTensorElementType(ret, type)) { - OrtApis::ReleaseTensorTypeAndShapeInfo(ret); - return status; - } - - auto* status = OrtApis::SetDimensions(ret, shape.GetDims().data(), shape.GetDims().size()); - if (status != nullptr) { - OrtApis::ReleaseTensorTypeAndShapeInfo(ret); - return status; - } +OrtTensorTypeAndShapeInfo::Ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, onnxruntime::TensorShape shape, + const std::vector* dim_params) { + auto type_and_shape = std::make_unique(); + type_and_shape->type = type; + type_and_shape->shape = std::move(shape); if (dim_params != nullptr) { - ret->dim_params = *dim_params; + type_and_shape->dim_params = *dim_params; } else { // we expect to be being called with a concrete shape so validate that - assert(shape.Size() >= 0); - ret->dim_params.resize(shape.NumDimensions(), ""); + assert(type_and_shape->shape.Size() >= 0); + type_and_shape->dim_params.resize(type_and_shape->shape.NumDimensions(), ""); } - *out = ret; - return nullptr; + return type_and_shape; } -OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, - const onnxruntime::DataTypeImpl& tensor_data_type, OrtTensorTypeAndShapeInfo** out) { +OrtTensorTypeAndShapeInfo::Ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(onnxruntime::TensorShape shape, + const onnxruntime::DataTypeImpl& tensor_data_type) { ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented"); + ORT_NOT_IMPLEMENTED("Tensor type is undefined"); } - return GetTensorShapeAndTypeHelper(type, shape, nullptr, out); + return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr); } -OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const std::vector* dim_params, - const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out) { +OrtTensorTypeAndShapeInfo::Ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(onnxruntime::TensorShape shape, const std::vector* dim_params, + const ONNX_NAMESPACE::TypeProto& type_proto) { auto value_case = type_proto.value_case(); assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType || value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType); @@ -198,13 +190,9 @@ OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const st : type_proto.sparse_tensor_type().elem_type(); ONNXTensorElementDataType type = TensorDataTypeToOnnxRuntimeTensorElementDataType(dtype); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented"); + ORT_NOT_IMPLEMENTED("Tensor type is undefined"); } - return GetTensorShapeAndTypeHelper(type, shape, dim_params, out); -} - -OrtStatus* OrtTensorTypeAndShapeInfo::Clone(OrtTensorTypeAndShapeInfo** out) { - return GetTensorShapeAndTypeHelper(type, shape, &dim_params, out); + return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params); } ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) { @@ -219,17 +207,21 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out const Tensor& tensor = v->Get(); shape = &tensor.Shape(); data_type = tensor.DataType(); + auto ptr = OrtTensorTypeAndShapeInfo::OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + *out = ptr.release(); } else { #if !defined(DISABLE_SPARSE_TENSORS) const SparseTensor& tensor = v->Get(); shape = &tensor.DenseShape(); data_type = tensor.DataType(); + auto ptr = OrtTensorTypeAndShapeInfo::OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + *out = ptr.release(); #endif } - return GetTensorShapeAndType(*shape, *data_type, out); } else { ORT_THROW("Argument is not a tensor"); } + return nullptr; API_IMPL_END } @@ -239,7 +231,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorValuesTypeAndShape, _In_ const OrtVa #if !defined(DISABLE_SPARSE_TENSORS) const auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v); const auto& values = sparse_tensor.Values(); - return GetTensorShapeAndType(values.Shape(), *values.DataType(), out); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(values.Shape(), *values.DataType()); + *out = ptr.release(); + return nullptr; #else ORT_UNUSED_PARAMETER(v); ORT_UNUSED_PARAMETER(out); @@ -279,7 +273,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndicesTypeShape, _In_ const OrtValu API_IMPL_BEGIN #if !defined(DISABLE_SPARSE_TENSORS) const Tensor& indices_tensor = GetIndicesTensor(*v, indices_format); - return GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType(), out); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType()); + *out = ptr.release(); + return nullptr; #else ORT_UNUSED_PARAMETER(v); ORT_UNUSED_PARAMETER(indices_format); @@ -309,13 +305,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndices, _In_ const OrtValue* v, ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXType* out) { API_IMPL_BEGIN - OrtTypeInfo* type_info; - auto status = OrtTypeInfo::FromOrtValue(*v, &type_info); - if (status != nullptr) - return status; - + auto type_info = OrtTypeInfo::FromOrtValue(*v); *out = type_info->type; - OrtApis::ReleaseTypeInfo(type_info); return nullptr; API_IMPL_END } @@ -334,8 +325,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, _Outptr_result *out = nullptr; return nullptr; } - - auto status = OrtTypeInfo::FromOrtValue(*v, out); - return status; + auto ptr = OrtTypeInfo::FromOrtValue(*v); + *out = ptr.release(); + return nullptr; API_IMPL_END } \ No newline at end of file diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h index affc3c98a506..283c1de37bda 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -2,14 +2,25 @@ // Licensed under the MIT License. #pragma once +#include #include #include #include "core/framework/tensor_shape.h" #include "core/session/onnxruntime_c_api.h" +namespace ONNX_NAMESPACE { +class TypeProto; +} + +namespace onnxruntime { +class DataTypeImpl; +} + struct OrtTensorTypeAndShapeInfo { public: + using Ptr = std::unique_ptr; + ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; onnxruntime::TensorShape shape; // dim_param values. empty string if dim_value or no dim_param was specified. @@ -17,10 +28,24 @@ struct OrtTensorTypeAndShapeInfo { std::vector dim_params; OrtTensorTypeAndShapeInfo() = default; - OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete; - OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = delete; + ~OrtTensorTypeAndShapeInfo(); + + Ptr Clone() const { + return std::make_unique(*this); + } + + OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = default; + OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = default; + + // Utils + static Ptr GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, onnxruntime::TensorShape shape, + const std::vector* dim_params); + + static Ptr GetTensorShapeAndType(onnxruntime::TensorShape shape, + const onnxruntime::DataTypeImpl& tensor_data_type); - OrtStatus* Clone(OrtTensorTypeAndShapeInfo** out); + static Ptr GetTensorShapeAndType(onnxruntime::TensorShape shape, const std::vector* dim_params, + const ONNX_NAMESPACE::TypeProto&); }; constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementDataType(int32_t dtype); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 924bc34b2b9e..154e5302382b 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -256,7 +256,9 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputTypeInfo, _In_ const OrtKernelIn return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo input does not have a type"); } - return OrtTypeInfo::FromTypeProto(type_proto, type_info); + auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); + *type_info = type_info_ret.release(); + return nullptr; API_IMPL_END } @@ -277,7 +279,9 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelI return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo output does not have a type"); } - return OrtTypeInfo::FromTypeProto(type_proto, type_info); + auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); + *type_info = type_info_ret.release(); + return nullptr; API_IMPL_END } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 2283a0dbfe96..ae201eebb2fe 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1324,7 +1324,9 @@ static ORT_STATUS_PTR GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefLis if (p.second->size() <= index) return OrtApis::CreateStatus(ORT_FAIL, "out of index"); const ONNX_NAMESPACE::TypeProto* type_proto = (*p.second)[index]->TypeAsProto(); - return OrtTypeInfo::FromTypeProto(type_proto, out); + auto type_info = OrtTypeInfo::FromTypeProto(*type_proto); + *out = type_info.release(); + return nullptr; API_IMPL_END } @@ -2703,6 +2705,8 @@ static constexpr OrtApi ort_api_1_to_15 = { &OrtApis::Logger_LogMessage, &OrtApis::Logger_GetLoggingSeverityLevel, &OrtApis::KernelInfoGetConstantInput_tensor, + &OrtApis::CastTypeInfoToOptionalTypeInfo, + &OrtApis::GetOptionalContainedTypeInfo }; // Asserts to do a some checks to ensure older Versions of the OrtApi never change (will detect an addition or deletion but not if they cancel out each other) diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 67b425a5e2c9..7cdacbbaf9e7 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -453,4 +453,10 @@ ORT_API_STATUS_IMPL(Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger ORT_API_STATUS_IMPL(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, _In_ size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); + +ORT_API_STATUS_IMPL(CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out); + +ORT_API_STATUS_IMPL(GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, + _Outptr_ OrtTypeInfo** out); } // namespace OrtApis diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp index 94f94dbac3b9..8e198beaac85 100644 --- a/winml/adapter/winml_adapter_model.cpp +++ b/winml/adapter/winml_adapter_model.cpp @@ -392,18 +392,16 @@ ORT_API_STATUS_IMPL(winmla::ModelGetOutputDescription, _In_ const OrtModel* mode ORT_API_STATUS_IMPL(winmla::ModelGetInputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info) { API_IMPL_BEGIN - if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->input_features_[index]->type(), type_info)) { - return status; - } + auto info = OrtTypeInfo::FromTypeProto(model->UseModelInfo()->input_features_[index]->type()); + *type_info = info.release(); return nullptr; API_IMPL_END } ORT_API_STATUS_IMPL(winmla::ModelGetOutputTypeInfo, _In_ const OrtModel* model, _In_ size_t index, _Outptr_ OrtTypeInfo** type_info) { API_IMPL_BEGIN - if (auto status = OrtTypeInfo::FromTypeProto(&model->UseModelInfo()->output_features_[index]->type(), type_info)) { - return status; - } + auto info = OrtTypeInfo::FromTypeProto(model->UseModelInfo()->output_features_[index]->type()); + *type_info = info.release(); return nullptr; API_IMPL_END } @@ -744,17 +742,11 @@ ORT_API(void, winmla::ReleaseModel, OrtModel* ptr) { #include "core/framework/onnxruntime_typeinfo.h" #include "core/framework/tensor_type_and_shape.h" -OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape shape, const std::vector* dim_params, OrtTensorTypeAndShapeInfo** out); - ORT_API_STATUS_IMPL(winmla::CreateTensorTypeInfo, _In_ const int64_t* dim_values, size_t dim_count, ONNXTensorElementDataType type, _Out_ OrtTypeInfo** ort_type_info) { API_IMPL_BEGIN - OrtTensorTypeAndShapeInfo* data = nullptr; auto tensor_shape = onnxruntime::TensorShape(dim_values, dim_count); - auto st = GetTensorShapeAndTypeHelper(type, tensor_shape, nullptr, &data); - if (st != nullptr){ - return st; - } - *ort_type_info = new OrtTypeInfo(ONNX_TYPE_TENSOR, data); + auto type_and_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(type, std::move(tensor_shape), nullptr); + *ort_type_info = OrtTypeInfo::MakePtr(ONNX_TYPE_TENSOR, std::move(type_and_shape)).release(); return nullptr; API_IMPL_END } From 9fb3811798d54075624fc6754c8dc3cadc67c403 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 21 Mar 2023 11:28:42 -0700 Subject: [PATCH 2/9] Regenerate proto code from the current onnx protos Add native methods from the merge Add Test Protobuf data Implement test sequence input loading Optimize Input/Output names conversion and validation Introduce OnnxValue to NamedOnnxValue, rename NativeOnnxTensorMemory Rework ToOrtValue interface Implement ManagedProjections Make sure all required map types are supported Generate input OrtValue using ManagedOnnxType Implement optional support, partial map support. Fix optional issues. Provide details for failing tests Comment out two tests due to invalid test data --- .gitignore | 1 + .../DisposableNamedOnnxValue.shared.cs | 332 ++- .../InferenceSession.shared.cs | 276 +- .../ManagedProjections.shared.cs | 277 ++ .../NamedOnnxValue.shared.cs | 183 +- .../NativeMethods.shared.cs | 15 +- .../NativeOnnxValueHelper.shared.cs | 21 +- .../OrtValue.shared.cs | 15 +- ...ory.shared.cs => OrtValueTensor.shared.cs} | 28 +- ...rosoft.ML.OnnxRuntime.EndToEndTests.csproj | 1 + .../InferenceTest.cs | 16 +- ...crosoft.ML.OnnxRuntime.Tests.Common.csproj | 9 +- .../OnnxData.cs | 1335 ++++++++++ .../OnnxMl.cs | 2 +- .../TestDataLoader.cs | 417 ++- .../InferenceTest.netcore.cs | 296 ++- .../core/framework/onnxruntime_typeinfo.cc | 115 +- .../core/framework/onnxruntime_typeinfo.h | 1 - tools/ci_build/github/Doxyfile_csharp.cfg | 2366 ++++++++++++++++- 19 files changed, 5208 insertions(+), 498 deletions(-) create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs rename csharp/src/Microsoft.ML.OnnxRuntime/{NativeOnnxTensorMemory.shared.cs => OrtValueTensor.shared.cs} (89%) create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxData.cs diff --git a/.gitignore b/.gitignore index 61fc6e474b8b..c9e217e135e4 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,7 @@ onnxprofile_profile_test_*.json /csharp/packages /csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.targets /csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.props +/csharp/**/*.vcxproj.user cmake/external/FeaturizersLibrary/ # Java specific ignores java/.gradle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs index b2a4c2ef47cb..3847863a0dec 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs @@ -3,11 +3,15 @@ using Microsoft.ML.OnnxRuntime.Tensors; using System; -using System.Buffers; using System.Collections.Generic; +using System.Reflection; namespace Microsoft.ML.OnnxRuntime { + /// + /// Return immutable collection of results + /// + /// public interface IDisposableReadOnlyCollection : IReadOnlyCollection, IDisposable { @@ -70,24 +74,44 @@ public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable /// Managed object created to represent output value, such as DenseTensor /// List or Dictionary /// - /// Use this to decide what you want to call to fetch data, AsTensor(), AsDictionary() - /// or AsEnumerable() /// Tensor element type if value type is a Tensor /// Object that holds native resources. /// Typically, this is an output OrtValue that holds native memory where Tensor is mapped but may also be /// other things that would need to be disposed by this instance depending on how IOrtValueOwner is implemented. - private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, IOrtValueOwner ortValueHolder) - : base(name, value) + private DisposableNamedOnnxValue(string name, Object value, TensorElementType elementType, IOrtValueOwner ortValueHolder) + : base(name, value, OnnxValueType.ONNX_TYPE_TENSOR) { _ortValueHolder = ortValueHolder; - ValueType = onnxValueType; ElementType = elementType; } /// - /// Returns OnnxValueType + /// Ctor for non-tensor values + /// + /// + /// + /// + /// + private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, IOrtValueOwner ortValueHolder) + : base(name, value, onnxValueType) + { + _ortValueHolder = ortValueHolder; + ElementType = TensorElementType.DataTypeMax; + } + + /// + /// Construct from a dictionary /// - public OnnxValueType ValueType { get; } + /// + /// + /// + /// + private DisposableNamedOnnxValue(string name, Object value, MapHelper mapHelper, IOrtValueOwner ortValueHolder) + : base(name, value, mapHelper) + { + _ortValueHolder = ortValueHolder; + ElementType = TensorElementType.DataTypeMax; + } /// /// Only valid if ValueType is Tensor @@ -101,22 +125,70 @@ private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxVa /// to do, as this class maintains a native buffer via _ortValueHolder and the memory will be /// disposed by it. This is the case when we are dealing with an OrtValue that is backed by native memory /// and not by pinned managed memory. + /// + /// This class is generally used for outputs to be created on top of the output OrtValue, + /// but the interface (derived from NamedOnnxValue) allows it to be passed as input and one of the test + /// cases does it. Unless we deprecate and re-do the interface, we must support it. /// /// always set to null /// An instance of OrtValue that does not own underlying memory - internal override OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle) + internal override OrtValue InputToOrtValue(NodeMetadata metadata, out IDisposable memoryHolder) { - if(_ortValueHolder == null) + if (_ortValueHolder == null) { throw new InvalidOperationException("The instance of this class does not own any OrtValues"); } // PinnedMemoryHandle holds the default value as DisposableNamedOnnxValue // doesn't hold any managed buffer (that needs to be pinned) - pinnedMemoryHandle = null; + memoryHolder = null; // Return non-owning instance of OrtValue return _ortValueHolder.Value; } + /// + /// Generally, this class is created on top of the values that are returned by the model run. + /// So, this method is not expected to be called. However, if it is called (an instance fed as output), + /// it will return the OrtValue that was previously created, since the caller must understand what they are doing. + /// + /// + /// + /// + internal override OrtValue OutputToOrtValue(NodeMetadata metadata, out IDisposable memoryOwner) + { + return InputToOrtValue(metadata, out memoryOwner); + } + + internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue) + { + return CreateFromOrtValue(name, ortValue, OrtAllocator.DefaultInstance); + } + + internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue, OrtAllocator allocator) + { + DisposableNamedOnnxValue result = null; + + IntPtr valueType; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueType(ortValue.Handle, out valueType)); + OnnxValueType onnxValueType = (OnnxValueType)valueType; + switch (onnxValueType) + { + case OnnxValueType.ONNX_TYPE_TENSOR: + result = FromNativeTensor(name, ortValue); + break; + + case OnnxValueType.ONNX_TYPE_SEQUENCE: + result = FromNativeSequence(name, ortValue, allocator); + break; + + case OnnxValueType.ONNX_TYPE_MAP: + result = FromNativeMap(name, ortValue, allocator); + break; + default: + throw new NotSupportedException("OnnxValueType : " + onnxValueType + " is not supported"); + } + return result; + } + /// /// Creates an instance of DisposableNamedOnnxValue and takes ownership of ortValueElement /// on success. @@ -124,7 +196,7 @@ internal override OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle) /// name of the value /// underlying OrtValue /// - internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, OrtValue ortValue) + private static DisposableNamedOnnxValue FromNativeTensor(string name, OrtValue ortValue) { DisposableNamedOnnxValue result = null; @@ -146,46 +218,46 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, switch (elemType) { case TensorElementType.Float: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Double: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Int16: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.UInt16: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Int32: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.UInt32: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Int64: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.UInt64: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.UInt8: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Int8: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.String: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Bool: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.Float16: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; case TensorElementType.BFloat16: - result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); + result = FromNativeTensor(name, ortValue); break; default: throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported"); @@ -195,37 +267,6 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, return result; } - internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue) - { - return CreateFromOrtValue(name, ortValue, OrtAllocator.DefaultInstance); - } - - internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue, OrtAllocator allocator) - { - DisposableNamedOnnxValue result = null; - - IntPtr valueType; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueType(ortValue.Handle, out valueType)); - OnnxValueType onnxValueType = (OnnxValueType)valueType; - switch (onnxValueType) - { - case OnnxValueType.ONNX_TYPE_TENSOR: - result = CreateTensorFromOnnxValue(name, ortValue); - break; - - case OnnxValueType.ONNX_TYPE_SEQUENCE: - result = DisposableNamedOnnxValueFromSequence(name, ortValue, allocator); - break; - - case OnnxValueType.ONNX_TYPE_MAP: - result = DisposableNamedOnnxValueFromNativeMap(name, ortValue, allocator); - break; - default: - throw new NotSupportedException("OnnxValueType : " + onnxValueType + " is not supported"); - } - return result; - } - /// /// This method creates an instance of DisposableNamedOnnxValue that has possession of ortValueElement /// native memory Tensor and returns it to the caller. The original ortValueElement argument looses @@ -236,34 +277,26 @@ internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValu /// name of the output /// native tensor /// DisposableNamedOnnxValue instance - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeTensor(string name, OrtValue ortValue) + private static DisposableNamedOnnxValue FromNativeTensor(string name, OrtValue ortValue) { - if (typeof(T) == typeof(string)) + var ortValueTensor = new OrtValueTensor(ortValue); + try { - var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValue); - try + if (typeof(T) == typeof(string)) { - var dt = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions); - return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper); - } catch(Exception) + var dt = new DenseTensor(ortValueTensor.GetBytesAsStringMemory(), ortValueTensor.Dimensions); + return new DisposableNamedOnnxValue(name, dt, ortValueTensor.ElementType, ortValueTensor); + } + else { - nativeTensorWrapper.Dispose(); - throw; + DenseTensor dt = new DenseTensor(ortValueTensor.Memory, ortValueTensor.Dimensions); + return new DisposableNamedOnnxValue(name, dt, ortValueTensor.ElementType, ortValueTensor); } } - else + catch (Exception) { - NativeOnnxTensorMemory nativeTensorWrapper = new NativeOnnxTensorMemory(ortValue); - try - { - DenseTensor dt = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions); - return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper); - } - catch (Exception) - { - nativeTensorWrapper.Dispose(); - throw; - } + ortValueTensor.Dispose(); + throw; } } @@ -275,7 +308,7 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeTensor /// ortValueElement that has native sequence /// used allocator /// DisposableNamedOnnxValue - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(string name, OrtValue ortValueSequence, OrtAllocator allocator) + private static DisposableNamedOnnxValue FromNativeSequence(string name, OrtValue ortValueSequence, OrtAllocator allocator) { DisposableNamedOnnxValue result = null; IntPtr count; @@ -295,8 +328,8 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(str } // NativeOrtValueCollectionOwner will take ownership of ortValueSequence and will make sure sequence // is also disposed. - var nativeCollectionManager = new NativeOrtValueCollectionOwner(ortValueSequence, sequence); - result = new DisposableNamedOnnxValue(name, sequence, OnnxValueType.ONNX_TYPE_SEQUENCE, TensorElementType.DataTypeMax, nativeCollectionManager); + var nativeCollectionManager = new NativeOrtValueCollectionOwner(ortValueSequence, sequence); + result = new DisposableNamedOnnxValue(name, sequence, OnnxValueType.ONNX_TYPE_SEQUENCE, nativeCollectionManager); } catch (Exception) { @@ -314,7 +347,7 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(str /// This function does not take ownership of the map as it we copy all keys an values into a dictionary. We let the caller dispose of it /// /// DisposableNamedOnnxValue - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(string name, OrtValue ortValueMap, OrtAllocator allocator) + private static DisposableNamedOnnxValue FromNativeMap(string name, OrtValue ortValueMap, OrtAllocator allocator) { DisposableNamedOnnxValue result = null; // Map processing is currently not recursing. It is assumed to contain @@ -323,44 +356,91 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(st // not mapped for client consumption. using (var cleanUpList = new DisposableList()) { - // Take possession of the map ortValueElement IntPtr nativeOnnxValueMapKeys = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(ortValueMap.Handle, 0, allocator.Pointer, out nativeOnnxValueMapKeys)); var ortValueKeys = new OrtValue(nativeOnnxValueMapKeys); cleanUpList.Add(ortValueKeys); + var typeAndShape = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapKeys, out typeAndShape)); + TensorElementType keyElemType; + try + { + IntPtr el_type; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type)); + keyElemType = (TensorElementType)el_type; + } + finally + { + NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape); + } + IntPtr nativeOnnxValueMapValues = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(ortValueMap.Handle, 1, allocator.Pointer, out nativeOnnxValueMapValues)); var ortValueValues = new OrtValue(nativeOnnxValueMapValues); cleanUpList.Add(ortValueValues); - IntPtr typeAndShape = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapKeys, out typeAndShape)); - TensorElementType elemType = TensorElementType.DataTypeMax; + typeAndShape = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapValues, out typeAndShape)); + TensorElementType valueElemType; try { IntPtr el_type; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type)); - elemType = (TensorElementType)el_type; + valueElemType = (TensorElementType)el_type; } finally { NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape); } - /// XXX: This code always assumes that the value type is float and makes no checks - /// similar to that of the key. Also Map type in general can also be another sequence or map, - /// not just a tensor - switch (elemType) + if (valueElemType != TensorElementType.Float) + { + throw new OnnxRuntimeException(ErrorCode.NotImplemented, $"Value element type: {valueElemType} not supported"); + } + + switch (keyElemType) { case TensorElementType.Int64: - result = DisposableNamedOnnxValueFromNativeMapElements(string.Empty, ortValueKeys, ortValueValues); + switch (valueElemType) + { + case TensorElementType.Float: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.Double: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.Int64: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.String: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + default: + break; + } break; case TensorElementType.String: - result = DisposableNamedOnnxValueFromNativeMapElements(string.Empty, ortValueKeys, ortValueValues); + switch (valueElemType) + { + case TensorElementType.Float: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.Double: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.Int64: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + case TensorElementType.String: + result = FromNativeMapElements(name, ortValueMap, ortValueKeys, ortValueValues); + break; + default: + break; + } break; default: - throw new NotSupportedException("Map of element type: " + elemType + " is not supported"); + throw new NotSupportedException("Map key type: " + keyElemType + " is not supported"); } } return result; @@ -381,40 +461,78 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(st /// tensor with map keys. /// tensor with map values /// instance of DisposableNamedOnnxValue with Dictionary - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMapElements(string name, + private static DisposableNamedOnnxValue FromNativeMapElements(string name, OrtValue ortValueMap, OrtValue ortValueTensorKeys, OrtValue ortValueTensorValues) { - using (var nativeTensorWrapperValues = new NativeOnnxTensorMemory(ortValueTensorValues)) + var listOfKeysValues = new DisposableList(); + var collOwner = new NativeOrtValueCollectionOwner(ortValueMap, listOfKeysValues); + try { - var denseTensorValues = new DenseTensor(nativeTensorWrapperValues.Memory, nativeTensorWrapperValues.Dimensions); + var tensorKeys = new OrtValueTensor(ortValueTensorKeys); + listOfKeysValues.Add(ortValueTensorKeys); + var tensorValues = new OrtValueTensor(ortValueTensorValues); + listOfKeysValues.Add(ortValueTensorValues); + MapHelper mapHelper = null; if (typeof(K) == typeof(string)) { - var map = new Dictionary(); - using (var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValueTensorKeys)) + var denseTensorKeys = new DenseTensor(tensorKeys.GetBytesAsStringMemory(), tensorKeys.Dimensions); + + if (typeof(V) == typeof(string)) { - var denseTensorKeys = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions); + var map = new Dictionary(); + var denseTensorValues = new DenseTensor(tensorValues.GetBytesAsStringMemory(), tensorValues.Dimensions); for (var i = 0; i < denseTensorKeys.Length; i++) { map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); } - return new DisposableNamedOnnxValue(name, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null); + mapHelper = new MapHelper(denseTensorKeys, denseTensorValues); + return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner); + } + else + { + var map = new Dictionary(); + var denseTensorValues = new DenseTensor(tensorValues.Memory, tensorValues.Dimensions); + for (var i = 0; i < denseTensorKeys.Length; i++) + { + map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); + } + mapHelper = new MapHelper(denseTensorKeys, denseTensorValues); + return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner); } } else { - var map = new Dictionary(); - using (var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValueTensorKeys)) + var denseTensorKeys = new DenseTensor(tensorKeys.Memory, tensorKeys.Dimensions); + if (typeof(V) == typeof(string)) { - var denseTensorKeys = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions); + var map = new Dictionary(); + var denseTensorValues = new DenseTensor(tensorValues.GetBytesAsStringMemory(), tensorValues.Dimensions); for (var i = 0; i < denseTensorKeys.Length; i++) { map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); } - return new DisposableNamedOnnxValue(name, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null); + mapHelper = new MapHelper(denseTensorKeys, denseTensorValues); + return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner); + } + else + { + var denseTensorValues = new DenseTensor(tensorValues.Memory, tensorValues.Dimensions); + var map = new Dictionary(); + for (var i = 0; i < denseTensorKeys.Length; i++) + { + map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); + } + mapHelper = new MapHelper(denseTensorKeys, denseTensorValues); + return new DisposableNamedOnnxValue(name, map, mapHelper, collOwner); } } } + catch (Exception) + { + collOwner.Dispose(); + throw; + } } #region IDisposable Support @@ -425,7 +543,7 @@ private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(st /// true if invoked by Dispose() protected virtual void Dispose(bool disposing) { - if(_disposed) + if (_disposed) { return; } @@ -448,9 +566,7 @@ protected virtual void Dispose(bool disposing) /// public void Dispose() { - // Do not change this code. Put cleanup code in Dispose(bool disposing) above. Dispose(true); - GC.SuppressFinalize(this); } #endregion diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs index 91366db5a245..7c5a8ea1e93f 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs @@ -8,6 +8,7 @@ using System.Linq; using Microsoft.ML.OnnxRuntime.Tensors; using System.Buffers; +using System.Diagnostics; namespace Microsoft.ML.OnnxRuntime { @@ -31,11 +32,21 @@ public class InferenceSession : IDisposable /// private Dictionary _inputMetadata; + /// + /// Ordered list of input names + /// + private List _inputNames; + /// /// Dictionary that represent output metadata /// private Dictionary _outputMetadata; + /// + /// Ordered list of output names + /// + private List _outputNames; + /// /// Dictionary that represents overridableInitializers metadata /// @@ -163,6 +174,11 @@ public InferenceSession(byte[] model, SessionOptions options) } } + /// + /// Ordered list of input names that can be accessed by index; + /// + public IReadOnlyList InputNames { get { return _inputNames; } } + /// /// Metadata regarding the output nodes, keyed by output names /// @@ -174,6 +190,11 @@ public InferenceSession(byte[] model, SessionOptions options) } } + /// + /// Ordered list of output names that can be accessed by index. + /// + public IReadOnlyList OutputNames { get { return _outputNames; } } + /// /// Metadata regarding the overridable initializers, keyed by node names /// @@ -203,7 +224,8 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl /// Specify a collection of that indicates the input values. /// Specify a collection of string that indicates the output names to fetch. /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. - public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames) + public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, + IReadOnlyCollection outputNames) { return Run(inputs, outputNames, _builtInRunOptions); } @@ -215,13 +237,15 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl /// Specify a collection of string that indicates the output names to fetch. /// /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. - public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames, RunOptions options) + public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, + IReadOnlyCollection outputNames, + RunOptions options) { using (var cleanupList = new DisposableList()) { - var inputNamesArray = ConvertNamesToUtf8(inputs, v => v.Name, cleanupList); - var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList); - var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputs, v => v.Name, LookupInputMetadata, cleanupList); + var inputValuesArray = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueForInput, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList); var ortValues = RunImpl(options, inputNamesArray, inputValuesArray, outputNamesArray, cleanupList); return CreateDisposableResult(ortValues, outputNames); @@ -238,9 +262,7 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl IReadOnlyCollection inputNames, IReadOnlyCollection inputValues) { - string[] outputNames = new string[_outputMetadata.Count]; - _outputMetadata.Keys.CopyTo(outputNames, 0); - return Run(inputNames, inputValues, outputNames, _builtInRunOptions); + return Run(inputNames, inputValues, _outputNames, _builtInRunOptions); } /// @@ -279,9 +301,9 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl using (var cleanupList = new DisposableList()) { - var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, LookupInputMetadata, cleanupList); IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); - var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList); var ortValues = RunImpl(options, inputNamesArray, inputValuesArray, outputNamesArray, cleanupList); @@ -336,11 +358,11 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl using (var cleanupList = new DisposableList()) { // prepare inputs - var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, LookupInputMetadata, cleanupList); IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); // prepare outputs - var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList); IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); NativeApiStatus.VerifySuccess(NativeMethods.OrtRun( @@ -371,7 +393,7 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl } /// - /// + /// /// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run. /// /// Outputs need to be created with correct type and dimension to receive the fetched data. @@ -386,11 +408,11 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl { using (var cleanupList = new DisposableList()) { - var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, cleanupList); - var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, LookupInputMetadata, cleanupList); + var inputValuesArray = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueForInput, cleanupList); - var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, cleanupList); - var outputValuesArray = GetOrtValuesHandles(outputs, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, LookupOutputMetadata, cleanupList); + var outputValuesArray = GetOrtValuesHandles(outputs, LookupOutputMetadata, ExtractOrtValueForOutput, cleanupList); NativeApiStatus.VerifySuccess(NativeMethods.OrtRun( _nativeHandle, @@ -444,11 +466,11 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl using (var cleanupList = new DisposableList()) { // prepare inputs - var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, cleanupList); - var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, LookupInputMetadata, cleanupList); + var inputValuesArray = GetOrtValuesHandles(inputs, LookupInputMetadata, ExtractOrtValueForInput, cleanupList); // prepare outputs - var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, LookupOutputMetadata, cleanupList); var outputValuesArray = GetOrtValuesHandles(outputValues, false); NativeApiStatus.VerifySuccess(NativeMethods.OrtRun( @@ -482,7 +504,7 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl } /// - /// + /// /// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run. /// /// Outputs need to be created with correct type and dimension to receive the fetched data. @@ -505,12 +527,12 @@ public IDisposableReadOnlyCollection Run(IReadOnlyColl using (var cleanupList = new DisposableList()) { // prepare inputs - var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList); + var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, LookupInputMetadata, cleanupList); var inputValuesArray = GetOrtValuesHandles(inputValues, true); // prepare outputs - var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, cleanupList); - var outputValuesArray = GetOrtValuesHandles(outputs, cleanupList); + var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, LookupOutputMetadata, cleanupList); + var outputValuesArray = GetOrtValuesHandles(outputs, LookupOutputMetadata, ExtractOrtValueForOutput, cleanupList); NativeApiStatus.VerifySuccess(NativeMethods.OrtRun( _nativeHandle, @@ -619,22 +641,96 @@ public string EndProfiling() // Delegate for string extraction from an arbitrary input/output object private delegate string NameExtractor(TInput input); + // delegate to fetch input/output OrtValue + private delegate OrtValue OrtValueExtractor(NamedOnnxValue value, NodeMetadata metadata, out IDisposable memOwner); + + // Delegate to lookup metadata for input/initializers/output + private delegate NodeMetadata MetadataLookup(string nodeName); + + /// + /// Checks if the name is a known input or overridable initializer name + /// and if so, returns metadata for it. + /// metadata + /// + /// + /// NodeMetadata for the nodeName + /// + private NodeMetadata LookupInputMetadata(string nodeName) + { + NodeMetadata meta; + if (!_inputMetadata.TryGetValue(nodeName, out meta) && + !_overridableInitializerMetadata.TryGetValue(nodeName, out meta)) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Input/output name: '{nodeName}' is not in the metadata"); + } + return meta; + } + + /// + /// Checks if the nodeName is a known output name and if so returns metadata for it. + /// + /// + /// + /// + private NodeMetadata LookupOutputMetadata(string nodeName) + { + NodeMetadata meta; + if (!_outputMetadata.TryGetValue(nodeName, out meta)) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Input/output name: '{nodeName}' is not in the metadata"); + } + return meta; + } + + /// + /// Fetches/creates OrtValue for the content of the input + /// + /// + /// + /// + /// + private static OrtValue ExtractOrtValueForInput(NamedOnnxValue input, NodeMetadata metadata, out IDisposable memOwner) + { + return input.InputToOrtValue(metadata, out memOwner); + } + + /// + /// Fetches/Creates OrtValue for output + /// + /// + /// + /// + /// May return null if the onnx value type does not support pre-creation of output OrtValues + private static OrtValue ExtractOrtValueForOutput(NamedOnnxValue output, NodeMetadata metadata, out IDisposable memOwner) + { + return output.OutputToOrtValue(metadata, out memOwner); + } + /// /// Run helper /// - /// names to convert to zero terminated utf8 and pin + /// names to convert to zero terminated utf8 and pin + /// extractor functor that helps extracting names from inputs + /// inputs/outputs metadata /// list to add pinned memory to for later disposal /// - private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection inputs, NameExtractor extractor, + private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection values, NameExtractor nameExtractor, + MetadataLookup metaLookup, DisposableList cleanupList) { - var result = new IntPtr[inputs.Count]; - for (int i = 0; i < inputs.Count; ++i) + cleanupList.Capacity += values.Count; + var result = new IntPtr[values.Count]; + for (int i = 0; i < values.Count; ++i) { - var name = extractor(inputs.ElementAt(i)); - var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name); - var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned)); - result[i] = pinnedHandle.Pointer; + var name = nameExtractor(values.ElementAt(i)); + NodeMetadata meta = metaLookup(name); + var utf8Name = meta.ZeroTerminatedUtf8Name; + Debug.Assert(utf8Name != null); + var pinnedHandle = new Memory(utf8Name).Pin(); + unsafe + { + result[i] = (IntPtr)pinnedHandle.Pointer; + } cleanupList.Add(pinnedHandle); } return result; @@ -642,28 +738,41 @@ public string EndProfiling() /// /// This function obtains ortValues for NamedOnnxValue. - /// The problem with NamedOnnxValue is that it does not contain any Onnx (OrtValue) - /// so calling ToOrtValue creates a new instance of OrtValue that needs to be disposed. + /// The problem with NamedOnnxValue is that it is not disposable and can not contain any disposable items. + /// so calling InputToOrtValue creates a new instance of OrtValue that needs to be disposed. /// The deriving object DisposableNamedValue actually contains and owns OrtValue and it returns /// it. /// - /// - /// + /// a collection of NamedOnnxValues + /// Metadata lookup function (input/initializers/output) + /// list to cleanup in an exception safe manner /// - private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, DisposableList cleanupList) + private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, MetadataLookup metaLookup, + OrtValueExtractor ortValueExtractor, + DisposableList cleanupList) { + cleanupList.Capacity += values.Count * 2; IntPtr[] result = new IntPtr[values.Count]; - for (int inputIndex = 0; inputIndex < values.Count; ++inputIndex) + for (int valueIndex = 0; valueIndex < values.Count; ++valueIndex) { - var input = values.ElementAt(inputIndex); - MemoryHandle? memHandle; - var ortValue = input.ToOrtValue(out memHandle); - if (memHandle.HasValue) + var value = values.ElementAt(valueIndex); + var meta = metaLookup(value.Name); + var ortValue = ortValueExtractor(value, meta, out IDisposable memHolder); + if (memHolder != null) { - cleanupList.Add(memHandle); + cleanupList.Add(memHolder); + } + if (ortValue != null) + { + if (ortValue.IsOwned) + cleanupList.Add(ortValue); + + result[valueIndex] = ortValue.Handle; + } + else + { + result[valueIndex] = IntPtr.Zero; } - cleanupList.Add(ortValue); - result[inputIndex] = ortValue.Handle; } return result; } @@ -687,7 +796,8 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection v private DisposableList RunImpl(RunOptions options, IntPtr[] inputNames, IntPtr[] inputValues, IntPtr[] outputNames, DisposableList cleanupList) { - var ortValues = new DisposableList(outputNames.Length); + cleanupList.Capacity += 1; + var ortValues = new DisposableList(outputNames.Length + 1); cleanupList.Add(ortValues); IntPtr[] outputValuesArray = new IntPtr[outputNames.Length]; @@ -717,8 +827,7 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection v { for (int i = 0; i < ortValues.Count; i++) { - var ortValue = ortValues[i]; - result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames.ElementAt(i), ortValue)); + result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames.ElementAt(i), ortValues[i])); } } catch (OnnxRuntimeException) @@ -820,40 +929,53 @@ private void InitWithSessionHandle(IntPtr session, SessionOptions options) _nativeHandle = session; try { - // Initialize input/output metadata - _inputMetadata = new Dictionary(); - _outputMetadata = new Dictionary(); - _overridableInitializerMetadata = new Dictionary(); // get input count UIntPtr inputCount = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount)); // get all the input names and metadata + _inputMetadata = new Dictionary((int)inputCount); + _inputNames = new List((int)inputCount); + for (ulong i = 0; i < (ulong)inputCount; i++) { - var iname = GetInputName(i); - _inputMetadata[iname] = GetInputMetadata(i); + var inputMeta = GetInputMetadata(i); + var iname = GetInputName(i, out byte[] utf8); + _inputNames.Add(iname); + inputMeta.ZeroTerminatedUtf8Name = utf8; + _inputMetadata[iname] = inputMeta; } // get output count UIntPtr outputCount = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount)); // get all the output names and metadata + _outputMetadata = new Dictionary((int)outputCount); + _outputNames = new List((int)outputCount); + for (ulong i = 0; i < (ulong)outputCount; i++) { - _outputMetadata[GetOutputName(i)] = GetOutputMetadata(i); + var outputMeta = GetOutputMetadata(i); + var oname = GetOutputName(i, out byte[] utf8); + _outputNames.Add(oname); + outputMeta.ZeroTerminatedUtf8Name = utf8; + _outputMetadata[oname] = outputMeta; } // get overridable initializer count UIntPtr initilaizerCount = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out initilaizerCount)); + _overridableInitializerMetadata = new Dictionary((int)initilaizerCount); // get all the overridable initializer names and metadata for (ulong i = 0; i < (ulong)initilaizerCount; i++) { - _overridableInitializerMetadata[GetOverridableInitializerName(i)] = GetOverridableInitializerMetadata(i); + var meta = GetOverridableInitializerMetadata(i); + var iname = GetOverridableInitializerName(i, out byte[] utf8); + meta.ZeroTerminatedUtf8Name = utf8; + _overridableInitializerMetadata[iname] = meta; } // set profiling's start time UIntPtr startTime = UIntPtr.Zero; @@ -861,7 +983,7 @@ private void InitWithSessionHandle(IntPtr session, SessionOptions options) out startTime)); _profilingStartTimeNs = (ulong)startTime; } - catch (OnnxRuntimeException) + catch (Exception) { if (_nativeHandle != IntPtr.Zero) { @@ -875,11 +997,11 @@ private void InitWithSessionHandle(IntPtr session, SessionOptions options) } - private string GetOutputName(ulong index) + private string GetOutputName(ulong index, out byte[] utf8) { + string str; var allocator = OrtAllocator.DefaultInstance; IntPtr nameHandle = IntPtr.Zero; - string str = null; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputName( _nativeHandle, (UIntPtr)index, @@ -888,15 +1010,15 @@ private string GetOutputName(ulong index) using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0)) { - str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle); + NativeOnnxValueHelper.StringAndUtf8FromNative(nameHandle, out str, out utf8); } return str; } - private string GetInputName(ulong index) + private string GetInputName(ulong index, out byte[] utf8) { - string str = null; + string str; var allocator = OrtAllocator.DefaultInstance; IntPtr nameHandle = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputName( @@ -907,14 +1029,14 @@ private string GetInputName(ulong index) using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0)) { - str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle); + NativeOnnxValueHelper.StringAndUtf8FromNative(nameHandle, out str, out utf8); } return str; } - private string GetOverridableInitializerName(ulong index) + private string GetOverridableInitializerName(ulong index, out byte[] utf8) { - string str = null; + string str; var allocator = OrtAllocator.DefaultInstance; IntPtr nameHandle = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerName( @@ -924,7 +1046,7 @@ private string GetOverridableInitializerName(ulong index) out nameHandle)); using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0)) { - str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle); + NativeOnnxValueHelper.StringAndUtf8FromNative(nameHandle, out str, out utf8); } return str; } @@ -1198,7 +1320,7 @@ internal TensorTypeAndShape(TensorElementType elementType, int[] dimensions, str ElementTypeInfo = TensorBase.GetElementTypeInfo(elementType); if (ElementTypeInfo == null) { - throw new ArgumentException("Unregistered TensorElementType value of: " + elementType.ToString()); + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Unregistered TensorElementType value of: " + elementType.ToString()); } ElementDataType = elementType; Dimensions = dimensions; @@ -1349,7 +1471,7 @@ private void CheckTensor() { if (!IsTensor) { - throw new InvalidOperationException("OnnxValueType must either be a tensor or sparse tensor"); + throw new OnnxRuntimeException(ErrorCode.Fail, "OnnxValueType must either be a tensor or sparse tensor"); } } @@ -1362,7 +1484,7 @@ public MapMetadata AsMapMetadata() { if (OnnxValueType != OnnxValueType.ONNX_TYPE_MAP) { - throw new InvalidOperationException("Instance does not contain Map metadata"); + throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Map metadata"); } return _metadata as MapMetadata; } @@ -1376,7 +1498,7 @@ public SequenceMetadata AsSequenceMetadata() { if (OnnxValueType != OnnxValueType.ONNX_TYPE_SEQUENCE) { - throw new InvalidOperationException("Instance does not contain Sequence metadata"); + throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Sequence metadata"); } return _metadata as SequenceMetadata; } @@ -1392,7 +1514,7 @@ public OptionalMetadata AsOptionalMetadata() { if (OnnxValueType != OnnxValueType.ONNX_TYPE_OPTIONAL) { - throw new InvalidOperationException("Instance does not contain Optional metadata"); + throw new OnnxRuntimeException(ErrorCode.Fail, "Instance does not contain Optional metadata"); } return _metadata as OptionalMetadata; } @@ -1403,6 +1525,15 @@ public OptionalMetadata AsOptionalMetadata() /// A value of OnnxValueType enum public OnnxValueType OnnxValueType { get; } + /// + /// Zero terminated UTF-8 name of the input/output + /// Present only on the top-level instance + /// metadata dictionary entries. + /// + /// Used to avoid utf8 conversions on every run and associated allocations + /// + public byte[] ZeroTerminatedUtf8Name { get; set; } + /// /// Tensor shape valid only if this is a Tensor. /// Preserved for API compatibility @@ -1457,6 +1588,9 @@ public TensorElementType ElementDataType } } + /// + /// Convinience method to check for string + /// public bool IsString { get diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs new file mode 100644 index 000000000000..664ba21cfd1b --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.ML.OnnxRuntime.Tensors; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +namespace Microsoft.ML.OnnxRuntime +{ + /// + /// The class helps to feed the NamedOnnxValue as inference input. + /// It projects managed classes to OrtValues so they can be consumed + /// by the native onnxruntime library. if possible, it will avoid copying data. + /// The NamedOnnxValue can be a tensor, sequence or map. + /// For recursive structures, create nested NamedOnnxValue instances. + /// For example, a sequence instance would contain a list of NamedOnnxValue instances + /// that in turn may represent tensors or other ONNX values. + /// + internal class ManagedTypeProjection : IDisposable + { + readonly DisposableList _disposables; + readonly OrtValue _ortValue; + bool _disposed = false; + + /// + /// Provides access to non-owning instance of OrtValue + /// + /// Provides access to the OrtValue to be used as input + internal OrtValue Value { get { return new OrtValue(_ortValue.Handle, false); } } + + /// + /// Constructor to create an input OrtValue projection from managed data + /// + /// + /// + /// + internal ManagedTypeProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata) + { + int requiredCapacity = 32; + var disposables = new DisposableList(requiredCapacity); + try + { + _ortValue = CreateDispatchProjection(namedOnnxValue, metadata, disposables); + } + catch (Exception) + { + disposables.Dispose(); + throw; + } + _disposables = disposables; + } + + /// + /// Dispatches the creation of the projection + /// + /// + /// + /// + /// + private OrtValue CreateDispatchProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata, DisposableList disposables) + { + OrtValue result; + + NodeMetadata meta = metadata; + // Use element meta to create types + if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL) + { + meta = metadata.AsOptionalMetadata().ElementMeta; + } + + if (namedOnnxValue.ValueType != meta.OnnxValueType) + { + throw new OnnxRuntimeException(ErrorCode.RuntimeException, + $"NamedOnnxValue: {namedOnnxValue.Name} has value type: {namedOnnxValue.ValueType}" + + $" expected: {meta.OnnxValueType} after optional type adjustment"); + } + + switch (namedOnnxValue.ValueType) + { + case OnnxValueType.ONNX_TYPE_TENSOR: + result = CreateTensorProjection(namedOnnxValue, meta, disposables); + break; + case OnnxValueType.ONNX_TYPE_SEQUENCE: + result = CreateSequenceProjection(namedOnnxValue, meta, disposables); + break; + case OnnxValueType.ONNX_TYPE_MAP: + result = CreateMapProjection(namedOnnxValue, meta, disposables); + break; + default: + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "ManagedTypeProjection can only project tensors, sequences, maps and optional types"); + } + return result; + } + + /// + /// The function creates OrtValue objects for each element of the sequence + /// and then creates an OrtValue for the whole sequence. + /// + /// NamedOnnxValue containing a IEnumeralbe + /// sequence metadata + /// cleanup list + /// + /// + private OrtValue CreateSequenceProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata, DisposableList disposables) + { + OrtValue result = null; + var elementMeta = metadata.AsSequenceMetadata().ElementMeta; + var elementOnnxValue = elementMeta.OnnxValueType; + var seqContainer = namedOnnxValue.AsEnumerable(); + + if (seqContainer is null) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"NamedOnnxValue: {namedOnnxValue.Name} sequence does not contain NamedOnnxValue elements"); + } + + int capacity = 0; + + if (seqContainer is ICollection) + { + capacity = ((ICollection)seqContainer).Count; + } + + // Record all the ortValues belonging to the sequence locally + var sequenceOrtValues = new List(capacity); + foreach (var element in seqContainer) + { + if (elementOnnxValue != element.ValueType) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"NamedOnnxValue: {namedOnnxValue.Name} sequence element expected to be {elementOnnxValue}, received {element.ValueType}"); + } + + sequenceOrtValues.Add(CreateDispatchProjection(element, elementMeta, disposables)); + } + + IntPtr[] ortValHandles = new IntPtr[sequenceOrtValues.Count]; + for (int i = 0; i < sequenceOrtValues.Count; i++) + { + ortValHandles[i] = sequenceOrtValues[i].Handle; + } + + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateValue(ortValHandles, + (UIntPtr)sequenceOrtValues.Count, (IntPtr)OnnxValueType.ONNX_TYPE_SEQUENCE, out IntPtr sequenceHandle)); + result = new OrtValue(sequenceHandle); + disposables.Add(result); + + return result; + } + + /// + /// Creates map projection. Since we support only primitive types in maps + /// we map two tensors (keys and values) + /// + /// + /// + /// + /// OrtValue + /// + private OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata elementMeta, DisposableList disposables) + { + OrtValue result = null; + var mapMeta = elementMeta.AsMapMetadata(); + Debug.Assert(mapMeta != null); + // Maps currently support only primitive types expressed as two parallel tensors and not nested Sequences or Maps + + var mapValuesMeta = mapMeta.ValueMetadata; + if (mapValuesMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"Node: {node.Name} onnxruntime only supports maps with primitive types values"); + } + + + var keys = node.GetDictionaryKeys(); + var ortValueKeys = OrtValue.CreateFromTensorObject(keys, + out MemoryHandle? memoryHandleKeys, out TensorElementType elementTypeKeys); + disposables.Add(ortValueKeys); + + if (memoryHandleKeys.HasValue) + { + disposables.Add(memoryHandleKeys); + } + + if (elementTypeKeys != mapMeta.KeyDataType) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"Map key data type supplied: {elementTypeKeys} metadata expected: {mapMeta.KeyDataType}"); + } + + var values = node.GetDictionaryValues(); + var ortValueValues = OrtValue.CreateFromTensorObject(values, + out MemoryHandle? memoryHandleValues, out TensorElementType elementTypeValues); + + disposables.Add(ortValueValues); + if (memoryHandleValues.HasValue) + { + disposables.Add(memoryHandleValues); + } + + if (elementTypeValues != mapValuesMeta.ElementDataType) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"Map value data type supplied: {elementTypeValues} metadata expected: {mapValuesMeta.ElementDataType}"); + } + + // Create Map OrtValue + IntPtr[] ortValHandles = { ortValueKeys.Handle, ortValueValues.Handle }; + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateValue(ortValHandles, (UIntPtr)2, + (IntPtr)OnnxValueType.ONNX_TYPE_MAP, out IntPtr ortValueMap)); + result = new OrtValue(ortValueMap); + disposables.Add(result); + return result; + } + + + /// + /// This pins memory that is contained within DenseTensor. + /// + /// NodeOnnxValue containing DenseTensor + /// + /// cleanup list + /// + /// + private OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta, DisposableList disposables) + { + var ortValue = OrtValue.CreateFromTensorObject(node.Value, + out MemoryHandle? memoryHandle, out TensorElementType elementType); + disposables.Add(ortValue); + + if (memoryHandle.HasValue) + { + disposables.Add(memoryHandle); + } + + if (elementType != elementMeta.ElementDataType) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, + $"Tensor element data type discovered: {elementType} metadata expected: {elementMeta.ElementDataType}"); + } + + return ortValue; + } + + #region IDisposable + /// + /// IDisposable implementation + /// + /// true if invoked by Dispose() + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + // dispose managed state (managed objects). + if (disposing) + { + _disposables.Dispose(); + } + _disposed = true; + } + + + public void Dispose() + { + Dispose(true); + } + + #endregion IDisposable + } +} + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs index 233e5fa9af87..9334313e87aa 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs @@ -5,38 +5,98 @@ using System; using System.Buffers; using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; namespace Microsoft.ML.OnnxRuntime { /// - /// The class associates a name with an Object. Currently it supports Tensor - /// as possible objects. The name of the class is a misnomer, it does not hold any - /// Onnx values. + /// The class holds keys and values for the dictionary + /// + internal class MapHelper + { + internal MapHelper(object keys, object values) + { + Keys = keys; + Values = values; + } + internal Object Keys { get; } // DenseTensor + internal Object Values { get; } // DenseTensor + } + + /// + /// The class associates a name with an Object. + /// The name of the class is a misnomer, it does not hold any Onnx values, + /// just managed representation of them. + /// + /// The class is currently used as both inputs and outputs. Because it is non- + /// disposable, it can not hold on to any native objects. + /// + /// When used as input, we temporarily create OrtValues that map managed inputs + /// directly. Thus we are able to avoid copying. + /// + /// For outputs, tensor buffers works the same as input, providing it matches + /// the expected output shape. For other types (maps and sequences, we create a copy of the data). + /// This is because, the class is not Disposable and it is a public interface, thus it can not own + /// the underlying OrtValues that must be destroyed before Run() returns. + /// + /// To avoid data copying on output, use DisposableNamedOnnxValue class that is returned from Run() methods. + /// This provides access to the native memory and avoids copying. + /// + /// It is a recursive structure that may contain Tensors (base case) + /// Other sequences and maps. Although the OnnxValueType is exposed, + /// the caller is supposed to know the actual data type contained. + /// For that one will need to consult model metadata. + /// /// public class NamedOnnxValue { /// /// Managed Tensor, Dictionary or IList /// - protected Object _value; + private Object _value; /// /// Name of the instance, model input/output /// - protected string _name; + private string _name; + + private MapHelper _mapHelper; // used for maps, otherwise null /// /// Constructs an instance of NamedOnnxValue and represents - /// a model input to an inference session. It also represents a modle output - /// when serves as a base for DisposablenamedOnnxvalue + /// a model input to an inference session. /// /// input/output name /// Object that may be a tensor, Dictionary, IList + [Obsolete("This the constructor with valueType or static factory methods")] protected NamedOnnxValue(string name, Object value) { _name = name; _value = value; + ValueType = OnnxValueType.ONNX_TYPE_UNKNOWN; + } + + internal NamedOnnxValue(string name, Object value, OnnxValueType valueType) + { + _name = name; + _value = value; + ValueType = valueType; + } + + internal NamedOnnxValue(string name, Object value, MapHelper helper) + { + _name = name; + _value = value; + ValueType = OnnxValueType.ONNX_TYPE_MAP; + _mapHelper = helper; } + /// + /// Onnx Value Type if known. In general, NamedOnnxValue is able to contain + /// arbitrary objects. + /// + public OnnxValueType ValueType { get; } + /// /// This is a factory method that instantiates NamedOnnxValue /// and associated name with an instance of a Tensor @@ -47,7 +107,40 @@ protected NamedOnnxValue(string name, Object value) /// public static NamedOnnxValue CreateFromTensor(string name, Tensor value) { - return new NamedOnnxValue(name, value); + return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_TENSOR); + } + + /// + /// This is a factory method that instantiates NamedOnnxValue. + /// It would contain a sequence of elements + /// + /// + /// + /// + public static NamedOnnxValue CreateFromSequence(string name, IEnumerable value) + { + return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_SEQUENCE); + } + + /// + /// This is a factory method that instantiates NamedOnnxValue. + /// + /// Keys type + /// Values type + /// + /// + /// new instance of NamedOnnxValue + public static NamedOnnxValue CreateFromMap(string name, IDictionary value) + { + // The order in which Keys and Values are unspecified, + // but it is guaranteed to be the same order + // These tensors are 1-D + var keysMemory = new Memory(value.Keys.ToArray()); + var keysTensor = new DenseTensor(keysMemory, new int[1] { keysMemory.Length }); + + var valuesMemory = new Memory(value.Values.ToArray()); + var valuesTensor = new DenseTensor(valuesMemory, new int[1] { valuesMemory.Length }); + return new NamedOnnxValue(name, value, new MapHelper(keysTensor, valuesTensor)); } /// @@ -94,15 +187,83 @@ public IEnumerable AsEnumerable() } /// - /// Pin the underlying memory and create an instance of OrtValue + /// Pin the underlying memory and create an instance of OrtValue containing a tensor /// based on the pinned managed memory. The caller is responsible for Disposing /// both OrtValue and pinnedMemoryHandle /// /// dispose after returned OrtValus is disposed /// an instance of OrtValue. The lifespan of OrtValue must overlap pinnedMemoryHandle - internal virtual OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle) + internal virtual OrtValue InputToOrtValue(NodeMetadata metadata, out IDisposable memoryOwner) + { + var projection = new ManagedTypeProjection(this, metadata); + memoryOwner = projection; + return projection.Value; + } + + /// + /// Produces an output value for outputs. This produces an output value + /// only for tensors or optional types that can contain a tensor. + /// For all others we return a null, letting ORT to create an output value. + /// + /// + /// + /// + internal virtual OrtValue OutputToOrtValue(NodeMetadata metadata, out IDisposable memoryOwner) { - return OrtValue.CreateFromTensorObject(_value, out pinnedMemoryHandle, out TensorElementType elementType); + // For NamedOnnxValue for output we only allow to produce OrtValue for tensors + // or optional type that may contain a tensor + if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) + { + var projection = new ManagedTypeProjection(this, metadata); + memoryOwner = projection; + return projection.Value; + } + + if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL) + { + var meta = metadata.AsOptionalMetadata().ElementMeta; + if (meta.OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) + { + var projection = new ManagedTypeProjection(this, meta); + memoryOwner = projection; + return projection.Value; + } + } + memoryOwner = null; + return null; + } + + /// + /// This method is used internally to feed dictionary keys + /// to create an OrtValue for map keys + /// + /// + /// DenseTensor" + internal Object GetDictionaryKeys() + { + if (ValueType != OnnxValueType.ONNX_TYPE_MAP) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "This NamedOnnxValue instance does not contain a dictionary"); + } + + Debug.Assert(_mapHelper != null); + return _mapHelper.Keys; + } + + /// + /// + /// + /// + /// DenseTensor" + internal Object GetDictionaryValues() + { + if (ValueType != OnnxValueType.ONNX_TYPE_MAP) + { + throw new OnnxRuntimeException(ErrorCode.Fail, "This NamedOnnxValue instance does not contain a dictionary"); + } + + Debug.Assert(_mapHelper != null); + return _mapHelper.Values; } // may expose different types of getters in future diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index fcba15a7b280..c92db2afd484 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -278,6 +278,12 @@ public struct OrtApi public IntPtr UpdateDnnlProviderOptions; public IntPtr GetDnnlProviderOptionsAsString; public IntPtr ReleaseDnnlProviderOptions; + public IntPtr KernelInfo_GetNodeName; + public IntPtr KernelInfo_GetLogger; + public IntPtr KernelContext_GetLogger; + public IntPtr Logger_LogMessage; + public IntPtr Logger_GetLoggingSeverityLevel; + public IntPtr KernelInfoGetConstantInput_tensor; public IntPtr CastTypeInfoToOptionalTypeInfo; public IntPtr GetOptionalContainedTypeInfo; } @@ -406,9 +412,10 @@ static NativeMethods() OrtSetLanguageProjection = (DOrtSetLanguageProjection)Marshal.GetDelegateForFunctionPointer(api_.SetLanguageProjection, typeof(DOrtSetLanguageProjection)); OrtGetValue = (DOrtGetValue)Marshal.GetDelegateForFunctionPointer(api_.GetValue, typeof(DOrtGetValue)); + OrtGetValueCount = (DOrtGetValueCount)Marshal.GetDelegateForFunctionPointer(api_.GetValueCount, typeof(DOrtGetValueCount)); + OrtCreateValue = (DOrtCreateValue)Marshal.GetDelegateForFunctionPointer(api_.CreateValue, typeof(DOrtCreateValue)); OrtGetValueType = (DOrtGetValueType)Marshal.GetDelegateForFunctionPointer(api_.GetValueType, typeof(DOrtGetValueType)); OrtGetOnnxTypeFromTypeInfo = (DOrtGetOnnxTypeFromTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.GetOnnxTypeFromTypeInfo, typeof(DOrtGetOnnxTypeFromTypeInfo)); - OrtGetValueCount = (DOrtGetValueCount)Marshal.GetDelegateForFunctionPointer(api_.GetValueCount, typeof(DOrtGetValueCount)); OrtGetTypeInfo = (DOrtGetTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.GetTypeInfo, typeof(DOrtGetTypeInfo)); OrtCreateTensorAsOrtValue = (DOrtCreateTensorAsOrtValue)Marshal.GetDelegateForFunctionPointer(api_.CreateTensorAsOrtValue, typeof(DOrtCreateTensorAsOrtValue)); OrtCreateTensorWithDataAsOrtValue = (DOrtCreateTensorWithDataAsOrtValue)Marshal.GetDelegateForFunctionPointer(api_.CreateTensorWithDataAsOrtValue, typeof(DOrtCreateTensorWithDataAsOrtValue)); @@ -1600,6 +1607,12 @@ internal class NativeLib public static DOrtGetValueCount OrtGetValueCount; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr/*(OrtStatus*)*/ DOrtCreateValue(IntPtr[] /* const OrtValue* const* in */ values, + UIntPtr /* size_t */ num_values, IntPtr /* (OnnxValueType */ onnxValueType, out IntPtr /* OrtValue** */ ortValue); + + public static DOrtCreateValue OrtCreateValue; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTypeInfo(IntPtr /*(OrtValue*)*/ value, IntPtr /*(OrtValue**)*/ typeInfo); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs index 0a1ae1912a1e..704eb4d14222 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs @@ -84,9 +84,28 @@ internal static string StringFromNativeUtf8(IntPtr nativeUtf8) return Encoding.UTF8.GetString(buffer, 0, buffer.Length); } + /// + /// Reads UTF-8 string from native C zero terminated string, + /// converts it to C# UTF-16 string and returns both C# string and utf-8 + /// bytes as a zero terminated array, suitable for use as a C-string + /// + /// input + /// C# UTF-16 string + /// UTF-8 bytes in a managed buffer, zero terminated + internal static void StringAndUtf8FromNative(IntPtr nativeUtf8, out string str, out byte[] utf8) + { + // .NET 8.0 has Marshal.PtrToStringUTF8 that does the below + int len = 0; + while (Marshal.ReadByte(nativeUtf8, len) != 0) ++len; + utf8 = new byte[len + 1]; + Marshal.Copy(nativeUtf8, utf8, 0, len); + utf8[len] = 0; + str = Encoding.UTF8.GetString(utf8, 0, len); + } + internal static string StringFromUtf8Span(ReadOnlySpan utf8Span) { - // For now we have to copy into byte[], this produces a copy + // XXX: For now we have to copy into byte[], this produces a copy // Converting from span is available in later versions var utf8Bytes = utf8Span.ToArray(); return Encoding.UTF8.GetString(utf8Bytes, 0, utf8Bytes.Length); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index c279c31a3770..a385f0a24985 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -346,18 +346,20 @@ out valueHandle // fill the native tensor, using GetValue(index) from the Tensor var len = tensor.Length; var nativeStrings = new IntPtr[len]; - using (var pinnedHandles = new DisposableList((int)len)) + using (var pinnedHandles = new DisposableList((int)len)) { for (int i = 0; i < len; i++) { var utf8str = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(tensor.GetValue(i)); - var gcHandle = GCHandle.Alloc(utf8str, GCHandleType.Pinned); - nativeStrings[i] = gcHandle.AddrOfPinnedObject(); - pinnedHandles.Add(new PinnedGCHandle(gcHandle)); + var pinnedUtf8 = new Memory(utf8str).Pin(); + unsafe + { + nativeStrings[i] = (IntPtr)pinnedUtf8.Pointer; + } + pinnedHandles.Add(pinnedUtf8); } - using (var pinnedStrings = new PinnedGCHandle(GCHandle.Alloc(nativeStrings, GCHandleType.Pinned))) - NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(ortValue.Handle, nativeStrings, (UIntPtr)len)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(ortValue.Handle, nativeStrings, (UIntPtr)len)); } } catch (OnnxRuntimeException) @@ -381,6 +383,7 @@ protected override bool ReleaseHandle() if (IsOwned) { NativeMethods.OrtReleaseValue(handle); + IsOwned = false; } // Prevent use after disposal handle = IntPtr.Zero; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValueTensor.shared.cs similarity index 89% rename from csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs rename to csharp/src/Microsoft.ML.OnnxRuntime/OrtValueTensor.shared.cs index ce339b6a528c..29ae3ab2be30 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValueTensor.shared.cs @@ -23,13 +23,14 @@ internal interface IOrtValueOwner : IDisposable /// This class is used in conjunction with DisposableNamedOnnxValue to /// own native collection OrtValue and dispose of it along with any DisposableNamedOnnxValues /// - internal class NativeOrtValueCollectionOwner : IOrtValueOwner, IDisposable + internal class NativeOrtValueCollectionOwner : IOrtValueOwner, IDisposable + where T:IDisposable { private OrtValue _ortValue; - private DisposableList _disposables; + private DisposableList _disposables; bool _disposed = false; - internal NativeOrtValueCollectionOwner(OrtValue ortValue, DisposableList disposables) + internal NativeOrtValueCollectionOwner(OrtValue ortValue, DisposableList disposables) { Debug.Assert(ortValue.IsOwned); _ortValue = new OrtValue(ortValue.Disown()); @@ -80,19 +81,24 @@ public void Dispose() /// /// This helper class owns the underlying OrtValue that is assumed to be a Tensor, /// it does not support any other ortValues and caches Tensor properties. + /// + /// It is easy to expose as a Tensor as DenseTensor can take Memory Mapping from + /// this. + /// + /// This class is disposable because of the MemoryManager inheritance /// /// - internal class NativeOnnxTensorMemory : MemoryManager, IOrtValueOwner + internal class OrtValueTensor : MemoryManager, IOrtValueOwner { private OrtValue _ortValue; // Disposable - private IntPtr _dataBufferPointer; // pointer to mutable tensor data in native memory - private string[] _dataBufferAsString; // string tensor values copied into managed memory + private readonly IntPtr _dataBufferPointer; // pointer to mutable tensor data in native memory + private readonly string[] _dataBufferAsString; // string tensor values copied into managed memory /// /// Constructs an instance and takes ownership of ortValue on success /// /// ortValue that is a Tensor - public NativeOnnxTensorMemory(OrtValue ortValue) + public OrtValueTensor(OrtValue ortValue) { Type type = null; int width = 0; @@ -115,7 +121,7 @@ public NativeOnnxTensorMemory(OrtValue ortValue) if (typeof(T) != type) { - var message = String.Format("The NativeOnnxTensorMemory type being instantiated for T = : {0} while supplied OrtValue contains T = {1}", + var message = String.Format("The OrtValueTensor type being instantiated for T = : {0} while supplied OrtValue contains T = {1}", typeof(T), type); throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message); } @@ -214,7 +220,7 @@ public NativeOnnxTensorMemory(OrtValue ortValue) public override Span GetSpan() { if (IsDisposed) - throw new ObjectDisposedException(nameof(NativeOnnxTensorMemory)); + throw new ObjectDisposedException(nameof(OrtValueTensor)); Span span = null; unsafe { @@ -226,10 +232,10 @@ public override Span GetSpan() public Memory GetBytesAsStringMemory() { if (IsDisposed) - throw new ObjectDisposedException(nameof(NativeOnnxTensorMemory)); + throw new ObjectDisposedException(nameof(OrtValueTensor)); if (typeof(T) != typeof(string)) - throw new NotSupportedException(nameof(NativeOnnxTensorMemory.GetBytesAsStringMemory) + ": T must be byte"); + throw new NotSupportedException(nameof(OrtValueTensor.GetBytesAsStringMemory) + ": T must be byte"); return (_dataBufferAsString == null) ? new Memory() : new Memory(_dataBufferAsString); } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj index e15409216aab..3e4802bf68f6 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj @@ -48,6 +48,7 @@ + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index 4ec453eef851..871e90164716 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -588,7 +588,7 @@ private void ThrowWrongInputName() var container = new List(); container.Add(NamedOnnxValue.CreateFromTensor("wrong_name", tensor)); var ex = Assert.Throws(() => session.Run(container)); - Assert.Contains("Invalid Feed Input", ex.Message); + Assert.Contains("Input/output name: 'wrong_name' is not in the metadata", ex.Message); session.Dispose(); } @@ -604,9 +604,8 @@ private void ThrowWrongInputType() var tensor = new DenseTensor(inputDataInt, inputMeta["data_0"].Dimensions); container.Add(NamedOnnxValue.CreateFromTensor("data_0", tensor)); var ex = Assert.Throws(() => session.Run(container)); - var msg = ex.ToString().Substring(0, 101); - // TODO: message is diff in LInux. Use substring match - Assert.Equal("Microsoft.ML.OnnxRuntime.OnnxRuntimeException: [ErrorCode:InvalidArgument] Unexpected input data type", msg); + var msg = ex.ToString(); + Assert.Contains("Tensor element data type discovered", msg); session.Dispose(); } @@ -624,7 +623,7 @@ private void ThrowExtraInputs() container.Add(nov1); container.Add(nov2); var ex = Assert.Throws(() => session.Run(container)); - Assert.StartsWith("[ErrorCode:InvalidArgument] Invalid Feed Input Name", ex.Message); + Assert.Contains("Input/output name: 'extra' is not in the metadata", ex.Message); session.Dispose(); } @@ -653,9 +652,10 @@ private void ThrowWrongOutputName() var inputTensor = tuple.Item3; var inputs = new List { NamedOnnxValue.CreateFromTensor("data_0", inputTensor) }; var outputTensor = new DenseTensor((ReadOnlySpan)new[] { 1, 2 }); - var outputs = new List { NamedOnnxValue.CreateFromTensor("bad_output_name", outputTensor) }; - var ex = Assert.Throws(() => session.Run(inputs, outputs)); - Assert.Contains("Invalid Output Name", ex.Message); + // var outputs = new List { NamedOnnxValue.CreateFromTensor("bad_output_name", outputTensor) }; + var bad_names = new string[] {"bad_output_name"}; + var ex = Assert.Throws(() => session.Run(inputs, bad_names)); + Assert.Contains("Input/output name: 'bad_output_name' is not in the metadata", ex.Message); session.Dispose(); } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj index a373039436e3..58c9cbe11dbd 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj @@ -9,7 +9,7 @@ true true true - $(OnnxRuntimeCsharpRoot)\..\cmake\external\onnx\onnx + $(OnnxRuntimeCsharpRoot)\..\cmake\external\onnx;\..\cmake\external\onnx\onnx 7.2 @@ -76,6 +76,7 @@ + @@ -107,6 +108,10 @@ + + + + @@ -131,4 +136,4 @@ - + \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxData.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxData.cs new file mode 100644 index 000000000000..0d3f3b4d3edd --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxData.cs @@ -0,0 +1,1335 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: onnx/onnx-data.proto3 +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Onnx { + + /// Holder for reflection information generated from onnx/onnx-data.proto3 + public static partial class OnnxDataReflection { + + #region Descriptor + /// File descriptor for onnx/onnx-data.proto3 + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static OnnxDataReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "ChVvbm54L29ubngtZGF0YS5wcm90bzMSBG9ubngaE29ubngvb25ueC1tbC5w", + "cm90bzMi8AIKDVNlcXVlbmNlUHJvdG8SDAoEbmFtZRgBIAEoCRIRCgllbGVt", + "X3R5cGUYAiABKAUSKAoNdGVuc29yX3ZhbHVlcxgDIAMoCzIRLm9ubnguVGVu", + "c29yUHJvdG8SNQoUc3BhcnNlX3RlbnNvcl92YWx1ZXMYBCADKAsyFy5vbm54", + "LlNwYXJzZVRlbnNvclByb3RvEiwKD3NlcXVlbmNlX3ZhbHVlcxgFIAMoCzIT", + "Lm9ubnguU2VxdWVuY2VQcm90bxIiCgptYXBfdmFsdWVzGAYgAygLMg4ub25u", + "eC5NYXBQcm90bxIsCg9vcHRpb25hbF92YWx1ZXMYByADKAsyEy5vbm54Lk9w", + "dGlvbmFsUHJvdG8iXQoIRGF0YVR5cGUSDQoJVU5ERUZJTkVEEAASCgoGVEVO", + "U09SEAESEQoNU1BBUlNFX1RFTlNPUhACEgwKCFNFUVVFTkNFEAMSBwoDTUFQ", + "EAQSDAoIT1BUSU9OQUwQBSJyCghNYXBQcm90bxIMCgRuYW1lGAEgASgJEhAK", + "CGtleV90eXBlGAIgASgFEgwKBGtleXMYAyADKAMSEwoLc3RyaW5nX2tleXMY", + "BCADKAwSIwoGdmFsdWVzGAUgASgLMhMub25ueC5TZXF1ZW5jZVByb3RvIusC", + "Cg1PcHRpb25hbFByb3RvEgwKBG5hbWUYASABKAkSEQoJZWxlbV90eXBlGAIg", + "ASgFEicKDHRlbnNvcl92YWx1ZRgDIAEoCzIRLm9ubnguVGVuc29yUHJvdG8S", + "NAoTc3BhcnNlX3RlbnNvcl92YWx1ZRgEIAEoCzIXLm9ubnguU3BhcnNlVGVu", + "c29yUHJvdG8SKwoOc2VxdWVuY2VfdmFsdWUYBSABKAsyEy5vbm54LlNlcXVl", + "bmNlUHJvdG8SIQoJbWFwX3ZhbHVlGAYgASgLMg4ub25ueC5NYXBQcm90bxIr", + "Cg5vcHRpb25hbF92YWx1ZRgHIAEoCzITLm9ubnguT3B0aW9uYWxQcm90byJd", + "CghEYXRhVHlwZRINCglVTkRFRklORUQQABIKCgZURU5TT1IQARIRCg1TUEFS", + "U0VfVEVOU09SEAISDAoIU0VRVUVOQ0UQAxIHCgNNQVAQBBIMCghPUFRJT05B", + "TBAFQgJIA2IGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Onnx.OnnxMlReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Onnx.SequenceProto), global::Onnx.SequenceProto.Parser, new[]{ "Name", "ElemType", "TensorValues", "SparseTensorValues", "SequenceValues", "MapValues", "OptionalValues" }, null, new[]{ typeof(global::Onnx.SequenceProto.Types.DataType) }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Onnx.MapProto), global::Onnx.MapProto.Parser, new[]{ "Name", "KeyType", "Keys", "StringKeys", "Values" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Onnx.OptionalProto), global::Onnx.OptionalProto.Parser, new[]{ "Name", "ElemType", "TensorValue", "SparseTensorValue", "SequenceValue", "MapValue", "OptionalValue" }, null, new[]{ typeof(global::Onnx.OptionalProto.Types.DataType) }, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Sequences + /// + /// Defines a dense, ordered, collection of elements that are of homogeneous types. + /// Sequences can be made out of tensors, maps, or sequences. + /// + /// If a sequence is made out of tensors, the tensors must have the same element + /// type (i.e. int32). In some cases, the tensors in a sequence can have different + /// shapes. Whether the tensors can have different shapes or not depends on the + /// type/shape associated with the corresponding "ValueInfo". For example, + /// "Sequence<Tensor<float, [M,N]>" means that all tensors have same shape. However, + /// "Sequence<Tensor<float, [omitted,omitted]>" means they can have different + /// shapes (all of rank 2), where "omitted" means the corresponding dimension has + /// no symbolic/constant value. Finally, "Sequence<Tensor<float, omitted>>" means + /// that the different tensors can have different ranks, when the "shape" itself + /// is omitted from the tensor-type. For a more complete description, refer to + /// https://github.com/onnx/onnx/blob/main/docs/IR.md#static-tensor-shapes. + /// + public sealed partial class SequenceProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SequenceProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Onnx.OnnxDataReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SequenceProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SequenceProto(SequenceProto other) : this() { + name_ = other.name_; + elemType_ = other.elemType_; + tensorValues_ = other.tensorValues_.Clone(); + sparseTensorValues_ = other.sparseTensorValues_.Clone(); + sequenceValues_ = other.sequenceValues_.Clone(); + mapValues_ = other.mapValues_.Clone(); + optionalValues_ = other.optionalValues_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public SequenceProto Clone() { + return new SequenceProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "elem_type" field. + public const int ElemTypeFieldNumber = 2; + private int elemType_; + /// + /// The data type of the element. + /// This field MUST have a valid SequenceProto.DataType value + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ElemType { + get { return elemType_; } + set { + elemType_ = value; + } + } + + /// Field number for the "tensor_values" field. + public const int TensorValuesFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_tensorValues_codec + = pb::FieldCodec.ForMessage(26, global::Onnx.TensorProto.Parser); + private readonly pbc::RepeatedField tensorValues_ = new pbc::RepeatedField(); + /// + /// For TensorProto values. + /// When this field is present, the elem_type field MUST be TENSOR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField TensorValues { + get { return tensorValues_; } + } + + /// Field number for the "sparse_tensor_values" field. + public const int SparseTensorValuesFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_sparseTensorValues_codec + = pb::FieldCodec.ForMessage(34, global::Onnx.SparseTensorProto.Parser); + private readonly pbc::RepeatedField sparseTensorValues_ = new pbc::RepeatedField(); + /// + /// For SparseTensorProto values. + /// When this field is present, the elem_type field MUST be SPARSE_TENSOR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SparseTensorValues { + get { return sparseTensorValues_; } + } + + /// Field number for the "sequence_values" field. + public const int SequenceValuesFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_sequenceValues_codec + = pb::FieldCodec.ForMessage(42, global::Onnx.SequenceProto.Parser); + private readonly pbc::RepeatedField sequenceValues_ = new pbc::RepeatedField(); + /// + /// For SequenceProto values, allowing sequences to be of themselves. + /// When this field is present, the elem_type field MUST be SEQUENCE. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField SequenceValues { + get { return sequenceValues_; } + } + + /// Field number for the "map_values" field. + public const int MapValuesFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_mapValues_codec + = pb::FieldCodec.ForMessage(50, global::Onnx.MapProto.Parser); + private readonly pbc::RepeatedField mapValues_ = new pbc::RepeatedField(); + /// + /// For MapProto values. + /// When this field is present, the elem_type field MUST be MAP. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField MapValues { + get { return mapValues_; } + } + + /// Field number for the "optional_values" field. + public const int OptionalValuesFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_optionalValues_codec + = pb::FieldCodec.ForMessage(58, global::Onnx.OptionalProto.Parser); + private readonly pbc::RepeatedField optionalValues_ = new pbc::RepeatedField(); + /// + /// For OptionalProto values. + /// When this field is present, the elem_type field MUST be Optional. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField OptionalValues { + get { return optionalValues_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as SequenceProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(SequenceProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (ElemType != other.ElemType) return false; + if(!tensorValues_.Equals(other.tensorValues_)) return false; + if(!sparseTensorValues_.Equals(other.sparseTensorValues_)) return false; + if(!sequenceValues_.Equals(other.sequenceValues_)) return false; + if(!mapValues_.Equals(other.mapValues_)) return false; + if(!optionalValues_.Equals(other.optionalValues_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (ElemType != 0) hash ^= ElemType.GetHashCode(); + hash ^= tensorValues_.GetHashCode(); + hash ^= sparseTensorValues_.GetHashCode(); + hash ^= sequenceValues_.GetHashCode(); + hash ^= mapValues_.GetHashCode(); + hash ^= optionalValues_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ElemType != 0) { + output.WriteRawTag(16); + output.WriteInt32(ElemType); + } + tensorValues_.WriteTo(output, _repeated_tensorValues_codec); + sparseTensorValues_.WriteTo(output, _repeated_sparseTensorValues_codec); + sequenceValues_.WriteTo(output, _repeated_sequenceValues_codec); + mapValues_.WriteTo(output, _repeated_mapValues_codec); + optionalValues_.WriteTo(output, _repeated_optionalValues_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ElemType != 0) { + output.WriteRawTag(16); + output.WriteInt32(ElemType); + } + tensorValues_.WriteTo(ref output, _repeated_tensorValues_codec); + sparseTensorValues_.WriteTo(ref output, _repeated_sparseTensorValues_codec); + sequenceValues_.WriteTo(ref output, _repeated_sequenceValues_codec); + mapValues_.WriteTo(ref output, _repeated_mapValues_codec); + optionalValues_.WriteTo(ref output, _repeated_optionalValues_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (ElemType != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ElemType); + } + size += tensorValues_.CalculateSize(_repeated_tensorValues_codec); + size += sparseTensorValues_.CalculateSize(_repeated_sparseTensorValues_codec); + size += sequenceValues_.CalculateSize(_repeated_sequenceValues_codec); + size += mapValues_.CalculateSize(_repeated_mapValues_codec); + size += optionalValues_.CalculateSize(_repeated_optionalValues_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(SequenceProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.ElemType != 0) { + ElemType = other.ElemType; + } + tensorValues_.Add(other.tensorValues_); + sparseTensorValues_.Add(other.sparseTensorValues_); + sequenceValues_.Add(other.sequenceValues_); + mapValues_.Add(other.mapValues_); + optionalValues_.Add(other.optionalValues_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + ElemType = input.ReadInt32(); + break; + } + case 26: { + tensorValues_.AddEntriesFrom(input, _repeated_tensorValues_codec); + break; + } + case 34: { + sparseTensorValues_.AddEntriesFrom(input, _repeated_sparseTensorValues_codec); + break; + } + case 42: { + sequenceValues_.AddEntriesFrom(input, _repeated_sequenceValues_codec); + break; + } + case 50: { + mapValues_.AddEntriesFrom(input, _repeated_mapValues_codec); + break; + } + case 58: { + optionalValues_.AddEntriesFrom(input, _repeated_optionalValues_codec); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + ElemType = input.ReadInt32(); + break; + } + case 26: { + tensorValues_.AddEntriesFrom(ref input, _repeated_tensorValues_codec); + break; + } + case 34: { + sparseTensorValues_.AddEntriesFrom(ref input, _repeated_sparseTensorValues_codec); + break; + } + case 42: { + sequenceValues_.AddEntriesFrom(ref input, _repeated_sequenceValues_codec); + break; + } + case 50: { + mapValues_.AddEntriesFrom(ref input, _repeated_mapValues_codec); + break; + } + case 58: { + optionalValues_.AddEntriesFrom(ref input, _repeated_optionalValues_codec); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the SequenceProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum DataType { + [pbr::OriginalName("UNDEFINED")] Undefined = 0, + [pbr::OriginalName("TENSOR")] Tensor = 1, + [pbr::OriginalName("SPARSE_TENSOR")] SparseTensor = 2, + [pbr::OriginalName("SEQUENCE")] Sequence = 3, + [pbr::OriginalName("MAP")] Map = 4, + [pbr::OriginalName("OPTIONAL")] Optional = 5, + } + + } + #endregion + + } + + /// + /// Maps + /// + /// Specifies an associative table, defined by keys and values. + /// MapProto is formed with a repeated field of keys (of type INT8, INT16, INT32, + /// INT64, UINT8, UINT16, UINT32, UINT64, or STRING) and values (of type TENSOR, + /// SPARSE_TENSOR, SEQUENCE, or MAP). Key types and value types have to remain + /// the same throughout the instantiation of the MapProto. + /// + public sealed partial class MapProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MapProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Onnx.OnnxDataReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MapProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MapProto(MapProto other) : this() { + name_ = other.name_; + keyType_ = other.keyType_; + keys_ = other.keys_.Clone(); + stringKeys_ = other.stringKeys_.Clone(); + values_ = other.values_ != null ? other.values_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public MapProto Clone() { + return new MapProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "key_type" field. + public const int KeyTypeFieldNumber = 2; + private int keyType_; + /// + /// The data type of the key. + /// This field MUST have a valid TensorProto.DataType value of + /// INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, or STRING + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int KeyType { + get { return keyType_; } + set { + keyType_ = value; + } + } + + /// Field number for the "keys" field. + public const int KeysFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_keys_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField keys_ = new pbc::RepeatedField(); + /// + /// Every element of keys has to be one of the following data types + /// INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, or STRING. + /// The integer cases are represented by the repeated int64 field keys below. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField Keys { + get { return keys_; } + } + + /// Field number for the "string_keys" field. + public const int StringKeysFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_stringKeys_codec + = pb::FieldCodec.ForBytes(34); + private readonly pbc::RepeatedField stringKeys_ = new pbc::RepeatedField(); + /// + /// If keys are strings, they are represented by the repeated bytes field + /// string_keys below. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public pbc::RepeatedField StringKeys { + get { return stringKeys_; } + } + + /// Field number for the "values" field. + public const int ValuesFieldNumber = 5; + private global::Onnx.SequenceProto values_; + /// + /// MapProto values are represented in a SequenceProto of the same length as the + /// repeated keys field and have to be one of the following data types + /// TENSOR, SPARSE_TENSOR, MAP, SEQUENCE. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.SequenceProto Values { + get { return values_; } + set { + values_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as MapProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(MapProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (KeyType != other.KeyType) return false; + if(!keys_.Equals(other.keys_)) return false; + if(!stringKeys_.Equals(other.stringKeys_)) return false; + if (!object.Equals(Values, other.Values)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (KeyType != 0) hash ^= KeyType.GetHashCode(); + hash ^= keys_.GetHashCode(); + hash ^= stringKeys_.GetHashCode(); + if (values_ != null) hash ^= Values.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (KeyType != 0) { + output.WriteRawTag(16); + output.WriteInt32(KeyType); + } + keys_.WriteTo(output, _repeated_keys_codec); + stringKeys_.WriteTo(output, _repeated_stringKeys_codec); + if (values_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Values); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (KeyType != 0) { + output.WriteRawTag(16); + output.WriteInt32(KeyType); + } + keys_.WriteTo(ref output, _repeated_keys_codec); + stringKeys_.WriteTo(ref output, _repeated_stringKeys_codec); + if (values_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Values); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (KeyType != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KeyType); + } + size += keys_.CalculateSize(_repeated_keys_codec); + size += stringKeys_.CalculateSize(_repeated_stringKeys_codec); + if (values_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Values); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(MapProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.KeyType != 0) { + KeyType = other.KeyType; + } + keys_.Add(other.keys_); + stringKeys_.Add(other.stringKeys_); + if (other.values_ != null) { + if (values_ == null) { + Values = new global::Onnx.SequenceProto(); + } + Values.MergeFrom(other.Values); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + KeyType = input.ReadInt32(); + break; + } + case 26: + case 24: { + keys_.AddEntriesFrom(input, _repeated_keys_codec); + break; + } + case 34: { + stringKeys_.AddEntriesFrom(input, _repeated_stringKeys_codec); + break; + } + case 42: { + if (values_ == null) { + Values = new global::Onnx.SequenceProto(); + } + input.ReadMessage(Values); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + KeyType = input.ReadInt32(); + break; + } + case 26: + case 24: { + keys_.AddEntriesFrom(ref input, _repeated_keys_codec); + break; + } + case 34: { + stringKeys_.AddEntriesFrom(ref input, _repeated_stringKeys_codec); + break; + } + case 42: { + if (values_ == null) { + Values = new global::Onnx.SequenceProto(); + } + input.ReadMessage(Values); + break; + } + } + } + } + #endif + + } + + /// + /// Optional + /// + public sealed partial class OptionalProto : pb::IMessage + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + , pb::IBufferMessage + #endif + { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OptionalProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static pbr::MessageDescriptor Descriptor { + get { return global::Onnx.OnnxDataReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptionalProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptionalProto(OptionalProto other) : this() { + name_ = other.name_; + elemType_ = other.elemType_; + tensorValue_ = other.tensorValue_ != null ? other.tensorValue_.Clone() : null; + sparseTensorValue_ = other.sparseTensorValue_ != null ? other.sparseTensorValue_.Clone() : null; + sequenceValue_ = other.sequenceValue_ != null ? other.sequenceValue_.Clone() : null; + mapValue_ = other.mapValue_ != null ? other.mapValue_.Clone() : null; + optionalValue_ = other.optionalValue_ != null ? other.optionalValue_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public OptionalProto Clone() { + return new OptionalProto(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "elem_type" field. + public const int ElemTypeFieldNumber = 2; + private int elemType_; + /// + /// The data type of the element, identifies if the OptionalProto value + /// is Tensor, Sparse Tensor, Sequence, Map, or Optional. + /// The type of the optional value MUST match the elem_type specified. + /// This field MUST have a valid OptionalProto.DataType value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int ElemType { + get { return elemType_; } + set { + elemType_ = value; + } + } + + /// Field number for the "tensor_value" field. + public const int TensorValueFieldNumber = 3; + private global::Onnx.TensorProto tensorValue_; + /// + /// For TensorProto value. + /// When this field is present, the elem_type field MUST be TENSOR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.TensorProto TensorValue { + get { return tensorValue_; } + set { + tensorValue_ = value; + } + } + + /// Field number for the "sparse_tensor_value" field. + public const int SparseTensorValueFieldNumber = 4; + private global::Onnx.SparseTensorProto sparseTensorValue_; + /// + /// For SparseTensorProto value. + /// When this field is present, the elem_type field MUST be SPARSE_TENSOR. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.SparseTensorProto SparseTensorValue { + get { return sparseTensorValue_; } + set { + sparseTensorValue_ = value; + } + } + + /// Field number for the "sequence_value" field. + public const int SequenceValueFieldNumber = 5; + private global::Onnx.SequenceProto sequenceValue_; + /// + /// For SequenceProto value. + /// When this field is present, the elem_type field MUST be SEQUENCE. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.SequenceProto SequenceValue { + get { return sequenceValue_; } + set { + sequenceValue_ = value; + } + } + + /// Field number for the "map_value" field. + public const int MapValueFieldNumber = 6; + private global::Onnx.MapProto mapValue_; + /// + /// For MapProto value. + /// When this field is present, the elem_type field MUST be MAP. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.MapProto MapValue { + get { return mapValue_; } + set { + mapValue_ = value; + } + } + + /// Field number for the "optional_value" field. + public const int OptionalValueFieldNumber = 7; + private global::Onnx.OptionalProto optionalValue_; + /// + /// For OptionalProto value, allowing optional to be of itself (completeness) + /// When this field is present, the elem_type field MUST be OPTIONAL. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public global::Onnx.OptionalProto OptionalValue { + get { return optionalValue_; } + set { + optionalValue_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override bool Equals(object other) { + return Equals(other as OptionalProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public bool Equals(OptionalProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (ElemType != other.ElemType) return false; + if (!object.Equals(TensorValue, other.TensorValue)) return false; + if (!object.Equals(SparseTensorValue, other.SparseTensorValue)) return false; + if (!object.Equals(SequenceValue, other.SequenceValue)) return false; + if (!object.Equals(MapValue, other.MapValue)) return false; + if (!object.Equals(OptionalValue, other.OptionalValue)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (ElemType != 0) hash ^= ElemType.GetHashCode(); + if (tensorValue_ != null) hash ^= TensorValue.GetHashCode(); + if (sparseTensorValue_ != null) hash ^= SparseTensorValue.GetHashCode(); + if (sequenceValue_ != null) hash ^= SequenceValue.GetHashCode(); + if (mapValue_ != null) hash ^= MapValue.GetHashCode(); + if (optionalValue_ != null) hash ^= OptionalValue.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void WriteTo(pb::CodedOutputStream output) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + output.WriteRawMessage(this); + #else + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ElemType != 0) { + output.WriteRawTag(16); + output.WriteInt32(ElemType); + } + if (tensorValue_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TensorValue); + } + if (sparseTensorValue_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SparseTensorValue); + } + if (sequenceValue_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SequenceValue); + } + if (mapValue_ != null) { + output.WriteRawTag(50); + output.WriteMessage(MapValue); + } + if (optionalValue_ != null) { + output.WriteRawTag(58); + output.WriteMessage(OptionalValue); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ElemType != 0) { + output.WriteRawTag(16); + output.WriteInt32(ElemType); + } + if (tensorValue_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TensorValue); + } + if (sparseTensorValue_ != null) { + output.WriteRawTag(34); + output.WriteMessage(SparseTensorValue); + } + if (sequenceValue_ != null) { + output.WriteRawTag(42); + output.WriteMessage(SequenceValue); + } + if (mapValue_ != null) { + output.WriteRawTag(50); + output.WriteMessage(MapValue); + } + if (optionalValue_ != null) { + output.WriteRawTag(58); + output.WriteMessage(OptionalValue); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(ref output); + } + } + #endif + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (ElemType != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ElemType); + } + if (tensorValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorValue); + } + if (sparseTensorValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SparseTensorValue); + } + if (sequenceValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SequenceValue); + } + if (mapValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MapValue); + } + if (optionalValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(OptionalValue); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(OptionalProto other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.ElemType != 0) { + ElemType = other.ElemType; + } + if (other.tensorValue_ != null) { + if (tensorValue_ == null) { + TensorValue = new global::Onnx.TensorProto(); + } + TensorValue.MergeFrom(other.TensorValue); + } + if (other.sparseTensorValue_ != null) { + if (sparseTensorValue_ == null) { + SparseTensorValue = new global::Onnx.SparseTensorProto(); + } + SparseTensorValue.MergeFrom(other.SparseTensorValue); + } + if (other.sequenceValue_ != null) { + if (sequenceValue_ == null) { + SequenceValue = new global::Onnx.SequenceProto(); + } + SequenceValue.MergeFrom(other.SequenceValue); + } + if (other.mapValue_ != null) { + if (mapValue_ == null) { + MapValue = new global::Onnx.MapProto(); + } + MapValue.MergeFrom(other.MapValue); + } + if (other.optionalValue_ != null) { + if (optionalValue_ == null) { + OptionalValue = new global::Onnx.OptionalProto(); + } + OptionalValue.MergeFrom(other.OptionalValue); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public void MergeFrom(pb::CodedInputStream input) { + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + input.ReadRawMessage(this); + #else + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + ElemType = input.ReadInt32(); + break; + } + case 26: { + if (tensorValue_ == null) { + TensorValue = new global::Onnx.TensorProto(); + } + input.ReadMessage(TensorValue); + break; + } + case 34: { + if (sparseTensorValue_ == null) { + SparseTensorValue = new global::Onnx.SparseTensorProto(); + } + input.ReadMessage(SparseTensorValue); + break; + } + case 42: { + if (sequenceValue_ == null) { + SequenceValue = new global::Onnx.SequenceProto(); + } + input.ReadMessage(SequenceValue); + break; + } + case 50: { + if (mapValue_ == null) { + MapValue = new global::Onnx.MapProto(); + } + input.ReadMessage(MapValue); + break; + } + case 58: { + if (optionalValue_ == null) { + OptionalValue = new global::Onnx.OptionalProto(); + } + input.ReadMessage(OptionalValue); + break; + } + } + } + #endif + } + + #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 16: { + ElemType = input.ReadInt32(); + break; + } + case 26: { + if (tensorValue_ == null) { + TensorValue = new global::Onnx.TensorProto(); + } + input.ReadMessage(TensorValue); + break; + } + case 34: { + if (sparseTensorValue_ == null) { + SparseTensorValue = new global::Onnx.SparseTensorProto(); + } + input.ReadMessage(SparseTensorValue); + break; + } + case 42: { + if (sequenceValue_ == null) { + SequenceValue = new global::Onnx.SequenceProto(); + } + input.ReadMessage(SequenceValue); + break; + } + case 50: { + if (mapValue_ == null) { + MapValue = new global::Onnx.MapProto(); + } + input.ReadMessage(MapValue); + break; + } + case 58: { + if (optionalValue_ == null) { + OptionalValue = new global::Onnx.OptionalProto(); + } + input.ReadMessage(OptionalValue); + break; + } + } + } + } + #endif + + #region Nested types + /// Container for nested types declared in the OptionalProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] + public static partial class Types { + public enum DataType { + [pbr::OriginalName("UNDEFINED")] Undefined = 0, + [pbr::OriginalName("TENSOR")] Tensor = 1, + [pbr::OriginalName("SPARSE_TENSOR")] SparseTensor = 2, + [pbr::OriginalName("SEQUENCE")] Sequence = 3, + [pbr::OriginalName("MAP")] Map = 4, + [pbr::OriginalName("OPTIONAL")] Optional = 5, + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs index 8805b95839f2..72686b377527 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs @@ -3864,7 +3864,7 @@ public sealed partial class TensorProto : pb::IMessage /// float16 values must be bit-wise converted to an uint16_t prior /// to writing to the buffer. /// When this field is present, the data_type field MUST be - /// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + /// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16 or BFLOAT16 /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs index d913aaa5b966..6bf4d702c9ce 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs @@ -1,7 +1,10 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; +using System.Net.NetworkInformation; +using Google.Protobuf; using Microsoft.ML.OnnxRuntime.Tensors; using Xunit; @@ -47,128 +50,87 @@ internal static float[] LoadTensorFromEmbeddedResource(string path) return tensorData.ToArray(); } - static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary nodeMetaDict) + static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, string nodeName, NodeMetadata nodeMeta) { - var intDims = new int[tensor.Dims.Count]; - for (int i = 0; i < tensor.Dims.Count; i++) + if (nodeMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) { - intDims[i] = (int)tensor.Dims[i]; + throw new InvalidDataException($"Metadata for: '{nodeName}' has a type: '{nodeMeta.OnnxValueType}'" + + $" but loading as tensor: '{tensor.Name}'"); } - NodeMetadata nodeMeta = null; - string nodeName = string.Empty; + var protoDt = (Tensors.TensorElementType)tensor.DataType; + var metaElementType = nodeMeta.ElementDataType; + if (!((protoDt == metaElementType) || + (protoDt == TensorElementType.UInt16 && + (metaElementType == TensorElementType.BFloat16 || metaElementType == TensorElementType.Float16)))) + throw new InvalidDataException($"Loaded tensor type: {protoDt} is expected to be equal to: {metaElementType}"); - if (nodeMetaDict.Count == 1) + // Tensors within Sequences may have no dimensions as the standard allows + // different dimensions for each tensor element of the sequence + if (nodeMeta.Dimensions.Length > 0 && nodeMeta.Dimensions.Length != tensor.Dims.Count) { - nodeMeta = nodeMetaDict.Values.First(); - nodeName = nodeMetaDict.Keys.First(); // valid for single node input + throw new InvalidDataException($"node: '{nodeName}' nodeMeta.Dim.Length: {nodeMeta.Dimensions.Length} " + + $"is expected to be equal to tensor.Dims.Count {tensor.Dims.Count}"); } - else if (nodeMetaDict.Count > 1) - { - if (tensor.Name.Length > 0) - { - nodeMeta = nodeMetaDict[tensor.Name]; - nodeName = tensor.Name; - if (!nodeMeta.IsTensor) - throw new Exception("LoadTensorFromFile can load Tensor types only: " + nodeName); - } - else - { - bool matchfound = false; - // try to find from matching type and shape - foreach (var key in nodeMetaDict.Keys) - { - var meta = nodeMetaDict[key]; - if (!meta.IsTensor) - throw new Exception("LoadTensorFromFile can load Tensor types only"); - if ((Tensors.TensorElementType)tensor.DataType == meta.ElementDataType && tensor.Dims.Count == meta.Dimensions.Length) - { - int i = 0; - for (; i < meta.Dimensions.Length; i++) - { - if (meta.Dimensions[i] != -1 && meta.Dimensions[i] != intDims[i]) - { - break; - } - } - if (i >= meta.Dimensions.Length) - { - matchfound = true; - nodeMeta = meta; - nodeName = key; - break; - } - } - } - if (!matchfound) - { - // throw error - throw new Exception($"No Matching Tensor found in InputOutputMetadata corresponding to the serialized tensor specified"); - } - } - } - else + var intDims = new int[tensor.Dims.Count]; + for (int i = 0; i < tensor.Dims.Count; i++) { - // throw error - throw new Exception($"While reading the serliazed tensor specified, metaDataDict has 0 elements"); + intDims[i] = (int)tensor.Dims[i]; } - if (nodeMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) - throw new Exception("LoadTensorFromFile can load Dense Tensor types only"); - - var protoDt = (Tensors.TensorElementType)tensor.DataType; - if (!((protoDt == nodeMeta.ElementDataType) || - (protoDt == TensorElementType.UInt16 && - (nodeMeta.ElementDataType == TensorElementType.BFloat16 || nodeMeta.ElementDataType == TensorElementType.Float16)))) - throw new Exception($"{tensor.DataType.ToString()} is expected to be equal to: " + nodeMeta.ElementDataType.ToString()); - - if (nodeMeta.Dimensions.Length != tensor.Dims.Count) - throw new Exception($"{nameof(nodeMeta.Dimensions.Length)} is expected to be equal to {nameof(tensor.Dims.Count)}"); - for (int i = 0; i < nodeMeta.Dimensions.Length; i++) { - if ((nodeMeta.Dimensions[i] != -1) && (nodeMeta.Dimensions[i] != intDims[i])) - throw new Exception($"{nameof(nodeMeta.Dimensions)}[{i}] is expected to either be -1 or {nameof(intDims)}[{i}]"); + if ((nodeMeta.Dimensions[i] != -1) && (nodeMeta.Dimensions[i] != tensor.Dims[i])) + throw new InvalidDataException($"Node: '{nodeName}' dimension at idx {i} is {nodeMeta.Dimensions}[{i}] " + + $"is expected to either be -1 or {tensor.Dims[i]}"); } - var elementType = nodeMeta.ElementDataType; + // element type for Float16 and BFloat16 in the loaded tensor would always be uint16, so + // we want to use element type from metadata + if (protoDt == TensorElementType.String) + return CreateNamedOnnxValueFromStringTensor(tensor.StringData, nodeName, intDims); + + return CreateNamedOnnxValueFromTensorRawData(nodeName, tensor.RawData.ToArray(), metaElementType, intDims); + } + + internal static NamedOnnxValue CreateNamedOnnxValueFromTensorRawData(string nodeName, byte[] rawData, TensorElementType elementType, int[] intDims) + { switch (elementType) { case TensorElementType.Float: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(float), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(float), intDims); case TensorElementType.Double: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(double), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(double), intDims); case TensorElementType.Int32: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(int), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(int), intDims); case TensorElementType.UInt32: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(uint), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(uint), intDims); case TensorElementType.Int16: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(short), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(short), intDims); case TensorElementType.UInt16: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ushort), intDims); case TensorElementType.Int64: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(long), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(long), intDims); case TensorElementType.UInt64: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ulong), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ulong), intDims); case TensorElementType.UInt8: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(byte), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(byte), intDims); case TensorElementType.Int8: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(sbyte), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(sbyte), intDims); case TensorElementType.Bool: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(bool), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(bool), intDims); case TensorElementType.Float16: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ushort), intDims); case TensorElementType.BFloat16: - return CreateNamedOnnxValueFromRawData(nodeName, tensor.RawData.ToArray(), sizeof(ushort), intDims); - case TensorElementType.String: - return CreateNamedOnnxValueFromString(tensor, intDims); + return CreateNamedOnnxValueFromRawData(nodeName, rawData, sizeof(ushort), intDims); default: - throw new Exception($"Tensors of type: " + nodeMeta.ElementType.ToString() + " not currently supported in the LoadTensorFromEmbeddedResource"); + throw new InvalidDataException($"Tensors of type: " + elementType.ToString() + + " not currently supported here, use: CreateNamedOnnxValueFromStringTensor."); } } - internal static NamedOnnxValue LoadTensorFromEmbeddedResourcePb(string path, IReadOnlyDictionary nodeMetaDict) + internal static NamedOnnxValue LoadTensorFromEmbeddedResourcePb(string path, string nodeName, NodeMetadata nodeMeta) { Onnx.TensorProto tensor = null; @@ -179,20 +141,269 @@ internal static NamedOnnxValue LoadTensorFromEmbeddedResourcePb(string path, IRe tensor = Onnx.TensorProto.Parser.ParseFrom(stream); } - return LoadTensorPb(tensor, nodeMetaDict); + return LoadTensorPb(tensor, nodeName, nodeMeta); } - internal static NamedOnnxValue LoadTensorFromFilePb(string filename, IReadOnlyDictionary nodeMetaDict) + internal static NamedOnnxValue LoadOnnxValueFromFilePb(string fullFilename, string nodeName, NodeMetadata nodeMeta) { + // No sparse tensor support yet //Set buffer size to 4MB int readBufferSize = 4194304; - Onnx.TensorProto tensor = null; - using (var file = new FileStream(filename, FileMode.Open, FileAccess.Read, FileShare.Read, readBufferSize)) + using (var file = new FileStream(fullFilename, FileMode.Open, FileAccess.Read, FileShare.Read, readBufferSize)) { - tensor = Onnx.TensorProto.Parser.ParseFrom(file); + switch (nodeMeta.OnnxValueType) + { + case OnnxValueType.ONNX_TYPE_TENSOR: + { + var tensor = Onnx.TensorProto.Parser.ParseFrom(file); + return LoadTensorPb(tensor, nodeName, nodeMeta); + } + case OnnxValueType.ONNX_TYPE_SEQUENCE: + { + var sequence = Onnx.SequenceProto.Parser.ParseFrom(file); + return CreateNamedOnnxValueFromSequence(sequence, nodeName, nodeMeta); + } + case OnnxValueType.ONNX_TYPE_MAP: + { + var map = Onnx.MapProto.Parser.ParseFrom(file); + return CreateNamedOnnxValueFromMap(map, nodeName, nodeMeta); + } + + case OnnxValueType.ONNX_TYPE_OPTIONAL: + { + var opt = Onnx.OptionalProto.Parser.ParseFrom(file); + return CreateNamedOnnxValueFromOptional(opt, nodeName, nodeMeta); + } + default: + throw new ArgumentException($"Unable to load value type {nodeMeta.OnnxValueType} not implemented"); + } } + } - return LoadTensorPb(tensor, nodeMetaDict); + private static void SequenceCheckMatchOnnxType(string nodeName, SequenceMetadata meta, + OnnxValueType onnxType) + { + if (meta.ElementMeta.OnnxValueType == onnxType) + return; + + throw new InvalidDataException($"Sequence node: '{nodeName}' " + + $"has element type: '{onnxType}'" + + $" expected: '{meta.ElementMeta.OnnxValueType}'"); + } + + private static string MakeSequenceElementName(string nodeName, string seqName, int seqNum) + { + if (seqName.Length > 0) + return $"seq.{nodeName}.data.{seqName}.{seqNum}"; + else + return $"seq.{nodeName}.data._.{seqNum}"; + } + + internal static NamedOnnxValue CreateNamedOnnxValueFromSequence(Onnx.SequenceProto sequence, string nodeName, NodeMetadata nodeMeta) + { + var sequenceMeta = nodeMeta.AsSequenceMetadata(); + var elemMeta = sequenceMeta.ElementMeta; + + int seqNum = 0; + var seqElemType = (Onnx.SequenceProto.Types.DataType)sequence.ElemType; + switch (seqElemType) + { + case Onnx.SequenceProto.Types.DataType.Tensor: + { + SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_TENSOR); + var sequenceOfTensors = new List(sequence.TensorValues.Count); + foreach (var tensor in sequence.TensorValues) + { + var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++); + var namedOnnxValue = LoadTensorPb(tensor, elemName, elemMeta); + sequenceOfTensors.Add(namedOnnxValue); + } + return NamedOnnxValue.CreateFromSequence(nodeName, sequenceOfTensors); + } + case Onnx.SequenceProto.Types.DataType.Sequence: + { + SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_SEQUENCE); + var seqOfSequences = new List(sequence.SequenceValues.Count); + foreach (var s in sequence.SequenceValues) + { + var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++); + seqOfSequences.Add(CreateNamedOnnxValueFromSequence(s, elemName, elemMeta)); + } + return NamedOnnxValue.CreateFromSequence(nodeName, seqOfSequences); + } + case Onnx.SequenceProto.Types.DataType.Map: + { + SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_MAP); + var seqOfMaps = new List(sequence.MapValues.Count); + foreach (var m in sequence.MapValues) + { + var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++); + seqOfMaps.Add(CreateNamedOnnxValueFromMap(m, elemName, elemMeta)); + } + return NamedOnnxValue.CreateFromSequence(nodeName, seqOfMaps); + } + case Onnx.SequenceProto.Types.DataType.Optional: + { + SequenceCheckMatchOnnxType(nodeName, sequenceMeta, OnnxValueType.ONNX_TYPE_OPTIONAL); + var seqOfOpts = new List(sequence.OptionalValues.Count); + foreach (var opt in sequence.OptionalValues) + { + var elemName = MakeSequenceElementName(nodeName, sequence.Name, seqNum++); + seqOfOpts.Add(CreateNamedOnnxValueFromOptional(opt, elemName, elemMeta)); + } + return NamedOnnxValue.CreateFromSequence(nodeName, seqOfOpts); + } + default: + throw new NotImplementedException($"Sequence test data loading does not support element type: " + + $"'{seqElemType}'"); + } + + } + + internal static NamedOnnxValue CastAndCreateFromMapKeys(string name, TensorElementType elementType, IList keys) + { + switch (elementType) + { + case TensorElementType.Float: + return CastAndCreateTensor(name, keys); + case TensorElementType.Double: + return CastAndCreateTensor(name, keys); + case TensorElementType.Int32: + return CastAndCreateTensor(name, keys); + case TensorElementType.UInt32: + return CastAndCreateTensor(name, keys); + case TensorElementType.Int16: + return CastAndCreateTensor(name, keys); + case TensorElementType.UInt16: + return CastAndCreateTensor(name, keys); + case TensorElementType.Int64: + return CastAndCreateTensor(name, keys); + case TensorElementType.UInt64: + return CastAndCreateTensor(name, keys); + case TensorElementType.UInt8: + return CastAndCreateTensor(name, keys); + case TensorElementType.Int8: + return CastAndCreateTensor(name, keys); + case TensorElementType.Bool: + return CastAndCreateTensor(name, keys); + case TensorElementType.Float16: + return CastAndCreateTensor(name, keys); + case TensorElementType.BFloat16: + return CastAndCreateTensor(name, keys); + default: + throw new NotImplementedException($"Tensors of type: " + elementType.ToString() + + " not currently supported here, use: CreateNamedOnnxValueFromStringTensor."); + } + } + + /// + /// All the keys in maps are stored as an array of longs, so + /// to create a real tensor we need to cast to create a continuous buffer + /// essentially packing it as a raw data. + /// + /// + /// + /// + /// + /// + /// + internal static NamedOnnxValue CastAndCreateTensor(string name, IList elements) + { + // Create raw data + T[] castKeys = new T[elements.Count]; + if (typeof(T) == typeof(Float16) || typeof(T) == typeof(BFloat16)) + { + for (int i = 0; i < elements.Count; i++) + { + var obj = Convert.ChangeType(elements[i], typeof(ushort)); + if (obj == null) + { + throw new InvalidDataException($"Conversion from long to {typeof(T)} failed"); + } + castKeys[i] = (T)obj; + } + } + else + { + for (int i = 0; i < elements.Count; i++) + { + var obj = (T)Convert.ChangeType(elements[i], typeof(T)); + if (obj == null) + { + throw new InvalidDataException($"Conversion from long to {typeof(T)} failed"); + } + castKeys[i] = (T)obj; + } + } + var tensor = new DenseTensor(castKeys, new int[] { elements.Count }); + return NamedOnnxValue.CreateFromTensor(name, tensor); + } + + internal static NamedOnnxValue CreateNamedOnnxValueFromMap(Onnx.MapProto map, string nodeName, NodeMetadata nodeMetadata) + { + // See GH issue https://github.com/onnx/onnx/issues/5072 + throw new NotImplementedException($"Loading map node: '{nodeName}' not implemented yet"); + + //var mapMeta = nodeMetadata.AsMapMetadata(); + + //if ((TensorElementType)map.KeyType != mapMeta.KeyDataType) + //{ + // throw new InvalidDataException($"Node: '{nodeName}' map key type expected: " + + // $"'{mapMeta.KeyDataType}', loaded from test data: '{(TensorElementType)map.KeyType}'"); + //} + + //// temp non-generic(!) container + //NamedOnnxValue keysTensor; + //if (mapMeta.KeyDataType == TensorElementType.String) + //{ + // keysTensor = CreateNamedOnnxValueFromStringTensor(map.StringKeys, nodeName, new int[] { map.StringKeys.Count }); + //} + //else + //{ + // keysTensor = CastAndCreateFromMapKeys(nodeName, mapMeta.KeyDataType, map.Keys); + //} + + //switch ((Onnx.SequenceProto.Types.DataType)map.Values.ElemType) + //{ + // case Onnx.SequenceProto.Types.DataType.Tensor: + // var tensorCount = map.Values.TensorValues.Count; + // break; + // default: + // throw new NotImplementedException("Does not support map value type other than a tensor"); + //} + + //return new NamedOnnxValue(string.Empty, new Object(), OnnxValueType.ONNX_TYPE_UNKNOWN); + } + + internal static NamedOnnxValue CreateNamedOnnxValueFromOptional(Onnx.OptionalProto optional, string nodeName, NodeMetadata nodeMetadata) + { + var meta = nodeMetadata.AsOptionalMetadata().ElementMeta; + switch((Onnx.OptionalProto.Types.DataType)optional.ElemType) + { + case Onnx.OptionalProto.Types.DataType.Tensor: + { + var tensor = optional.TensorValue; + return LoadTensorPb(tensor, nodeName, meta); + } + case Onnx.OptionalProto.Types.DataType.Sequence: + { + var sequence = optional.SequenceValue; + return CreateNamedOnnxValueFromSequence(sequence, nodeName, meta); + } + case Onnx.OptionalProto.Types.DataType.Map: + { + var map = optional.MapValue; + return CreateNamedOnnxValueFromMap(map, nodeName, meta); + } + case Onnx.OptionalProto.Types.DataType.Optional: + throw new NotImplementedException($"Unable to load '{nodeName}' optional contained within optional"); + default: + // Test data contains OptionalProto with the contained element type undefined. + // the premise is, if the element is not fed as an input, we should not care + // what Onnx type it is. However, we do not need to support AFAIK such inputs + // since the value for them could never be supplied. + throw new NotImplementedException($"Unable to load '{nodeName}' optional element type of: {(Onnx.OptionalProto.Types.DataType)optional.ElemType} type"); + } } internal static NamedOnnxValue CreateNamedOnnxValueFromRawData(string name, byte[] rawData, int elemWidth, int[] dimensions) @@ -218,21 +429,17 @@ internal static NamedOnnxValue CreateNamedOnnxValueFromRawData(string name, b return NamedOnnxValue.CreateFromTensor(name, dt); } - internal static NamedOnnxValue CreateNamedOnnxValueFromString(Onnx.TensorProto tensor, int[] dimensions) - { - if (tensor.DataType != (int)Onnx.TensorProto.Types.DataType.String) - { - throw new ArgumentException("Expecting string data"); - } - - string[] strArray = new string[tensor.StringData.Count]; - for (int i = 0; i < tensor.StringData.Count; ++i) + internal static NamedOnnxValue CreateNamedOnnxValueFromStringTensor(IList strings, + string nodeName, int[] dimensions) + { + string[] strArray = new string[strings.Count]; + for (int i = 0; i < strings.Count; ++i) { - strArray[i] = System.Text.Encoding.UTF8.GetString(tensor.StringData[i].ToByteArray()); + strArray[i] = System.Text.Encoding.UTF8.GetString(strings[i].ToByteArray()); } var dt = new DenseTensor(strArray, dimensions); - return NamedOnnxValue.CreateFromTensor(tensor.Name, dt); + return NamedOnnxValue.CreateFromTensor(nodeName, dt); } internal static float[] LoadTensorFromFile(string filename, bool skipheader = true) diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index 351a80e86854..6e4fce8ee2f2 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -3,6 +3,7 @@ using System.IO; using System.Linq; using System.Runtime.InteropServices; +using System.Text.RegularExpressions; using Microsoft.ML.OnnxRuntime.Tensors; using Xunit; @@ -40,8 +41,8 @@ public void CanCreateAndDisposeSessionWithModelPath() { Assert.NotNull(session); Assert.NotNull(session.InputMetadata); - Assert.Equal(1, session.InputMetadata.Count); // 1 input node - Assert.True(session.InputMetadata.ContainsKey("data_0")); // input node name + Assert.Equal(1, session.InputMetadata.Count); // 1 input nodeMeta + Assert.True(session.InputMetadata.ContainsKey("data_0")); // input nodeMeta name Assert.Equal(typeof(float), session.InputMetadata["data_0"].ElementType); Assert.True(session.InputMetadata["data_0"].IsTensor); var expectedInputDimensions = new int[] { 1, 3, 224, 224 }; @@ -52,8 +53,8 @@ public void CanCreateAndDisposeSessionWithModelPath() } Assert.NotNull(session.OutputMetadata); - Assert.Equal(1, session.OutputMetadata.Count); // 1 output node - Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output node name + Assert.Equal(1, session.OutputMetadata.Count); // 1 output nodeMeta + Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output nodeMeta name Assert.Equal(typeof(float), session.OutputMetadata["softmaxout_1"].ElementType); Assert.True(session.OutputMetadata["softmaxout_1"].IsTensor); var expectedOutputDimensions = new int[] { 1, 1000, 1, 1 }; @@ -246,9 +247,11 @@ private void TestTensorRTProviderOptions() { "fp16_test_tiny_yolov2", "ImageScaler is not a registered function/op"}, { "fp16_coreml_FNS-Candy", "ImageScaler is not a registered function/op" }, { "fp16_coreml_LinearRegression_NYCTaxi", "Error in Node:featureVectorizer : No Op registered for FeatureVectorizer with domain_version of 1"}, - { "test_bidaf", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile." }, + // { "test_bidaf", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile." }, { "test_mnist", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile" }, - { "BERT_Squad", "Could not find an implementation for the node bert / embeddings / one_hot:OneHot(9)" }, + { "BERT_Squad", "Could not find an implementation for the nodeMeta bert / embeddings / one_hot:OneHot(9)" }, + { "test_BERT_Squad", "Test tensor data element type does not match metadata: Int64 is expected to be equal to: Float" }, + { "keras_prelu_ImageNet_small", "Unable to match file: input_1.pb to input/output metadata"}, { "mlperf_ssd_mobilenet_300", "Could not find file output_0.pb" }, { "tf_resnet_v1_50", "result mismatch when Conv BN Fusion is applied" }, { "tf_resnet_v1_101", "result mismatch when Conv BN Fusion is applied" }, @@ -256,101 +259,93 @@ private void TestTensorRTProviderOptions() { "cntk_simple_seg", "Bad onnx test output caused by wrong SAME_UPPER/SAME_LOWER for ConvTranspose" }, { "coreml_Imputer-LogisticRegression_sklearn_load_breast_cancer", "Can't determine model file name" }, { "mask_rcnn_keras", "Model should be edited to remove the extra outputs" }, - { "test_max_float64", "node test error"}, - { "test_min_uint8", "node test error"}, - { "test_mod_mixed_sign_float64", "node test error"}, - { "test_momentum", "node test error"}, - { "test_max_uint16", "node test error"}, - { "test_resize_downsample_scales_linear_align_corners", "node test error"}, - { "test_adagrad_multiple", "node test error"}, - { "test_einsum_inner_prod", "node test error"}, - { "test_sequence_insert_at_back", "node test error"}, - { "test_mod_mixed_sign_int8", "node test error"}, - { "test_maxunpool_export_with_output_shape", "node test error"}, - { "test_min_int16", "node test error"}, - { "test_adagrad", "node test error"}, - { "test_min_float64", "node test error"}, - { "test_max_int16", "node test error"}, - { "test_sequence_insert_at_front", "node test error"}, - { "test_training_dropout_default", "node test error"}, - { "test_training_dropout", "node test error"}, - { "test_adam", "node test error"}, - { "test_training_dropout_mask", "node test error"}, - { "test_clip_default_int8_inbounds", "node test error"}, - { "test_eyelike_with_dtype", "node test error"}, - { "test_cast_STRING_to_FLOAT", "node test error"}, - { "test_cast_FLOAT_to_DOUBLE", "node test error"}, - { "test_cast_BFLOAT16_to_FLOAT", "node test error"}, - { "test_cast_FLOAT_to_BFLOAT16", "node test error"}, + + { "test_maxunpool_export_with_output_shape", "results mismatch"}, + + { "test_min_int8", "Could not find an implementation for Min(13) node with name"}, + { "test_min_uint8", "Could not find an implementation for Min(13) node with name"}, + { "test_min_int16", "Could not find an implementation for Min(13) node with name"}, + { "test_min_uint16", "Could not find an implementation for Min(13) node with name"}, + + { "test_max_int8", "Could not find an implementation for Max(13) node with name"}, + { "test_max_uint8", "Could not find an implementation for Max(13) node with name"}, + { "test_max_int16", "Could not find an implementation for Max(13) node with name"}, + { "test_max_uint16", "Could not find an implementation for Max(13) nodeMeta with name '"}, + + { "test_mul_uint8", "Could not find an implementation for Mul(14) node with name" }, + + { "test_clip_default_int8_inbounds", "nodeMeta test error"}, + { "test_eyelike_with_dtype", "nodeMeta test error"}, + { "test_cast_STRING_to_FLOAT", "nodeMeta test error"}, + { "test_cast_FLOAT_to_DOUBLE", "nodeMeta test error"}, + { "test_cast_BFLOAT16_to_FLOAT", "nodeMeta test error"}, + { "test_cast_FLOAT_to_BFLOAT16", "nodeMeta test error"}, { "test_cast_FLOAT_to_STRING", "Output strings can not be compared exactly"}, - { "test_castlike_STRING_to_FLOAT", "node test error"}, - { "test_castlike_STRING_to_FLOAT_expanded", "node test error"}, - { "test_castlike_FLOAT16_to_DOUBLE", "node test error"}, - { "test_castlike_FLOAT16_to_DOUBLE_expanded", "node test error"}, - { "test_castlike_FLOAT_to_DOUBLE", "node test error"}, - { "test_castlike_FLOAT_to_DOUBLE_expanded", "node test error"}, + { "test_castlike_STRING_to_FLOAT", "nodeMeta test error"}, + { "test_castlike_STRING_to_FLOAT_expanded", "nodeMeta test error"}, + { "test_castlike_FLOAT16_to_DOUBLE", "nodeMeta test error"}, + { "test_castlike_FLOAT16_to_DOUBLE_expanded", "nodeMeta test error"}, + { "test_castlike_FLOAT_to_DOUBLE", "nodeMeta test error"}, + { "test_castlike_FLOAT_to_DOUBLE_expanded", "nodeMeta test error"}, { "test_castlike_BFLOAT16_to_FLOAT", "Length is expected to be equal to Count (metadata and expected data mismatch) "}, { "test_castlike_BFLOAT16_to_FLOAT_expanded", "Length is expected to be equal to Count metadata and expected data mismatch"}, - { "test_castlike_FLOAT_to_BFLOAT16", "node test error"}, - { "test_castlike_FLOAT_to_BFLOAT16_expanded", "node test error"}, - { "test_castlike_FLOAT_to_STRING", "node test error"}, - { "test_castlike_FLOAT_to_STRING_expanded", "node test error"}, - { "test_bitshift_right_uint16", "node test error"}, - { "test_bitshift_left_uint16", "node test error"}, - { "test_pow_types_float32_uint64", "node test error"}, - { "test_max_uint8", "node test error"}, - { "test_momentum_multiple", "node test error"}, - { "test_pow_types_float32_uint32", "node test error"}, - { "test_if_seq", "sequence type is not supported in test infra."}, - { "test_resize_downsample_scales_cubic_align_corners", "node test error"}, - { "test_einsum_batch_matmul", "node test error"}, - { "test_nesterov_momentum", "node test error"}, - { "test_min_uint16", "node test error"}, - { "test_adam_multiple", "node test error"}, - { "test_loop13_seq", "sequence type is not supported in test infra." }, - { "test_training_dropout_default_mask", "node test error"}, - { "test_min_int8", "node test error"}, - { "test_identity_sequence", "data type not supported"}, + { "test_castlike_FLOAT_to_BFLOAT16", "Length is expected to be equal to Count. Testdata dims length do not match that of model metadata"}, + { "test_castlike_FLOAT_to_BFLOAT16_expanded", "Length is expected to be equal to Count"}, + { "test_castlike_FLOAT_to_STRING", "string comparison does not match due to float rounding"}, + { "test_castlike_FLOAT_to_STRING_expanded", "string comparison does not match due to float rounding"}, + + { "test_bitshift_right_uint16", "Could not find an implementation for BitShift(11) nodeMeta with name ''"}, + { "test_bitshift_left_uint16", "Could not find an implementation for BitShift(11)"}, + + { "test_pow_types_float32_uint64", "Could not find an implementation for Pow(15) node with name ''"}, + { "test_pow_types_float32_uint32", "Could not find an implementation for Pow(15) node with name ''"}, + + { "test_resize_downsample_scales_cubic_align_corners", "Results mismatch"}, + { "test_resize_downsample_scales_linear_align_corners", "Results mismatch"}, + { "test_gru_batchwise", "batchwise operations not supported"}, - { "test_lstm_batchwise", "batchwise operations not supported"}, + { "test_lstm_batchwise", "Batchwise recurrent operations(layout == 1) are not supported.If you need support create a github issue with justification."}, { "test_simple_rnn_batchwise", "batchwise operations not supported"}, { "test_batchnorm_example_training_mode", "opset14 version not implemented yet"}, - { "test_bernoulli", "random generator"}, - { "test_bernoulli_seed", "random generator"}, - { "test_bernoulli_double", "random generator"}, - { "test_bernoulli_expanded", "random generator"}, - { "test_bernoulli_seed_expanded", "random generator"}, - { "test_bernoulli_double_expanded", "random generator"}, - { "test_shape", "opset15 version not implemented yet"}, - { "test_optional_get_element", "optional type is not supported in test infra."}, - { "test_optional_get_element_sequence", "optional type is not supported in test infra."}, - { "test_identity_opt", "optional type is not supported in test infra." }, - { "test_if_opt", "optional type is not supported in test infra." }, - { "test_loop16_seq_none", "sequence type is not supported in test infra." }, - { "test_sequence_map_extract_shapes", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_1_sequence_1_tensor", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_1_sequence_1_tensor_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_add_1_sequence_1_tensor", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_1_sequence_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_2_sequences", "sequence type is not supported in test infra." }, - { "test_sequence_map_add_2_sequences_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_2_sequences_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_extract_shapes_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_add_1_sequence_1_tensor_expanded", "sequence type is not supported in test infra." }, - { "test_sequence_map_add_2_sequences", "sequence type is not supported in test infra." }, - { "test_sequence_map_identity_1_sequence", "sequence type is not supported in test infra." }, - { "BERT-Squad-int8", "training domain"}, - { "YOLOv3-12-int8", "training_domain"}, + + { "test_bernoulli", "random generator, results mismatch"}, + { "test_bernoulli_seed", "random generator, results mismatch"}, + { "test_bernoulli_double", "random generator, results mismatch"}, + { "test_bernoulli_expanded", "random generator, results mismatch"}, + { "test_bernoulli_seed_expanded", "random generator, results mismatch"}, + { "test_bernoulli_double_expanded", "random generator, results mismatch"}, + // the expansion of Softplus uses Exp(1). ORT has a Softplus kernel, so testing the expansion is // unnecessary and fails as ORT support for Exp started at opset 6 (as ORT didn't exist until opset 7). - { "test_softplus_example_expanded", "Not applicable"}, - { "test_softplus_expanded", "Not applicable"}, - { "test_col2im_pads", "due to a typo in test data"}, - { "test_optional_has_element_empty_optional_input", "C# API doesn't support optional input"}, - { "test_optional_get_element_optional_tensor", "C# API doesn't support optional input"}, - { "test_optional_get_element_optional_sequence", "C# API doesn't support optional input"}, - { "test_optional_has_element_tensor_input", "C# API doesn't support optional input"}, - { "test_optional_has_element_optional_input", "C# API doesn't support optional input"}, + + { "test_clip_default_int8_max_expanded", "Could not find an implementation for Less(13) nodeMeta with name ''" }, + { "test_softplus_expanded", "Could not find an implementation for Exp(1) node with name ''"}, + { "test_softplus_example_expanded", "Could not find an implementation for Exp(1) node with name ''"}, + { "test_div_uint8", "Could not find an implementation for Div(14) nodeMeta with name ''"}, + { "test_add_uint8", "Opset18 Could not find an implementation for Add(14) nodeMeta with name ''"}, + { "test_col2im_pads", "Results mismatch due to a typo in test data"}, + + { "test_optional_has_element_empty_optional_input", "OptionalProto test metadata. Unable to load 'optional_input' optional element type of: Undefined type"}, + { "test_loop13_seq", "3rd input is an empty sequence. Ort API does not tolerate empty seq: Number of values should be at least 1" }, + + // Training tests + { "BERT-Squad-int8", "training domain"}, + { "YOLOv3-12-int8", "training_domain"}, + + { "test_training_dropout_default", "results mismatch"}, + { "test_training_dropout_default_mask", "Results mismatch"}, + { "test_training_dropout", "results mismatch"}, + { "test_training_dropout_mask", "results mismatch."}, + + { "test_momentum", "ai.onnx.preview.training:Momentum(-1) is not a registered function/op"}, + { "test_momentum_multiple", "ai.onnx.preview.training:Momentum(-1) is not a registered function/op"}, + { "test_nesterov_momentum", "ai.onnx.preview.training:Momentum(-1) is not a registered function/op"}, + + { "test_adam", "ai.onnx.preview.training:Adam(-1) is not a registered function/op"}, + { "test_adam_multiple", "ai.onnx.preview.training:Adam(-1) is not a registered function/op"}, + + { "test_adagrad", "ai.onnx.preview.training:Adagrad(-1) is not a registered function/op"}, + { "test_adagrad_multiple", "ai.onnx.preview.training:Adagrad(-1) is not a registered function/op"}, }; // The following models fails on nocontribops win CI @@ -453,6 +448,47 @@ public static IEnumerable GetSkippedModelForTest() } } + string MatchInputOutputWithFile(string fileName, InferenceSession session, bool input, out NodeMetadata result) + { + string nodeName = string.Empty; + result = null; + var names = (input) ? session.InputNames : session.OutputNames; + var metadata = (input) ? session.InputMetadata : session.OutputMetadata; + string regEx = (input) ? @"input_(\d{1,}).pb" : @"output_(\d{1,}).pb"; + + // Extract the number from the file name, if not try to match the input/output name with the name of the file. + try + { + // captures start at index 1 + var group = Regex.Matches(fileName, regEx).Single().Groups[1]; + var num = int.Parse(group.Value); + if (num >= 0 && num < names.Count) + { + nodeName = names[num]; + result = metadata[nodeName]; + } + else + { + throw new InvalidDataException($"Filename '{fileName}' input/output number '{num}' is out of range for '{names.Count}' inputs/outputs"); + } + } + catch (Exception) + { + // Either does not match or can not parse the number + } + + if (result is null) + { + // try matching the file name directly against the input/output name + if (!metadata.TryGetValue(fileName, out result)) + { + throw new InvalidDataException($"Unable to match file: {fileName} to input/output metadata"); + } + nodeName = fileName; + } + return nodeName; + } + [Theory(DisplayName = "TestPreTrainedModels")] [MemberData(nameof(GetModelsForTest))] [MemberData(nameof(GetSkippedModelForTest), Skip = "Skipped due to Error, please fix the error and enable the test")] @@ -494,6 +530,7 @@ private void TestPreTrainedModels(string opsetDir, string modelName) using (var session = new InferenceSession(onnxModelFileName)) { var inMeta = session.InputMetadata; + var outMeta = session.OutputMetadata; string testDataDirNamePattern = "test_data*"; if (opset == "opset9" && modelName == "LSTM_Seq_lens_unpacked") { @@ -501,15 +538,17 @@ private void TestPreTrainedModels(string opsetDir, string modelName) } foreach (var testDataDir in modelDir.EnumerateDirectories(testDataDirNamePattern)) { - var inputContainer = new List(); - var outputContainer = new List(); + var inputContainer = new List(inMeta.Count); + var outputContainer = new List(outMeta.Count); foreach (var f in testDataDir.EnumerateFiles("input_*.pb")) { - inputContainer.Add(TestDataLoader.LoadTensorFromFilePb(f.FullName, inMeta)); + var nodeName = MatchInputOutputWithFile(f.Name, session, true, out NodeMetadata nodeMeta); + inputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta)); } foreach (var f in testDataDir.EnumerateFiles("output_*.pb")) { - outputContainer.Add(TestDataLoader.LoadTensorFromFilePb(f.FullName, session.OutputMetadata)); + var nodeName = MatchInputOutputWithFile(f.Name, session, false, out NodeMetadata nodeMeta); + outputContainer.Add(TestDataLoader.LoadOnnxValueFromFilePb(f.FullName, nodeName, nodeMeta)); } using (var resultCollection = session.Run(inputContainer)) @@ -517,7 +556,6 @@ private void TestPreTrainedModels(string opsetDir, string modelName) foreach (var result in resultCollection) { Assert.True(session.OutputMetadata.ContainsKey(result.Name)); - var outputMeta = session.OutputMetadata[result.Name]; NamedOnnxValue outputValue = null; foreach (var o in outputContainer) { @@ -527,17 +565,32 @@ private void TestPreTrainedModels(string opsetDir, string modelName) break; } } - if (outputValue == null) - { - outputValue = outputContainer.First(); // in case the output data file does not contain the name - } - if (outputMeta.OnnxValueType == OnnxValueType.ONNX_TYPE_TENSOR) // Only Dense tensors now + + Assert.NotNull(outputValue); + + var outputMeta = session.OutputMetadata[result.Name]; + if (outputMeta.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL) { - VerifyTensorResults(outputMeta.ElementDataType, result, outputValue); + outputMeta = outputMeta.AsOptionalMetadata().ElementMeta; } - else + + Assert.Equal(outputValue.ValueType, outputMeta.OnnxValueType); + + switch(outputValue.ValueType) { - Assert.True(false, "TestPreTrainedModels cannot handle Onnxtype: " + outputMeta.OnnxValueType.ToString()); + case OnnxValueType.ONNX_TYPE_TENSOR: // Only Dense tensors now + { + VerifyTensorResults(outputMeta.ElementDataType, result, outputValue); + } + break; + case OnnxValueType.ONNX_TYPE_SEQUENCE: + { + VerifySequenceResults(result, outputValue, outputMeta); + } + break; + default: + Assert.True(false, $"TestPreTrainedModels cannot handle Onnxtype: {outputValue.ValueType}"); + break; } } } @@ -562,7 +615,36 @@ private void TestPreTrainedModels(string opsetDir, string modelName) } } - private void VerifyTensorResults(TensorElementType elementType, DisposableNamedOnnxValue result, NamedOnnxValue outputValue) + private void VerifySequenceResults(NamedOnnxValue result, NamedOnnxValue expectedValue, NodeMetadata metaData) + { + var meta = metaData.AsSequenceMetadata(); + var resultSequence = result.AsEnumerable(); + var expectedSequence = expectedValue.AsEnumerable(); + Assert.Equal(resultSequence.Count(), expectedSequence.Count()); + + foreach (var (resultItem, expectedItem) in resultSequence.Zip(expectedSequence, (r, e) => (r, e))) + { + Assert.Equal(resultItem.ValueType, expectedItem.ValueType); + Assert.Equal(resultItem.ValueType, meta.ElementMeta.OnnxValueType); + switch (resultItem.ValueType) + { + case OnnxValueType.ONNX_TYPE_TENSOR: + VerifyTensorResults(meta.ElementMeta.ElementDataType, resultItem, expectedItem); + break; + case OnnxValueType.ONNX_TYPE_SEQUENCE: + { + VerifySequenceResults(resultItem, expectedItem, meta.ElementMeta); + } + break; + default: + Assert.True(false, "VerifySequenceResults cannot handle Onnxtype: " + resultItem.ValueType.ToString()); + break; + } + Assert.Equal(resultItem.AsTensor(), expectedItem.AsTensor(), new FloatComparer()); + } + } + + private void VerifyTensorResults(TensorElementType elementType, NamedOnnxValue result, NamedOnnxValue outputValue) { switch (elementType) { diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 833e2da3ca42..144df446281b 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -37,16 +37,16 @@ OrtTypeInfo::OrtTypeInfo(ONNXType type1) noexcept : type(type1) { } OrtTypeInfo::OrtTypeInfo(std::unique_ptr map_type_info1) noexcept - : type(ONNX_TYPE_MAP), map_type_info(std::move(map_type_info1)) {} + : type(ONNX_TYPE_MAP), map_type_info(std::move(map_type_info1)) {} OrtTypeInfo::OrtTypeInfo(std::unique_ptr sequence_type_info1) noexcept - : type(ONNX_TYPE_SEQUENCE), sequence_type_info(std::move(sequence_type_info1)) {} + : type(ONNX_TYPE_SEQUENCE), sequence_type_info(std::move(sequence_type_info1)) {} OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info1) noexcept - : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info1)) {} + : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info1)) {} OrtTypeInfo::OrtTypeInfo(ONNXType type1, std::unique_ptr data1) noexcept - : type(type1), data(std::move(data1)) { + : type(type1), data(std::move(data1)) { } OrtTypeInfo::~OrtTypeInfo() = default; @@ -59,7 +59,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeI } ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, - _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { + _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data.get() : nullptr; return nullptr; @@ -67,7 +67,7 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtType } ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtMapTypeInfo** out) { + _Outptr_result_maybenull_ const OrtMapTypeInfo** out) { API_IMPL_BEGIN *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info.get() : nullptr; return nullptr; @@ -75,7 +75,7 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* } ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out) { + _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out) { API_IMPL_BEGIN *out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info.get() : nullptr; return nullptr; @@ -83,9 +83,9 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeI } ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out) { + _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out) { API_IMPL_BEGIN - *out = (type_info->type != ONNX_TYPE_OPTIONAL) ? type_info->optional_type_info.get() : nullptr; + *out = (type_info->type == ONNX_TYPE_OPTIONAL) ? type_info->optional_type_info.get() : nullptr; return nullptr; API_IMPL_END } @@ -104,9 +104,12 @@ ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { } OrtTypeInfo::Ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { + + Ptr result = MakePtr(ONNX_TYPE_UNKNOWN); + onnxruntime::MLDataType type = value.Type(); if (type == nullptr) { - return MakePtr(ONNX_TYPE_UNKNOWN); + return result; } // GetType and GetType do not have TypeProto populated because they return a static @@ -155,36 +158,37 @@ OrtTypeInfo::Ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { // Place Opaque first as tensors will be mostly handled above and maps and sequences are not common switch (type_proto->value_case()) { case on::TypeProto::kOpaqueType: { - return MakePtr(ONNX_TYPE_OPAQUE); + result = MakePtr(ONNX_TYPE_OPAQUE); } break; case on::TypeProto::kMapType: { #if !defined(DISABLE_ML_OPS) auto map_type_info = OrtMapTypeInfo::FromTypeProto(*type_proto); - return MakePtr(std::move(map_type_info)); - } break; + result = MakePtr(std::move(map_type_info)); #else ORT_NOT_IMPLEMENTED("Map types are not supported in this build"); #endif + } break; case on::TypeProto::kSequenceType: { auto seq_info = OrtSequenceTypeInfo::FromTypeProto(*type_proto); - return MakePtr(std::move(seq_info)); + result = MakePtr(std::move(seq_info)); } break; // Real Tensor support -#if !defined(DISABLE_SPARSE_TENSORS) case on::TypeProto::kSparseTensorType: +#if !defined(DISABLE_SPARSE_TENSORS) [[fallthrough]]; #else ORT_NOT_IMPLEMENTED("SparseTensor types are not supported in this build"); + break; #endif - case on::TypeProto::kTensorType: { + case on::TypeProto::kTensorType: ORT_THROW("Tensor types should have been handled already"); - } break; + break; default: - // NOT_IMPLEMENTED + ORT_NOT_IMPLEMENTED("This OrtValue is neither Tensor, SparseTensor, Map or Sequence type"); break; } } - ORT_NOT_IMPLEMENTED("This OrtValue is neither Tensor, SparseTensor, Map or Sequence type"); + return result; } const DataTypeImpl* OrtTypeInfo::ElementTypeFromProto(int type) { @@ -193,6 +197,8 @@ const DataTypeImpl* OrtTypeInfo::ElementTypeFromProto(int type) { } OrtTypeInfo::Ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& input) { + Ptr result; + auto value_case = input.value_case(); switch (value_case) { case on::TypeProto::kSparseTensorType: @@ -200,6 +206,7 @@ OrtTypeInfo::Ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& inp [[fallthrough]]; #else ORT_NOT_IMPLEMENTED("SparseTensor types are not supported in this build"); + break; #endif case on::TypeProto::kTensorType: { ONNXType ten_type = ONNX_TYPE_UNKNOWN; @@ -251,46 +258,42 @@ OrtTypeInfo::Ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& inp type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input); } - auto type_info = MakePtr(ten_type, std::move(type_shape)); - type_info->denotation = input.denotation(); - return type_info; + result = MakePtr(ten_type, std::move(type_shape)); + result->denotation = input.denotation(); } break; case on::TypeProto::kSequenceType: { auto sequence_type_info = OrtSequenceTypeInfo::FromTypeProto(input); - auto type_info = MakePtr(std::move(sequence_type_info)); - type_info->denotation = input.denotation(); - return type_info; + result = MakePtr(std::move(sequence_type_info)); + result->denotation = input.denotation(); } break; #if !defined(DISABLE_ML_OPS) case on::TypeProto::kMapType: { auto map_type_info = OrtMapTypeInfo::FromTypeProto(input); - auto type_info = MakePtr(std::move(map_type_info)); - type_info->denotation = input.denotation(); - return type_info; + result = MakePtr(std::move(map_type_info)); + result->denotation = input.denotation(); } break; #endif case on::TypeProto::kOptionalType: { auto optional_type_info = OrtOptionalTypeInfo::FromTypeProto(input); - auto type_info = MakePtr(std::move(optional_type_info)); - type_info->denotation = input.denotation(); - return type_info; + result = MakePtr(std::move(optional_type_info)); + result->denotation = input.denotation(); } break; case on::TypeProto::kOpaqueType: { - auto type_info = MakePtr(ONNX_TYPE_OPAQUE); - type_info->denotation = input.denotation(); - return type_info; + result = MakePtr(ONNX_TYPE_OPAQUE); + result->denotation = input.denotation(); } break; case on::TypeProto::VALUE_NOT_SET: ORT_THROW("This TypeProto does not have ValueCase set"); break; default: - // Not implemented + ORT_NOT_IMPLEMENTED("The type is not tensor, sparse tensor, sequence, map or optional type"); break; } - ORT_NOT_IMPLEMENTED("The type is not tensor, sparse tensor, sequence, map or optional type"); + return result; } OrtTypeInfo::Ptr OrtTypeInfo::Clone() const { + Ptr result; switch (type) { case ONNX_TYPE_SPARSETENSOR: #if !defined(DISABLE_SPARSE_TENSORS) @@ -303,37 +306,33 @@ OrtTypeInfo::Ptr OrtTypeInfo::Clone() const { if (data) { info = data->Clone(); } - auto type_info = MakePtr(type, std::move(info)); - type_info->denotation = denotation; - return type_info; - } + result = MakePtr(type, std::move(info)); + result->denotation = denotation; + } break; case ONNX_TYPE_SEQUENCE: { auto seq_clone = sequence_type_info->Clone(); - auto type_info = MakePtr(std::move(seq_clone)); - type_info->denotation = denotation; - return type_info; - } + result = MakePtr(std::move(seq_clone)); + result->denotation = denotation; + } break; case ONNX_TYPE_MAP: { auto map_clone = map_type_info->Clone(); - auto type_info = MakePtr(std::move(map_clone)); - type_info->denotation = denotation; - return type_info; - } + result = MakePtr(std::move(map_clone)); + result->denotation = denotation; + } break; case ONNX_TYPE_OPTIONAL: { auto opt_clone = optional_type_info->Clone(); - auto type_info = MakePtr(std::move(opt_clone)); - type_info->denotation = denotation; - return type_info; - } + result = MakePtr(std::move(opt_clone)); + result->denotation = denotation; + } break; case ONNX_TYPE_OPAQUE: { - auto type_info = MakePtr(type); - type_info->denotation = denotation; - return type_info; - } + result = MakePtr(type); + result->denotation = denotation; + } break; default: - // Not implemented + ORT_NOT_IMPLEMENTED("The type is not tensor, sparse tensor, sequence, map or optional type"); break; } - ORT_NOT_IMPLEMENTED("The type is not tensor, sparse tensor, sequence, map or optional type"); + + return result; } diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 207df877ff66..f71b485fd8a7 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -3,7 +3,6 @@ #pragma once -#include #include #include #include diff --git a/tools/ci_build/github/Doxyfile_csharp.cfg b/tools/ci_build/github/Doxyfile_csharp.cfg index 78fb4d5e9af5..dccc15ed1137 100644 --- a/tools/ci_build/github/Doxyfile_csharp.cfg +++ b/tools/ci_build/github/Doxyfile_csharp.cfg @@ -1,20 +1,146 @@ -## Onnxruntime C# API Doxygen configuration file -# Doxyfile 1.8.20 +# Doxyfile 1.9.4 + +# This file describes the settings to be used by the documentation system +# doxygen (www.doxygen.org) for a project. +# +# All text after a double hash (##) is considered a comment and is placed in +# front of the TAG it is preceding. +# +# All text after a single hash (#) is considered a comment and will be ignored. +# The format is: +# TAG = value [value, ...] +# For lists, items can also be appended using: +# TAG += value [value, ...] +# Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables: +# doxygen -x_noenv [configFile] #--------------------------------------------------------------------------- # Project related configuration options #--------------------------------------------------------------------------- + +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. +# The default value is: UTF-8. + +## Onnxruntime C# API Doxygen configuration file + DOXYFILE_ENCODING = UTF-8 -PROJECT_NAME = "Onnxruntime" + +# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by +# double-quotes, unless you are using Doxywizard) that should identify the +# project for which the documentation is generated. This name is used in the +# title of most generated pages and in a few other places. +# The default value is: My Project. + +PROJECT_NAME = Onnxruntime + +# The PROJECT_NUMBER tag can be used to enter a project or revision number. This +# could be handy for archiving the generated documentation or if some version +# control system is used. + PROJECT_NUMBER = + +# Using the PROJECT_BRIEF tag one can provide an optional one line description +# for a project that appears at the top of each page and should give viewer a +# quick idea about the purpose of the project. Keep the description short. + PROJECT_BRIEF = + +# With the PROJECT_LOGO tag one can specify a logo or an icon that is included +# in the documentation. The maximum height of the logo should not exceed 55 +# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy +# the logo to the output directory. + PROJECT_LOGO = + +# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path +# into which the generated documentation will be written. If a relative path is +# entered, it will be relative to the location where doxygen was started. If +# left blank the current directory will be used. + OUTPUT_DIRECTORY = $(ORT_DOXY_OUT)\csharp_dox + +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 +# sub-directories (in 2 levels) under the output directory of each output format +# and will distribute the generated files over these directories. Enabling this +# option can be useful when feeding doxygen a huge amount of source files, where +# putting all generated files in the same directory would otherwise causes +# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to +# control the number of sub-directories. +# The default value is: NO. + CREATE_SUBDIRS = NO + +# Controls the number of sub-directories that will be created when +# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every +# level increment doubles the number of directories, resulting in 4096 +# directories at level 8 which is the default and also the maximum value. The +# sub-directories are organized in 2 levels, the first level always has a fixed +# numer of 16 directories. +# Minimum value: 0, maximum value: 8, default value: 8. +# This tag requires that the tag CREATE_SUBDIRS is set to YES. + +CREATE_SUBDIRS_LEVEL = 8 + +# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII +# characters to appear in the names of generated files. If set to NO, non-ASCII +# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode +# U+3044. +# The default value is: NO. + ALLOW_UNICODE_NAMES = NO + +# The OUTPUT_LANGUAGE tag is used to specify the language in which all +# documentation generated by doxygen is written. Doxygen will use this +# information to generate all constant output in the proper language. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, +# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English +# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, +# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with +# English messages), Korean, Korean-en (Korean with English messages), Latvian, +# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, +# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, +# Swedish, Turkish, Ukrainian and Vietnamese. +# The default value is: English. + OUTPUT_LANGUAGE = English + +# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member +# descriptions after the members that are listed in the file and class +# documentation (similar to Javadoc). Set to NO to disable this. +# The default value is: YES. + BRIEF_MEMBER_DESC = YES + +# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief +# description of a member or function before the detailed description +# +# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the +# brief descriptions will be completely suppressed. +# The default value is: YES. + REPEAT_BRIEF = YES + +# This tag implements a quasi-intelligent brief description abbreviator that is +# used to form the text in various listings. Each string in this list, if found +# as the leading text of the brief description, will be stripped from the text +# and the result, after processing the whole list, is used as the annotated +# text. Otherwise, the brief description is used as-is. If left blank, the +# following values are used ($name is automatically replaced with the name of +# the entity):The $name class, The $name widget, The $name file, is, provides, +# specifies, contains, represents, a, an and the. + ABBREVIATE_BRIEF = "The $name class" \ "The $name widget" \ "The $name file" \ @@ -26,313 +152,2539 @@ ABBREVIATE_BRIEF = "The $name class" \ a \ an \ the + +# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then +# doxygen will generate a detailed section even if there is only a brief +# description. +# The default value is: NO. + ALWAYS_DETAILED_SEC = NO + +# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all +# inherited members of a class in the documentation of that class as if those +# members were ordinary class members. Constructors, destructors and assignment +# operators of the base classes will not be shown. +# The default value is: NO. + INLINE_INHERITED_MEMB = NO + +# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path +# before files name in the file list and in the header files. If set to NO the +# shortest path that makes the file name unique will be used +# The default value is: YES. + FULL_PATH_NAMES = YES + +# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. +# Stripping is only done if one of the specified strings matches the left-hand +# part of the path. The tag can be used to show relative paths in the file list. +# If left blank the directory from which doxygen is run is used as the path to +# strip. +# +# Note that you can specify absolute paths here, but also relative paths, which +# will be relative from the directory where doxygen is started. +# This tag requires that the tag FULL_PATH_NAMES is set to YES. + STRIP_FROM_PATH = + +# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the +# path mentioned in the documentation of a class, which tells the reader which +# header file to include in order to use a class. If left blank only the name of +# the header file containing the class definition is used. Otherwise one should +# specify the list of include paths that are normally passed to the compiler +# using the -I flag. + STRIP_FROM_INC_PATH = + +# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but +# less readable) file names. This can be useful is your file systems doesn't +# support long names like on DOS, Mac, or CD-ROM. +# The default value is: NO. + SHORT_NAMES = NO + +# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the +# first line (until the first dot) of a Javadoc-style comment as the brief +# description. If set to NO, the Javadoc-style will behave just like regular Qt- +# style comments (thus requiring an explicit @brief command for a brief +# description.) +# The default value is: NO. + JAVADOC_AUTOBRIEF = NO + +# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line +# such as +# /*************** +# as being the beginning of a Javadoc-style comment "banner". If set to NO, the +# Javadoc-style will behave just like regular comments and it will not be +# interpreted by doxygen. +# The default value is: NO. + JAVADOC_BANNER = NO + +# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first +# line (until the first dot) of a Qt-style comment as the brief description. If +# set to NO, the Qt-style will behave just like regular Qt-style comments (thus +# requiring an explicit \brief command for a brief description.) +# The default value is: NO. + QT_AUTOBRIEF = NO + +# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a +# multi-line C++ special comment block (i.e. a block of //! or /// comments) as +# a brief description. This used to be the default behavior. The new default is +# to treat a multi-line C++ comment block as a detailed description. Set this +# tag to YES if you prefer the old behavior instead. +# +# Note that setting this tag to YES also means that rational rose comments are +# not recognized any more. +# The default value is: NO. + MULTILINE_CPP_IS_BRIEF = NO + +# By default Python docstrings are displayed as preformatted text and doxygen's +# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the +# doxygen's special commands can be used and the contents of the docstring +# documentation blocks is shown as doxygen documentation. +# The default value is: YES. + PYTHON_DOCSTRING = YES + +# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the +# documentation from any documented member that it re-implements. +# The default value is: YES. + INHERIT_DOCS = YES + +# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new +# page for each member. If set to NO, the documentation of a member will be part +# of the file/class/namespace that contains it. +# The default value is: NO. + SEPARATE_MEMBER_PAGES = NO + +# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen +# uses this value to replace tabs by spaces in code fragments. +# Minimum value: 1, maximum value: 16, default value: 4. + TAB_SIZE = 4 + +# This tag can be used to specify a number of aliases that act as commands in +# the documentation. An alias has the form: +# name=value +# For example adding +# "sideeffect=@par Side Effects:^^" +# will allow you to put the command \sideeffect (or @sideeffect) in the +# documentation, which will result in a user-defined paragraph with heading +# "Side Effects:". Note that you cannot put \n's in the value part of an alias +# to insert newlines (in the resulting output). You can put ^^ in the value part +# of an alias to insert a newline as if a physical newline was in the original +# file. When you need a literal { or } or , in the value part of an alias you +# have to escape them by means of a backslash (\), this can lead to conflicts +# with the commands \{ and \} for these it is advised to use the version @{ and +# @} or use a double escape (\\{ and \\}) + ALIASES = + +# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources +# only. Doxygen will then generate output that is more tailored for C. For +# instance, some of the names that are used will be different. The list of all +# members will be omitted, etc. +# The default value is: NO. + OPTIMIZE_OUTPUT_FOR_C = NO + +# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or +# Python sources only. Doxygen will then generate output that is more tailored +# for that language. For instance, namespaces will be presented as packages, +# qualified scopes will look different, etc. +# The default value is: NO. + OPTIMIZE_OUTPUT_JAVA = NO + +# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran +# sources. Doxygen will then generate output that is tailored for Fortran. +# The default value is: NO. + OPTIMIZE_FOR_FORTRAN = NO + +# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL +# sources. Doxygen will then generate output that is tailored for VHDL. +# The default value is: NO. + OPTIMIZE_OUTPUT_VHDL = NO + +# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice +# sources only. Doxygen will then generate output that is more tailored for that +# language. For instance, namespaces will be presented as modules, types will be +# separated into more groups, etc. +# The default value is: NO. + OPTIMIZE_OUTPUT_SLICE = NO + +# Doxygen selects the parser to use depending on the extension of the files it +# parses. With this tag you can assign which parser to use for a given +# extension. Doxygen has a built-in mapping, but you can override or extend it +# using this tag. The format is ext=language, where ext is a file extension, and +# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, +# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice, +# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: +# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser +# tries to guess whether the code is fixed or free formatted code, this is the +# default for Fortran type files). For instance to make doxygen treat .inc files +# as Fortran files (default is PHP), and .f files as C (default is Fortran), +# use: inc=Fortran f=C. +# +# Note: For files without extension you can use no_extension as a placeholder. +# +# Note that for custom extensions you also need to set FILE_PATTERNS otherwise +# the files are not read by doxygen. When specifying no_extension you should add +# * to the FILE_PATTERNS. +# +# Note see also the list of default file extension mappings. + EXTENSION_MAPPING = + +# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments +# according to the Markdown format, which allows for more readable +# documentation. See https://daringfireball.net/projects/markdown/ for details. +# The output of markdown processing is further processed by doxygen, so you can +# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in +# case of backward compatibilities issues. +# The default value is: YES. + MARKDOWN_SUPPORT = YES + +# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up +# to that level are automatically included in the table of contents, even if +# they do not have an id attribute. +# Note: This feature currently applies only to Markdown headings. +# Minimum value: 0, maximum value: 99, default value: 5. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + TOC_INCLUDE_HEADINGS = 5 + +# When enabled doxygen tries to link words that correspond to documented +# classes, or namespaces to their corresponding documentation. Such a link can +# be prevented in individual cases by putting a % sign in front of the word or +# globally by setting AUTOLINK_SUPPORT to NO. +# The default value is: YES. + AUTOLINK_SUPPORT = YES + +# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want +# to include (a tag file for) the STL sources as input, then you should set this +# tag to YES in order to let doxygen match functions declarations and +# definitions whose arguments contain STL classes (e.g. func(std::string); +# versus func(std::string) {}). This also make the inheritance and collaboration +# diagrams that involve STL classes more complete and accurate. +# The default value is: NO. + BUILTIN_STL_SUPPORT = NO + +# If you use Microsoft's C++/CLI language, you should set this option to YES to +# enable parsing support. +# The default value is: NO. + CPP_CLI_SUPPORT = NO + +# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: +# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen +# will parse them like normal C++ but will assume all classes use public instead +# of private inheritance when no explicit protection keyword is present. +# The default value is: NO. + SIP_SUPPORT = NO + +# For Microsoft's IDL there are propget and propput attributes to indicate +# getter and setter methods for a property. Setting this option to YES will make +# doxygen to replace the get and set methods by a property in the documentation. +# This will only work if the methods are indeed getting or setting a simple +# type. If this is not the case, or you want to show the methods anyway, you +# should set this option to NO. +# The default value is: YES. + IDL_PROPERTY_SUPPORT = YES + +# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC +# tag is set to YES then doxygen will reuse the documentation of the first +# member in the group (if any) for the other members of the group. By default +# all members of a group must be documented explicitly. +# The default value is: NO. + DISTRIBUTE_GROUP_DOC = NO + +# If one adds a struct or class to a group and this option is enabled, then also +# any nested class or struct is added to the same group. By default this option +# is disabled and one has to add nested compounds explicitly via \ingroup. +# The default value is: NO. + GROUP_NESTED_COMPOUNDS = NO + +# Set the SUBGROUPING tag to YES to allow class member groups of the same type +# (for instance a group of public functions) to be put as a subgroup of that +# type (e.g. under the Public Functions section). Set it to NO to prevent +# subgrouping. Alternatively, this can be done per class using the +# \nosubgrouping command. +# The default value is: YES. + SUBGROUPING = YES + +# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions +# are shown inside the group in which they are included (e.g. using \ingroup) +# instead of on a separate page (for HTML and Man pages) or section (for LaTeX +# and RTF). +# +# Note that this feature does not work in combination with +# SEPARATE_MEMBER_PAGES. +# The default value is: NO. + INLINE_GROUPED_CLASSES = NO + +# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions +# with only public data fields or simple typedef fields will be shown inline in +# the documentation of the scope in which they are defined (i.e. file, +# namespace, or group documentation), provided this scope is documented. If set +# to NO, structs, classes, and unions are shown on a separate page (for HTML and +# Man pages) or section (for LaTeX and RTF). +# The default value is: NO. + INLINE_SIMPLE_STRUCTS = NO + +# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or +# enum is documented as struct, union, or enum with the name of the typedef. So +# typedef struct TypeS {} TypeT, will appear in the documentation as a struct +# with name TypeT. When disabled the typedef will appear as a member of a file, +# namespace, or class. And the struct will be named TypeS. This can typically be +# useful for C code in case the coding convention dictates that all compound +# types are typedef'ed and only the typedef is referenced, never the tag name. +# The default value is: NO. + TYPEDEF_HIDES_STRUCT = NO + +# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This +# cache is used to resolve symbols given their name and scope. Since this can be +# an expensive process and often the same symbol appears multiple times in the +# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small +# doxygen will become slower. If the cache is too large, memory is wasted. The +# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range +# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 +# symbols. At the end of a run doxygen will report the cache usage and suggest +# the optimal cache size from a speed point of view. +# Minimum value: 0, maximum value: 9, default value: 0. + LOOKUP_CACHE_SIZE = 0 + +# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use +# during processing. When set to 0 doxygen will based this on the number of +# cores available in the system. You can set it explicitly to a value larger +# than 0 to get more control over the balance between CPU load and processing +# speed. At this moment only the input processing can be done using multiple +# threads. Since this is still an experimental feature the default is set to 1, +# which effectively disables parallel processing. Please report any issues you +# encounter. Generating dot graphs in parallel is controlled by the +# DOT_NUM_THREADS setting. +# Minimum value: 0, maximum value: 32, default value: 1. + NUM_PROC_THREADS = 1 + #--------------------------------------------------------------------------- # Build related configuration options #--------------------------------------------------------------------------- + +# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in +# documentation are documented, even if no documentation was available. Private +# class members and static file members will be hidden unless the +# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. +# Note: This will also disable the warnings about undocumented members that are +# normally produced when WARNINGS is set to YES. +# The default value is: NO. + EXTRACT_ALL = NO + +# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will +# be included in the documentation. +# The default value is: NO. + EXTRACT_PRIVATE = NO + +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + EXTRACT_PRIV_VIRTUAL = NO + +# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal +# scope will be included in the documentation. +# The default value is: NO. + EXTRACT_PACKAGE = NO + +# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be +# included in the documentation. +# The default value is: NO. + EXTRACT_STATIC = NO + +# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined +# locally in source files will be included in the documentation. If set to NO, +# only classes defined in header files are included. Does not have any effect +# for Java sources. +# The default value is: YES. + EXTRACT_LOCAL_CLASSES = YES + +# This flag is only useful for Objective-C code. If set to YES, local methods, +# which are defined in the implementation section but not in the interface are +# included in the documentation. If set to NO, only methods in the interface are +# included. +# The default value is: NO. + EXTRACT_LOCAL_METHODS = NO + +# If this flag is set to YES, the members of anonymous namespaces will be +# extracted and appear in the documentation as a namespace called +# 'anonymous_namespace{file}', where file will be replaced with the base name of +# the file that contains the anonymous namespace. By default anonymous namespace +# are hidden. +# The default value is: NO. + EXTRACT_ANON_NSPACES = NO + +# If this flag is set to YES, the name of an unnamed parameter in a declaration +# will be determined by the corresponding definition. By default unnamed +# parameters remain unnamed in the output. +# The default value is: YES. + +RESOLVE_UNNAMED_PARAMS = YES + +# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all +# undocumented members inside documented classes or files. If set to NO these +# members will be included in the various overviews, but no documentation +# section is generated. This option has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + HIDE_UNDOC_MEMBERS = NO + +# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all +# undocumented classes that are normally visible in the class hierarchy. If set +# to NO, these classes will be included in the various overviews. This option +# has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + HIDE_UNDOC_CLASSES = NO + +# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend +# declarations. If set to NO, these declarations will be included in the +# documentation. +# The default value is: NO. + HIDE_FRIEND_COMPOUNDS = NO + +# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any +# documentation blocks found inside the body of a function. If set to NO, these +# blocks will be appended to the function's detailed documentation block. +# The default value is: NO. + HIDE_IN_BODY_DOCS = NO + +# The INTERNAL_DOCS tag determines if documentation that is typed after a +# \internal command is included. If the tag is set to NO then the documentation +# will be excluded. Set it to YES to include the internal documentation. +# The default value is: NO. + INTERNAL_DOCS = NO + +# With the correct setting of option CASE_SENSE_NAMES doxygen will better be +# able to match the capabilities of the underlying filesystem. In case the +# filesystem is case sensitive (i.e. it supports files in the same directory +# whose names only differ in casing), the option must be set to YES to properly +# deal with such files in case they appear in the input. For filesystems that +# are not case sensitive the option should be set to NO to properly deal with +# output files written for symbols that only differ in casing, such as for two +# classes, one named CLASS and the other named Class, and to also support +# references to files without having to specify the exact matching casing. On +# Windows (including Cygwin) and MacOS, users should typically set this option +# to NO, whereas on Linux or other Unix flavors it should typically be set to +# YES. +# The default value is: system dependent. + CASE_SENSE_NAMES = NO + +# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with +# their full class and namespace scopes in the documentation. If set to YES, the +# scope will be hidden. +# The default value is: NO. + HIDE_SCOPE_NAMES = NO + +# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will +# append additional text to a page's title, such as Class Reference. If set to +# YES the compound reference will be hidden. +# The default value is: NO. + HIDE_COMPOUND_REFERENCE= NO + +# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class +# will show which file needs to be included to use the class. +# The default value is: YES. + +SHOW_HEADERFILE = YES + +# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of +# the files that are included by a file in the documentation of that file. +# The default value is: YES. + SHOW_INCLUDE_FILES = YES + +# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each +# grouped member an include statement to the documentation, telling the reader +# which file to include in order to use the member. +# The default value is: NO. + SHOW_GROUPED_MEMB_INC = NO + +# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include +# files with double quotes in the documentation rather than with sharp brackets. +# The default value is: NO. + FORCE_LOCAL_INCLUDES = NO + +# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the +# documentation for inline members. +# The default value is: YES. + INLINE_INFO = YES + +# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the +# (detailed) documentation of file and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. +# The default value is: YES. + SORT_MEMBER_DOCS = YES + +# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief +# descriptions of file, namespace and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. Note that +# this will also influence the order of the classes in the class list. +# The default value is: NO. + SORT_BRIEF_DOCS = NO + +# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the +# (brief and detailed) documentation of class members so that constructors and +# destructors are listed first. If set to NO the constructors will appear in the +# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. +# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief +# member documentation. +# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting +# detailed member documentation. +# The default value is: NO. + SORT_MEMBERS_CTORS_1ST = NO + +# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy +# of group names into alphabetical order. If set to NO the group names will +# appear in their defined order. +# The default value is: NO. + SORT_GROUP_NAMES = NO + +# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by +# fully-qualified names, including namespaces. If set to NO, the class list will +# be sorted only by class name, not including the namespace part. +# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. +# Note: This option applies only to the class list, not to the alphabetical +# list. +# The default value is: NO. + SORT_BY_SCOPE_NAME = NO + +# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper +# type resolution of all parameters of a function it will reject a match between +# the prototype and the implementation of a member function even if there is +# only one candidate or it is obvious which candidate to choose by doing a +# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still +# accept a match between prototype and implementation in such cases. +# The default value is: NO. + STRICT_PROTO_MATCHING = NO + +# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo +# list. This list is created by putting \todo commands in the documentation. +# The default value is: YES. + GENERATE_TODOLIST = YES + +# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test +# list. This list is created by putting \test commands in the documentation. +# The default value is: YES. + GENERATE_TESTLIST = YES + +# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug +# list. This list is created by putting \bug commands in the documentation. +# The default value is: YES. + GENERATE_BUGLIST = YES + +# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) +# the deprecated list. This list is created by putting \deprecated commands in +# the documentation. +# The default value is: YES. + GENERATE_DEPRECATEDLIST= YES + +# The ENABLED_SECTIONS tag can be used to enable conditional documentation +# sections, marked by \if ... \endif and \cond +# ... \endcond blocks. + ENABLED_SECTIONS = + +# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the +# initial value of a variable or macro / define can have for it to appear in the +# documentation. If the initializer consists of more lines than specified here +# it will be hidden. Use a value of 0 to hide initializers completely. The +# appearance of the value of individual variables and macros / defines can be +# controlled using \showinitializer or \hideinitializer command in the +# documentation regardless of this setting. +# Minimum value: 0, maximum value: 10000, default value: 30. + MAX_INITIALIZER_LINES = 30 + +# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at +# the bottom of the documentation of classes and structs. If set to YES, the +# list will mention the files that were used to generate the documentation. +# The default value is: YES. + SHOW_USED_FILES = NO + +# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This +# will remove the Files entry from the Quick Index and from the Folder Tree View +# (if specified). +# The default value is: YES. + SHOW_FILES = YES + +# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces +# page. This will remove the Namespaces entry from the Quick Index and from the +# Folder Tree View (if specified). +# The default value is: YES. + SHOW_NAMESPACES = YES + +# The FILE_VERSION_FILTER tag can be used to specify a program or script that +# doxygen should invoke to get the current version for each file (typically from +# the version control system). Doxygen will invoke the program by executing (via +# popen()) the command command input-file, where command is the value of the +# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided +# by doxygen. Whatever the program writes to standard output is used as the file +# version. For an example see the documentation. + FILE_VERSION_FILTER = "git -C $(ORT_DOXY_SRC) log -n 1 --format=%h -- afile" + +# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed +# by doxygen. The layout file controls the global structure of the generated +# output files in an output format independent way. To create the layout file +# that represents doxygen's defaults, run doxygen with the -l option. You can +# optionally specify a file name after the option, if omitted DoxygenLayout.xml +# will be used as the name of the layout file. See also section "Changing the +# layout of pages" for information. +# +# Note that if you run doxygen from a directory containing a file called +# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE +# tag is left empty. + LAYOUT_FILE = + +# The CITE_BIB_FILES tag can be used to specify one or more bib files containing +# the reference definitions. This must be a list of .bib files. The .bib +# extension is automatically appended if omitted. This requires the bibtex tool +# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info. +# For LaTeX the style of the bibliography can be controlled using +# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the +# search path. See also \cite for info how to create references. + CITE_BIB_FILES = + #--------------------------------------------------------------------------- # Configuration options related to warning and progress messages #--------------------------------------------------------------------------- + +# The QUIET tag can be used to turn on/off the messages that are generated to +# standard output by doxygen. If QUIET is set to YES this implies that the +# messages are off. +# The default value is: NO. + QUIET = NO + +# The WARNINGS tag can be used to turn on/off the warning messages that are +# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES +# this implies that the warnings are on. +# +# Tip: Turn warnings on while writing the documentation. +# The default value is: YES. + WARNINGS = YES + +# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate +# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: YES. + WARN_IF_UNDOCUMENTED = YES + +# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for +# potential errors in the documentation, such as documenting some parameters in +# a documented function twice, or documenting parameters that don't exist or +# using markup commands wrongly. +# The default value is: YES. + WARN_IF_DOC_ERROR = YES + +# If WARN_IF_INCOMPLETE_DOC is set to YES, doxygen will warn about incomplete +# function parameter documentation. If set to NO, doxygen will accept that some +# parameters have no documentation without warning. +# The default value is: YES. + +WARN_IF_INCOMPLETE_DOC = YES + +# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that +# are documented, but have no documentation for their parameters or return +# value. If set to NO, doxygen will only warn about wrong parameter +# documentation, but not about the absence of documentation. If EXTRACT_ALL is +# set to YES then this flag will automatically be disabled. See also +# WARN_IF_INCOMPLETE_DOC +# The default value is: NO. + WARN_NO_PARAMDOC = YES + +# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when +# a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS +# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but +# at the end of the doxygen process doxygen will return with a non-zero status. +# Possible values are: NO, YES and FAIL_ON_WARNINGS. +# The default value is: NO. + WARN_AS_ERROR = YES + +# The WARN_FORMAT tag determines the format of the warning messages that doxygen +# can produce. The string should contain the $file, $line, and $text tags, which +# will be replaced by the file and line number from which the warning originated +# and the warning text. Optionally the format may contain $version, which will +# be replaced by the version of the file (if it could be obtained via +# FILE_VERSION_FILTER) +# See also: WARN_LINE_FORMAT +# The default value is: $file:$line: $text. + WARN_FORMAT = "$file:$line: $text" -WARN_LOGFILE = + +# In the $text part of the WARN_FORMAT command it is possible that a reference +# to a more specific place is given. To make it easier to jump to this place +# (outside of doxygen) the user can define a custom "cut" / "paste" string. +# Example: +# WARN_LINE_FORMAT = "'vi $file +$line'" +# See also: WARN_FORMAT +# The default value is: at line $line of file $file. + +WARN_LINE_FORMAT = "at line $line of file $file" + +# The WARN_LOGFILE tag can be used to specify a file to which warning and error +# messages should be written. If left blank the output is written to standard +# error (stderr). In case the file specified cannot be opened for writing the +# warning and error messages are written to standard error. When as file - is +# specified the warning and error messages are written to standard output +# (stdout). + +WARN_LOGFILE = + #--------------------------------------------------------------------------- # Configuration options related to the input files #--------------------------------------------------------------------------- + +# The INPUT tag is used to specify the files and/or directories that contain +# documented source files. You may enter file names like myfile.cpp or +# directories like /usr/src/myproject. Separate the files or directories with +# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING +# Note: If this tag is empty the current directory is searched. + INPUT = $(ORT_DOXY_SRC)\csharp\src\Microsoft.ML.OnnxRuntime \ - $(ORT_DOXY_SRC)\csharp\src\Microsoft.ML.OnnxRuntime\Tensors + $(ORT_DOXY_SRC)\csharp\src\Microsoft.ML.OnnxRuntime\Tensors + +# This tag can be used to specify the character encoding of the source files +# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses +# libiconv (or the iconv built into libc) for the transcoding. See the libiconv +# documentation (see: +# https://www.gnu.org/software/libiconv/) for the list of possible encodings. +# The default value is: UTF-8. + INPUT_ENCODING = UTF-8 + +# If the value of the INPUT tag contains directories, you can use the +# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and +# *.h) to filter out the source-files in the directories. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# read by doxygen. +# +# Note the list of default checked file patterns might differ from the list of +# default file extension mappings. +# +# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, +# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, +# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, +# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C +# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, +# *.vhdl, *.ucf, *.qsf and *.ice. + FILE_PATTERNS = *.cs + +# The RECURSIVE tag can be used to specify whether or not subdirectories should +# be searched for input files as well. +# The default value is: NO. + RECURSIVE = NO + +# The EXCLUDE tag can be used to specify files and/or directories that should be +# excluded from the INPUT source files. This way you can easily exclude a +# subdirectory from a directory tree whose root is specified with the INPUT tag. +# +# Note that relative paths are relative to the directory from which doxygen is +# run. + EXCLUDE = + +# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or +# directories that are symbolic links (a Unix file system feature) are excluded +# from the input. +# The default value is: NO. + EXCLUDE_SYMLINKS = NO + +# If the value of the INPUT tag contains directories, you can use the +# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude +# certain files from those directories. +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories for example use the pattern */test/* + EXCLUDE_PATTERNS = Native*.cs + +# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names +# (namespaces, classes, functions, etc.) that should be excluded from the +# output. The symbol name can be a fully qualified name, a word, or if the +# wildcard * is used, a substring. Examples: ANamespace, AClass, +# ANamespace::AClass, ANamespace::*Test +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories use the pattern */test/* + EXCLUDE_SYMBOLS = -EXAMPLE_PATH = + +# The EXAMPLE_PATH tag can be used to specify one or more files or directories +# that contain example code fragments that are included (see the \include +# command). + +EXAMPLE_PATH = + +# If the value of the EXAMPLE_PATH tag contains directories, you can use the +# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and +# *.h) to filter out the source-files in the directories. If left blank all +# files are included. + EXAMPLE_PATTERNS = * + +# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be +# searched for input files to be used with the \include or \dontinclude commands +# irrespective of the value of the RECURSIVE tag. +# The default value is: NO. + EXAMPLE_RECURSIVE = NO + +# The IMAGE_PATH tag can be used to specify one or more files or directories +# that contain images that are to be included in the documentation (see the +# \image command). + IMAGE_PATH = + +# The INPUT_FILTER tag can be used to specify a program that doxygen should +# invoke to filter for each input file. Doxygen will invoke the filter program +# by executing (via popen()) the command: +# +# +# +# where is the value of the INPUT_FILTER tag, and is the +# name of an input file. Doxygen will then use the output that the filter +# program writes to standard output. If FILTER_PATTERNS is specified, this tag +# will be ignored. +# +# Note that the filter must not add or remove lines; it is applied before the +# code is scanned, but not when the output code is generated. If lines are added +# or removed, the anchors will not be placed correctly. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. + INPUT_FILTER = + +# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern +# basis. Doxygen will compare the file name with each pattern and apply the +# filter if there is a match. The filters are a list of the form: pattern=filter +# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how +# filters are used. If the FILTER_PATTERNS tag is empty or if none of the +# patterns match the file name, INPUT_FILTER is applied. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. + FILTER_PATTERNS = + +# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using +# INPUT_FILTER) will also be used to filter the input files that are used for +# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES). +# The default value is: NO. + FILTER_SOURCE_FILES = NO + +# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file +# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and +# it is also possible to disable source filtering for a specific pattern using +# *.ext= (so without naming a filter). +# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. + FILTER_SOURCE_PATTERNS = + +# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that +# is part of the input, its contents will be placed on the main page +# (index.html). This can be useful if you have a project on for instance GitHub +# and want to reuse the introduction page also for the doxygen output. + USE_MDFILE_AS_MAINPAGE = + #--------------------------------------------------------------------------- # Configuration options related to source browsing #--------------------------------------------------------------------------- + +# If the SOURCE_BROWSER tag is set to YES then a list of source files will be +# generated. Documented entities will be cross-referenced with these sources. +# +# Note: To get rid of all source code in the generated output, make sure that +# also VERBATIM_HEADERS is set to NO. +# The default value is: NO. + SOURCE_BROWSER = NO + +# Setting the INLINE_SOURCES tag to YES will include the body of functions, +# classes and enums directly into the documentation. +# The default value is: NO. + INLINE_SOURCES = NO + +# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any +# special comment blocks from generated source code fragments. Normal C, C++ and +# Fortran comments will always remain visible. +# The default value is: YES. + STRIP_CODE_COMMENTS = YES + +# If the REFERENCED_BY_RELATION tag is set to YES then for each documented +# entity all documented functions referencing it will be listed. +# The default value is: NO. + REFERENCED_BY_RELATION = NO + +# If the REFERENCES_RELATION tag is set to YES then for each documented function +# all documented entities called/used by that function will be listed. +# The default value is: NO. + REFERENCES_RELATION = NO + +# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set +# to YES then the hyperlinks from functions in REFERENCES_RELATION and +# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will +# link to the documentation. +# The default value is: YES. + REFERENCES_LINK_SOURCE = YES + +# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the +# source code will show a tooltip with additional information such as prototype, +# brief description and links to the definition and documentation. Since this +# will make the HTML file larger and loading of large files a bit slower, you +# can opt to disable this feature. +# The default value is: YES. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + SOURCE_TOOLTIPS = YES + +# If the USE_HTAGS tag is set to YES then the references to source code will +# point to the HTML generated by the htags(1) tool instead of doxygen built-in +# source browser. The htags tool is part of GNU's global source tagging system +# (see https://www.gnu.org/software/global/global.html). You will need version +# 4.8.6 or higher. +# +# To use it do the following: +# - Install the latest version of global +# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file +# - Make sure the INPUT points to the root of the source tree +# - Run doxygen as normal +# +# Doxygen will invoke htags (and that will in turn invoke gtags), so these +# tools must be available from the command line (i.e. in the search path). +# +# The result: instead of the source browser generated by doxygen, the links to +# source code will now point to the output of htags. +# The default value is: NO. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + USE_HTAGS = NO + +# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a +# verbatim copy of the header file for each class for which an include is +# specified. Set to NO to disable this. +# See also: Section \class. +# The default value is: YES. + VERBATIM_HEADERS = YES + +# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the +# clang parser (see: +# http://clang.llvm.org/) for more accurate parsing at the cost of reduced +# performance. This can be particularly helpful with template rich C++ code for +# which doxygen's built-in parser lacks the necessary type information. +# Note: The availability of this option depends on whether or not doxygen was +# generated with the -Duse_libclang=ON option for CMake. +# The default value is: NO. + CLANG_ASSISTED_PARSING = NO + +# If the CLANG_ASSISTED_PARSING tag is set to YES and the CLANG_ADD_INC_PATHS +# tag is set to YES then doxygen will add the directory of each input to the +# include path. +# The default value is: YES. +# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. + +CLANG_ADD_INC_PATHS = YES + +# If clang assisted parsing is enabled you can provide the compiler with command +# line options that you would normally use when invoking the compiler. Note that +# the include paths will already be set by doxygen for the files and directories +# specified with INPUT and INCLUDE_PATH. +# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. + CLANG_OPTIONS = + +# If clang assisted parsing is enabled you can provide the clang parser with the +# path to the directory containing a file called compile_commands.json. This +# file is the compilation database (see: +# http://clang.llvm.org/docs/HowToSetupToolingForLLVM.html) containing the +# options used when the source files were built. This is equivalent to +# specifying the -p option to a clang tool, such as clang-check. These options +# will then be passed to the parser. Any options specified with CLANG_OPTIONS +# will be added as well. +# Note: The availability of this option depends on whether or not doxygen was +# generated with the -Duse_libclang=ON option for CMake. + CLANG_DATABASE_PATH = + #--------------------------------------------------------------------------- # Configuration options related to the alphabetical class index #--------------------------------------------------------------------------- + +# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all +# compounds will be generated. Enable this if the project contains a lot of +# classes, structs, unions or interfaces. +# The default value is: YES. + ALPHABETICAL_INDEX = YES + +# In case all classes in a project start with a common prefix, all classes will +# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag +# can be used to specify a prefix (or a list of prefixes) that should be ignored +# while generating the index headers. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + IGNORE_PREFIX = + #--------------------------------------------------------------------------- # Configuration options related to the HTML output #--------------------------------------------------------------------------- + +# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output +# The default value is: YES. + GENERATE_HTML = YES + +# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a +# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of +# it. +# The default directory is: html. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_OUTPUT = html + +# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each +# generated HTML page (for example: .htm, .php, .asp). +# The default value is: .html. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_FILE_EXTENSION = .html + +# The HTML_HEADER tag can be used to specify a user-defined HTML header file for +# each generated HTML page. If the tag is left blank doxygen will generate a +# standard header. +# +# To get valid HTML the header file that includes any scripts and style sheets +# that doxygen needs, which is dependent on the configuration options used (e.g. +# the setting GENERATE_TREEVIEW). It is highly recommended to start with a +# default header using +# doxygen -w html new_header.html new_footer.html new_stylesheet.css +# YourConfigFile +# and then modify the file new_header.html. See also section "Doxygen usage" +# for information on how to generate the default header that doxygen normally +# uses. +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. For a description +# of the possible markers and block names see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_HEADER = + +# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each +# generated HTML page. If the tag is left blank doxygen will generate a standard +# footer. See HTML_HEADER for more information on how to generate a default +# footer and what special commands can be used inside the footer. See also +# section "Doxygen usage" for information on how to generate the default footer +# that doxygen normally uses. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_FOOTER = + +# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style +# sheet that is used by each HTML page. It can be used to fine-tune the look of +# the HTML output. If left blank doxygen will generate a default style sheet. +# See also section "Doxygen usage" for information on how to generate the style +# sheet that doxygen normally uses. +# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as +# it is more robust and this tag (HTML_STYLESHEET) will in the future become +# obsolete. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_STYLESHEET = + +# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined +# cascading style sheets that are included after the standard style sheets +# created by doxygen. Using this option one can overrule certain style aspects. +# This is preferred over using HTML_STYLESHEET since it does not replace the +# standard style sheet and is therefore more robust against future updates. +# Doxygen will copy the style sheet files to the output directory. +# Note: The order of the extra style sheet files is of importance (e.g. the last +# style sheet in the list overrules the setting of the previous ones in the +# list). For an example see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_EXTRA_STYLESHEET = + +# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or +# other source files which should be copied to the HTML output directory. Note +# that these files will be copied to the base HTML output directory. Use the +# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these +# files. In the HTML_STYLESHEET file, use the file name only. Also note that the +# files will be copied as-is; there are no commands or markers available. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_EXTRA_FILES = + +# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen +# will adjust the colors in the style sheet and background images according to +# this color. Hue is specified as an angle on a color-wheel, see +# https://en.wikipedia.org/wiki/Hue for more information. For instance the value +# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 +# purple, and 360 is red again. +# Minimum value: 0, maximum value: 359, default value: 220. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_COLORSTYLE_HUE = 220 + +# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors +# in the HTML output. For a value of 0 the output will use gray-scales only. A +# value of 255 will produce the most vivid colors. +# Minimum value: 0, maximum value: 255, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_COLORSTYLE_SAT = 100 + +# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the +# luminance component of the colors in the HTML output. Values below 100 +# gradually make the output lighter, whereas values above 100 make the output +# darker. The value divided by 100 is the actual gamma applied, so 80 represents +# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not +# change the gamma. +# Minimum value: 40, maximum value: 240, default value: 80. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_COLORSTYLE_GAMMA = 80 + +# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML +# page will contain the date and time when the page was generated. Setting this +# to YES can help to show when doxygen was last run and thus if the +# documentation is up to date. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_TIMESTAMP = NO + +# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML +# documentation will contain a main index with vertical navigation menus that +# are dynamically created via JavaScript. If disabled, the navigation index will +# consists of multiple levels of tabs that are statically embedded in every HTML +# page. Disable this option to support browsers that do not have JavaScript, +# like the Qt help browser. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_DYNAMIC_MENUS = YES + +# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML +# documentation will contain sections that can be hidden and shown after the +# page has loaded. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_DYNAMIC_SECTIONS = NO + +# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries +# shown in the various tree structured indices initially; the user can expand +# and collapse entries dynamically later on. Doxygen will expand the tree to +# such a level that at most the specified number of entries are visible (unless +# a fully collapsed tree already exceeds this amount). So setting the number of +# entries 1 will produce a full collapsed tree by default. 0 is a special value +# representing an infinite number of entries and will result in a full expanded +# tree by default. +# Minimum value: 0, maximum value: 9999, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_INDEX_NUM_ENTRIES = 100 + +# If the GENERATE_DOCSET tag is set to YES, additional index files will be +# generated that can be used as input for Apple's Xcode 3 integrated development +# environment (see: +# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To +# create a documentation set, doxygen will generate a Makefile in the HTML +# output directory. Running make will produce the docset in that directory and +# running make install will install the docset in +# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at +# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy +# genXcode/_index.html for more information. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_DOCSET = NO + +# This tag determines the name of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# The default value is: Doxygen generated docs. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + DOCSET_FEEDNAME = "Doxygen generated docs" + +# This tag determines the URL of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDURL = + +# This tag specifies a string that should uniquely identify the documentation +# set bundle. This should be a reverse domain-name style string, e.g. +# com.mycompany.MyDocSet. Doxygen will append .docset to the name. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + DOCSET_BUNDLE_ID = org.doxygen.Project + +# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify +# the documentation publisher. This should be a reverse domain-name style +# string, e.g. com.mycompany.MyDocSet.documentation. +# The default value is: org.doxygen.Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + DOCSET_PUBLISHER_ID = org.doxygen.Publisher + +# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. +# The default value is: Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + DOCSET_PUBLISHER_NAME = Publisher + +# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three +# additional HTML index files: index.hhp, index.hhc, and index.hhk. The +# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop +# on Windows. In the beginning of 2021 Microsoft took the original page, with +# a.o. the download links, offline the HTML help workshop was already many years +# in maintenance mode). You can download the HTML help workshop from the web +# archives at Installation executable (see: +# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo +# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe). +# +# The HTML Help Workshop contains a compiler that can convert all HTML output +# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML +# files are now used as the Windows 98 help format, and will replace the old +# Windows help format (.hlp) on all Windows platforms in the future. Compressed +# HTML files also contain an index, a table of contents, and you can search for +# words in the documentation. The HTML workshop also contains a viewer for +# compressed HTML files. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_HTMLHELP = NO + +# The CHM_FILE tag can be used to specify the file name of the resulting .chm +# file. You can add a path in front of the file if the result should not be +# written to the html output directory. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + CHM_FILE = + +# The HHC_LOCATION tag can be used to specify the location (absolute path +# including file name) of the HTML help compiler (hhc.exe). If non-empty, +# doxygen will try to run the HTML help compiler on the generated index.hhp. +# The file has to be specified with full path. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + HHC_LOCATION = + +# The GENERATE_CHI flag controls if a separate .chi index file is generated +# (YES) or that it should be included in the main .chm file (NO). +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + GENERATE_CHI = NO + +# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) +# and project file content. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + CHM_INDEX_ENCODING = + +# The BINARY_TOC flag controls whether a binary table of contents is generated +# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it +# enables the Previous and Next buttons. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + BINARY_TOC = NO + +# The TOC_EXPAND flag can be set to YES to add extra items for group members to +# the table of contents of the HTML help documentation and to the tree view. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + TOC_EXPAND = NO + +# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and +# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that +# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help +# (.qch) of the generated HTML documentation. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_QHP = NO + +# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify +# the file name of the resulting .qch file. The path specified is relative to +# the HTML output folder. +# This tag requires that the tag GENERATE_QHP is set to YES. + QCH_FILE = + +# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help +# Project output. For more information please see Qt Help Project / Namespace +# (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace). +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_NAMESPACE = org.doxygen.Project + +# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt +# Help Project output. For more information please see Qt Help Project / Virtual +# Folders (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders). +# The default value is: doc. +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_VIRTUAL_FOLDER = doc + +# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom +# filter to add. For more information please see Qt Help Project / Custom +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_CUST_FILTER_NAME = + +# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the +# custom filter to add. For more information please see Qt Help Project / Custom +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_CUST_FILTER_ATTRS = + +# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this +# project's filter section matches. Qt Help Project / Filter Attributes (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes). +# This tag requires that the tag GENERATE_QHP is set to YES. + QHP_SECT_FILTER_ATTRS = + +# The QHG_LOCATION tag can be used to specify the location (absolute path +# including file name) of Qt's qhelpgenerator. If non-empty doxygen will try to +# run qhelpgenerator on the generated .qhp file. +# This tag requires that the tag GENERATE_QHP is set to YES. + QHG_LOCATION = + +# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be +# generated, together with the HTML files, they form an Eclipse help plugin. To +# install this plugin and make it available under the help contents menu in +# Eclipse, the contents of the directory containing the HTML and XML files needs +# to be copied into the plugins directory of eclipse. The name of the directory +# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. +# After copying Eclipse needs to be restarted before the help appears. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_ECLIPSEHELP = NO + +# A unique identifier for the Eclipse help plugin. When installing the plugin +# the directory name containing the HTML and XML files should also have this +# name. Each documentation set should have its own identifier. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. + ECLIPSE_DOC_ID = org.doxygen.Project + +# If you want full control over the layout of the generated HTML pages it might +# be necessary to disable the index and replace it with your own. The +# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top +# of each HTML page. A value of NO enables the index and the value YES disables +# it. Since the tabs in the index contain the same information as the navigation +# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + DISABLE_INDEX = NO + +# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index +# structure should be generated to display hierarchical information. If the tag +# value is set to YES, a side panel will be generated containing a tree-like +# index structure (just like the one that is generated for HTML Help). For this +# to work a browser that supports JavaScript, DHTML, CSS and frames is required +# (i.e. any modern browser). Windows users are probably better off using the +# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can +# further fine tune the look of the index (see "Fine-tuning the output"). As an +# example, the default style sheet generated by doxygen has an example that +# shows how to put an image at the root of the tree instead of the PROJECT_NAME. +# Since the tree basically has the same information as the tab index, you could +# consider setting DISABLE_INDEX to YES when enabling this option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + GENERATE_TREEVIEW = NO + +# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the +# FULL_SIDEBAR option determines if the side bar is limited to only the treeview +# area (value NO) or if it should extend to the full height of the window (value +# YES). Setting this to YES gives a layout similar to +# https://docs.readthedocs.io with more room for contents, but less room for the +# project logo, title, and description. If either GENERATE_TREEVIEW or +# DISABLE_INDEX is set to NO, this option has no effect. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FULL_SIDEBAR = NO + +# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that +# doxygen will group on one line in the generated HTML documentation. +# +# Note that a value of 0 will completely suppress the enum values from appearing +# in the overview section. +# Minimum value: 0, maximum value: 20, default value: 4. +# This tag requires that the tag GENERATE_HTML is set to YES. + ENUM_VALUES_PER_LINE = 4 + +# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used +# to set the initial width (in pixels) of the frame in which the tree is shown. +# Minimum value: 0, maximum value: 1500, default value: 250. +# This tag requires that the tag GENERATE_HTML is set to YES. + TREEVIEW_WIDTH = 250 + +# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to +# external symbols imported via tag files in a separate window. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + EXT_LINKS_IN_WINDOW = NO + +# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email +# addresses. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +OBFUSCATE_EMAILS = YES + +# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg +# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see +# https://inkscape.org) to generate formulas as SVG images instead of PNGs for +# the HTML output. These images will generally look nicer at scaled resolutions. +# Possible values are: png (the default) and svg (looks nicer but requires the +# pdf2svg or inkscape tool). +# The default value is: png. +# This tag requires that the tag GENERATE_HTML is set to YES. + HTML_FORMULA_FORMAT = png + +# Use this tag to change the font size of LaTeX formulas included as images in +# the HTML documentation. When you change the font size after a successful +# doxygen run you need to manually remove any form_*.png images from the HTML +# output directory to force them to be regenerated. +# Minimum value: 8, maximum value: 50, default value: 10. +# This tag requires that the tag GENERATE_HTML is set to YES. + FORMULA_FONTSIZE = 10 + +# Use the FORMULA_TRANSPARENT tag to determine whether or not the images +# generated for formulas are transparent PNGs. Transparent PNGs are not +# supported properly for IE 6.0, but are supported on all modern browsers. +# +# Note that when changing this option you need to delete any form_*.png files in +# the HTML output directory before the changes have effect. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + FORMULA_TRANSPARENT = YES + +# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands +# to create new LaTeX commands to be used in formulas as building blocks. See +# the section "Including formulas" for details. + FORMULA_MACROFILE = + +# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see +# https://www.mathjax.org) which uses client side JavaScript for the rendering +# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX +# installed or if you want to formulas look prettier in the HTML output. When +# enabled you may also need to install MathJax separately and configure the path +# to it using the MATHJAX_RELPATH option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + USE_MATHJAX = NO + +# With MATHJAX_VERSION it is possible to specify the MathJax version to be used. +# Note that the different versions of MathJax have different requirements with +# regards to the different settings, so it is possible that also other MathJax +# settings have to be changed when switching between the different MathJax +# versions. +# Possible values are: MathJax_2 and MathJax_3. +# The default value is: MathJax_2. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_VERSION = MathJax_2 + +# When MathJax is enabled you can set the default output format to be used for +# the MathJax output. For more details about the output format see MathJax +# version 2 (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3 +# (see: +# http://docs.mathjax.org/en/latest/web/components/output.html). +# Possible values are: HTML-CSS (which is slower, but has the best +# compatibility. This is the name for Mathjax version 2, for MathJax version 3 +# this will be translated into chtml), NativeMML (i.e. MathML. Only supported +# for NathJax 2. For MathJax version 3 chtml will be used instead.), chtml (This +# is the name for Mathjax version 3, for MathJax version 2 this will be +# translated into HTML-CSS) and SVG. +# The default value is: HTML-CSS. +# This tag requires that the tag USE_MATHJAX is set to YES. + MATHJAX_FORMAT = HTML-CSS + +# When MathJax is enabled you need to specify the location relative to the HTML +# output directory using the MATHJAX_RELPATH option. The destination directory +# should contain the MathJax.js script. For instance, if the mathjax directory +# is located at the same level as the HTML output directory, then +# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax +# Content Delivery Network so you can quickly see the result without installing +# MathJax. However, it is strongly recommended to install a local copy of +# MathJax from https://www.mathjax.org before deployment. The default value is: +# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2 +# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3 +# This tag requires that the tag USE_MATHJAX is set to YES. + MATHJAX_RELPATH = https://cdn.jsdelivr.net/npm/mathjax@2 + +# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax +# extension names that should be enabled during MathJax rendering. For example +# for MathJax version 2 (see +# https://docs.mathjax.org/en/v2.7-latest/tex.html#tex-and-latex-extensions): +# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols +# For example for MathJax version 3 (see +# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html): +# MATHJAX_EXTENSIONS = ams +# This tag requires that the tag USE_MATHJAX is set to YES. + MATHJAX_EXTENSIONS = + +# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces +# of code that will be used on startup of the MathJax code. See the MathJax site +# (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an +# example see the documentation. +# This tag requires that the tag USE_MATHJAX is set to YES. + MATHJAX_CODEFILE = + +# When the SEARCHENGINE tag is enabled doxygen will generate a search box for +# the HTML output. The underlying search engine uses javascript and DHTML and +# should work on any modern browser. Note that when using HTML help +# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET) +# there is already a search function so this one should typically be disabled. +# For large projects the javascript based search engine can be slow, then +# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to +# search using the keyboard; to jump to the search box use + S +# (what the is depends on the OS and browser, but it is typically +# , / - public byte[] ZeroTerminatedUtf8Name { get; set; } + internal byte[] ZeroTerminatedUtf8Name { get; set; } /// /// Tensor shape valid only if this is a Tensor. diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs index 9334313e87aa..9854d1940f67 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs @@ -95,7 +95,7 @@ internal NamedOnnxValue(string name, Object value, MapHelper helper) /// Onnx Value Type if known. In general, NamedOnnxValue is able to contain /// arbitrary objects. /// - public OnnxValueType ValueType { get; } + public OnnxValueType ValueType { get; internal set; } /// /// This is a factory method that instantiates NamedOnnxValue @@ -203,7 +203,9 @@ internal virtual OrtValue InputToOrtValue(NodeMetadata metadata, out IDisposable /// /// Produces an output value for outputs. This produces an output value /// only for tensors or optional types that can contain a tensor. - /// For all others we return a null, letting ORT to create an output value. + /// For all other Onnx value types, this method throws. Use Run() overloads + /// that return DisposableNamedOnnxValue to get access to all Onnx value types + /// that may be returned as output. /// /// /// @@ -229,8 +231,11 @@ internal virtual OrtValue OutputToOrtValue(NodeMetadata metadata, out IDisposabl return projection.Value; } } - memoryOwner = null; - return null; + + throw new OnnxRuntimeException(ErrorCode.NotImplemented, + $"Can not create output OrtValue for NamedOnnxValue '{metadata.OnnxValueType}' type." + + $" Only tensors can be pre-allocated for outputs " + + $" Use Run() overloads that return DisposableNamedOnnxValue to get access to all Onnx value types that may be returned as output."); } /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index a385f0a24985..868cf00ae334 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -20,7 +20,7 @@ public enum OnnxValueType ONNX_TYPE_MAP = 3, // It's a map ONNX_TYPE_OPAQUE = 4, // It's an experimental Opaque object ONNX_TYPE_SPARSETENSOR = 5, // It's a Sparse Tensor - ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (but unknown) + ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKOWN) } /// From c83e7de3a6ebe15258a2bdf5efe7a25da258f54c Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 7 Apr 2023 11:43:04 -0700 Subject: [PATCH 5/9] Address library review comments --- .../core/session/onnxruntime_c_api.h | 13 ++- .../framework/onnxruntime_map_type_info.cc | 90 +++++++++++++------ .../framework/onnxruntime_map_type_info.h | 11 ++- .../onnxruntime_optional_type_info.cc | 27 +++--- .../onnxruntime_optional_type_info.h | 12 ++- .../onnxruntime_sequence_type_info.cc | 23 ++--- .../onnxruntime_sequence_type_info.h | 10 +-- .../core/framework/onnxruntime_typeinfo.cc | 63 +++++++------ .../core/framework/onnxruntime_typeinfo.h | 20 ++--- .../core/framework/tensor_type_and_shape.cc | 56 +++++++----- .../core/framework/tensor_type_and_shape.h | 33 ++++--- 11 files changed, 203 insertions(+), 155 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index f4975f0047d8..3cc959c8d704 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -564,7 +564,14 @@ typedef struct OrtMIGraphXProviderOptions { */ typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus - OrtOpenVINOProviderOptions() : device_type{}, enable_vpu_fast_compile{}, device_id{}, num_of_threads{}, cache_dir{}, context{}, enable_opencl_throttling{}, enable_dynamic_shapes{} {} + OrtOpenVINOProviderOptions() : device_type{}, + enable_vpu_fast_compile{}, + device_id{}, + num_of_threads{}, + cache_dir{}, + context{}, + enable_opencl_throttling{}, + enable_dynamic_shapes{} {} #endif /** \brief Device type string * @@ -4087,8 +4094,8 @@ struct OrtApi { * \since Version 1.15. */ ORT_API2_STATUS(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); - - /** \brief Get Optional Type information from an ::OrtTypeInfo + + /** \brief Get Optional Type information from an ::OrtTypeInfo * * This augments ::OrtTypeInfo to return an ::OrtOptionalTypeInfo when the type is optional. * The OrtOptionalTypeInfo also has a nested ::OrtTypeInfo that describes the type of the optional value. diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc index bcf925dce48e..4e9acdcf5ef0 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -1,13 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + #include "core/framework/onnxruntime_map_type_info.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/onnx_protobuf.h" #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" -OrtMapTypeInfo::OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, std::unique_ptr map_value_type) noexcept - : map_key_type_(map_key_type), map_value_type_(std::move(map_value_type)) { +OrtMapTypeInfo::OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, + std::unique_ptr map_value_type) noexcept + : map_key_type_(map_key_type), map_value_type_(std::move(map_value_type)) { } OrtMapTypeInfo::~OrtMapTypeInfo() = default; @@ -16,30 +18,62 @@ static ONNXTensorElementDataType ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { using TensorType = ONNX_NAMESPACE::TensorProto_DataType; switch (data_type) { - case TensorType::TensorProto_DataType_BOOL: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; } - case TensorType::TensorProto_DataType_STRING: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; } // maps to c++ type std::string - case TensorType::TensorProto_DataType_FLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; } // maps to c type float16 - case TensorType::TensorProto_DataType_FLOAT: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } // maps to c type float - case TensorType::TensorProto_DataType_DOUBLE: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; } // maps to c type double - case TensorType::TensorProto_DataType_INT8: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; } // maps to c type int8_t - case TensorType::TensorProto_DataType_INT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; } // maps to c type int16_t - case TensorType::TensorProto_DataType_INT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; } // maps to c type int32_t - case TensorType::TensorProto_DataType_INT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; } // maps to c type int64_t - case TensorType::TensorProto_DataType_UINT8: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; } // maps to c type uint8_t - case TensorType::TensorProto_DataType_UINT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; } // maps to c type uint16_t - case TensorType::TensorProto_DataType_UINT32: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; } // maps to c type uint32_t - case TensorType::TensorProto_DataType_UINT64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; } // maps to c type uint64_t - case TensorType::TensorProto_DataType_COMPLEX64: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; } // complex with float32 real and imaginary components - case TensorType::TensorProto_DataType_COMPLEX128: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; } // complex with float64 real and imaginary components - case TensorType::TensorProto_DataType_BFLOAT16: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; } // Non-IEEE floating-point format based on IEEE754 single-precision - default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } + case TensorType::TensorProto_DataType_BOOL: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; + } + case TensorType::TensorProto_DataType_STRING: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; + } // maps to c++ type std::string + case TensorType::TensorProto_DataType_FLOAT16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + } // maps to c type float16 + case TensorType::TensorProto_DataType_FLOAT: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } // maps to c type float + case TensorType::TensorProto_DataType_DOUBLE: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + } // maps to c type double + case TensorType::TensorProto_DataType_INT8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + } // maps to c type int8_t + case TensorType::TensorProto_DataType_INT16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; + } // maps to c type int16_t + case TensorType::TensorProto_DataType_INT32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + } // maps to c type int32_t + case TensorType::TensorProto_DataType_INT64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } // maps to c type int64_t + case TensorType::TensorProto_DataType_UINT8: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + } // maps to c type uint8_t + case TensorType::TensorProto_DataType_UINT16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; + } // maps to c type uint16_t + case TensorType::TensorProto_DataType_UINT32: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; + } // maps to c type uint32_t + case TensorType::TensorProto_DataType_UINT64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; + } // maps to c type uint64_t + case TensorType::TensorProto_DataType_COMPLEX64: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64; + } // complex with float32 real and imaginary components + case TensorType::TensorProto_DataType_COMPLEX128: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128; + } // complex with float64 real and imaginary components + case TensorType::TensorProto_DataType_BFLOAT16: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; + } // Non-IEEE floating-point format based on IEEE754 single-precision + default: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + } } } -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(disable : 26409) -#endif -OrtMapTypeInfo::Ptr OrtMapTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& type_proto) { +std::unique_ptr OrtMapTypeInfo::FromTypeProto( + const ONNX_NAMESPACE::TypeProto& type_proto) { auto value_case = type_proto.value_case(); if (value_case != ONNX_NAMESPACE::TypeProto::kMapType) { ORT_THROW("type_proto is not of type map!"); @@ -47,7 +81,8 @@ OrtMapTypeInfo::Ptr OrtMapTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProt // Get the key type of the map const auto& type_proto_map = type_proto.map_type(); - const auto map_key_type = ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType(type_proto_map.key_type())); + const auto map_key_type = ToONNXTensorElementDataType( + ONNX_NAMESPACE::TensorProto_DataType(type_proto_map.key_type())); // Get the value type of the map auto map_value_type_info = OrtTypeInfo::FromTypeProto(type_proto_map.value_type()); @@ -55,7 +90,7 @@ OrtMapTypeInfo::Ptr OrtMapTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProt return std::make_unique(map_key_type, std::move(map_value_type_info)); } -OrtMapTypeInfo::Ptr OrtMapTypeInfo::Clone() const { +std::unique_ptr OrtMapTypeInfo::Clone() const { auto map_value_type_copy = map_value_type_->Clone(); return std::make_unique(map_key_type_, std::move(map_value_type_copy)); } @@ -69,7 +104,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_ API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, + _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** out) { API_IMPL_BEGIN auto clone = map_type_info->map_value_type_->Clone(); *out = clone.release(); @@ -78,5 +114,5 @@ ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, _In_ const OrtMapTypeInfo* map_typ } ORT_API(void, OrtApis::ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo* ptr) { - OrtMapTypeInfo::Ptr p(ptr); + std::unique_ptr p(ptr); } \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.h b/onnxruntime/core/framework/onnxruntime_map_type_info.h index 9a72be3db490..1d47d51ddf22 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.h +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.h @@ -10,23 +10,22 @@ namespace ONNX_NAMESPACE { class TypeProto; } +struct OrtTypeInfo; + struct OrtMapTypeInfo { public: - using Ptr = std::unique_ptr; - ONNXTensorElementDataType map_key_type_ = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; std::unique_ptr map_value_type_; - static Ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); + static std::unique_ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); - Ptr Clone() const; + std::unique_ptr Clone() const; OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, std::unique_ptr map_value_type) noexcept; ~OrtMapTypeInfo(); - private: - OrtMapTypeInfo(const OrtMapTypeInfo& other) = delete; + OrtMapTypeInfo(const OrtMapTypeInfo& other) = delete; OrtMapTypeInfo& operator=(const OrtMapTypeInfo& other) = delete; }; diff --git a/onnxruntime/core/framework/onnxruntime_optional_type_info.cc b/onnxruntime/core/framework/onnxruntime_optional_type_info.cc index 0ad5fc1a9ca2..d0eda0881987 100644 --- a/onnxruntime/core/framework/onnxruntime_optional_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_optional_type_info.cc @@ -6,18 +6,17 @@ #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" -OrtOptionalTypeInfo::OrtOptionalTypeInfo(OrtTypeInfo::Ptr contained_type) noexcept +OrtOptionalTypeInfo::OrtOptionalTypeInfo(std::unique_ptr contained_type) noexcept : contained_type_(std::move(contained_type)) { } OrtOptionalTypeInfo::~OrtOptionalTypeInfo() = default; -OrtOptionalTypeInfo::Ptr OrtOptionalTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& type_proto) { +std::unique_ptr OrtOptionalTypeInfo::FromTypeProto( + const ONNX_NAMESPACE::TypeProto& type_proto) { const auto value_case = type_proto.value_case(); - if (value_case != ONNX_NAMESPACE::TypeProto::kOptionalType) { - ORT_THROW("type_proto is not of optional type"); - } + ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kOptionalType, "type_proto is not of optional type"); const auto& type_proto_optional = type_proto.optional_type(); auto contained_type_info = OrtTypeInfo::FromTypeProto(type_proto_optional.elem_type()); @@ -25,16 +24,16 @@ OrtOptionalTypeInfo::Ptr OrtOptionalTypeInfo::FromTypeProto(const ONNX_NAMESPACE return std::make_unique(std::move(contained_type_info)); } -OrtOptionalTypeInfo::Ptr OrtOptionalTypeInfo::Clone() const { +std::unique_ptr OrtOptionalTypeInfo::Clone() const { auto contained_type_copy = contained_type_->Clone(); return std::make_unique(std::move(contained_type_copy)); } - ORT_API_STATUS_IMPL(OrtApis::GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, - _Outptr_ OrtTypeInfo** out) { - API_IMPL_BEGIN - auto type_info = optional_type_info->contained_type_->Clone(); - *out = type_info.release(); - return nullptr; - API_IMPL_END - } +ORT_API_STATUS_IMPL(OrtApis::GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, + _Outptr_ OrtTypeInfo** out) { + API_IMPL_BEGIN + auto type_info = optional_type_info->contained_type_->Clone(); + *out = type_info.release(); + return nullptr; + API_IMPL_END +} diff --git a/onnxruntime/core/framework/onnxruntime_optional_type_info.h b/onnxruntime/core/framework/onnxruntime_optional_type_info.h index 9a44839f110f..561d055689b5 100644 --- a/onnxruntime/core/framework/onnxruntime_optional_type_info.h +++ b/onnxruntime/core/framework/onnxruntime_optional_type_info.h @@ -4,7 +4,7 @@ #include -#include "core/framework/onnxruntime_typeinfo.h" +struct OrtTypeInfo; namespace ONNX_NAMESPACE { class TypeProto; @@ -12,16 +12,14 @@ class TypeProto; struct OrtOptionalTypeInfo { - using Ptr = std::unique_ptr; - - explicit OrtOptionalTypeInfo(OrtTypeInfo::Ptr contained_type) noexcept; + explicit OrtOptionalTypeInfo(std::unique_ptr contained_type) noexcept; ~OrtOptionalTypeInfo(); - OrtTypeInfo::Ptr contained_type_; + std::unique_ptr contained_type_; - Ptr Clone() const; + std::unique_ptr Clone() const; - static Ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); + static std::unique_ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); OrtOptionalTypeInfo(const OrtOptionalTypeInfo& other) = delete; OrtOptionalTypeInfo& operator=(const OrtOptionalTypeInfo& other) = delete; diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc index 4022aa6e4f1a..3f1d85261091 100644 --- a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc @@ -6,23 +6,17 @@ #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" -OrtSequenceTypeInfo::OrtSequenceTypeInfo(OrtTypeInfo::Ptr sequence_key_type) noexcept - : sequence_key_type_(std::move(sequence_key_type)) { +OrtSequenceTypeInfo::OrtSequenceTypeInfo(std::unique_ptr sequence_key_type) noexcept + : sequence_key_type_(std::move(sequence_key_type)) { } OrtSequenceTypeInfo::~OrtSequenceTypeInfo() = default; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(disable : 26409) -#endif - - OrtSequenceTypeInfo::Ptr OrtSequenceTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& type_proto) { - +std::unique_ptr OrtSequenceTypeInfo::FromTypeProto( + const ONNX_NAMESPACE::TypeProto& type_proto) { const auto value_case = type_proto.value_case(); - if (value_case != ONNX_NAMESPACE::TypeProto::kSequenceType) { - ORT_THROW("type_proto is not of type sequence!"); - } + ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kSequenceType, "type_proto is not of type sequence!"); const auto& type_proto_sequence = type_proto.sequence_type(); auto key_type_info = OrtTypeInfo::FromTypeProto(type_proto_sequence.elem_type()); @@ -30,12 +24,13 @@ OrtSequenceTypeInfo::~OrtSequenceTypeInfo() = default; return std::make_unique(std::move(key_type_info)); } -OrtSequenceTypeInfo::Ptr OrtSequenceTypeInfo::Clone() const { +std::unique_ptr OrtSequenceTypeInfo::Clone() const { auto key_type_copy = sequence_key_type_->Clone(); return std::make_unique(std::move(key_type_copy)); } -ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, +ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, + _In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** out) { API_IMPL_BEGIN auto key_type_copy = sequence_type_info->sequence_key_type_->Clone(); @@ -45,5 +40,5 @@ ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, _In_ const OrtSequenceTypeI } ORT_API(void, OrtApis::ReleaseSequenceTypeInfo, _Frees_ptr_opt_ OrtSequenceTypeInfo* ptr) { - OrtSequenceTypeInfo::Ptr p(ptr); + std::unique_ptr p(ptr); } \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.h b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h index dc016f44c380..d1d1412b92ce 100644 --- a/onnxruntime/core/framework/onnxruntime_sequence_type_info.h +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.h @@ -13,16 +13,14 @@ class TypeProto; struct OrtSequenceTypeInfo { public: - using Ptr = std::unique_ptr; - - explicit OrtSequenceTypeInfo(OrtTypeInfo::Ptr sequence_key_type) noexcept; + explicit OrtSequenceTypeInfo(std::unique_ptr sequence_key_type) noexcept; ~OrtSequenceTypeInfo(); - OrtTypeInfo::Ptr sequence_key_type_; + std::unique_ptr sequence_key_type_; - Ptr Clone() const; + std::unique_ptr Clone() const; - static Ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); + static std::unique_ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); OrtSequenceTypeInfo(const OrtSequenceTypeInfo& other) = delete; OrtSequenceTypeInfo& operator=(const OrtSequenceTypeInfo& other) = delete; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 144df446281b..678e7e6e7823 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -29,24 +29,20 @@ using onnxruntime::TensorShape; namespace on = ONNX_NAMESPACE; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(disable : 26409) -#endif - -OrtTypeInfo::OrtTypeInfo(ONNXType type1) noexcept : type(type1) { +OrtTypeInfo::OrtTypeInfo(ONNXType type) noexcept : type(type) { } -OrtTypeInfo::OrtTypeInfo(std::unique_ptr map_type_info1) noexcept - : type(ONNX_TYPE_MAP), map_type_info(std::move(map_type_info1)) {} +OrtTypeInfo::OrtTypeInfo(std::unique_ptr map_type_info) noexcept + : type(ONNX_TYPE_MAP), map_type_info(std::move(map_type_info)) {} -OrtTypeInfo::OrtTypeInfo(std::unique_ptr sequence_type_info1) noexcept - : type(ONNX_TYPE_SEQUENCE), sequence_type_info(std::move(sequence_type_info1)) {} +OrtTypeInfo::OrtTypeInfo(std::unique_ptr sequence_type_info) noexcept + : type(ONNX_TYPE_SEQUENCE), sequence_type_info(std::move(sequence_type_info)) {} -OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info1) noexcept - : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info1)) {} +OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info) noexcept + : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info)) {} -OrtTypeInfo::OrtTypeInfo(ONNXType type1, std::unique_ptr data1) noexcept - : type(type1), data(std::move(data1)) { +OrtTypeInfo::OrtTypeInfo(ONNXType type, std::unique_ptr data) noexcept + : type(type), data(std::move(data)) { } OrtTypeInfo::~OrtTypeInfo() = default; @@ -103,9 +99,8 @@ ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { std::unique_ptr p(ptr); } -OrtTypeInfo::Ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { - - Ptr result = MakePtr(ONNX_TYPE_UNKNOWN); +std::unique_ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { + auto result = MakePtr(ONNX_TYPE_UNKNOWN); onnxruntime::MLDataType type = value.Type(); if (type == nullptr) { @@ -142,15 +137,13 @@ OrtTypeInfo::Ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { if (type->IsTensorSequenceType()) { const auto* tensor_data_type = value.Get().DataType(); - if (tensor_data_type != nullptr) { - TensorShape void_shape = {}; - auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type); - auto type_info = MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); - auto sequence_type_info = std::make_unique(std::move(type_info)); - return MakePtr(std::move(sequence_type_info)); - } else { - ORT_THROW("OrtValue is TensorSequence type but has no element Tensor DataType."); - } + ORT_ENFORCE(tensor_data_type != nullptr, "OrtValue is TensorSequence type but has no element Tensor DataType."); + + TensorShape void_shape = {}; + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type); + auto type_info = MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); + auto sequence_type_info = std::make_unique(std::move(type_info)); + return MakePtr(std::move(sequence_type_info)); } const auto* type_proto = type->GetTypeProto(); @@ -196,8 +189,8 @@ const DataTypeImpl* OrtTypeInfo::ElementTypeFromProto(int type) { return tensor_type->GetElementType(); } -OrtTypeInfo::Ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& input) { - Ptr result; +std::unique_ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& input) { + std::unique_ptr result; auto value_case = input.value_case(); switch (value_case) { @@ -228,10 +221,12 @@ OrtTypeInfo::Ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& inp if (onnxruntime::utils::HasShape(*sparse_type)) { sp = &sparse_type->shape(); } +#else + ORT_NOT_IMPLEMENTED("SparseTensor types are not supported in this build"); #endif } - OrtTensorTypeAndShapeInfo::Ptr type_shape; + std::unique_ptr type_shape; if (sp != nullptr) { const on::TensorShapeProto& s = *sp; std::vector dims(s.dim_size()); @@ -266,13 +261,15 @@ OrtTypeInfo::Ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& inp result = MakePtr(std::move(sequence_type_info)); result->denotation = input.denotation(); } break; -#if !defined(DISABLE_ML_OPS) case on::TypeProto::kMapType: { +#if !defined(DISABLE_ML_OPS) auto map_type_info = OrtMapTypeInfo::FromTypeProto(input); result = MakePtr(std::move(map_type_info)); result->denotation = input.denotation(); - } break; +#else + ORT_NOT_IMPLEMENTED("Map types are not supported in this build"); #endif + } break; case on::TypeProto::kOptionalType: { auto optional_type_info = OrtOptionalTypeInfo::FromTypeProto(input); result = MakePtr(std::move(optional_type_info)); @@ -292,8 +289,8 @@ OrtTypeInfo::Ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto& inp return result; } -OrtTypeInfo::Ptr OrtTypeInfo::Clone() const { - Ptr result; +std::unique_ptr OrtTypeInfo::Clone() const { + std::unique_ptr result; switch (type) { case ONNX_TYPE_SPARSETENSOR: #if !defined(DISABLE_SPARSE_TENSORS) @@ -302,7 +299,7 @@ OrtTypeInfo::Ptr OrtTypeInfo::Clone() const { ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build."); #endif case ONNX_TYPE_TENSOR: { - OrtTensorTypeAndShapeInfo::Ptr info; + std::unique_ptr info; if (data) { info = data->Clone(); } diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index f71b485fd8a7..06b0b8d989f5 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -28,8 +28,6 @@ struct OrtTensorTypeAndShapeInfo; * This class is mainly for the C API */ struct OrtTypeInfo { - // Provide default construction - using Ptr = std::unique_ptr; ONNXType type; std::string denotation; @@ -39,21 +37,21 @@ struct OrtTypeInfo { std::unique_ptr sequence_type_info; std::unique_ptr optional_type_info; - Ptr Clone() const; + std::unique_ptr Clone() const; - static Ptr FromOrtValue(const OrtValue& value); - static Ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); + static std::unique_ptr FromOrtValue(const OrtValue& value); + static std::unique_ptr FromTypeProto(const ONNX_NAMESPACE::TypeProto&); static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type); - explicit OrtTypeInfo(ONNXType type1) noexcept; + explicit OrtTypeInfo(ONNXType type) noexcept; - explicit OrtTypeInfo(std::unique_ptr map_type_info1) noexcept; + explicit OrtTypeInfo(std::unique_ptr map_type_info) noexcept; - OrtTypeInfo(ONNXType type1, std::unique_ptr data1) noexcept; + OrtTypeInfo(ONNXType type, std::unique_ptr data) noexcept; - explicit OrtTypeInfo(std::unique_ptr sequence_type_info1) noexcept; + explicit OrtTypeInfo(std::unique_ptr sequence_type_info) noexcept; - explicit OrtTypeInfo(std::unique_ptr optional_type_info1) noexcept; + explicit OrtTypeInfo(std::unique_ptr optional_type_info) noexcept; OrtTypeInfo(const OrtTypeInfo&) = delete; @@ -62,7 +60,7 @@ struct OrtTypeInfo { ~OrtTypeInfo(); template - static Ptr MakePtr(Args... args) { + static std::unique_ptr MakePtr(Args... args) { return std::make_unique(std::forward(args)...); } diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index f2a4f457b68a..1b73ed1d837b 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -27,11 +27,11 @@ using onnxruntime::SparseTensor; using onnxruntime::narrow; using onnxruntime::Tensor; +OrtTensorTypeAndShapeInfo::OrtTensorTypeAndShapeInfo() = default; OrtTensorTypeAndShapeInfo::~OrtTensorTypeAndShapeInfo() = default; +OrtTensorTypeAndShapeInfo::OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = default; +OrtTensorTypeAndShapeInfo& OrtTensorTypeAndShapeInfo::operator=(const OrtTensorTypeAndShapeInfo& other) = default; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(disable : 26409) -#endif ORT_API_STATUS_IMPL(OrtApis::CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN *out = std::make_unique().release(); @@ -51,29 +51,34 @@ ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShape API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* this_ptr, _In_ const int64_t* dim_values, size_t dim_count) { +ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* this_ptr, + _In_ const int64_t* dim_values, size_t dim_count) { API_IMPL_BEGIN this_ptr->shape = onnxruntime::TensorShape(dim_values, dim_count); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetTensorElementType, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_ ONNXTensorElementDataType* out) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorElementType, _In_ const struct OrtTensorTypeAndShapeInfo* info, + _Out_ ONNXTensorElementDataType* out) { *out = info->type; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetDimensionsCount, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out) { +ORT_API_STATUS_IMPL(OrtApis::GetDimensionsCount, _In_ const struct OrtTensorTypeAndShapeInfo* info, + _Out_ size_t* out) { *out = info->shape.NumDimensions(); return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) { +ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, + _Out_ int64_t* dim_values, size_t dim_values_length) { info->shape.CopyDims(dim_values, dim_values_length); return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, +ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, + _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_writes_all_(dim_params_length) const char** names, size_t dim_params_length) { for (size_t idx = 0, end = std::min(info->dim_params.size(), dim_params_length); idx < end; ++idx) { names[idx] = info->dim_params[idx].c_str(); @@ -82,7 +87,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, _In_ const struct OrtTensorT return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, + _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) { API_IMPL_BEGIN *out = SafeInt{this_ptr->shape.Size()}; return nullptr; @@ -154,8 +160,10 @@ ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType( return TensorDataTypeToOnnxRuntimeTensorElementDataType(prim_type->GetDataType()); } -OrtTensorTypeAndShapeInfo::Ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, onnxruntime::TensorShape shape, - const std::vector* dim_params) { +std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper( + ONNXTensorElementDataType type, + onnxruntime::TensorShape shape, + const std::vector* dim_params) { auto type_and_shape = std::make_unique(); type_and_shape->type = type; type_and_shape->shape = std::move(shape); @@ -171,8 +179,9 @@ OrtTensorTypeAndShapeInfo::Ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeH return type_and_shape; } -OrtTensorTypeAndShapeInfo::Ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(onnxruntime::TensorShape shape, - const onnxruntime::DataTypeImpl& tensor_data_type) { +std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( + onnxruntime::TensorShape shape, + const onnxruntime::DataTypeImpl& tensor_data_type) { ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { ORT_NOT_IMPLEMENTED("Tensor type is undefined"); @@ -180,8 +189,10 @@ OrtTensorTypeAndShapeInfo::Ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr); } -OrtTensorTypeAndShapeInfo::Ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(onnxruntime::TensorShape shape, const std::vector* dim_params, - const ONNX_NAMESPACE::TypeProto& type_proto) { +std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( + onnxruntime::TensorShape shape, + const std::vector* dim_params, + const ONNX_NAMESPACE::TypeProto& type_proto) { auto value_case = type_proto.value_case(); assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType || value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType); @@ -195,10 +206,12 @@ OrtTensorTypeAndShapeInfo::Ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params); } -ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, + _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN if (!v->IsAllocated()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "the ort_value must contain a constructed tensor or sparse tensor"); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "the ort_value must contain a constructed tensor or sparse tensor"); } if (v->IsTensor() || v->IsSparseTensor()) { const onnxruntime::TensorShape* shape = nullptr; @@ -207,15 +220,17 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out const Tensor& tensor = v->Get(); shape = &tensor.Shape(); data_type = tensor.DataType(); - auto ptr = OrtTensorTypeAndShapeInfo::OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); *out = ptr.release(); } else { #if !defined(DISABLE_SPARSE_TENSORS) const SparseTensor& tensor = v->Get(); shape = &tensor.DenseShape(); data_type = tensor.DataType(); - auto ptr = OrtTensorTypeAndShapeInfo::OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); *out = ptr.release(); +#else + ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build."); #endif } } else { @@ -316,7 +331,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXTyp * \param value * \return The returned value should be freed by OrtReleaseTypeInfo after use */ -ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, _Outptr_result_maybenull_ struct OrtTypeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, + _In_ const OrtValue* v, _Outptr_result_maybenull_ struct OrtTypeInfo** out) { API_IMPL_BEGIN // TODO: This is consistent with the previous implementation but inconsistent with GetValueType which returns // ONNX_TYPE_UNKNOWN if v->Type() is null. Should we instead just call OrtTypeInfo::FromOrtValue and diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h index 283c1de37bda..149d8c156821 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -19,7 +19,6 @@ class DataTypeImpl; struct OrtTensorTypeAndShapeInfo { public: - using Ptr = std::unique_ptr; ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; onnxruntime::TensorShape shape; @@ -27,25 +26,31 @@ struct OrtTensorTypeAndShapeInfo { // one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs std::vector dim_params; - OrtTensorTypeAndShapeInfo() = default; + OrtTensorTypeAndShapeInfo(); ~OrtTensorTypeAndShapeInfo(); - Ptr Clone() const { - return std::make_unique(*this); - } + // Utils + static std::unique_ptr GetTensorShapeAndTypeHelper( + ONNXTensorElementDataType type, + onnxruntime::TensorShape shape, + const std::vector* dim_params); - OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = default; - OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other) = default; + static std::unique_ptr GetTensorShapeAndType( + onnxruntime::TensorShape shape, + const onnxruntime::DataTypeImpl& tensor_data_type); - // Utils - static Ptr GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, onnxruntime::TensorShape shape, - const std::vector* dim_params); + static std::unique_ptr GetTensorShapeAndType( + onnxruntime::TensorShape shape, + const std::vector* dim_params, + const ONNX_NAMESPACE::TypeProto&); - static Ptr GetTensorShapeAndType(onnxruntime::TensorShape shape, - const onnxruntime::DataTypeImpl& tensor_data_type); + std::unique_ptr Clone() const { + return std::make_unique(*this); + } - static Ptr GetTensorShapeAndType(onnxruntime::TensorShape shape, const std::vector* dim_params, - const ONNX_NAMESPACE::TypeProto&); + // Copy ops are public because std::make_unique above requires them to be accessible + OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other); + OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other); }; constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementDataType(int32_t dtype); From e0de4eae43d4d76f34d0cf61931787dcad8cc7b7 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 7 Apr 2023 15:08:29 -0700 Subject: [PATCH 6/9] Add OrtTypeInfo tests --- .../test/framework/model_builder_utils.h | 23 ++++++ onnxruntime/test/framework/type_info_test.cc | 82 +++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 onnxruntime/test/framework/type_info_test.cc diff --git a/onnxruntime/test/framework/model_builder_utils.h b/onnxruntime/test/framework/model_builder_utils.h index b959d5cb2b57..6b8fdf03112c 100644 --- a/onnxruntime/test/framework/model_builder_utils.h +++ b/onnxruntime/test/framework/model_builder_utils.h @@ -57,6 +57,29 @@ struct Type { dim->set_dim_param(d); } } + + static Type MakeSequence(const ONNX_NAMESPACE::TypeProto& element_proto) { + ONNX_NAMESPACE::TypeProto proto; + proto.mutable_sequence_type()->mutable_elem_type()->CopyFrom(element_proto); + return Type(std::move(proto)); + } + + static Type MakeMap(ONNX_NAMESPACE::TensorProto_DataType dtype, const ONNX_NAMESPACE::TypeProto& value_proto) { + ONNX_NAMESPACE::TypeProto proto; + auto& mut_map = *proto.mutable_map_type(); + mut_map.set_key_type(static_cast(dtype)); + mut_map.mutable_value_type()->CopyFrom(value_proto); + return Type(std::move(proto)); + } + + static Type MakeOptional(const ONNX_NAMESPACE::TypeProto& contained_proto) { + ONNX_NAMESPACE::TypeProto proto; + proto.mutable_optional_type()->mutable_elem_type()->CopyFrom(contained_proto); + return Type(std::move(proto)); + } + +private: + explicit Type(ONNX_NAMESPACE::TypeProto type_proto) : value(std::move(type_proto)) {} }; } // namespace modelbuilder diff --git a/onnxruntime/test/framework/type_info_test.cc b/onnxruntime/test/framework/type_info_test.cc new file mode 100644 index 000000000000..1be6297deaf3 --- /dev/null +++ b/onnxruntime/test/framework/type_info_test.cc @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "model_builder_utils.h" + +#include "core/framework/onnxruntime_optional_type_info.h" +#include "core/framework/onnxruntime_map_type_info.h" +#include "core/framework/onnxruntime_sequence_type_info.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/onnxruntime_typeinfo.h" + +namespace onnxruntime { +namespace test { + +namespace mb = modelbuilder; + +TEST(TypeInfoTests, TensorProto) { + mb::Type tensor_type = {1, 2, 3, 4}; + + auto tensor_type_info = OrtTypeInfo::FromTypeProto(tensor_type.value); + ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info->type); + ASSERT_NE(nullptr, tensor_type_info->data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->data->shape.GetDims())); +} + +TEST(TypeInfoTests, SequenceWithTensorElement) { + mb::Type tensor_type = {1, 2, 3, 4}; + auto sequence_proto = mb::Type::MakeSequence(tensor_type.value); + auto seq_type_info = OrtTypeInfo::FromTypeProto(sequence_proto.value); + + ASSERT_EQ(ONNX_TYPE_SEQUENCE, seq_type_info->type); + ASSERT_NE(nullptr, seq_type_info->sequence_type_info); + const auto& tensor_type_info = *seq_type_info->sequence_type_info->sequence_key_type_; + + ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); + ASSERT_NE(nullptr, tensor_type_info.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); +} + +TEST(TypeInfoTests, MapWithTensorValue) { + mb::Type value_type = {1, 2, 3, 4}; + auto map_proto = mb::Type::MakeMap(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, value_type.value); + auto map_type_info = OrtTypeInfo::FromTypeProto(map_proto.value); + + ASSERT_EQ(ONNX_TYPE_MAP, map_type_info->type); + ASSERT_NE(nullptr, map_type_info->map_type_info); + const auto& map_info = *map_type_info->map_type_info; + + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, map_info.map_key_type_); + ASSERT_NE(nullptr, map_info.map_value_type_); + const auto& tensor_type_info = *map_info.map_value_type_; + + ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); + ASSERT_NE(nullptr, tensor_type_info.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); +} + +TEST(TypeInfoTests, OptionalWithTensorProto) { + mb::Type tensor_type = {1, 2, 3, 4}; + auto optional_proto = mb::Type::MakeOptional(tensor_type.value); + + auto optional_type_info = OrtTypeInfo::FromTypeProto(optional_proto.value); + + ASSERT_EQ(ONNX_TYPE_OPTIONAL, optional_type_info->type); + ASSERT_NE(nullptr, optional_type_info->optional_type_info); + ASSERT_NE(nullptr, optional_type_info->optional_type_info->contained_type_); + + const auto& contained_type = *optional_type_info->optional_type_info->contained_type_; + ASSERT_EQ(ONNX_TYPE_TENSOR, contained_type.type); + ASSERT_NE(nullptr, contained_type.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.data->shape.GetDims())); +} + +} // namespace test +} // namespace onnxruntime \ No newline at end of file From 850aaefb53fd3198e07f98f2f3950ebb8ff17a17 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 7 Apr 2023 15:43:13 -0700 Subject: [PATCH 7/9] Address comments in C# code --- .../DisposableNamedOnnxValue.shared.cs | 12 ++---- .../InferenceSession.shared.cs | 6 +-- .../NamedOnnxValue.shared.cs | 43 ++++++++++++++++--- .../InferenceTest.cs | 6 +-- .../InferenceTest.netcore.cs | 8 +--- 5 files changed, 47 insertions(+), 28 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs index 3847863a0dec..34e71074d9d9 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs @@ -4,7 +4,6 @@ using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Collections.Generic; -using System.Reflection; namespace Microsoft.ML.OnnxRuntime { @@ -56,7 +55,6 @@ public void Dispose() /// /// This class serves as a container for model run output values including /// tensors, sequences of tensors, sequences and maps. - /// It extends NamedOnnxValue, exposes the OnnxValueType and Tensor type /// The class must be disposed of. /// It disposes of _ortValueHolder that owns the underlying Ort output value and /// anything else that would need to be disposed by the instance of the class. @@ -100,7 +98,9 @@ private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxVa } /// - /// Construct from a dictionary + /// Construct an instance that would contain a map in a form of a Dictionary + /// Currently a limited number of primitive types are supported as map keys and values. + /// So this is not a full implementation of the map type. /// /// /// @@ -394,11 +394,7 @@ private static DisposableNamedOnnxValue FromNativeMap(string name, OrtValue ortV NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape); } - if (valueElemType != TensorElementType.Float) - { - throw new OnnxRuntimeException(ErrorCode.NotImplemented, $"Value element type: {valueElemType} not supported"); - } - + // The supported combinations of key and value types are taken from the ORT C API. switch (keyElemType) { case TensorElementType.Int64: diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs index b40ec845d51e..a6be0afdad09 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs @@ -661,7 +661,7 @@ private NodeMetadata LookupInputMetadata(string nodeName) if (!_inputMetadata.TryGetValue(nodeName, out meta) && !_overridableInitializerMetadata.TryGetValue(nodeName, out meta)) { - throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Input/output name: '{nodeName}' is not in the metadata"); + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Input name: '{nodeName}' is not in the metadata"); } return meta; } @@ -677,7 +677,7 @@ private NodeMetadata LookupOutputMetadata(string nodeName) NodeMetadata meta; if (!_outputMetadata.TryGetValue(nodeName, out meta)) { - throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Input/output name: '{nodeName}' is not in the metadata"); + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Output name: '{nodeName}' is not in the metadata"); } return meta; } @@ -1102,7 +1102,7 @@ internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo) return GetOptionalMetadataFromTypeInfo(typeInfo); } - throw new NotImplementedException("Value type not supported in this code"); + throw new OnnxRuntimeException(ErrorCode.NotImplemented, $"Value type: '{valueType}' not supported in this code"); } internal static NodeMetadata GetSequenceMetadataFromTypeInfo(IntPtr typeInfo) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs index 9854d1940f67..5f08daf73806 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs @@ -12,6 +12,9 @@ namespace Microsoft.ML.OnnxRuntime { /// /// The class holds keys and values for the dictionary + /// in a for of two DenseTensors. The class is used to avoid + /// data copy and make these available to the native code. + /// Strings require special handling. /// internal class MapHelper { @@ -36,7 +39,7 @@ internal MapHelper(object keys, object values) /// directly. Thus we are able to avoid copying. /// /// For outputs, tensor buffers works the same as input, providing it matches - /// the expected output shape. For other types (maps and sequences, we create a copy of the data). + /// the expected output shape. For other types (maps and sequences) we create a copy of the data. /// This is because, the class is not Disposable and it is a public interface, thus it can not own /// the underlying OrtValues that must be destroyed before Run() returns. /// @@ -46,7 +49,14 @@ internal MapHelper(object keys, object values) /// It is a recursive structure that may contain Tensors (base case) /// Other sequences and maps. Although the OnnxValueType is exposed, /// the caller is supposed to know the actual data type contained. - /// For that one will need to consult model metadata. + /// + /// The convention is that for tensors, it would contain a DenseTensor instance or + /// anything derived from Tensor. + /// + /// For sequences, it would contain a IList where T is an instance of NamedOnnxValue that + /// would contain a tensor or another type. + /// + /// For Maps, it would contain a IDictionary where K,V are primitive types or strings. /// /// public class NamedOnnxValue @@ -68,7 +78,7 @@ public class NamedOnnxValue /// /// input/output name /// Object that may be a tensor, Dictionary, IList - [Obsolete("This the constructor with valueType or static factory methods")] + [Obsolete("Use constructors with valueType or static factory methods")] protected NamedOnnxValue(string name, Object value) { _name = name; @@ -76,13 +86,30 @@ protected NamedOnnxValue(string name, Object value) ValueType = OnnxValueType.ONNX_TYPE_UNKNOWN; } + /// + /// Constructs an instance that contains a tensor, sequence or optional type. + /// + /// + /// + /// internal NamedOnnxValue(string name, Object value, OnnxValueType valueType) { _name = name; _value = value; ValueType = valueType; + + if (valueType == OnnxValueType.ONNX_TYPE_MAP) + { + throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Use another __ctor for maps"); + } } + /// + /// Use this to construct maps + /// + /// + /// + /// internal NamedOnnxValue(string name, Object value, MapHelper helper) { _name = name; @@ -93,7 +120,7 @@ internal NamedOnnxValue(string name, Object value, MapHelper helper) /// /// Onnx Value Type if known. In general, NamedOnnxValue is able to contain - /// arbitrary objects. + /// arbitrary objects. Please, follow the convention described in the class doc. /// public OnnxValueType ValueType { get; internal set; } @@ -123,7 +150,7 @@ public static NamedOnnxValue CreateFromSequence(string name, IEnumerable v } /// - /// This is a factory method that instantiates NamedOnnxValue. + /// Instantiates NamedOnnxValue that contains IDictionary /// /// Keys type /// Values type @@ -166,6 +193,8 @@ public Tensor AsTensor() /// /// Try-get value as an Enumerable<T>. + /// T is usually a NamedOnnxValue instance that may contain + /// Tensors, Sequences, Maps or optional types /// /// Type /// Enumerable object if contained value is a Enumerable. Null otherwise @@ -178,8 +207,8 @@ public IEnumerable AsEnumerable() /// /// Try-get value as an Dictionary<K,V>. /// - /// Key type - /// Value type + /// Key type currently primitive type only + /// Value type, currently primitive type only /// Dictionary object if contained value is a Dictionary. Null otherwise public IDictionary AsDictionary() { diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index 871e90164716..da83b640f257 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -588,7 +588,7 @@ private void ThrowWrongInputName() var container = new List(); container.Add(NamedOnnxValue.CreateFromTensor("wrong_name", tensor)); var ex = Assert.Throws(() => session.Run(container)); - Assert.Contains("Input/output name: 'wrong_name' is not in the metadata", ex.Message); + Assert.Contains("Input name: 'wrong_name' is not in the metadata", ex.Message); session.Dispose(); } @@ -623,7 +623,7 @@ private void ThrowExtraInputs() container.Add(nov1); container.Add(nov2); var ex = Assert.Throws(() => session.Run(container)); - Assert.Contains("Input/output name: 'extra' is not in the metadata", ex.Message); + Assert.Contains("Input name: 'extra' is not in the metadata", ex.Message); session.Dispose(); } @@ -655,7 +655,7 @@ private void ThrowWrongOutputName() // var outputs = new List { NamedOnnxValue.CreateFromTensor("bad_output_name", outputTensor) }; var bad_names = new string[] {"bad_output_name"}; var ex = Assert.Throws(() => session.Run(inputs, bad_names)); - Assert.Contains("Input/output name: 'bad_output_name' is not in the metadata", ex.Message); + Assert.Contains("Output name: 'bad_output_name' is not in the metadata", ex.Message); session.Dispose(); } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index b9785f8b4d63..76518e341f93 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -247,7 +247,6 @@ private void TestTensorRTProviderOptions() { "fp16_test_tiny_yolov2", "ImageScaler is not a registered function/op"}, { "fp16_coreml_FNS-Candy", "ImageScaler is not a registered function/op" }, { "fp16_coreml_LinearRegression_NYCTaxi", "Error in Node:featureVectorizer : No Op registered for FeatureVectorizer with domain_version of 1"}, - // { "test_bidaf", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile." }, { "test_mnist", "Does not run in opset9, runs in other opsets. The model runs but I don't have a data set to debug output locally. Tensors of type ElementType not currently supported in the LoadTensorFromFile" }, { "BERT_Squad", "Could not find an implementation for the nodeMeta bert / embeddings / one_hot:OneHot(9)" }, @@ -471,12 +470,7 @@ private string MatchInputOutputWithFile(string fileName, InferenceSession sessio if (result is null) { - // try matching the file name directly against the input/output name - if (!metadata.TryGetValue(fileName, out result)) - { - throw new InvalidDataException($"Unable to match file: {fileName} to input/output metadata"); - } - nodeName = fileName; + throw new InvalidDataException($"Unable to match file: {fileName} to input/output metadata"); } return nodeName; } From 99d09cf3a40ab7cdbcebd8f99db10f626eede027 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 10 Apr 2023 10:42:17 -0700 Subject: [PATCH 8/9] Disable Maps tests when ML_OPS are disabled --- onnxruntime/test/framework/type_info_test.cc | 36 +++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/onnxruntime/test/framework/type_info_test.cc b/onnxruntime/test/framework/type_info_test.cc index 1be6297deaf3..ee787fb071d9 100644 --- a/onnxruntime/test/framework/type_info_test.cc +++ b/onnxruntime/test/framework/type_info_test.cc @@ -42,6 +42,24 @@ TEST(TypeInfoTests, SequenceWithTensorElement) { ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); } +TEST(TypeInfoTests, OptionalWithTensorProto) { + mb::Type tensor_type = {1, 2, 3, 4}; + auto optional_proto = mb::Type::MakeOptional(tensor_type.value); + + auto optional_type_info = OrtTypeInfo::FromTypeProto(optional_proto.value); + + ASSERT_EQ(ONNX_TYPE_OPTIONAL, optional_type_info->type); + ASSERT_NE(nullptr, optional_type_info->optional_type_info); + ASSERT_NE(nullptr, optional_type_info->optional_type_info->contained_type_); + + const auto& contained_type = *optional_type_info->optional_type_info->contained_type_; + ASSERT_EQ(ONNX_TYPE_TENSOR, contained_type.type); + ASSERT_NE(nullptr, contained_type.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.data->shape.GetDims())); +} + +#if !defined(DISABLE_ML_OPS) TEST(TypeInfoTests, MapWithTensorValue) { mb::Type value_type = {1, 2, 3, 4}; auto map_proto = mb::Type::MakeMap(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, value_type.value); @@ -60,23 +78,7 @@ TEST(TypeInfoTests, MapWithTensorValue) { ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); } - -TEST(TypeInfoTests, OptionalWithTensorProto) { - mb::Type tensor_type = {1, 2, 3, 4}; - auto optional_proto = mb::Type::MakeOptional(tensor_type.value); - - auto optional_type_info = OrtTypeInfo::FromTypeProto(optional_proto.value); - - ASSERT_EQ(ONNX_TYPE_OPTIONAL, optional_type_info->type); - ASSERT_NE(nullptr, optional_type_info->optional_type_info); - ASSERT_NE(nullptr, optional_type_info->optional_type_info->contained_type_); - - const auto& contained_type = *optional_type_info->optional_type_info->contained_type_; - ASSERT_EQ(ONNX_TYPE_TENSOR, contained_type.type); - ASSERT_NE(nullptr, contained_type.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.data->shape.GetDims())); -} +#endif } // namespace test } // namespace onnxruntime \ No newline at end of file From daff1d91c8587f738e5ae24667763f35183d4d93 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 10 Apr 2023 18:29:02 -0700 Subject: [PATCH 9/9] Address review comments --- onnxruntime/core/framework/onnxruntime_map_type_info.cc | 5 ++--- onnxruntime/core/framework/onnxruntime_map_type_info.h | 2 +- onnxruntime/core/framework/tensor_type_and_shape.h | 4 ++++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc index 4e9acdcf5ef0..b87ea179c070 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -75,9 +75,8 @@ ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { std::unique_ptr OrtMapTypeInfo::FromTypeProto( const ONNX_NAMESPACE::TypeProto& type_proto) { auto value_case = type_proto.value_case(); - if (value_case != ONNX_NAMESPACE::TypeProto::kMapType) { - ORT_THROW("type_proto is not of type map!"); - } + + ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kMapType, "type_proto is not of type map!"); // Get the key type of the map const auto& type_proto_map = type_proto.map_type(); diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.h b/onnxruntime/core/framework/onnxruntime_map_type_info.h index 1d47d51ddf22..6b20a94b30a5 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.h +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.h @@ -25,7 +25,7 @@ struct OrtMapTypeInfo { OrtMapTypeInfo(ONNXTensorElementDataType map_key_type, std::unique_ptr map_value_type) noexcept; ~OrtMapTypeInfo(); - OrtMapTypeInfo(const OrtMapTypeInfo& other) = delete; + OrtMapTypeInfo(const OrtMapTypeInfo& other) = delete; OrtMapTypeInfo& operator=(const OrtMapTypeInfo& other) = delete; }; diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h index 149d8c156821..9da1d8cd6414 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -44,6 +44,10 @@ struct OrtTensorTypeAndShapeInfo { const std::vector* dim_params, const ONNX_NAMESPACE::TypeProto&); + // We provide Clone() here to satisfy the existing coding pattern + // as we need copies made on the heap even though we achieve that + // via a copy __ctor which can not be made private due to make_unique + // which is a requirement. std::unique_ptr Clone() const { return std::make_unique(*this); }