Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C# ] Improve string marshalling and reduce GC pressure #15545

Merged
merged 15 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 142 additions & 98 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

<PropertyGroup>
<Platforms>AnyCPU;x86</Platforms>
<LangVersion>7.2</LangVersion>
<LangVersion>7.3</LangVersion>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>..\..\OnnxRuntime.snk</AssemblyOriginatorKeyFile>
Expand Down
49 changes: 30 additions & 19 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ public struct OrtApi
public IntPtr KernelInfoGetConstantInput_tensor;
public IntPtr CastTypeInfoToOptionalTypeInfo;
public IntPtr GetOptionalContainedTypeInfo;
public IntPtr GetStringTensorElementBuffer;
}

internal static class NativeMethods
Expand Down Expand Up @@ -421,6 +422,7 @@ static NativeMethods()
OrtCreateTensorWithDataAsOrtValue = (DOrtCreateTensorWithDataAsOrtValue)Marshal.GetDelegateForFunctionPointer(api_.CreateTensorWithDataAsOrtValue, typeof(DOrtCreateTensorWithDataAsOrtValue));
OrtGetTensorMutableData = (DOrtGetTensorMutableData)Marshal.GetDelegateForFunctionPointer(api_.GetTensorMutableData, typeof(DOrtGetTensorMutableData));
OrtFillStringTensor = (DOrtFillStringTensor)Marshal.GetDelegateForFunctionPointer(api_.FillStringTensor, typeof(DOrtFillStringTensor));
OrtGetStringTensorElementBuffer = (DOrtGetStringTensorElementBuffer)Marshal.GetDelegateForFunctionPointer(api_.GetStringTensorElementBuffer, typeof(DOrtGetStringTensorElementBuffer));
OrtGetStringTensorContent = (DOrtGetStringTensorContent)Marshal.GetDelegateForFunctionPointer(api_.GetStringTensorContent, typeof(DOrtGetStringTensorContent));
OrtGetStringTensorDataLength = (DOrtGetStringTensorDataLength)Marshal.GetDelegateForFunctionPointer(api_.GetStringTensorDataLength, typeof(DOrtGetStringTensorDataLength));
OrtCastTypeInfoToTensorInfo = (DOrtCastTypeInfoToTensorInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToTensorInfo, typeof(DOrtCastTypeInfoToTensorInfo));
Expand Down Expand Up @@ -859,7 +861,7 @@ internal class NativeLib
public static DOrtDisableCpuMemArena OrtDisableCpuMemArena;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, IntPtr /* const char* */logId);
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */logId);
public static DOrtSetSessionLogId OrtSetSessionLogId;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
Expand Down Expand Up @@ -890,8 +892,8 @@ internal class NativeLib
/// <param name="configValue">Config value</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtAddSessionConfigEntry(IntPtr /* OrtSessionOptions* */ options,
IntPtr /* const char* */configKey,
IntPtr /* const char* */ configValue);
byte[] /* const char* */configKey,
byte[] /* const char* */ configValue);
Comment on lines -937 to +939
Copy link
Contributor

@skottmckay skottmckay Apr 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of interest, when should byte[] be used instead of IntPtr, and why (given that IntPtr has worked up until now)? #Pending

Copy link
Member Author

@yuslepukhin yuslepukhin Apr 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Managed memory must be pinned before it is passed to native code so that it is not collected or moved by the GC. Before this PR we relied on GCHandle-like pinning (the same as MemoryHandle), which turns out to be very expensive. The IntPtr passed that way did not require marshalling.

Chunks of memory allocated on the unmanaged heap do not require pinning and can be passed on as IntPtr.

I learned that when one passes an array of a blittable type, it is pinned automatically using 'fixed()'-style pinning, which is much more efficient. We also do not have to write a try/finally for that. We have, in effect, already been using this in other places if one looks carefully (IntPtr[] is also considered blittable).

