Skip to content

Commit

Permalink
Regenerate proto code from the current onnx protos
Browse files Browse the repository at this point in the history
Add native methods from the merge
Add Test Protobuf data
Implement test sequence input loading
Optimize Input/Output names conversion and validation
Introduce OnnxValue to NamedOnnxValue, rename NativeOnnxTensorMemory
Rework ToOrtValue interface
Implement ManagedProjections
Make sure all required map types are supported
Generate input OrtValue using ManagedOnnxType
Implement optional support, partial map support.
Fix optional issues. Provide details for failing tests
Comment out two tests due to invalid test data
  • Loading branch information
yuslepukhin committed Apr 3, 2023
1 parent ae93f2e commit 9fb3811
Show file tree
Hide file tree
Showing 19 changed files with 5,208 additions and 498 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ onnxprofile_profile_test_*.json
/csharp/packages
/csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.targets
/csharp/src/Microsoft.ML.OnnxRuntime/targets/**/*.props
/csharp/**/*.vcxproj.user
cmake/external/FeaturizersLibrary/
# Java specific ignores
java/.gradle
Expand Down
332 changes: 224 additions & 108 deletions csharp/src/Microsoft.ML.OnnxRuntime/DisposableNamedOnnxValue.shared.cs

Large diffs are not rendered by default.

276 changes: 205 additions & 71 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs

Large diffs are not rendered by default.

277 changes: 277 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// The class helps to feed the NamedOnnxValue as inference input.
/// It projects managed classes to OrtValues so they can be consumed
/// by the native onnxruntime library. if possible, it will avoid copying data.
/// The NamedOnnxValue can be a tensor, sequence or map.
/// For recursive structures, create nested NamedOnnxValue instances.
/// For example, a sequence instance would contain a list of NamedOnnxValue instances
/// that in turn may represent tensors or other ONNX values.
/// </summary>
internal class ManagedTypeProjection : IDisposable
{
readonly DisposableList<IDisposable> _disposables;
readonly OrtValue _ortValue;
bool _disposed = false;

/// <summary>
/// Provides access to non-owning instance of OrtValue
/// </summary>
/// <value>Provides access to the OrtValue to be used as input</value>
internal OrtValue Value { get { return new OrtValue(_ortValue.Handle, false); } }

/// <summary>
/// Constructor to create an input OrtValue projection from managed data
/// </summary>
/// <param name="namedOnnxValue"></param>
/// <param name="metadata"></param>
/// <exception cref="OnnxRuntimeException"></exception>
internal ManagedTypeProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata)
{
int requiredCapacity = 32;
var disposables = new DisposableList<IDisposable>(requiredCapacity);
try
{
_ortValue = CreateDispatchProjection(namedOnnxValue, metadata, disposables);
}
catch (Exception)
{
disposables.Dispose();
throw;
}
_disposables = disposables;
}

/// <summary>
/// Dispatches the creation of the projection
/// </summary>
/// <param name="namedOnnxValue"></param>
/// <param name="metadata"></param>
/// <param name="disposables"></param>
/// <returns></returns>
private OrtValue CreateDispatchProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata, DisposableList<IDisposable> disposables)
{
OrtValue result;

NodeMetadata meta = metadata;
// Use element meta to create types
if (metadata.OnnxValueType == OnnxValueType.ONNX_TYPE_OPTIONAL)
{
meta = metadata.AsOptionalMetadata().ElementMeta;
}

if (namedOnnxValue.ValueType != meta.OnnxValueType)
{
throw new OnnxRuntimeException(ErrorCode.RuntimeException,
$"NamedOnnxValue: {namedOnnxValue.Name} has value type: {namedOnnxValue.ValueType}" +
$" expected: {meta.OnnxValueType} after optional type adjustment");
}

switch (namedOnnxValue.ValueType)
{
case OnnxValueType.ONNX_TYPE_TENSOR:
result = CreateTensorProjection(namedOnnxValue, meta, disposables);
break;
case OnnxValueType.ONNX_TYPE_SEQUENCE:
result = CreateSequenceProjection(namedOnnxValue, meta, disposables);
break;
case OnnxValueType.ONNX_TYPE_MAP:
result = CreateMapProjection(namedOnnxValue, meta, disposables);
break;
default:
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "ManagedTypeProjection can only project tensors, sequences, maps and optional types");
}
return result;
}

/// <summary>
/// The function creates OrtValue objects for each element of the sequence
/// and then creates an OrtValue for the whole sequence.
/// </summary>
/// <param name="namedOnnxValue">NamedOnnxValue containing a IEnumeralbe<NameOnnValue></param>
/// <param name="metadata">sequence metadata</param>
/// <param name="disposables">cleanup list</param>
/// <returns></returns>
/// <exception cref="OnnxRuntimeException"></exception>
private OrtValue CreateSequenceProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata, DisposableList<IDisposable> disposables)
{
OrtValue result = null;
var elementMeta = metadata.AsSequenceMetadata().ElementMeta;
var elementOnnxValue = elementMeta.OnnxValueType;
var seqContainer = namedOnnxValue.AsEnumerable<NamedOnnxValue>();

if (seqContainer is null)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"NamedOnnxValue: {namedOnnxValue.Name} sequence does not contain NamedOnnxValue elements");
}

