Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C#] Add missing rocm csharp api #15540

Merged
merged 9 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ public struct OrtApi
public IntPtr GetOptionalContainedTypeInfo;
public IntPtr GetResizedStringTensorElementBuffer;
public IntPtr KernelContext_GetAllocator;
public IntPtr GetBuildInfoString;
Copy link
Member Author

@cloudhan cloudhan May 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

backlog: add GetBuildInfoString stub to avoid rocm apis off-by-one problem. #Resolved

public IntPtr CreateROCMProviderOptions;
public IntPtr UpdateROCMProviderOptions;
public IntPtr GetROCMProviderOptionsAsString;
public IntPtr ReleaseROCMProviderOptions;
}

internal static class NativeMethods
Expand Down Expand Up @@ -493,6 +498,12 @@ static NativeMethods()
api_.SessionOptionsAppendExecutionProvider,
typeof(DSessionOptionsAppendExecutionProvider));
OrtUpdateEnvWithCustomLogLevel = (DOrtUpdateEnvWithCustomLogLevel)Marshal.GetDelegateForFunctionPointer(api_.UpdateEnvWithCustomLogLevel, typeof(DOrtUpdateEnvWithCustomLogLevel));
SessionOptionsAppendExecutionProvider_ROCM = (DSessionOptionsAppendExecutionProvider_ROCM)Marshal.GetDelegateForFunctionPointer(
api_.SessionOptionsAppendExecutionProvider_ROCM, typeof(DSessionOptionsAppendExecutionProvider_ROCM));
OrtCreateROCMProviderOptions = (DOrtCreateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateROCMProviderOptions, typeof(DOrtCreateROCMProviderOptions));
OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions));
OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString));
OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions));
}

internal class NativeLib
Expand Down Expand Up @@ -659,6 +670,51 @@ internal class NativeLib
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtReleaseCUDAProviderOptions(IntPtr /*(OrtCUDAProviderOptions*)*/ cudaProviderOptionsInstance);
public static DOrtReleaseCUDAProviderOptions OrtReleaseCUDAProviderOptions;

/// <summary>
/// Creates native OrtROCMProviderOptions instance
/// </summary>
/// <param name="rocmProviderOptionsInstance">(output) native instance of OrtROCMProviderOptions</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */DOrtCreateROCMProviderOptions(
out IntPtr /*(OrtROCMProviderOptions**)*/ rocmProviderOptionsInstance);
public static DOrtCreateROCMProviderOptions OrtCreateROCMProviderOptions;

/// <summary>
/// Updates native OrtROCMProviderOptions instance using given key/value pairs
/// </summary>
/// <param name="rocmProviderOptionsInstance">native instance of OrtROCMProviderOptions</param>
/// <param name="providerOptionsKeys">configuration keys of OrtROCMProviderOptions</param>
/// <param name="providerOptionsValues">configuration values of OrtROCMProviderOptions</param>
/// <param name="numKeys">number of configuration keys</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */DOrtUpdateROCMProviderOptions(
IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance,
IntPtr[] /*(const char* const *)*/ providerOptionsKeys,
IntPtr[] /*(const char* const *)*/ providerOptionsValues,
UIntPtr /*(size_t)*/ numKeys);
public static DOrtUpdateROCMProviderOptions OrtUpdateROCMProviderOptions;

/// <summary>
/// Get native OrtROCMProviderOptions in serialized string
/// </summary>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="ptr">is a UTF-8 null terminated string allocated using 'allocator'</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */DOrtGetROCMProviderOptionsAsString(
IntPtr /*(OrtROCMProviderOptions**)*/ rocmProviderOptionsInstance,
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(char**)*/ptr);
public static DOrtGetROCMProviderOptionsAsString OrtGetROCMProviderOptionsAsString;

/// <summary>
/// Releases native OrtROCMProviderOptions instance
/// </summary>
/// <param name="rocmProviderOptionsInstance">native instance of OrtROCMProviderOptions to be released</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtReleaseROCMProviderOptions(IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance);
public static DOrtReleaseROCMProviderOptions OrtReleaseROCMProviderOptions;

#endregion

#region Status API
Expand Down Expand Up @@ -1040,6 +1096,18 @@ internal class NativeLib

public static DSessionOptionsAppendExecutionProvider_CUDA_V2 SessionOptionsAppendExecutionProvider_CUDA_V2;

/// <summary>
/// Append a ROCm EP instance (configured based on given provider options) to the native OrtSessionOptions instance
/// </summary>
/// <param name="options">Native OrtSessionOptions instance</param>
/// <param name="rocmProviderOptions">Native OrtROCMProviderOptions instance</param>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_ROCM(
IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const OrtROCMProviderOptions*)*/ rocmProviderOptions);

