From 9905ea39900e03aea6dc934a415fc42d129e35ac Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 16 May 2023 03:43:38 +0000 Subject: [PATCH 1/9] Add C# binding for rocm ep and add missing api from C++ side --- .../NativeMethods.shared.cs | 68 ++++++++++ .../ProviderOptions.shared.cs | 90 ++++++++++++- .../SessionOptions.shared.cs | 56 +++++++- .../core/session/onnxruntime_c_api.h | 61 +++++++++ .../rocm/rocm_execution_provider_info.cc | 14 ++ .../rocm/rocm_execution_provider_info.h | 1 + .../providers/rocm/rocm_provider_factory.cc | 21 +++ onnxruntime/core/session/onnxruntime_c_api.cc | 11 +- onnxruntime/core/session/ort_apis.h | 8 ++ .../core/session/provider_bridge_ort.cc | 122 ++++++++++++++++++ 10 files changed, 446 insertions(+), 6 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 5fcddf7d0cea..9522906f4a21 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -287,6 +287,11 @@ public struct OrtApi public IntPtr GetOptionalContainedTypeInfo; public IntPtr GetResizedStringTensorElementBuffer; public IntPtr KernelContext_GetAllocator; + public IntPtr GetBuildInfoString; + public IntPtr CreateROCMProviderOptions; + public IntPtr UpdateROCMProviderOptions; + public IntPtr GetROCMProviderOptionsAsString; + public IntPtr ReleaseROCMProviderOptions; } internal static class NativeMethods @@ -493,6 +498,12 @@ static NativeMethods() api_.SessionOptionsAppendExecutionProvider, typeof(DSessionOptionsAppendExecutionProvider)); OrtUpdateEnvWithCustomLogLevel = (DOrtUpdateEnvWithCustomLogLevel)Marshal.GetDelegateForFunctionPointer(api_.UpdateEnvWithCustomLogLevel, typeof(DOrtUpdateEnvWithCustomLogLevel)); + SessionOptionsAppendExecutionProvider_ROCM = (DSessionOptionsAppendExecutionProvider_ROCM)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsAppendExecutionProvider_ROCM, typeof(DSessionOptionsAppendExecutionProvider_ROCM)); + OrtCreateROCMProviderOptions = (DOrtCreateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateROCMProviderOptions, typeof(DOrtCreateROCMProviderOptions)); + OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions)); + OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString)); + OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions)); } internal class NativeLib @@ -659,6 +670,51 @@ internal class NativeLib [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseCUDAProviderOptions(IntPtr /*(OrtCUDAProviderOptions*)*/ cudaProviderOptionsInstance); public static DOrtReleaseCUDAProviderOptions OrtReleaseCUDAProviderOptions; + + /// + /// Creates native OrtROCMProviderOptions instance + /// + /// (output) native instance of OrtROCMProviderOptions + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */DOrtCreateROCMProviderOptions( + out IntPtr /*(OrtROCMProviderOptions**)*/ rocmProviderOptionsInstance); + public static DOrtCreateROCMProviderOptions OrtCreateROCMProviderOptions; + + /// + /// Updates native OrtROCMProviderOptions instance using given key/value pairs + /// + /// native instance of OrtROCMProviderOptions + /// configuration keys of OrtROCMProviderOptions + /// configuration values of OrtROCMProviderOptions + /// number of configuration keys + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */DOrtUpdateROCMProviderOptions( + IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance, + IntPtr[] /*(const char* const *)*/ providerOptionsKeys, + IntPtr[] /*(const char* const *)*/ providerOptionsValues, + UIntPtr /*(size_t)*/ numKeys); + public static DOrtUpdateROCMProviderOptions OrtUpdateROCMProviderOptions; + + /// + /// Get native OrtROCMProviderOptions in serialized string + /// + /// instance of OrtAllocator + /// is a UTF-8 null terminated string allocated using 'allocator' + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */DOrtGetROCMProviderOptionsAsString( + IntPtr /*(OrtROCMProviderOptions**)*/ rocmProviderOptionsInstance, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ptr); + public static DOrtGetROCMProviderOptionsAsString OrtGetROCMProviderOptionsAsString; + + /// + /// Releases native OrtROCMProviderOptions instance + /// + /// native instance of OrtROCMProviderOptions to be released + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseROCMProviderOptions(IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance); + public static DOrtReleaseROCMProviderOptions OrtReleaseROCMProviderOptions; + #endregion #region Status API @@ -1040,6 +1096,18 @@ internal class NativeLib public static DSessionOptionsAppendExecutionProvider_CUDA_V2 SessionOptionsAppendExecutionProvider_CUDA_V2; + /// + /// Append a ROCm EP instance (configured based on given provider options) to the native OrtSessionOptions instance + /// + /// Native OrtSessionOptions instance + /// Native OrtROCMProviderOptions instance + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_ROCM( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtROCMProviderOptions*)*/ rocmProviderOptions); + + public static DSessionOptionsAppendExecutionProvider_ROCM SessionOptionsAppendExecutionProvider_ROCM; + /// /// Free Dimension override (by denotation) /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs index da46c5f9430f..c2f1de6a289e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs @@ -46,7 +46,7 @@ public string GetOptions() { var allocator = OrtAllocator.DefaultInstance; // Process provider options string - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorRTProviderOptionsAsString(handle, + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorRTProviderOptionsAsString(handle, allocator.Pointer, out IntPtr providerOptions)); return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator); } @@ -151,7 +151,7 @@ public string GetOptions() { var allocator = OrtAllocator.DefaultInstance; // Process provider options string - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetCUDAProviderOptionsAsString(handle, + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetCUDAProviderOptionsAsString(handle, allocator.Pointer, out IntPtr providerOptions)); return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator); } @@ -206,6 +206,92 @@ protected override bool ReleaseHandle() } + /// + /// Holds the options for configuring a ROCm Execution Provider instance + /// + public class OrtROCMProviderOptions : SafeHandle + { + internal IntPtr Handle + { + get + { + return handle; + } + } + + + #region Constructor + + /// + /// Constructs an empty OrtROCMroviderOptions instance + /// + public OrtROCMProviderOptions() : base(IntPtr.Zero, true) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateROCMProviderOptions(out handle)); + } + + #endregion + + #region Public Methods + + /// + /// Get ROCm EP provider options + /// + /// return C# UTF-16 encoded string + public string GetOptions() + { + var allocator = OrtAllocator.DefaultInstance; + // Process provider options string + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetROCMProviderOptionsAsString(handle, + allocator.Pointer, out IntPtr providerOptions)); + return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator); + } + + private static IntPtr UpdateROCMProviderOptions(IntPtr handle, IntPtr[] keys, IntPtr[] values, UIntPtr count) + { + return NativeMethods.OrtUpdateROCMProviderOptions(handle, keys, values, count); + } + + /// + /// Updates the configuration knobs of OrtROCMProviderOptions that will eventually be used to configure a ROCm EP + /// Please refer to the following on different key/value pairs to configure a ROCm EP and their meaning: + /// https://onnxruntime.ai/docs/execution-providers/ROCm-ExecutionProvider.html + /// + /// key/value pairs used to configure a ROCm Execution Provider + public void UpdateOptions(Dictionary providerOptions) + { + ProviderOptionsUpdater.Update(providerOptions, handle, UpdateROCMProviderOptions); + } + + #endregion + + #region Public Properties + + /// + /// Overrides SafeHandle.IsInvalid + /// + /// returns true if handle is equal to Zero + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + #endregion + + #region SafeHandle + /// + /// Overrides SafeHandle.ReleaseHandle() to properly dispose of + /// the native instance of OrtROCMProviderOptions + /// + /// always returns true + protected override bool ReleaseHandle() + { + NativeMethods.OrtReleaseROCMProviderOptions(handle); + handle = IntPtr.Zero; + return true; + } + + #endregion + } + + /// /// This helper class contains methods to handle values of provider options /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 30951bae3f9f..5c71fc153fb5 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -161,10 +161,33 @@ public static SessionOptions MakeSessionOptionWithTvmProvider(String settings = /// A SessionsOptions() object configured for execution on deviceId public static SessionOptions MakeSessionOptionWithRocmProvider(int deviceId = 0) { + CheckRocmExecutionProviderDLLs(); SessionOptions options = new SessionOptions(); - options.AppendExecutionProvider_ROCM(deviceId); + options.AppendExecutionProvider_ROCm(deviceId); return options; } + + /// + /// A helper method to construct a SessionOptions object for ROCm execution provider. + /// Use only if ROCm is installed and you have the onnxruntime package specific to this Execution Provider. + /// + /// ROCm EP provider options + /// A SessionsOptions() object configured for execution on provider options + public static SessionOptions MakeSessionOptionWithRocmProvider(OrtROCMProviderOptions rocmProviderOptions) + { + CheckRocmExecutionProviderDLLs(); + SessionOptions options = new SessionOptions(); + try + { + options.AppendExecutionProvider_ROCm(rocmProviderOptions); + return options; + } + catch (Exception) + { + options.Dispose(); + throw; + } + } #endregion #region ExecutionProviderAppends @@ -272,11 +295,16 @@ public void AppendExecutionProvider_Tensorrt(OrtTensorRTProviderOptions trtProvi #endif } + [Obsolete("AppendExecutionProvider_ROCM is deprecated, use AppendExecutionProvider_ROCm instead.")] + public void AppendExecutionProvider_ROCM(int deviceId = 0) { + AppendExecutionProvider_ROCm(deviceId); + } + /// /// Use only if you have the onnxruntime package specific to this Execution Provider. /// /// Device Id - public void AppendExecutionProvider_ROCM(int deviceId = 0) + public void AppendExecutionProvider_ROCm(int deviceId = 0) { #if __MOBILE__ throw new NotSupportedException("The ROCM Execution Provider is not supported in this build"); @@ -286,6 +314,20 @@ public void AppendExecutionProvider_ROCM(int deviceId = 0) #endif } + /// + /// Append a ROCm EP instance (based on specified configuration) to the SessionOptions instance. + /// Use only if you have the onnxruntime package specific to this Execution Provider. + /// + /// ROCm EP provider options + public void AppendExecutionProvider_ROCm(OrtROCMProviderOptions rocmProviderOptions) + { +#if __MOBILE__ + throw new NotSupportedException("The ROCm Execution Provider is not supported in this build"); +#else + NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_ROCM(handle, rocmProviderOptions.Handle)); +#endif + } + /// /// Use only if you have the onnxruntime package specific to this Execution Provider. /// @@ -833,6 +875,16 @@ private static bool CheckTensorrtExecutionProviderDLLs() } return true; } + + private static bool CheckRocmExecutionProviderDLLs() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + throw new NotSupportedException("ROCm Execution Provider is not currently supported on Windows."); + } + return true; + } + #endregion #region SafeHandle diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 292fe71a42b9..e554da6eeb19 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4211,6 +4211,67 @@ struct OrtApi { * \since Version 1.15. */ const char*(ORT_API_CALL* GetBuildInfoString)(void); + + /// \name OrtROCMProviderOptions + /// @{ + + /** \brief Create an OrtROCMProviderOptions + * + * \param[out] out Newly created ::OrtROCMProviderOptions. Must be released with OrtApi::ReleaseROCMProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.16. + */ + ORT_API2_STATUS(CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out); + + /** \brief Set options in a ROCm Execution Provider. + * + * Please refer to https://onnxruntime.ai/docs/execution-providers/ROCm-ExecutionProvider.html + * to know the available keys and values. Key should be in null terminated string format of the member of + * ::OrtROCMProviderOptions and value should be its related range. + * + * For example, key="device_id" and value="0" + * + * \param[in] rocm_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.16. + */ + ORT_API2_STATUS(UpdateROCMProviderOptions, _Inout_ OrtROCMProviderOptions* rocm_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** + * Get serialized ROCm provider options string. + * + * For example, "device_id=0;arena_extend_strategy=0;......" + * + * \param rocm_options - OrtROCMProviderOptions instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.16. + */ + ORT_API2_STATUS(GetROCMProviderOptionsAsString, _In_ const OrtROCMProviderOptions* rocm_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtROCMProviderOptions + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + * + * \since Version 1.16. + */ + void(ORT_API_CALL* ReleaseROCMProviderOptions)(_Frees_ptr_opt_ OrtROCMProviderOptions* input); + + /// @} }; /* diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 5e32ae3067f1..2901c1a83d0a 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -121,4 +121,18 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution return options; } +ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) { + const ProviderOptions options{ + {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, + {rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.arena_extend_strategy))}, + {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, + {rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, + {rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, + {rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, + }; + + return options; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h index 7666b990400d..5b1fd095891a 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -66,6 +66,7 @@ struct ROCMExecutionProviderInfo { static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const ROCMExecutionProviderInfo& info); + static ProviderOptions ToProviderOptions(const OrtROCMProviderOptions& info); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 0eaf2148875d..2e8ee8dd5341 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -179,6 +179,27 @@ struct ROCM_Provider : Provider { return std::make_shared(info); } + void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { + auto info = onnxruntime::ROCMExecutionProviderInfo::FromProviderOptions(options); + auto& rocm_options = *reinterpret_cast(provider_options); + + rocm_options.device_id = info.device_id; + rocm_options.gpu_mem_limit = info.gpu_mem_limit; + rocm_options.arena_extend_strategy = static_cast(info.arena_extend_strategy); + rocm_options.miopen_conv_exhaustive_search = info.miopen_conv_exhaustive_search; + rocm_options.do_copy_in_default_stream = info.do_copy_in_default_stream; + rocm_options.has_user_compute_stream = info.has_user_compute_stream; + rocm_options.user_compute_stream = info.user_compute_stream; + rocm_options.default_memory_arena_cfg = info.default_memory_arena_cfg; + rocm_options.tunable_op_enable = info.tunable_op.enable; + rocm_options.tunable_op_tuning_enable = info. tunable_op.tuning_enable; + } + + ProviderOptions GetProviderOptions(const void* provider_options) override { + auto& options = *reinterpret_cast(provider_options); + return onnxruntime::ROCMExecutionProviderInfo::ToProviderOptions(options); + } + void Initialize() override { InitializeRegistry(); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 53af75a2c688..708cb825a661 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2725,8 +2725,15 @@ static constexpr OrtApi ort_api_1_to_16 = { &OrtApis::GetOptionalContainedTypeInfo, &OrtApis::GetResizedStringTensorElementBuffer, &OrtApis::KernelContext_GetAllocator, - &OrtApis::GetBuildInfoString}; -// End of Version 15 - DO NOT MODIFY ABOVE (see above text for more information) + &OrtApis::GetBuildInfoString, + // End of Version 15 - DO NOT MODIFY ABOVE (see above text for more information) + + // Start of Version 16 API in progress, safe to modify/rename/rearrange until we ship + &OrtApis::CreateROCMProviderOptions, + &OrtApis::UpdateROCMProviderOptions, + &OrtApis::GetROCMProviderOptionsAsString, + &OrtApis::ReleaseROCMProviderOptions, +}; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. static_assert(sizeof(OrtApiBase) == sizeof(void*) * 2, "New methods can't be added to OrtApiBase as it is not versioned"); diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 8b8ffd38d697..f5b1a0f505da 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -467,4 +467,12 @@ ORT_API_STATUS_IMPL(GetResizedStringTensorElementBuffer, _Inout_ OrtValue* value ORT_API_STATUS_IMPL(KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out); ORT_API(const char*, GetBuildInfoString); + +ORT_API_STATUS_IMPL(CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out); +ORT_API_STATUS_IMPL(UpdateROCMProviderOptions, _Inout_ OrtROCMProviderOptions* rocm_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + size_t num_keys); +ORT_API_STATUS_IMPL(GetROCMProviderOptionsAsString, _In_ const OrtROCMProviderOptions* rocm_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); +ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions*); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 7d92cc3f7e03..f0e431ead441 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1466,6 +1466,14 @@ ProviderOptions GetProviderInfo_Cuda(const OrtCUDAProviderOptionsV2* provider_op return s_library_cuda.Get().GetProviderOptions(reinterpret_cast(provider_options)); } +void UpdateProviderInfo_Rocm(OrtROCMProviderOptions* provider_options, const ProviderOptions& options) { + return s_library_rocm.Get().UpdateProviderOptions(reinterpret_cast(provider_options), options); +} + +ProviderOptions GetProviderInfo_Rocm(const OrtROCMProviderOptions* provider_options) { + return s_library_rocm.Get().GetProviderOptions(reinterpret_cast(provider_options)); +} + } // namespace onnxruntime ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) { @@ -2091,3 +2099,117 @@ ORT_API(void, OrtApis::ReleaseDnnlProviderOptions, _Frees_ptr_opt_ OrtDnnlProvid ORT_UNUSED_PARAMETER(ptr); #endif } + +ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out) { + API_IMPL_BEGIN +#ifdef USE_ROCM + +// Need to use 'new' here, so disable C26409 +#ifdef _WIN32 +#pragma warning(push) +#pragma warning(disable : 26409) +#endif + *out = new OrtROCMProviderOptions(); +#ifdef _WIN32 +#pragma warning(pop) +#endif + (*out)->device_id = 0; + (*out)->miopen_conv_exhaustive_search = 0; + (*out)->gpu_mem_limit = std::numeric_limits::max(); + (*out)->arena_extend_strategy = 0; + (*out)->do_copy_in_default_stream = 1; + (*out)->has_user_compute_stream = 0; + (*out)->user_compute_stream = nullptr; + (*out)->default_memory_arena_cfg = nullptr; + (*out)->tunable_op_enable = 0; + (*out)->tunable_op_tuning_enable = 0; + return nullptr; +#else + ORT_UNUSED_PARAMETER(out); + return CreateStatus(ORT_FAIL, "ROCm execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateROCMProviderOptions, + _Inout_ OrtROCMProviderOptions* rocm_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + size_t num_keys) { + API_IMPL_BEGIN +#ifdef USE_ROCM + onnxruntime::ProviderOptions provider_options_map; + for (size_t i = 0; i != num_keys; ++i) { + if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || + provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty"); + } + + provider_options_map[provider_options_keys[i]] = provider_options_values[i]; + } + + onnxruntime::UpdateProviderInfo_Rocm(rocm_options, + reinterpret_cast(provider_options_map)); + return nullptr; +#else + ORT_UNUSED_PARAMETER(rocm_options); + ORT_UNUSED_PARAMETER(provider_options_keys); + ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, _In_ const OrtROCMProviderOptions* rocm_options, + _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr) { + API_IMPL_BEGIN +#ifdef USE_ROCM + onnxruntime::ProviderOptions options = onnxruntime::GetProviderInfo_Rocm(rocm_options); + onnxruntime::ProviderOptions::iterator it = options.begin(); + std::string options_str = ""; + + while (it != options.end()) { + if (options_str == "") { + options_str += it->first; + options_str += "="; + options_str += it->second; + } else { + options_str += ";"; + options_str += it->first; + options_str += "="; + options_str += it->second; + } + it++; + } + + *ptr = onnxruntime::StrDup(options_str, allocator); + return nullptr; +#else + ORT_UNUSED_PARAMETER(rocm_options); + ORT_UNUSED_PARAMETER(allocator); + ORT_UNUSED_PARAMETER(ptr); + return CreateStatus(ORT_FAIL, "ROCm execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { +#ifdef USE_ROCM + +// Need to use 'delete' here, so disable C26409 +#ifdef _WIN32 +#pragma warning(push) +#pragma warning(disable : 26409) +#endif + + delete ptr; + +#ifdef _WIN32 +#pragma warning(pop) +#endif + +#else + ORT_UNUSED_PARAMETER(ptr); +#endif +} From 58f72345522c51c5d73e54f5d367a2721509ed95 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 17 Apr 2023 10:49:02 +0000 Subject: [PATCH 2/9] Rename AppendExecutionProvider_ROCM, in favor of AppendExecutionProvider_ROCm --- csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs | 5 ----- .../Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs | 4 ++-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 5c71fc153fb5..067e7804c86a 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -295,11 +295,6 @@ public void AppendExecutionProvider_Tensorrt(OrtTensorRTProviderOptions trtProvi #endif } - [Obsolete("AppendExecutionProvider_ROCM is deprecated, use AppendExecutionProvider_ROCm instead.")] - public void AppendExecutionProvider_ROCM(int deviceId = 0) { - AppendExecutionProvider_ROCm(deviceId); - } - /// /// Use only if you have the onnxruntime package specific to this Execution Provider. /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index 5ebca9b05bb1..9d56b2f1593d 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -147,7 +147,7 @@ public void TestSessionOptions() #endif #if USE_ROCM - opt.AppendExecutionProvider_ROCM(0); + opt.AppendExecutionProvider_ROCm(0); #endif #if USE_TENSORRT @@ -1690,7 +1690,7 @@ private void TestAllocator() #endif #if USE_ROCM - options.AppendExecutionProvider_ROCM(0); + options.AppendExecutionProvider_ROCm(0); #endif using (var session = new InferenceSession(model, options)) From b57f865fee0c8cef0b5e14236db2d370b2d3ad0b Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 18 Apr 2023 06:22:30 +0000 Subject: [PATCH 3/9] Use unique_ptr to avoid new/delete --- .../core/session/provider_bridge_ort.cc | 25 ++----------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index f0e431ead441..6b18170df4b9 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2103,16 +2103,7 @@ ORT_API(void, OrtApis::ReleaseDnnlProviderOptions, _Frees_ptr_opt_ OrtDnnlProvid ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out) { API_IMPL_BEGIN #ifdef USE_ROCM - -// Need to use 'new' here, so disable C26409 -#ifdef _WIN32 -#pragma warning(push) -#pragma warning(disable : 26409) -#endif - *out = new OrtROCMProviderOptions(); -#ifdef _WIN32 -#pragma warning(pop) -#endif + *out = std::make_unique().release(); (*out)->device_id = 0; (*out)->miopen_conv_exhaustive_search = 0; (*out)->gpu_mem_limit = std::numeric_limits::max(); @@ -2196,19 +2187,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, _In_ const OrtROCMP ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { #ifdef USE_ROCM - -// Need to use 'delete' here, so disable C26409 -#ifdef _WIN32 -#pragma warning(push) -#pragma warning(disable : 26409) -#endif - - delete ptr; - -#ifdef _WIN32 -#pragma warning(pop) -#endif - + std::unique_ptr p(ptr); #else ORT_UNUSED_PARAMETER(ptr); #endif From d3aabaae644981587f776e2787348a187e723f0a Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 18 Apr 2023 09:50:55 +0000 Subject: [PATCH 4/9] Add HIP and HIP_PINNED and fix OrtMemoryInfo creation --- .../Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs | 10 ++++++++++ .../Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh | 2 ++ .../InferenceTest.cs | 4 ++-- onnxruntime/core/framework/allocator.cc | 8 ++++++++ 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs index 688bdf630fb7..9093d5a65293 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs @@ -163,6 +163,16 @@ internal OrtMemoryInfo(IntPtr allocInfo, bool owned) /// public static readonly byte[] allocatorCUDA_PINNED = Encoding.UTF8.GetBytes("CudaPinned" + Char.MinValue); /// + /// Predefined utf8 encoded allocator names. Use them to construct an instance of + /// OrtMemoryInfo to avoid UTF-16 to UTF-8 conversion costs. + /// + public static readonly byte[] allocatorHIP = Encoding.UTF8.GetBytes("Hip" + Char.MinValue); + /// + /// Predefined utf8 encoded allocator names. Use them to construct an instance of + /// OrtMemoryInfo to avoid UTF-16 to UTF-8 conversion costs. + /// + public static readonly byte[] allocatorHIP_PINNED = Encoding.UTF8.GetBytes("HipPinned" + Char.MinValue); + /// /// Create an instance of OrtMemoryInfo according to the specification /// Memory info instances are usually used to get a handle of a native allocator /// that is present within the current inference session object. That, in turn, depends diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh index f56fdc802cd4..e7293bedc0e4 100755 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh @@ -40,6 +40,8 @@ if [ $RunTestCsharp = "true" ]; then exit 1 fi dotnet test -p:DefineConstants=USE_TENSORRT $BUILD_SOURCESDIRECTORY/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj --no-restore --verbosity detailed + elif [ $PACKAGENAME = "Microsoft.ML.OnnxRuntime.ROCm" ]; then + dotnet test -p:DefineConstants=USE_ROCM $BUILD_SOURCESDIRECTORY/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj --no-restore --verbosity detailed else dotnet test $BUILD_SOURCESDIRECTORY/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj --no-restore --verbosity detailed fi diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index 9d56b2f1593d..2c9f92b3b8d6 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -1655,9 +1655,9 @@ void TestCUDAAllocatorInternal(InferenceSession session) void TestROCMAllocatorInternal(InferenceSession session) { int device_id = 0; - using (var info_rocm = new OrtMemoryInfo(OrtMemoryInfo.allocatorROCM, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default)) + using (var info_rocm = new OrtMemoryInfo(OrtMemoryInfo.allocatorHIP, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default)) { - Assert.Equal("Rocm", info_rocm.Name); + Assert.Equal("Hip", info_rocm.Name); Assert.Equal(device_id, info_rocm.Id); Assert.Equal(OrtAllocatorType.ArenaAllocator, info_rocm.GetAllocatorType()); Assert.Equal(OrtMemType.Default, info_rocm.GetMemoryType()); diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 1d20f3b78082..af812bb8c941 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -154,6 +154,14 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA *out = new OrtMemoryInfo( onnxruntime::DML, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::HIP) == 0) { + *out = new OrtMemoryInfo( + onnxruntime::HIP, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + mem_type1); + } else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) { + *out = new OrtMemoryInfo( + onnxruntime::HIP_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(id1)), + id1, mem_type1); } else { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported."); } From 0ab74e62eb19c5bcac8f920114d36bf4270568d9 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 25 Apr 2023 05:38:17 +0000 Subject: [PATCH 5/9] Minor --- .../InferenceTest.netcore.cs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index 0e1e5ea5ea1a..4a12ee8e0d79 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -953,8 +953,10 @@ private void TestModelSerialization() } } - // TestGpu() will test the CUDA EP on CUDA enabled builds and - // the DML EP on DML enabled builds + // TestGpu() will test + // - the CUDA EP on CUDA enabled builds + // - the DML EP on DML enabled builds + // - the ROCm EP on ROCm enabled builds [GpuFact(DisplayName = "TestGpu")] private void TestGpu() { From 95e0e4292a55b2360d0ac1ddd38839398545ce37 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 16 May 2023 03:58:13 +0000 Subject: [PATCH 6/9] Fix filename case --- tools/ci_build/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 54e844bc08af..6c179b99e432 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2045,7 +2045,7 @@ def build_nuget_package( ort_build_dir = '/p:OnnxRuntimeBuildDirectory="' + native_dir + '"' # dotnet restore - cmd_args = ["dotnet", "restore", sln, "--configfile", "Nuget.CSharp.config"] + cmd_args = ["dotnet", "restore", sln, "--configfile", "NuGet.CSharp.config"] run_subprocess(cmd_args, cwd=csharp_build_dir) # build csharp bindings and create nuget package for each config From 9cb0b2afa389a47080ccfb5f628962ee4b7410c1 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 16 May 2023 04:12:33 +0000 Subject: [PATCH 7/9] Fix format --- onnxruntime/core/providers/rocm/rocm_provider_factory.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 2e8ee8dd5341..e1a07a348644 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -192,7 +192,7 @@ struct ROCM_Provider : Provider { rocm_options.user_compute_stream = info.user_compute_stream; rocm_options.default_memory_arena_cfg = info.default_memory_arena_cfg; rocm_options.tunable_op_enable = info.tunable_op.enable; - rocm_options.tunable_op_tuning_enable = info. tunable_op.tuning_enable; + rocm_options.tunable_op_tuning_enable = info.tunable_op.tuning_enable; } ProviderOptions GetProviderOptions(const void* provider_options) override { From 13ed9ce46ec780f23185d791295c96d2fb894ee0 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 16 May 2023 07:44:49 +0000 Subject: [PATCH 8/9] Add stub apis for min build --- .../core/session/provider_registration.cc | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 9254b0a296b9..248f32bc25c3 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -373,4 +373,35 @@ ORT_API_STATUS_IMPL(OrtApis::GetDnnlProviderOptionsAsString, ORT_API(void, OrtApis::ReleaseDnnlProviderOptions, _Frees_ptr_opt_ OrtDnnlProviderOptions* ptr) { ORT_UNUSED_PARAMETER(ptr); } + +ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out) { + ORT_UNUSED_PARAMETER(out); + return CreateNotEnabledStatus("ROCM"); +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateROCMProviderOptions, + _Inout_ OrtROCMProviderOptions* rocm_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + size_t num_keys) { + ORT_UNUSED_PARAMETER(rocm_options); + ORT_UNUSED_PARAMETER(provider_options_keys); + ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateNotEnabledStatus("ROCM"); +} + +ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, + _In_ const OrtROCMProviderOptions* rocm_options, + _Inout_ OrtAllocator* allocator, + _Outptr_ char** ptr) { + ORT_UNUSED_PARAMETER(rocm_options); + ORT_UNUSED_PARAMETER(allocator); + ORT_UNUSED_PARAMETER(ptr); + return CreateNotEnabledStatus("ROCM"); +} + +ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { + ORT_UNUSED_PARAMETER(ptr); +} #endif From aa2590d283d29f2619868bb19c3634447e1365bf Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Wed, 17 May 2023 01:40:39 +0000 Subject: [PATCH 9/9] Address review comment --- .../SessionOptions.shared.cs | 24 +++++++---- .../core/session/provider_bridge_ort.cc | 41 ++++++++----------- 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 067e7804c86a..9380820ea0ce 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -163,8 +163,16 @@ public static SessionOptions MakeSessionOptionWithRocmProvider(int deviceId = 0) { CheckRocmExecutionProviderDLLs(); SessionOptions options = new SessionOptions(); - options.AppendExecutionProvider_ROCm(deviceId); - return options; + try + { + options.AppendExecutionProvider_ROCm(deviceId); + return options; + } + catch (Exception) + { + options.Dispose(); + throw; + } } /// @@ -434,10 +442,10 @@ public void AppendExecutionProvider(string providerName, Dictionary /// path to the custom op library @@ -472,13 +480,13 @@ public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHand // SessionOptions.RegisterCustomOpLibrary calls NativeMethods.OrtRegisterCustomOpsLibrary_V2 // SessionOptions.RegisterCustomOpLibraryV2 calls NativeMethods.OrtRegisterCustomOpsLibrary var utf8Path = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath); - NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, utf8Path, + NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, utf8Path, out libraryHandle)); } /// /// Register the custom operators from the Microsoft.ML.OnnxRuntime.Extensions NuGet package. - /// A reference to Microsoft.ML.OnnxRuntime.Extensions must be manually added to your project. + /// A reference to Microsoft.ML.OnnxRuntime.Extensions must be manually added to your project. /// /// Throws if the extensions library is not found. public void RegisterOrtExtensions() diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 6b18170df4b9..94d4f4daa6ab 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1466,14 +1466,6 @@ ProviderOptions GetProviderInfo_Cuda(const OrtCUDAProviderOptionsV2* provider_op return s_library_cuda.Get().GetProviderOptions(reinterpret_cast(provider_options)); } -void UpdateProviderInfo_Rocm(OrtROCMProviderOptions* provider_options, const ProviderOptions& options) { - return s_library_rocm.Get().UpdateProviderOptions(reinterpret_cast(provider_options), options); -} - -ProviderOptions GetProviderInfo_Rocm(const OrtROCMProviderOptions* provider_options) { - return s_library_rocm.Get().GetProviderOptions(reinterpret_cast(provider_options)); -} - } // namespace onnxruntime ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) { @@ -2103,17 +2095,19 @@ ORT_API(void, OrtApis::ReleaseDnnlProviderOptions, _Frees_ptr_opt_ OrtDnnlProvid ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out) { API_IMPL_BEGIN #ifdef USE_ROCM - *out = std::make_unique().release(); - (*out)->device_id = 0; - (*out)->miopen_conv_exhaustive_search = 0; - (*out)->gpu_mem_limit = std::numeric_limits::max(); - (*out)->arena_extend_strategy = 0; - (*out)->do_copy_in_default_stream = 1; - (*out)->has_user_compute_stream = 0; - (*out)->user_compute_stream = nullptr; - (*out)->default_memory_arena_cfg = nullptr; - (*out)->tunable_op_enable = 0; - (*out)->tunable_op_tuning_enable = 0; + auto options = std::make_unique(); + options->device_id = 0; + options->miopen_conv_exhaustive_search = 0; + options->gpu_mem_limit = std::numeric_limits::max(); + options->arena_extend_strategy = 0; + options->do_copy_in_default_stream = 1; + options->has_user_compute_stream = 0; + options->user_compute_stream = nullptr; + options->default_memory_arena_cfg = nullptr; + options->tunable_op_enable = 0; + options->tunable_op_tuning_enable = 0; + + *out = options.release(); return nullptr; #else ORT_UNUSED_PARAMETER(out); @@ -2139,15 +2133,14 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateROCMProviderOptions, provider_options_map[provider_options_keys[i]] = provider_options_values[i]; } - onnxruntime::UpdateProviderInfo_Rocm(rocm_options, - reinterpret_cast(provider_options_map)); + onnxruntime::s_library_rocm.Get().UpdateProviderOptions(rocm_options, provider_options_map); return nullptr; #else ORT_UNUSED_PARAMETER(rocm_options); ORT_UNUSED_PARAMETER(provider_options_keys); ORT_UNUSED_PARAMETER(provider_options_values); ORT_UNUSED_PARAMETER(num_keys); - return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build."); + return CreateStatus(ORT_FAIL, "ROCm execution provider is not enabled in this build."); #endif API_IMPL_END } @@ -2156,9 +2149,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, _In_ const OrtROCMP _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr) { API_IMPL_BEGIN #ifdef USE_ROCM - onnxruntime::ProviderOptions options = onnxruntime::GetProviderInfo_Rocm(rocm_options); + onnxruntime::ProviderOptions options = onnxruntime::s_library_rocm.Get().GetProviderOptions(rocm_options); onnxruntime::ProviderOptions::iterator it = options.begin(); - std::string options_str = ""; + std::string options_str; while (it != options.end()) { if (options_str == "") {