int capacity = 0;

if (seqContainer is ICollection<NamedOnnxValue>)
{
capacity = ((ICollection<NamedOnnxValue>)seqContainer).Count;
}

// Record all the ortValues belonging to the sequence locally
var sequenceOrtValues = new List<OrtValue>(capacity);
foreach (var element in seqContainer)
{
if (elementOnnxValue != element.ValueType)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"NamedOnnxValue: {namedOnnxValue.Name} sequence element expected to be {elementOnnxValue}, received {element.ValueType}");
}

sequenceOrtValues.Add(CreateDispatchProjection(element, elementMeta, disposables));
}

IntPtr[] ortValHandles = new IntPtr[sequenceOrtValues.Count];
for (int i = 0; i < sequenceOrtValues.Count; i++)
{
ortValHandles[i] = sequenceOrtValues[i].Handle;
}

NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateValue(ortValHandles,
(UIntPtr)sequenceOrtValues.Count, (IntPtr)OnnxValueType.ONNX_TYPE_SEQUENCE, out IntPtr sequenceHandle));
result = new OrtValue(sequenceHandle);
disposables.Add(result);

return result;
}

/// <summary>
/// Creates map projection. Since we support only primitive types in maps
/// we map two tensors (keys and values)
/// </summary>
/// <param name="node"></param>
/// <param name="elementMeta"></param>
/// <param name="disposables"></param>
/// <returns>OrtValue</returns>
/// <exception cref="OnnxRuntimeException"></exception>
private OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata elementMeta, DisposableList<IDisposable> disposables)
{
OrtValue result = null;
var mapMeta = elementMeta.AsMapMetadata();
Debug.Assert(mapMeta != null);
// Maps currently support only primitive types expressed as two parallel tensors and not nested Sequences or Maps

var mapValuesMeta = mapMeta.ValueMetadata;
if (mapValuesMeta.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"Node: {node.Name} onnxruntime only supports maps with primitive types values");
}


var keys = node.GetDictionaryKeys();
var ortValueKeys = OrtValue.CreateFromTensorObject(keys,
out MemoryHandle? memoryHandleKeys, out TensorElementType elementTypeKeys);
disposables.Add(ortValueKeys);

if (memoryHandleKeys.HasValue)
{
disposables.Add(memoryHandleKeys);
}

if (elementTypeKeys != mapMeta.KeyDataType)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"Map key data type supplied: {elementTypeKeys} metadata expected: {mapMeta.KeyDataType}");
}

var values = node.GetDictionaryValues();
var ortValueValues = OrtValue.CreateFromTensorObject(values,
out MemoryHandle? memoryHandleValues, out TensorElementType elementTypeValues);

disposables.Add(ortValueValues);
if (memoryHandleValues.HasValue)
{
disposables.Add(memoryHandleValues);
}

if (elementTypeValues != mapValuesMeta.ElementDataType)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"Map value data type supplied: {elementTypeValues} metadata expected: {mapValuesMeta.ElementDataType}");
}

// Create Map OrtValue
IntPtr[] ortValHandles = { ortValueKeys.Handle, ortValueValues.Handle };
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateValue(ortValHandles, (UIntPtr)2,
(IntPtr)OnnxValueType.ONNX_TYPE_MAP, out IntPtr ortValueMap));
result = new OrtValue(ortValueMap);
disposables.Add(result);
return result;
}


/// <summary>
/// This pins memory that is contained within DenseTensor.
/// </summary>
/// <param name="node">NodeOnnxValue containing DenseTensor</param>
/// <param name="elementMeta"></param>
/// <param name="disposables">cleanup list</param>
/// <returns></returns>
/// <exception cref="OnnxRuntimeException"></exception>
private OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta, DisposableList<IDisposable> disposables)
{
var ortValue = OrtValue.CreateFromTensorObject(node.Value,
out MemoryHandle? memoryHandle, out TensorElementType elementType);
disposables.Add(ortValue);

if (memoryHandle.HasValue)
{
disposables.Add(memoryHandle);
}

if (elementType != elementMeta.ElementDataType)
{
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
$"Tensor element data type discovered: {elementType} metadata expected: {elementMeta.ElementDataType}");
}

return ortValue;
}

#region IDisposable
/// <summary>
/// IDisposable implementation
/// </summary>
/// <param name="disposing">true if invoked by Dispose()</param>
protected virtual void Dispose(bool disposing)
{
if (_disposed)
{
return;
}

// dispose managed state (managed objects).
if (disposing)
{
_disposables.Dispose();
}
_disposed = true;
}


public void Dispose()
{
Dispose(true);
}

#endregion IDisposable
}
}

Loading

0 comments on commit 9fb3811

Please sign in to comment.