public static DSessionOptionsAppendExecutionProvider_ROCM SessionOptionsAppendExecutionProvider_ROCM;

/// <summary>
/// Free Dimension override (by denotation)
/// </summary>
Expand Down
10 changes: 10 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@ internal OrtMemoryInfo(IntPtr allocInfo, bool owned)
/// </summary>
public static readonly byte[] allocatorCUDA_PINNED = Encoding.UTF8.GetBytes("CudaPinned" + Char.MinValue);
/// <summary>
/// Predefined utf8 encoded allocator names. Use them to construct an instance of
/// OrtMemoryInfo to avoid UTF-16 to UTF-8 conversion costs.
/// </summary>
public static readonly byte[] allocatorHIP = Encoding.UTF8.GetBytes("Hip" + Char.MinValue);
/// <summary>
/// Predefined utf8 encoded allocator names. Use them to construct an instance of
/// OrtMemoryInfo to avoid UTF-16 to UTF-8 conversion costs.
/// </summary>
public static readonly byte[] allocatorHIP_PINNED = Encoding.UTF8.GetBytes("HipPinned" + Char.MinValue);
/// <summary>
/// Create an instance of OrtMemoryInfo according to the specification
/// Memory info instances are usually used to get a handle of a native allocator
/// that is present within the current inference session object. That, in turn, depends
Expand Down
90 changes: 88 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public string GetOptions()
{
var allocator = OrtAllocator.DefaultInstance;
// Process provider options string
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorRTProviderOptionsAsString(handle,
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorRTProviderOptionsAsString(handle,
allocator.Pointer, out IntPtr providerOptions));
return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator);
}
Expand Down Expand Up @@ -151,7 +151,7 @@ public string GetOptions()
{
var allocator = OrtAllocator.DefaultInstance;
// Process provider options string
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetCUDAProviderOptionsAsString(handle,
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetCUDAProviderOptionsAsString(handle,
allocator.Pointer, out IntPtr providerOptions));
return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator);
}
Expand Down Expand Up @@ -206,6 +206,92 @@ protected override bool ReleaseHandle()
}


/// <summary>
/// Holds the options for configuring a ROCm Execution Provider instance
/// </summary>
public class OrtROCMProviderOptions : SafeHandle
{
internal IntPtr Handle
{
get
{
return handle;
}
}


#region Constructor

/// <summary>
/// Constructs an empty OrtROCMroviderOptions instance
/// </summary>
public OrtROCMProviderOptions() : base(IntPtr.Zero, true)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateROCMProviderOptions(out handle));
}

#endregion

#region Public Methods

/// <summary>
/// Get ROCm EP provider options
/// </summary>
/// <returns> return C# UTF-16 encoded string </returns>
public string GetOptions()
{
var allocator = OrtAllocator.DefaultInstance;
// Process provider options string
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetROCMProviderOptionsAsString(handle,
allocator.Pointer, out IntPtr providerOptions));
return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator);
}

private static IntPtr UpdateROCMProviderOptions(IntPtr handle, IntPtr[] keys, IntPtr[] values, UIntPtr count)
{
return NativeMethods.OrtUpdateROCMProviderOptions(handle, keys, values, count);
}

/// <summary>
/// Updates the configuration knobs of OrtROCMProviderOptions that will eventually be used to configure a ROCm EP
/// Please refer to the following on different key/value pairs to configure a ROCm EP and their meaning:
/// https://onnxruntime.ai/docs/execution-providers/ROCm-ExecutionProvider.html
/// </summary>
/// <param name="providerOptions">key/value pairs used to configure a ROCm Execution Provider</param>
public void UpdateOptions(Dictionary<string, string> providerOptions)
{
ProviderOptionsUpdater.Update(providerOptions, handle, UpdateROCMProviderOptions);
}

#endregion

#region Public Properties

/// <summary>
/// Overrides SafeHandle.IsInvalid
/// </summary>
/// <value>returns true if handle is equal to Zero</value>
public override bool IsInvalid { get { return handle == IntPtr.Zero; } }

#endregion

#region SafeHandle
/// <summary>
/// Overrides SafeHandle.ReleaseHandle() to properly dispose of
/// the native instance of OrtROCMProviderOptions
/// </summary>
/// <returns>always returns true</returns>
protected override bool ReleaseHandle()
{
NativeMethods.OrtReleaseROCMProviderOptions(handle);
handle = IntPtr.Zero;
return true;
}

#endregion
}