In fact, this was one of the most expensive parts.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a guideline that could be documented so that going forward the best thing to do can be captured in that instead of trying to infer it from existing implementation details?

public static DOrtAddSessionConfigEntry OrtAddSessionConfigEntry;

//
Expand Down Expand Up @@ -936,7 +938,7 @@ internal class NativeLib
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_DML(IntPtr /*(OrtSessionOptions*) */ options, int device_id);

[DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_OpenVINO(IntPtr /*(OrtSessionOptions*)*/ options, IntPtr /*(const char*)*/ device_id);
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_OpenVINO(IntPtr /*(OrtSessionOptions*)*/ options, byte[] /*(const char*)*/ device_id);

[DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Tensorrt(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
Expand All @@ -945,7 +947,7 @@ internal class NativeLib
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);

[DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Tvm(IntPtr /*(OrtSessionOptions*) */ options, IntPtr /*(char char*)*/ settings);
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Tvm(IntPtr /*(OrtSessionOptions*) */ options, byte[] /*(const char*)*/ settings);
#endif
/// <summary>
/// Append a TensorRT EP instance (configured based on given provider options) to the native OrtSessionOptions instance
Expand Down Expand Up @@ -1003,7 +1005,7 @@ internal class NativeLib
/// <param name="dimValue">Dimension value</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const char*)*/ dimDenotation,
byte[] /*(const char*)*/ dimDenotation,
long dimValue);

public static DOrtAddFreeDimensionOverride OrtAddFreeDimensionOverride;
Expand All @@ -1016,7 +1018,7 @@ internal class NativeLib
/// <param name="dimValue">Dimension value</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverrideByName(IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const char*)*/ dimName,
byte[] /*(const char*)*/ dimName,
long dimValue);

public static DOrtAddFreeDimensionOverrideByName OrtAddFreeDimensionOverrideByName;
Expand All @@ -1029,7 +1031,7 @@ internal class NativeLib
/// <param name="libraryHandle">(out) Native library handle</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options,
IntPtr /*(const char*)*/ libraryPath,
byte[] /*(const char*)*/ libraryPath,
out IntPtr /*(void**)*/ libraryHandle);

public static DOrtRegisterCustomOpsLibrary OrtRegisterCustomOpsLibrary;
Expand All @@ -1042,7 +1044,7 @@ internal class NativeLib
/// <param name="ortValue">Native OrtValue instance</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const char*)*/ name,
byte[] /*(const char*)*/ name,
IntPtr /*(OrtValue*)*/ ortValue);

public static DOrtAddInitializer OrtAddInitializer;
Expand All @@ -1062,7 +1064,7 @@ internal class NativeLib
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider(
IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const char*)*/ providerName,
byte[] /*(const char*)*/ providerName,
IntPtr[] /*(const char* const *)*/ providerOptionsKeys,
IntPtr[] /*(const char* const *)*/ providerOptionsValues,
UIntPtr /*(size_t)*/ numKeys);
Expand Down Expand Up @@ -1090,7 +1092,7 @@ internal class NativeLib
public static DOrtRunOptionsSetRunLogSeverityLevel OrtRunOptionsSetRunLogSeverityLevel;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsSetRunTag(IntPtr /* OrtRunOptions* */ options, IntPtr /* const char* */ runTag);
public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsSetRunTag(IntPtr /* OrtRunOptions* */ options, byte[] /* const char* */ runTag);
public static DOrtRunOptionsSetRunTag OrtRunOptionsSetRunTag;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
Expand Down Expand Up @@ -1124,8 +1126,8 @@ internal class NativeLib
/// <param name="configValue">Config value</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtAddRunConfigEntry(IntPtr /* OrtRunOptions* */ options,
IntPtr /* const char* */configKey,
IntPtr /* const char* */ configValue);
byte[] /* const char* */configKey,
byte[] /* const char* */ configValue);
public static DOrtAddRunConfigEntry OrtAddRunConfigEntry;

