diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs index 18b6977394df..a4f8832707c1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.cs @@ -5,7 +5,7 @@ using System; using System.Buffers; using System.Collections.Generic; - +using System.Diagnostics; namespace Microsoft.ML.OnnxRuntime { @@ -50,45 +50,75 @@ public void Dispose() #endregion } + /// + /// This class serves as a container for model run output values including + /// tensors, sequences of tensors, sequences and maps + /// It extends NamedOnnxValue, exposes the OnnxValueType and Tensor type + /// The class must be disposed of. + /// It disposes of _ortValueHolder that owns the underlying Ort output value or + /// anything that the class that implements that interfaces needs to dispose. + /// Use factory method CreateFromOrtValue to obtain an instance of the class. + /// public class DisposableNamedOnnxValue : NamedOnnxValue, IDisposable { - private NativeMemoryHandler _nativeMemoryManager; - private TensorElementType _elementType; - private OnnxValueType _onnxValueType; + private IOrtValueOwner _ortValueHolder; private bool _disposed = false; - private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, NativeMemoryHandler nativeMemoryManager) + private DisposableNamedOnnxValue(string name, Object value, OnnxValueType onnxValueType, TensorElementType elementType, IOrtValueOwner ortValueHolder) : base(name, value) { - _onnxValueType = onnxValueType; - _elementType = elementType; - _nativeMemoryManager = nativeMemoryManager; + _ortValueHolder = ortValueHolder; + ValueType = onnxValueType; + ElementType = elementType; } /// - /// Overrides the base class method. Since the instance already has access to the - /// underlying OrtValue handle, it returns an instance of OrtValue that does not own the raw handle + /// Returns OnnxValueType + /// + public OnnxValueType ValueType { get; } + + /// + /// Only valid if ValueType is Tensor + /// + public TensorElementType ElementType { get; } + + /// + /// Overrides the base class method. Since the instance already owns underlying OrtValue handle, + /// it returns an instance of OrtValue that does not own the raw handle /// that to the output onnxValue. With respect to pinnedMemoryHandle, it has no operation - /// to do, as this class doesn't maintain a managed buffer. It doesn't have to maintain it - /// as it already is associated with the object of interest (native OrtValue) + /// to do, as this class maintains a native buffer via _ortValueHolder and the memory will be + /// disposed by it. This is the case when we are dealing with an OrtValue that is backed by native memory + /// and not by pinned managed memory. /// - /// + /// always set to null + /// An instance of OrtValue that does not own underlying memory internal override OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle) { + if(_ortValueHolder == null) + { + throw new InvalidOperationException("The instance of this class does not own any OrtValues"); + } // PinnedMemoryHandle holds the default value as DisposableNamedOnnxValue // doesn't hold any managed buffer (that needs to be pinned) pinnedMemoryHandle = null; - // Assign the onnxValue by querying this instance's NativeOnnxTensorMemory instance - return new OrtValue(_nativeMemoryManager.Handle, false); + // Return non-owning instance of OrtValue + return _ortValueHolder.Value; } - internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, IntPtr nativeOnnxValue) + /// + /// Creates an instance of DisposableNamedOnnxValue and takes ownership of ortValueElement + /// on success. + /// + /// name of the value + /// underlying OrtValue + /// + internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, OrtValue ortValue) { DisposableNamedOnnxValue result = null; /* Get Tensor element type */ //TODO: Assumed value is Tensor, need to support non-tensor types in future IntPtr typeAndShape = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValue, out typeAndShape)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(ortValue.Handle, out typeAndShape)); TensorElementType elemType = TensorElementType.DataTypeMax; try { @@ -104,40 +134,40 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, switch (elemType) { case TensorElementType.Float: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.Double: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.Int16: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.UInt16: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.Int32: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.UInt32: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.Int64: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.UInt64: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.UInt8: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.Int8: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.String: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; case TensorElementType.Bool: - result = DisposableNamedOnnxValueFromNativeTensor(name, nativeOnnxValue); + result = DisposableNamedOnnxValueFromNativeTensor(name, ortValue); break; default: throw new NotSupportedException("Tensor of element type: " + elemType + " is not supported"); @@ -149,115 +179,223 @@ internal static DisposableNamedOnnxValue CreateTensorFromOnnxValue(string name, internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue) { - var result = CreateFromOnnxValue(name, ortValue.Handle, OrtAllocator.DefaultInstance); - ortValue.Disown(); - return result; + return CreateFromOrtValue(name, ortValue, OrtAllocator.DefaultInstance); } - internal static DisposableNamedOnnxValue CreateFromOnnxValue(string name, IntPtr nativeOnnxValue, OrtAllocator allocator) + internal static DisposableNamedOnnxValue CreateFromOrtValue(string name, OrtValue ortValue, OrtAllocator allocator) { + DisposableNamedOnnxValue result = null; + IntPtr valueType; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueType(nativeOnnxValue, out valueType)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueType(ortValue.Handle, out valueType)); OnnxValueType onnxValueType = (OnnxValueType)valueType; switch (onnxValueType) { case OnnxValueType.ONNX_TYPE_TENSOR: - return CreateTensorFromOnnxValue(name, nativeOnnxValue); + result = CreateTensorFromOnnxValue(name, ortValue); + break; case OnnxValueType.ONNX_TYPE_SEQUENCE: - IntPtr count = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueCount(nativeOnnxValue, out count)); - var sequence = new DisposableList(count.ToInt32()); - for (int i = 0; i < count.ToInt32(); i++) - { - IntPtr nativeOnnxValueSeq; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(nativeOnnxValue, i, allocator.Pointer, out nativeOnnxValueSeq)); - sequence.Add(CreateFromOnnxValue(string.Empty, nativeOnnxValueSeq, allocator)); - } - return new DisposableNamedOnnxValue(name, sequence, OnnxValueType.ONNX_TYPE_SEQUENCE, TensorElementType.DataTypeMax, null); + result = DisposableNamedOnnxValueFromSequence(name, ortValue, allocator); + break; case OnnxValueType.ONNX_TYPE_MAP: - IntPtr nativeOnnxValueMapKeys = IntPtr.Zero; - IntPtr nativeOnnxValueMapValues = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(nativeOnnxValue, 0, allocator.Pointer, out nativeOnnxValueMapKeys)); - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(nativeOnnxValue, 1, allocator.Pointer, out nativeOnnxValueMapValues)); - - IntPtr typeAndShape = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapKeys, out typeAndShape)); - TensorElementType elemType = TensorElementType.DataTypeMax; - try - { - IntPtr el_type; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type)); - elemType = (TensorElementType)el_type; - } - finally - { - NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape); - } - - switch (elemType) - { - case TensorElementType.Int64: - return DisposableNamedOnnxValueFromNativeMap(string.Empty, nativeOnnxValueMapKeys, nativeOnnxValueMapValues); - case TensorElementType.String: - return DisposableNamedOnnxValueFromNativeMap(string.Empty, nativeOnnxValueMapKeys, nativeOnnxValueMapValues); - default: - throw new NotSupportedException("Map of element type: " + elemType + " is not supported"); - } + result = DisposableNamedOnnxValueFromNativeMap(name, ortValue, allocator); + break; default: throw new NotSupportedException("OnnxValueType : " + onnxValueType + " is not supported"); } + return result; } - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeTensor(string name, IntPtr nativeOnnxValue) + /// + /// This method creates an instance of DisposableNamedOnnxValue that has possession of ortValueElement + /// native memory Tensor and returns it to the caller. The original ortValueElement argument looses + /// ownership of the native ortValueElement handle, however, the caller is still responsible for disposing them + /// on exception. Disposing of OrtValue that has no ownership is a no-op and fine. + /// + /// data type + /// name of the output + /// native tensor + /// DisposableNamedOnnxValue instance + private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeTensor(string name, OrtValue ortValue) { if (typeof(T) == typeof(string)) { - var nativeTensorWrapper = new NativeOnnxTensorMemory(nativeOnnxValue); - var dt = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions); - return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper); + var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValue); + try + { + var dt = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions); + return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper); + } catch(Exception e) + { + nativeTensorWrapper.Dispose(); + throw e; + } } else { - NativeOnnxTensorMemory nativeTensorWrapper = new NativeOnnxTensorMemory(nativeOnnxValue); - DenseTensor dt = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions); - return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper); + NativeOnnxTensorMemory nativeTensorWrapper = new NativeOnnxTensorMemory(ortValue); + try + { + DenseTensor dt = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions); + return new DisposableNamedOnnxValue(name, dt, OnnxValueType.ONNX_TYPE_TENSOR, nativeTensorWrapper.ElementType, nativeTensorWrapper); + } + catch (Exception e) + { + nativeTensorWrapper.Dispose(); + throw e; + } } } - private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(string name, IntPtr nativeOnnxValueKeys, IntPtr nativeOnnxValueValues) + /// + /// This method will create an instance of DisposableNamedOnnxValue that will own ortSequenceValue + /// an all disposable native objects that are elements of the sequence + /// + /// + /// ortValueElement that has native sequence + /// used allocator + /// DisposableNamedOnnxValue + private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromSequence(string name, OrtValue ortValueSequence, OrtAllocator allocator) { - var nativeTensorWrapperValues = new NativeOnnxTensorMemory(nativeOnnxValueValues); - var denseTensorValues = new DenseTensor(nativeTensorWrapperValues.Memory, nativeTensorWrapperValues.Dimensions); + DisposableNamedOnnxValue result = null; + IntPtr count; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValueCount(ortValueSequence.Handle, out count)); + var sequence = new DisposableList(count.ToInt32()); + try + { + for (int i = 0; i < count.ToInt32(); i++) + { + IntPtr nativeOnnxValueSeq; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(ortValueSequence.Handle, i, allocator.Pointer, out nativeOnnxValueSeq)); + using (var ortValueElement = new OrtValue(nativeOnnxValueSeq)) + { + // Will take ownership or throw + sequence.Add(CreateFromOrtValue(string.Empty, ortValueElement, allocator)); + } + } + // NativeOrtValueCollectionOwner will take ownership of ortValueSequence and will make sure sequence + // is also disposed. + var nativeCollectionManager = new NativeOrtValueCollectionOwner(ortValueSequence, sequence); + result = new DisposableNamedOnnxValue(name, sequence, OnnxValueType.ONNX_TYPE_SEQUENCE, TensorElementType.DataTypeMax, nativeCollectionManager); + } + catch (Exception e) + { + sequence.Dispose(); + throw e; + } + return result; + } - if (typeof(K) == typeof(string)) + /// + /// Will extract keys and values from the map and create a DisposableNamedOnnxValue from it + /// + /// name of the output + /// ortValue that represents a map. + /// This function does not take ownership of the map as it we copy all keys an values into a dictionary. We let the caller dispose of it + /// + /// DisposableNamedOnnxValue + private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMap(string name, OrtValue ortValueMap, OrtAllocator allocator) + { + DisposableNamedOnnxValue result = null; + // Map processing is currently not recursing. It is assumed to contain + // only primitive types and strings tensors. No sequences or maps. + // The data is being copied to a dictionary and all ortValues are being disposed. + // not mapped for client consumption. + using (var cleanUpList = new DisposableList()) { - var map = new Dictionary(); - var nativeTensorWrapper = new NativeOnnxTensorMemory(nativeOnnxValueKeys); - var denseTensorKeys = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions); - for (var i = 0; i < denseTensorKeys.Length; i++) + // Take possession of the map ortValueElement + IntPtr nativeOnnxValueMapKeys = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(ortValueMap.Handle, 0, allocator.Pointer, out nativeOnnxValueMapKeys)); + var ortValueKeys = new OrtValue(nativeOnnxValueMapKeys); + cleanUpList.Add(ortValueKeys); + + IntPtr nativeOnnxValueMapValues = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetValue(ortValueMap.Handle, 1, allocator.Pointer, out nativeOnnxValueMapValues)); + var ortValueValues = new OrtValue(nativeOnnxValueMapValues); + cleanUpList.Add(ortValueValues); + + IntPtr typeAndShape = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(nativeOnnxValueMapKeys, out typeAndShape)); + TensorElementType elemType = TensorElementType.DataTypeMax; + try + { + IntPtr el_type; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(typeAndShape, out el_type)); + elemType = (TensorElementType)el_type; + } + finally { - map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); + NativeMethods.OrtReleaseTensorTypeAndShapeInfo(typeAndShape); + } + + /// XXX: This code always assumes that the value type is float and makes no checks + /// similar to that of the key. Also Map type in general can also be another sequence or map, + /// not just a tensor + switch (elemType) + { + case TensorElementType.Int64: + result = DisposableNamedOnnxValueFromNativeMapElements(string.Empty, ortValueKeys, ortValueValues); + break; + case TensorElementType.String: + result = DisposableNamedOnnxValueFromNativeMapElements(string.Empty, ortValueKeys, ortValueValues); + break; + default: + throw new NotSupportedException("Map of element type: " + elemType + " is not supported"); } - // release native memory - nativeTensorWrapperValues.Dispose(); - nativeTensorWrapper.Dispose(); - return new DisposableNamedOnnxValue(string.Empty, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null); } - else + return result; + } + + + /// + /// This method maps keys and values of the map and copies them into a Dictionary + /// and returns as an instance of DisposableNamedOnnxValue that does not own or dispose + /// any onnx/ortValueElement. The method takes possession of ortValueTensorKeys and ortValueTensorValues + /// and disposes of them. The original ortValueElement looses ownership of the Tensor. The caller is still responsible + /// for disposing these arguments. Disposing ortValueElement that does not have ownership is a no-op, however, either + /// of the arguments may still need to be disposed on exception. + /// + /// key type + /// value type + /// name of the output parameter + /// tensor with map keys. + /// tensor with map values + /// instance of DisposableNamedOnnxValue with Dictionary + private static DisposableNamedOnnxValue DisposableNamedOnnxValueFromNativeMapElements(string name, + OrtValue ortValueTensorKeys, OrtValue ortValueTensorValues) + { + using (var nativeTensorWrapperValues = new NativeOnnxTensorMemory(ortValueTensorValues)) { - var map = new Dictionary(); - var nativeTensorWrapper = new NativeOnnxTensorMemory(nativeOnnxValueKeys); - var denseTensorKeys = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions); - for (var i = 0; i < denseTensorKeys.Length; i++) + var denseTensorValues = new DenseTensor(nativeTensorWrapperValues.Memory, nativeTensorWrapperValues.Dimensions); + + if (typeof(K) == typeof(string)) { - map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); + var map = new Dictionary(); + using (var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValueTensorKeys)) + { + var denseTensorKeys = new DenseTensor(nativeTensorWrapper.GetBytesAsStringMemory(), nativeTensorWrapper.Dimensions); + for (var i = 0; i < denseTensorKeys.Length; i++) + { + map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); + } + return new DisposableNamedOnnxValue(name, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null); + } + } + else + { + var map = new Dictionary(); + using (var nativeTensorWrapper = new NativeOnnxTensorMemory(ortValueTensorKeys)) + { + var denseTensorKeys = new DenseTensor(nativeTensorWrapper.Memory, nativeTensorWrapper.Dimensions); + for (var i = 0; i < denseTensorKeys.Length; i++) + { + map.Add(denseTensorKeys.GetValue(i), denseTensorValues.GetValue(i)); + } + return new DisposableNamedOnnxValue(name, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null); + } } - // release native memory - nativeTensorWrapperValues.Dispose(); - nativeTensorWrapper.Dispose(); - return new DisposableNamedOnnxValue(string.Empty, map, OnnxValueType.ONNX_TYPE_MAP, TensorElementType.DataTypeMax, null); } } @@ -273,10 +411,11 @@ protected virtual void Dispose(bool disposing) // dispose managed state (managed objects). if (disposing) { - if (_nativeMemoryManager != null) + // _ortValueHolder can be null when no native memory is involved + if (_ortValueHolder != null) { - _nativeMemoryManager.Dispose(); - _nativeMemoryManager = null; + _ortValueHolder.Dispose(); + _ortValueHolder = null; } } _disposed = true; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index 9bd81a0c4bd4..9e6184eac9d1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -502,8 +502,7 @@ public IDisposableReadOnlyCollection RunWithBindingAnd for (int i = 0; i < outputNames.Length; ++i) { var ortValue = ortValues.ElementAt(i); - result.Add(DisposableNamedOnnxValue.CreateTensorFromOnnxValue(outputNames[i], ortValue.Handle)); - ortValue.Disown(); + result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames[i], ortValue)); } } catch(Exception e) { @@ -696,7 +695,7 @@ private void Init(byte[] modelData, SessionOptions options) /// /// Initializes the session object with a native session handle /// - /// Handle of a native session object + /// Value of a native session object /// Session options private void InitWithSessionHandle(IntPtr session, SessionOptions options) { diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs index 4b991624e30a..d87de4db9354 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.cs @@ -5,13 +5,13 @@ using System; using System.Buffers; using System.Collections.Generic; -using System.Runtime.InteropServices.ComTypes; namespace Microsoft.ML.OnnxRuntime { /// - /// The name of the class is a misnomer, it does not hold any - /// Onnx values + /// The class associates a name with an Object. Currently it supports Tensor + /// as possible objects. The name of the class is a misnomer, it does not hold any + /// Onnx values. /// public class NamedOnnxValue { @@ -24,6 +24,14 @@ protected NamedOnnxValue(string name, Object value) _value = value; } + /// + /// This is a factory method that instantiates NamedOnnxValue + /// and associated name with an instance of a Tensor + /// + /// + /// name + /// Tensor + /// public static NamedOnnxValue CreateFromTensor(string name, Tensor value) { return new NamedOnnxValue(name, value); @@ -65,11 +73,12 @@ public IDictionary AsDictionary() } /// - /// Pin the underlying memory and create native onnx value + /// Pin the underlying memory and create an instance of OrtValue + /// based on the pinned managed memory. The caller is responsible for Disposing + /// both OrtValue and pinnedMemoryHandle /// - /// - /// - /// + /// dispose after returned OrtValus is disposed + /// an instance of OrtValue. The lifespan of OrtValue must overlap pinnedMemoryHandle internal virtual OrtValue ToOrtValue(out MemoryHandle? pinnedMemoryHandle) { return OrtValue.CreateFromTensorObject(_value, out pinnedMemoryHandle, out TensorElementType elementType); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs index b62ac98a2bf3..61ac3324b6b0 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxTensorMemory.cs @@ -4,39 +4,100 @@ using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Buffers; +using System.Diagnostics; +using System.Runtime.InteropServices; using System.Text; -using System.Threading; namespace Microsoft.ML.OnnxRuntime { /// - /// TODO: dmitrism -> Get rid of this class. - /// A non-public interface detailing the contract to be honored by NativeOnnxTensorMemory + /// Provides access from the underlying object that owns disposable OrtValue + /// The returned value does not own the actual memory and does nothing on Dispose() /// - internal interface NativeMemoryHandler : IDisposable + internal interface IOrtValueOwner : IDisposable { - IntPtr Handle { get; } + OrtValue Value { get; } } - internal class NativeOnnxTensorMemory : MemoryManager, NativeMemoryHandler + /// + /// This class is used in conjunction with DisposableNamedOnnxValue to + /// own native collection OrtValue and dispose of it along with any DisposableNamedOnnxValues + /// + internal class NativeOrtValueCollectionOwner : IOrtValueOwner, IDisposable { - private bool _disposed; - private IntPtr _onnxValueHandle; // pointer to onnxvalue object in native + private OrtValue _ortValue; + private DisposableList _disposables; + bool _disposed = false; + + internal NativeOrtValueCollectionOwner(OrtValue ortValue, DisposableList disposables) + { + Debug.Assert(ortValue.IsOwned); + _ortValue = new OrtValue(ortValue.Disown()); + _disposables = disposables; + } + + #region IOrtValueOwner + /// + /// Returns a non-owning ortValue + /// + public OrtValue Value { get { return new OrtValue(_ortValue.Handle, false); } } + #endregion IOrtValueOwner + + #region Disposable + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + // dispose managed state (managed objects). + if (disposing) + { + if(_disposables != null) + { + _disposables.Dispose(); + _disposables = null; + } + // _ortValueHolder can be null when no native memory is involved + if (_ortValue != null) + { + _ortValue.Dispose(); + _ortValue = null; + } + _disposed = true; + } + } + public void Dispose() + { + // Do not change this code. Put cleanup code in Dispose(bool disposing) above. + Dispose(true); + GC.SuppressFinalize(this); + } + #endregion Disposable + } + + /// + /// This helper class owns the underlying OrtValue that is assumed to be a Tensor, + /// it does not support any other ortValues and caches Tensor properties. + /// + /// + internal class NativeOnnxTensorMemory : MemoryManager, IOrtValueOwner + { + private OrtValue _ortValue; // Disposable private IntPtr _dataBufferPointer; // pointer to mutable tensor data in native memory private string[] _dataBufferAsString; // string tensor values copied into managed memory - private Tensors.TensorElementType _elementType; - private int _elementCount; - private int _elementWidth; - private int[] _dimensions; - public NativeOnnxTensorMemory(IntPtr onnxValueHandle) + /// + /// Constructs an instance and takes ownership of ortValue on success + /// + /// ortValue that is a Tensor + public NativeOnnxTensorMemory(OrtValue ortValue) { Type type = null; int width = 0; - _onnxValueHandle = onnxValueHandle; - _disposed = false; IntPtr typeAndShape = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(onnxValueHandle, out typeAndShape)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(ortValue.Handle, out typeAndShape)); try { TensorElementType elemType; @@ -50,8 +111,8 @@ public NativeOnnxTensorMemory(IntPtr onnxValueHandle) if (typeof(T) != type) throw new NotSupportedException(nameof(NativeOnnxTensorMemory) + " does not support T = " + nameof(T)); - _elementType = elemType; - _elementWidth = width; + ElementType = elemType; + ElementWidth = width; UIntPtr dimension; long count; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShape, out dimension)); @@ -66,24 +127,24 @@ public NativeOnnxTensorMemory(IntPtr onnxValueHandle) } long[] shape = new long[dimension.ToUInt64()]; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(typeAndShape, shape, dimension)); //Note: shape must be alive during the call + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(typeAndShape, shape, dimension)); //Note: shape must be alive during the call - _elementCount = (int)count; - _dimensions = new int[dimension.ToUInt64()]; + Count = (int)count; + Dimensions = new int[dimension.ToUInt64()]; for (ulong i = 0; i < dimension.ToUInt64(); i++) { - _dimensions[i] = (int)shape[i]; + Dimensions[i] = (int)shape[i]; } if (typeof(T) != typeof(string)) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorMutableData(_onnxValueHandle, out _dataBufferPointer)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorMutableData(ortValue.Handle, out _dataBufferPointer)); } else { UIntPtr strLen; - var offsets = new UIntPtr[_elementCount]; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetStringTensorDataLength(_onnxValueHandle, out strLen)); + var offsets = new UIntPtr[Count]; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetStringTensorDataLength(ortValue.Handle, out strLen)); var dataBuffer = new byte[strLen.ToUInt64()]; using (var dataBufferHandle = new Memory(dataBuffer).Pin()) @@ -94,11 +155,11 @@ public NativeOnnxTensorMemory(IntPtr onnxValueHandle) _dataBufferPointer = (IntPtr)dataBufferHandle.Pointer; NativeApiStatus.VerifySuccess( NativeMethods.OrtGetStringTensorContent( - _onnxValueHandle, _dataBufferPointer, strLen, + ortValue.Handle, _dataBufferPointer, strLen, (IntPtr)offsetMemoryHandle.Pointer, - (UIntPtr)_elementCount)); + (UIntPtr)Count)); } - _dataBufferAsString = new string[_elementCount]; + _dataBufferAsString = new string[Count]; for (var i = 0; i < offsets.Length; i++) { @@ -110,6 +171,8 @@ public NativeOnnxTensorMemory(IntPtr onnxValueHandle) } } } + // Transfer ownership + _ortValue = new OrtValue(ortValue.Disown()); } finally { @@ -117,20 +180,28 @@ public NativeOnnxTensorMemory(IntPtr onnxValueHandle) } } - public IntPtr Handle { get { return _onnxValueHandle; } } + /// + /// Returns a non-owning copy of OrtValue so the + /// result can not release native memory + /// + public OrtValue Value { get { return new OrtValue(_ortValue.Handle, false); } } - public bool IsDisposed => _disposed; + public bool IsDisposed { get; private set; } = false; - public int[] Dimensions => _dimensions; + public int[] Dimensions { get; } - public int Rank => _dimensions.Length; + public int Rank => Dimensions.Length; - public int Count => _elementCount; + public int Count { get; } - public int ElementWidth => _elementWidth; + public int ElementWidth { get; } - public Tensors.TensorElementType ElementType => _elementType; + public Tensors.TensorElementType ElementType { get; } + /// + /// Used by MemoryManager to produce Memory Property + /// + /// SpanT public override Span GetSpan() { if (IsDisposed) @@ -138,12 +209,11 @@ public override Span GetSpan() Span span = null; unsafe { - span = new Span((void*)_dataBufferPointer, _elementCount); + span = new Span((void*)_dataBufferPointer, Count); } return span; } - public Memory GetBytesAsStringMemory() { if (IsDisposed) @@ -155,23 +225,26 @@ public Memory GetBytesAsStringMemory() return (_dataBufferAsString == null) ? new Memory() : new Memory(_dataBufferAsString); } + /// + /// Satisfy MemoryManager abstract implementation + /// + /// + /// public override MemoryHandle Pin(int elementIndex = 0) { //Note: always pin the full buffer and return unsafe { - if (elementIndex >= _elementCount) + if (elementIndex >= Count) { throw new ArgumentOutOfRangeException(nameof(elementIndex)); } - return new MemoryHandle((void*)((int)_dataBufferPointer + elementIndex * _elementWidth)); //could not use Unsafe.Add + return new MemoryHandle((void*)((int)_dataBufferPointer + elementIndex * ElementWidth)); //could not use Unsafe.Add } } // MemoryHandle returned above by Pin() should be disposed. // Unpin() is purely to satisfy the interface. - // TODO: This class needs work. It is not clear what happens - // if the MemoryHandle remains alive and this class gets Disposed. public override void Unpin() { } public void Dispose() @@ -182,25 +255,17 @@ public void Dispose() protected override void Dispose(bool disposing) { - if(_disposed) + if (IsDisposed) { return; } - if (_onnxValueHandle != IntPtr.Zero) + if (_ortValue != null) { - NativeMethods.OrtReleaseValue(_onnxValueHandle); - _onnxValueHandle = IntPtr.Zero; + _ortValue.Dispose(); + _ortValue = null; } - - _disposed = true; - } - - protected override bool TryGetArray(out ArraySegment arraySegment) - { - // cannot expose managed array - arraySegment = default(ArraySegment); - return false; + IsDisposed = true; } } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index f20a04ea8172..006f90ce762e 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -1577,27 +1577,35 @@ private void TestModelSequenceOfMapIntFloat() using (var outputs = session.Run(container)) { // first output is a tensor containing label - var outNode1 = outputs.ElementAtOrDefault(0); - Assert.Equal("label", outNode1.Name); + var outNode0 = outputs.ElementAtOrDefault(0); + Assert.Equal("label", outNode0.Name); + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outNode0.ValueType); + Assert.Equal(TensorElementType.Int64, (TensorElementType)outNode0.ElementType); // try-cast as a tensor - var outLabelTensor = outNode1.AsTensor(); + var outLabelTensor = outNode0.AsTensor(); + Assert.NotNull(outLabelTensor); - // Label 1 should have highest probaility + // Label 1 should have highest probability Assert.Equal(1, outLabelTensor[0]); // second output is a sequence> // try-cast to an sequence of NOV - var outNode2 = outputs.ElementAtOrDefault(1); - Assert.Equal("probabilities", outNode2.Name); + var outNode1 = outputs.ElementAtOrDefault(1); + Assert.Equal("probabilities", outNode1.Name); + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outNode1.ValueType); // try-cast to an sequence of NOV - var seq = outNode2.AsEnumerable(); + var seq = outNode1.AsEnumerable(); + Assert.NotNull(seq); + // Try-cast into DisposableNov so we can control and check the process + // try-cast first element in sequence to map/dictionary type if (System.Environment.Is64BitProcess) { var map = seq.First().AsDictionary(); + Assert.NotNull(map); Assert.Equal(0.25938290, map[0], 6); Assert.Equal(0.40904793, map[1], 6); Assert.Equal(0.33156919, map[2], 6); @@ -1605,6 +1613,7 @@ private void TestModelSequenceOfMapIntFloat() else // 32-bit { var map = seq.First().AsDictionary(); + Assert.NotNull(map); Assert.Equal(0.25938290, map[0], 6); Assert.Equal(0.40904793, map[1], 6); Assert.Equal(0.33156919, map[2], 6); @@ -1638,25 +1647,30 @@ private void TestModelSequenceOfMapStringFloat() using (var outputs = session.Run(container)) { // first output is a tensor containing label - var outNode1 = outputs.ElementAtOrDefault(0); - Assert.Equal("label", outNode1.Name); + var outNode0 = outputs.ElementAtOrDefault(0); + Assert.Equal("label", outNode0.Name); + Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, outNode0.ValueType); + Assert.Equal(TensorElementType.String, (TensorElementType)outNode0.ElementType); // try-cast as a tensor - var outLabelTensor = outNode1.AsTensor(); + var outLabelTensor = outNode0.AsTensor(); + Assert.NotNull(outLabelTensor); - // Label 1 should have highest probaility + // Label 1 should have highest probability Assert.Equal("1", outLabelTensor[0]); // second output is a sequence> // try-cast to an sequence of NOV - var outNode2 = outputs.ElementAtOrDefault(1); - Assert.Equal("probabilities", outNode2.Name); + var outNode1 = outputs.ElementAtOrDefault(1); + Assert.Equal("probabilities", outNode1.Name); + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outNode1.ValueType); // try-cast to an sequence of NOV - var seq = outNode2.AsEnumerable(); + var seq = outNode1.AsEnumerable(); // try-cast first element in sequence to map/dictionary type var map = seq.First().AsDictionary(); + Assert.NotNull(map); //verify values are valid Assert.Equal(0.25938290, map["0"], 6); Assert.Equal(0.40904793, map["1"], 6); @@ -1693,6 +1707,7 @@ private void TestModelSequenceOfTensors() // try-cast to an sequence of NOV var outNode = outputs.ElementAtOrDefault(0); Assert.Equal("output_sequence", outNode.Name); + Assert.Equal(OnnxValueType.ONNX_TYPE_SEQUENCE, outNode.ValueType); // try-cast to an sequence of NOV var seq = outNode.AsEnumerable(); @@ -1703,6 +1718,8 @@ private void TestModelSequenceOfTensors() // try-cast the elements in sequence to tensor type var firstTensorInOuputSequence = seq.First().AsTensor(); var secondTensorInOuputSequence = seq.Last().AsTensor(); + Assert.NotNull(firstTensorInOuputSequence); + Assert.NotNull(secondTensorInOuputSequence); // make sure the tensors in the output sequence hold the correct values Assert.True(firstTensorInOuputSequence.GetValue(0) == 1);