/// <summary>
/// This helper class contains methods to handle values of provider options
/// </summary>
Expand Down
51 changes: 49 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,33 @@ public static SessionOptions MakeSessionOptionWithTvmProvider(String settings =
/// <returns>A SessionsOptions() object configured for execution on deviceId</returns>
public static SessionOptions MakeSessionOptionWithRocmProvider(int deviceId = 0)
{
CheckRocmExecutionProviderDLLs();
SessionOptions options = new SessionOptions();
options.AppendExecutionProvider_ROCM(deviceId);
options.AppendExecutionProvider_ROCm(deviceId);
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
return options;
}

/// <summary>
/// A helper method to construct a SessionOptions object for ROCm execution provider.
/// Use only if ROCm is installed and you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="rocmProviderOptions">ROCm EP provider options</param>
/// <returns>A SessionsOptions() object configured for execution on provider options</returns>
public static SessionOptions MakeSessionOptionWithRocmProvider(OrtROCMProviderOptions rocmProviderOptions)
{
CheckRocmExecutionProviderDLLs();
SessionOptions options = new SessionOptions();
try
{
options.AppendExecutionProvider_ROCm(rocmProviderOptions);
return options;
}
catch (Exception)
{
options.Dispose();
throw;
}
}
#endregion

#region ExecutionProviderAppends
Expand Down Expand Up @@ -276,7 +299,7 @@ public void AppendExecutionProvider_Tensorrt(OrtTensorRTProviderOptions trtProvi
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="deviceId">Device Id</param>
public void AppendExecutionProvider_ROCM(int deviceId = 0)
public void AppendExecutionProvider_ROCm(int deviceId = 0)
{
#if __MOBILE__
throw new NotSupportedException("The ROCM Execution Provider is not supported in this build");
Expand All @@ -286,6 +309,20 @@ public void AppendExecutionProvider_ROCM(int deviceId = 0)
#endif
}

/// <summary>
/// Append a ROCm EP instance (based on specified configuration) to the SessionOptions instance.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="rocmProviderOptions">ROCm EP provider options</param>
public void AppendExecutionProvider_ROCm(OrtROCMProviderOptions rocmProviderOptions)
{
#if __MOBILE__
throw new NotSupportedException("The ROCm Execution Provider is not supported in this build");
#else
NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_ROCM(handle, rocmProviderOptions.Handle));
#endif
}

/// <summary>
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
Expand Down Expand Up @@ -833,6 +870,16 @@ private static bool CheckTensorrtExecutionProviderDLLs()
}
return true;
}

private static bool CheckRocmExecutionProviderDLLs()
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
throw new NotSupportedException("ROCm Execution Provider is not currently supported on Windows.");
}
return true;
}

#endregion

#region SafeHandle
Expand Down
2 changes: 2 additions & 0 deletions csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ if [ $RunTestCsharp = "true" ]; then
exit 1
fi
dotnet test -p:DefineConstants=USE_TENSORRT $BUILD_SOURCESDIRECTORY/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj --no-restore --verbosity detailed
elif [ $PACKAGENAME = "Microsoft.ML.OnnxRuntime.ROCm" ]; then
dotnet test -p:DefineConstants=USE_ROCM $BUILD_SOURCESDIRECTORY/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj --no-restore --verbosity detailed
else
dotnet test $BUILD_SOURCESDIRECTORY/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj --no-restore --verbosity detailed
fi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public void TestSessionOptions()
#endif

#if USE_ROCM
opt.AppendExecutionProvider_ROCM(0);
opt.AppendExecutionProvider_ROCm(0);
#endif

#if USE_TENSORRT
Expand Down Expand Up @@ -1655,9 +1655,9 @@ void TestCUDAAllocatorInternal(InferenceSession session)
void TestROCMAllocatorInternal(InferenceSession session)
{
int device_id = 0;
using (var info_rocm = new OrtMemoryInfo(OrtMemoryInfo.allocatorROCM, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default))
using (var info_rocm = new OrtMemoryInfo(OrtMemoryInfo.allocatorHIP, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default))
{
Assert.Equal("Rocm", info_rocm.Name);
Assert.Equal("Hip", info_rocm.Name);
Assert.Equal(device_id, info_rocm.Id);
Assert.Equal(OrtAllocatorType.ArenaAllocator, info_rocm.GetAllocatorType());
Assert.Equal(OrtMemType.Default, info_rocm.GetMemoryType());
Expand Down Expand Up @@ -1690,7 +1690,7 @@ private void TestAllocator()
#endif

#if USE_ROCM
options.AppendExecutionProvider_ROCM(0);
options.AppendExecutionProvider_ROCm(0);
#endif

using (var session = new InferenceSession(model, options))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -953,8 +953,10 @@ private void TestModelSerialization()
}
}

// TestGpu() will test the CUDA EP on CUDA enabled builds and
// the DML EP on DML enabled builds
// TestGpu() will test
// - the CUDA EP on CUDA enabled builds
// - the DML EP on DML enabled builds
// - the ROCm EP on ROCm enabled builds
[GpuFact(DisplayName = "TestGpu")]
private void TestGpu()
{
Expand Down
Loading