#endregion
Expand All @@ -1134,7 +1136,7 @@ internal class NativeLib

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateMemoryInfo(
IntPtr /*(const char*) */name,
byte[] /*(const char*) */name,
OrtAllocatorType allocatorType,
int identifier,
OrtMemType memType,
Expand Down Expand Up @@ -1305,7 +1307,7 @@ internal class NativeLib
/// The param instance is copied internally so this argument may be released.
/// </param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus*/ DOrtBindInput(IntPtr /*(OrtIoBinding)*/ io_binding, IntPtr /*(const char*)*/ name, IntPtr /*const OrtValue**/ ort_value);
public delegate IntPtr /* OrtStatus*/ DOrtBindInput(IntPtr /*(OrtIoBinding)*/ io_binding, byte[] /*(const char*)*/ name, IntPtr /*const OrtValue**/ ort_value);

public static DOrtBindInput OrtBindInput;

Expand Down Expand Up @@ -1333,7 +1335,7 @@ internal class NativeLib
/// The param instance is copied internally so this argument may be released.
/// </param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus*/ DOrtBindOutput(IntPtr /*(OrtIoBinding)*/ io_binding, IntPtr /*(const char*) */ name, IntPtr /*const OrtValue**/ ort_value);
public delegate IntPtr /* OrtStatus*/ DOrtBindOutput(IntPtr /*(OrtIoBinding)*/ io_binding, byte[] /*(const char*) */ name, IntPtr /*const OrtValue**/ ort_value);

public static DOrtBindOutput OrtBindOutput;

Expand All @@ -1347,7 +1349,7 @@ internal class NativeLib
/// <param name="mem_info">OrtMemoryInfo instance that contains device id. May be obtained from the device specific allocator instance</param>
/// <returns></returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus*/ DOrtBindOutputToDevice(IntPtr /*(OrtIoBinding)*/ io_binding, IntPtr /*(const char*) */ name, IntPtr /* const OrtMemoryInfo */ mem_info);
public delegate IntPtr /* OrtStatus*/ DOrtBindOutputToDevice(IntPtr /*(OrtIoBinding)*/ io_binding, byte[] /*(const char*) */ name, IntPtr /* const OrtMemoryInfo */ mem_info);

public static DOrtBindOutputToDevice OrtBindOutputToDevice;

Expand Down Expand Up @@ -1657,12 +1659,21 @@ internal class NativeLib

public static DOrtFillStringTensor OrtFillStringTensor;

/// <summary>
/// Managed signature for the native GetStringTensorElementBuffer API entry.
/// The delegate instance is created from the OrtApi function pointer via
/// Marshal.GetDelegateForFunctionPointer, so the calling convention is stated
/// explicitly, matching every other native delegate declared in this class.
/// </summary>
/// <param name="value">Native OrtValue instance (presumably a string tensor; confirm against the C API)</param>
/// <param name="index">Element index, marshalled as size_t</param>
/// <param name="length_in_bytes">Length in bytes, marshalled as size_t</param>
/// <param name="buffer">(out) Receives a native char* pointer; NOTE(review): ownership/lifetime is defined by the native API, not visible here</param>
/// <returns>Native OrtStatus pointer</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementBuffer(
    IntPtr /* OrtValue* */ value,
    UIntPtr /* size_t */ index,
    UIntPtr /* size_t */ length_in_bytes,
    out IntPtr /* char** */ buffer
);

public static DOrtGetStringTensorElementBuffer OrtGetStringTensorElementBuffer;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent(
IntPtr /*(OrtValue*)*/ value,
IntPtr /*(void*)*/ dst_buffer,
byte[] /*(void*)*/ dst_buffer,
UIntPtr dst_buffer_len,
IntPtr offsets,
UIntPtr[] offsets,
UIntPtr offsets_len);

public static DOrtGetStringTensorContent OrtGetStringTensorContent;
Expand Down
Loading