diff --git a/include/ur_api.h b/include/ur_api.h index 2f56b5dd16..d21e146fda 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -225,6 +225,7 @@ typedef enum ur_function_t { UR_FUNCTION_ENQUEUE_TIMESTAMP_RECORDING_EXP = 223, ///< Enumerator for ::urEnqueueTimestampRecordingExp UR_FUNCTION_ENQUEUE_KERNEL_LAUNCH_CUSTOM_EXP = 224, ///< Enumerator for ::urEnqueueKernelLaunchCustomExp UR_FUNCTION_KERNEL_GET_SUGGESTED_LOCAL_WORK_SIZE = 225, ///< Enumerator for ::urKernelGetSuggestedLocalWorkSize + UR_FUNCTION_LOADER_CONFIG_SET_MOCK_CALLBACKS = 229, ///< Enumerator for ::urLoaderConfigSetMockCallbacks /// @cond UR_FUNCTION_FORCE_UINT32 = 0x7fffffff /// @endcond @@ -269,6 +270,7 @@ typedef enum ur_structure_type_t { UR_STRUCTURE_TYPE_KERNEL_ARG_VALUE_PROPERTIES = 32, ///< ::ur_kernel_arg_value_properties_t UR_STRUCTURE_TYPE_KERNEL_ARG_LOCAL_PROPERTIES = 33, ///< ::ur_kernel_arg_local_properties_t UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC = 35, ///< ::ur_usm_alloc_location_desc_t + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES = 38, ///< ::ur_mock_callback_properties_t UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC = 0x1000, ///< ::ur_exp_command_buffer_desc_t UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC = 0x1001, ///< ::ur_exp_command_buffer_update_kernel_launch_desc_t UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC = 0x1002, ///< ::ur_exp_command_buffer_update_memobj_arg_desc_t @@ -598,7 +600,7 @@ urLoaderConfigCreate( /// + `NULL == hLoaderConfig` UR_APIEXPORT ur_result_t UR_APICALL urLoaderConfigRetain( - ur_loader_config_handle_t hLoaderConfig ///< [in] loader config handle to retain + ur_loader_config_handle_t hLoaderConfig ///< [in][retain] loader config handle to retain ); /////////////////////////////////////////////////////////////////////////////// @@ -619,7 +621,7 @@ urLoaderConfigRetain( /// + `NULL == hLoaderConfig` UR_APIEXPORT ur_result_t UR_APICALL urLoaderConfigRelease( - ur_loader_config_handle_t hLoaderConfig ///< [in] config handle to release + ur_loader_config_handle_t hLoaderConfig ///< [in][release] config handle to release ); /////////////////////////////////////////////////////////////////////////////// @@ -739,6 +741,65 @@ urLoaderConfigSetCodeLocationCallback( void *pUserData ///< [in][out][optional] pointer to data to be passed to callback. ); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Callback override mode +typedef enum ur_callback_override_mode_t { + UR_CALLBACK_OVERRIDE_MODE_BEFORE = 0, ///< Invoke callback before function. + UR_CALLBACK_OVERRIDE_MODE_REPLACE = 1, ///< Invoke callback instead of function. + UR_CALLBACK_OVERRIDE_MODE_AFTER = 2, ///< Invoke callback after function. + /// @cond + UR_CALLBACK_OVERRIDE_MODE_FORCE_UINT32 = 0x7fffffff + /// @endcond + +} ur_callback_override_mode_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Callback to replace or instrument generic mock functionality in the +/// mock layer. +typedef ur_result_t (*ur_mock_callback_t)( + void *pParams ///< [in][out] Pointer to the appropriate param struct for the function +); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Callback properties +typedef struct ur_mock_callback_properties_t { + ur_structure_type_t stype; ///< [in] type of this structure, must be + ///< ::UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES + void *pNext; ///< [in,out][optional] pointer to extension-specific structure + const char *name; ///< [in] Full name of the function to associate callback with. + ur_callback_override_mode_t mode; ///< [in] Override mode for this callback. + ur_mock_callback_t pCallback; ///< [in] Callback function pointer. + +} ur_mock_callback_properties_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Set a callback to be called before, after or instead of a given entry +/// point +/// +/// @details +/// - The callback layer will pass the function's parameter struct (e.g. +/// **::ur_adapter_get_params_t**) to the ::ur_mock_callback_t so +/// parameters can be accessed and modified. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hLoaderConfig` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == pCallbackProperties` +/// + `NULL == pCallbackProperties->name` +/// + `NULL == pCallbackProperties->pCallback` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_CALLBACK_OVERRIDE_MODE_AFTER < pCallbackProperties->mode` +UR_APIEXPORT ur_result_t UR_APICALL +urLoaderConfigSetMockCallbacks( + ur_loader_config_handle_t hLoaderConfig, ///< [in] Handle to config object the layer will be enabled for. + ur_mock_callback_properties_t *pCallbackProperties ///< [in] Pointer to callback properties struct. +); + /////////////////////////////////////////////////////////////////////////////// /// @brief Initialize the 'oneAPI' loader /// @@ -842,7 +903,7 @@ urAdapterGet( /// + `NULL == hAdapter` UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to release + ur_adapter_handle_t hAdapter ///< [in][release] Adapter handle to release ); /////////////////////////////////////////////////////////////////////////////// @@ -860,7 +921,7 @@ urAdapterRelease( /// + `NULL == hAdapter` UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ); /////////////////////////////////////////////////////////////////////////////// @@ -1724,7 +1785,7 @@ urDeviceGetInfo( /// + `NULL == hDevice` UR_APIEXPORT ur_result_t UR_APICALL urDeviceRetain( - ur_device_handle_t hDevice ///< [in] handle of the device to get a reference of. + ur_device_handle_t hDevice ///< [in][retain] handle of the device to get a reference of. ); /////////////////////////////////////////////////////////////////////////////// @@ -1752,7 +1813,7 @@ urDeviceRetain( /// + `NULL == hDevice` UR_APIEXPORT ur_result_t UR_APICALL urDeviceRelease( - ur_device_handle_t hDevice ///< [in] handle of the device to release. + ur_device_handle_t hDevice ///< [in][release] handle of the device to release. ); /////////////////////////////////////////////////////////////////////////////// @@ -2205,7 +2266,7 @@ urContextCreate( /// + `NULL == hContext` UR_APIEXPORT ur_result_t UR_APICALL urContextRetain( - ur_context_handle_t hContext ///< [in] handle of the context to get a reference of. + ur_context_handle_t hContext ///< [in][retain] handle of the context to get a reference of. ); /////////////////////////////////////////////////////////////////////////////// @@ -2259,7 +2320,7 @@ typedef enum ur_context_info_t { /// + `NULL == hContext` UR_APIEXPORT ur_result_t UR_APICALL urContextRelease( - ur_context_handle_t hContext ///< [in] handle of the context to release. + ur_context_handle_t hContext ///< [in][release] handle of the context to release. ); /////////////////////////////////////////////////////////////////////////////// @@ -2723,7 +2784,7 @@ urMemBufferCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t hMem ///< [in][retain] handle of the memory object to get access ); /////////////////////////////////////////////////////////////////////////////// @@ -2745,7 +2806,7 @@ urMemRetain( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY UR_APIEXPORT ur_result_t UR_APICALL urMemRelease( - ur_mem_handle_t hMem ///< [in] handle of the memory object to release + ur_mem_handle_t hMem ///< [in][release] handle of the memory object to release ); /////////////////////////////////////////////////////////////////////////////// @@ -3111,7 +3172,7 @@ urSamplerCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL urSamplerRetain( - ur_sampler_handle_t hSampler ///< [in] handle of the sampler object to get access + ur_sampler_handle_t hSampler ///< [in][retain] handle of the sampler object to get access ); /////////////////////////////////////////////////////////////////////////////// @@ -3134,7 +3195,7 @@ urSamplerRetain( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL urSamplerRelease( - ur_sampler_handle_t hSampler ///< [in] handle of the sampler object to release + ur_sampler_handle_t hSampler ///< [in][release] handle of the sampler object to release ); /////////////////////////////////////////////////////////////////////////////// @@ -3671,7 +3732,7 @@ urUSMPoolCreate( /// + `NULL == pPool` UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ); /////////////////////////////////////////////////////////////////////////////// @@ -3693,7 +3754,7 @@ urUSMPoolRetain( /// + `NULL == pPool` UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRelease( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][release] pointer to USM memory pool ); /////////////////////////////////////////////////////////////////////////////// @@ -4027,7 +4088,7 @@ urPhysicalMemCreate( /// + `NULL == hPhysicalMem` UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemRetain( - ur_physical_mem_handle_t hPhysicalMem ///< [in] handle of the physical memory object to retain. + ur_physical_mem_handle_t hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ); /////////////////////////////////////////////////////////////////////////////// @@ -4042,7 +4103,7 @@ urPhysicalMemRetain( /// + `NULL == hPhysicalMem` UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemRelease( - ur_physical_mem_handle_t hPhysicalMem ///< [in] handle of the physical memory object to release. + ur_physical_mem_handle_t hPhysicalMem ///< [in][release] handle of the physical memory object to release. ); #if !defined(__GNUC__) @@ -4310,7 +4371,7 @@ urProgramLink( /// + `NULL == hProgram` UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t hProgram ///< [in][retain] handle for the Program to retain ); /////////////////////////////////////////////////////////////////////////////// @@ -4335,7 +4396,7 @@ urProgramRetain( /// + `NULL == hProgram` UR_APIEXPORT ur_result_t UR_APICALL urProgramRelease( - ur_program_handle_t hProgram ///< [in] handle for the Program to release + ur_program_handle_t hProgram ///< [in][release] handle for the Program to release ); /////////////////////////////////////////////////////////////////////////////// @@ -4957,7 +5018,7 @@ urKernelGetSubGroupInfo( /// + `NULL == hKernel` UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ); /////////////////////////////////////////////////////////////////////////////// @@ -4982,7 +5043,7 @@ urKernelRetain( /// + `NULL == hKernel` UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release + ur_kernel_handle_t hKernel ///< [in][release] handle for the Kernel to release ); /////////////////////////////////////////////////////////////////////////////// @@ -5461,7 +5522,7 @@ urQueueCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t hQueue ///< [in][retain] handle of the queue object to get access ); /////////////////////////////////////////////////////////////////////////////// @@ -5490,7 +5551,7 @@ urQueueRetain( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to release + ur_queue_handle_t hQueue ///< [in][release] handle of the queue object to release ); /////////////////////////////////////////////////////////////////////////////// @@ -5855,7 +5916,7 @@ urEventWait( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY UR_APIEXPORT ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ); /////////////////////////////////////////////////////////////////////////////// @@ -5878,7 +5939,7 @@ urEventRetain( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY UR_APIEXPORT ur_result_t UR_APICALL urEventRelease( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][release] handle of the event object ); /////////////////////////////////////////////////////////////////////////////// @@ -7887,7 +7948,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( ur_context_handle_t hContext, ///< [in] handle of the context object ur_device_handle_t hDevice, ///< [in] handle of the device object - ur_exp_interop_mem_handle_t hInteropMem ///< [in] handle of interop memory to be freed + ur_exp_interop_mem_handle_t hInteropMem ///< [in][release] handle of interop memory to be freed ); /////////////////////////////////////////////////////////////////////////////// @@ -8185,7 +8246,7 @@ urCommandBufferCreateExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainExp( - ur_exp_command_buffer_handle_t hCommandBuffer ///< [in] Handle of the command-buffer object. + ur_exp_command_buffer_handle_t hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ); /////////////////////////////////////////////////////////////////////////////// @@ -8204,7 +8265,7 @@ urCommandBufferRetainExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseExp( - ur_exp_command_buffer_handle_t hCommandBuffer ///< [in] Handle of the command-buffer object. + ur_exp_command_buffer_handle_t hCommandBuffer ///< [in][release] Handle of the command-buffer object. ); /////////////////////////////////////////////////////////////////////////////// @@ -8752,7 +8813,7 @@ urCommandBufferRetainCommandExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( - ur_exp_command_buffer_command_handle_t hCommand ///< [in] Handle of the command-buffer command. + ur_exp_command_buffer_command_handle_t hCommand ///< [in][release] Handle of the command-buffer command. ); /////////////////////////////////////////////////////////////////////////////// @@ -9513,6 +9574,15 @@ typedef struct ur_loader_config_set_code_location_callback_params_t { void **ppUserData; } ur_loader_config_set_code_location_callback_params_t; +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function parameters for urLoaderConfigSetMockCallbacks +/// @details Each entry is a pointer to the parameter passed to the function; +/// allowing the callback the ability to modify the parameter's value +typedef struct ur_loader_config_set_mock_callbacks_params_t { + ur_loader_config_handle_t *phLoaderConfig; + ur_mock_callback_properties_t **ppCallbackProperties; +} ur_loader_config_set_mock_callbacks_params_t; + /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for urPlatformGet /// @details Each entry is a pointer to the parameter passed to the function; diff --git a/include/ur_print.h b/include/ur_print.h index c8fb41753e..bc724cbdf6 100644 --- a/include/ur_print.h +++ b/include/ur_print.h @@ -98,6 +98,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urPrintLoaderConfigInfo(enum ur_loader_confi /// - `buff_size < out_size` UR_APIEXPORT ur_result_t UR_APICALL urPrintCodeLocation(const struct ur_code_location_t params, char *buffer, const size_t buff_size, size_t *out_size); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print ur_callback_override_mode_t enum +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// - `buff_size < out_size` +UR_APIEXPORT ur_result_t UR_APICALL urPrintCallbackOverrideMode(enum ur_callback_override_mode_t value, char *buffer, const size_t buff_size, size_t *out_size); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print ur_mock_callback_properties_t struct +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// - `buff_size < out_size` +UR_APIEXPORT ur_result_t UR_APICALL urPrintMockCallbackProperties(const struct ur_mock_callback_properties_t params, char *buffer, const size_t buff_size, size_t *out_size); + /////////////////////////////////////////////////////////////////////////////// /// @brief Print ur_adapter_info_t enum /// @returns @@ -1074,6 +1090,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urPrintLoaderConfigEnableLayerParams(const s /// - `buff_size < out_size` UR_APIEXPORT ur_result_t UR_APICALL urPrintLoaderConfigSetCodeLocationCallbackParams(const struct ur_loader_config_set_code_location_callback_params_t *params, char *buffer, const size_t buff_size, size_t *out_size); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print ur_loader_config_set_mock_callbacks_params_t struct +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// - `buff_size < out_size` +UR_APIEXPORT ur_result_t UR_APICALL urPrintLoaderConfigSetMockCallbacksParams(const struct ur_loader_config_set_mock_callbacks_params_t *params, char *buffer, const size_t buff_size, size_t *out_size); + /////////////////////////////////////////////////////////////////////////////// /// @brief Print ur_platform_get_params_t struct /// @returns diff --git a/include/ur_print.hpp b/include/ur_print.hpp index 0cc35e84c5..03701dc965 100644 --- a/include/ur_print.hpp +++ b/include/ur_print.hpp @@ -227,6 +227,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct inline std::ostream &operator<<(std::ostream &os, enum ur_device_init_flag_t value); inline std::ostream &operator<<(std::ostream &os, enum ur_loader_config_info_t value); inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_code_location_t params); +inline std::ostream &operator<<(std::ostream &os, enum ur_callback_override_mode_t value); +inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_mock_callback_properties_t params); inline std::ostream &operator<<(std::ostream &os, enum ur_adapter_info_t value); inline std::ostream &operator<<(std::ostream &os, enum ur_adapter_backend_t value); inline std::ostream &operator<<(std::ostream &os, enum ur_platform_info_t value); @@ -932,6 +934,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_function_t value) { case UR_FUNCTION_KERNEL_GET_SUGGESTED_LOCAL_WORK_SIZE: os << "UR_FUNCTION_KERNEL_GET_SUGGESTED_LOCAL_WORK_SIZE"; break; + case UR_FUNCTION_LOADER_CONFIG_SET_MOCK_CALLBACKS: + os << "UR_FUNCTION_LOADER_CONFIG_SET_MOCK_CALLBACKS"; + break; default: os << "unknown enumerator"; break; @@ -1049,6 +1054,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_structure_type_t value case UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC: os << "UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC"; break; + case UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES: + os << "UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES"; + break; case UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC: os << "UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC"; break; @@ -1277,6 +1285,11 @@ inline ur_result_t printStruct(std::ostream &os, const void *ptr) { printPtr(os, pstruct); } break; + case UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES: { + const ur_mock_callback_properties_t *pstruct = (const ur_mock_callback_properties_t *)ptr; + printPtr(os, pstruct); + } break; + case UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC: { const ur_exp_command_buffer_desc_t *pstruct = (const ur_exp_command_buffer_desc_t *)ptr; printPtr(os, pstruct); @@ -1850,6 +1863,63 @@ inline std::ostream &operator<<(std::ostream &os, const struct ur_code_location_ return os; } /////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ur_callback_override_mode_t type +/// @returns +/// std::ostream & +inline std::ostream &operator<<(std::ostream &os, enum ur_callback_override_mode_t value) { + switch (value) { + case UR_CALLBACK_OVERRIDE_MODE_BEFORE: + os << "UR_CALLBACK_OVERRIDE_MODE_BEFORE"; + break; + case UR_CALLBACK_OVERRIDE_MODE_REPLACE: + os << "UR_CALLBACK_OVERRIDE_MODE_REPLACE"; + break; + case UR_CALLBACK_OVERRIDE_MODE_AFTER: + os << "UR_CALLBACK_OVERRIDE_MODE_AFTER"; + break; + default: + os << "unknown enumerator"; + break; + } + return os; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ur_mock_callback_properties_t type +/// @returns +/// std::ostream & +inline std::ostream &operator<<(std::ostream &os, const struct ur_mock_callback_properties_t params) { + os << "(struct ur_mock_callback_properties_t){"; + + os << ".stype = "; + + os << (params.stype); + + os << ", "; + os << ".pNext = "; + + ur::details::printStruct(os, + (params.pNext)); + + os << ", "; + os << ".name = "; + + ur::details::printPtr(os, + (params.name)); + + os << ", "; + os << ".mode = "; + + os << (params.mode); + + os << ", "; + os << ".pCallback = "; + + os << (params.pCallback); + + os << "}"; + return os; +} +/////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ur_adapter_info_t type /// @returns /// std::ostream & @@ -10101,6 +10171,26 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct return os; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ur_loader_config_set_mock_callbacks_params_t type +/// @returns +/// std::ostream & +inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_loader_config_set_mock_callbacks_params_t *params) { + + os << ".hLoaderConfig = "; + + ur::details::printPtr(os, + *(params->phLoaderConfig)); + + os << ", "; + os << ".pCallbackProperties = "; + + ur::details::printPtr(os, + *(params->ppCallbackProperties)); + + return os; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ur_platform_get_params_t type /// @returns @@ -17042,6 +17132,9 @@ inline ur_result_t UR_APICALL printFunctionParams(std::ostream &os, ur_function_ case UR_FUNCTION_LOADER_CONFIG_SET_CODE_LOCATION_CALLBACK: { os << (const struct ur_loader_config_set_code_location_callback_params_t *)params; } break; + case UR_FUNCTION_LOADER_CONFIG_SET_MOCK_CALLBACKS: { + os << (const struct ur_loader_config_set_mock_callbacks_params_t *)params; + } break; case UR_FUNCTION_PLATFORM_GET: { os << (const struct ur_platform_get_params_t *)params; } break; diff --git a/scripts/YaML.md b/scripts/YaML.md index a995b5858c..52d37ba091 100644 --- a/scripts/YaML.md +++ b/scripts/YaML.md @@ -620,11 +620,12 @@ class ur_name_t(Structure): - `out` is used for params that are write-only; if the param is a pointer, then the memory being pointed to is also write-only - `in,out` is used for params that are both read and write; typically this is used for pointers to other data structures that contain both read and write params - `nocheck` is used to specify that no additional validation checks will be generated. - + `desc` may include one the following annotations: {`"[optional]"`, `"[range(start,end)]"`, `"[release]"`, `"[typename(typeVarName)]"`, `"[bounds(offset,size)]"`} + + `desc` may include one the following annotations: {`"[optional]"`, `"[range(start,end)]"`, `"[retain]"`, `"[release]"`, `"[typename(typeVarName)]"`, `"[bounds(offset,size)]"`} - `optional` is used for params that are handles or pointers where it is legal for the value to be `nullptr` - `range` is used for params that are array pointers to specify the valid range that the is valid to read + `start` and `end` must be an ISO-C standard identifier or literal + `start` is inclusive and `end` is exclusive + - `retain` is used for params that are handles or pointers to handles where the function will increment the reference counter associated with the handle(s) - `release` is used for params that are handles or pointers to handles where the function will destroy any backing memory associated with the handle(s) - `typename` is used to denote the type enum for params that are opaque pointers to values of tagged data types. - `bounds` is used for params that are memory objects or USM allocations. It specifies the range within the memory allocation represented by the param that will be accessed by the operation. diff --git a/scripts/core/INTRO.rst b/scripts/core/INTRO.rst index a81c282070..649a5c100c 100644 --- a/scripts/core/INTRO.rst +++ b/scripts/core/INTRO.rst @@ -256,6 +256,38 @@ Currently, UR looks for these adapter libraries: For more information about the usage of mentioned environment variables see `Environment Variables`_ section. +Mocking +--------------------- +A mock UR adapter can be accessed for test purposes by enabling the ``MOCK`` +layer as described below. When the mock layer is enabled, calls to the API will +still be intercepted by other layers (e.g. validation, tracing), but they will +stop short of the loader - the call chain will end in either a generic fallback +behavior defined by the mock layer itself, or a user defined replacement +callback. + +The default fallback behavior for entry points in the mock layer is to simply +return ``UR_RESULT_SUCCESS``. For entry points concerning handles, i.e. those +that create a new handle or modify the reference count of an existing one, a +dummy handle mechanism is used. This means the layer will return generic +handles that track a reference count, and ``Retain``/``Release`` entry points will +function as expected when used with these handles. + +During global setup the behavior of the mock layer can be customized by setting +callbacks via ${x}LoaderConfigSetMockCallbacks. This entry point accepts a ``pNext`` +chain of structs, with each registering a callback with a given entry point in +the API. Callbacks can be registered to be called ``BEFORE`` or ``AFTER`` the +generic implementation, or they can be registered to entirely ``REPLACE`` it. A +given entry point can only have one of each kind of callback associated with +it, multiple structs with the same function/mode combination will override +eachother. + +The callback signature defined by ``${x}_mock_callback_t`` takes a single +``void *`` parameter. When calling a user callback the layer will pack the +entry point's parameters into the appropriate ``_params_t`` struct (e.g. +``ur_adapter_get_params_t``) and pass a pointer to that struct into the +callback. This allows parameters to be accessed and modified. The definitions +for these parameter structs can be found in the main API header. + Layers --------------------- UR comes with a mechanism that allows various API intercept layers to be enabled, either through the API or with an environment variable (see `Environment Variables`_). @@ -278,6 +310,8 @@ Layers currently included with the runtime are as follows: - Enables the XPTI tracing layer, see Tracing_ for more detail. * - UR_LAYER_ASAN \| UR_LAYER_MSAN \| UR_LAYER_TSAN - Enables the device-side sanitizer layer, see Sanitizers_ for more detail. + * - UR_LAYER_MOCK + - Enables adapter mocking for test purposes. Similar behavior to the null adapter except entry points can be overridden or instrumented with callbacks. See Mocking_ for more detail. Environment Variables --------------------- diff --git a/scripts/core/adapter.yml b/scripts/core/adapter.yml index a2331244e1..958d135b78 100644 --- a/scripts/core/adapter.yml +++ b/scripts/core/adapter.yml @@ -56,7 +56,7 @@ params: - type: "$x_adapter_handle_t" name: hAdapter desc: | - [in] Adapter handle to release + [in][release] Adapter handle to release --- #-------------------------------------------------------------------------- type: function desc: "Get a reference to the adapter handle." @@ -70,7 +70,7 @@ params: - type: "$x_adapter_handle_t" name: hAdapter desc: | - [in] Adapter handle to retain + [in][retain] Adapter handle to retain --- #-------------------------------------------------------------------------- type: function desc: "Get the last adapter specific error." diff --git a/scripts/core/context.yml b/scripts/core/context.yml index fef1161fca..69987cab99 100644 --- a/scripts/core/context.yml +++ b/scripts/core/context.yml @@ -80,7 +80,7 @@ params: - type: "$x_context_handle_t" name: hContext desc: | - [in] handle of the context to get a reference of. + [in][retain] handle of the context to get a reference of. --- #-------------------------------------------------------------------------- type: enum desc: "Supported context info" @@ -129,7 +129,7 @@ params: - type: "$x_context_handle_t" name: hContext desc: | - [in] handle of the context to release. + [in][release] handle of the context to release. --- #-------------------------------------------------------------------------- type: function desc: "Retrieves various information about context" diff --git a/scripts/core/device.yml b/scripts/core/device.yml index f1861731d3..232791a675 100644 --- a/scripts/core/device.yml +++ b/scripts/core/device.yml @@ -503,7 +503,7 @@ params: - type: "$x_device_handle_t" name: hDevice desc: | - [in] handle of the device to get a reference of. + [in][retain] handle of the device to get a reference of. --- #-------------------------------------------------------------------------- type: function desc: "Releases the device handle reference indicating end of its usage" @@ -522,7 +522,7 @@ params: - type: "$x_device_handle_t" name: hDevice desc: | - [in] handle of the device to release. + [in][release] handle of the device to release. --- #-------------------------------------------------------------------------- type: enum desc: "Device affinity domain" diff --git a/scripts/core/event.yml b/scripts/core/event.yml index 4e8be75cf4..45bcbf7d40 100644 --- a/scripts/core/event.yml +++ b/scripts/core/event.yml @@ -230,7 +230,7 @@ analogue: params: - type: $x_event_handle_t name: hEvent - desc: "[in] handle of the event object" + desc: "[in][retain] handle of the event object" returns: - $X_RESULT_ERROR_INVALID_EVENT - $X_RESULT_ERROR_OUT_OF_RESOURCES @@ -246,7 +246,7 @@ analogue: params: - type: $x_event_handle_t name: hEvent - desc: "[in] handle of the event object" + desc: "[in][release] handle of the event object" returns: - $X_RESULT_ERROR_INVALID_EVENT - $X_RESULT_ERROR_OUT_OF_RESOURCES diff --git a/scripts/core/exp-bindless-images.yml b/scripts/core/exp-bindless-images.yml index 622e378f0b..7615f2ee53 100644 --- a/scripts/core/exp-bindless-images.yml +++ b/scripts/core/exp-bindless-images.yml @@ -700,7 +700,7 @@ params: desc: "[in] handle of the device object" - type: $x_exp_interop_mem_handle_t name: hInteropMem - desc: "[in] handle of interop memory to be freed" + desc: "[in][release] handle of interop memory to be freed" returns: - $X_RESULT_ERROR_INVALID_CONTEXT - $X_RESULT_ERROR_INVALID_VALUE diff --git a/scripts/core/exp-command-buffer.yml b/scripts/core/exp-command-buffer.yml index 6e276eac88..63ec03974f 100644 --- a/scripts/core/exp-command-buffer.yml +++ b/scripts/core/exp-command-buffer.yml @@ -252,7 +252,7 @@ name: RetainExp params: - type: $x_exp_command_buffer_handle_t name: hCommandBuffer - desc: "[in] Handle of the command-buffer object." + desc: "[in][retain] Handle of the command-buffer object." returns: - $X_RESULT_ERROR_INVALID_COMMAND_BUFFER_EXP - $X_RESULT_ERROR_OUT_OF_RESOURCES @@ -265,7 +265,7 @@ name: ReleaseExp params: - type: $x_exp_command_buffer_handle_t name: hCommandBuffer - desc: "[in] Handle of the command-buffer object." + desc: "[in][release] Handle of the command-buffer object." returns: - $X_RESULT_ERROR_INVALID_COMMAND_BUFFER_EXP - $X_RESULT_ERROR_OUT_OF_RESOURCES @@ -893,7 +893,7 @@ name: ReleaseCommandExp params: - type: $x_exp_command_buffer_command_handle_t name: hCommand - desc: "[in] Handle of the command-buffer command." + desc: "[in][release] Handle of the command-buffer command." returns: - $X_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP - $X_RESULT_ERROR_OUT_OF_RESOURCES diff --git a/scripts/core/kernel.yml b/scripts/core/kernel.yml index 5446f3bc1d..972a9e34c2 100644 --- a/scripts/core/kernel.yml +++ b/scripts/core/kernel.yml @@ -292,7 +292,7 @@ details: params: - type: $x_kernel_handle_t name: hKernel - desc: "[in] handle for the Kernel to retain" + desc: "[in][retain] handle for the Kernel to retain" --- #-------------------------------------------------------------------------- type: function desc: "Release Kernel." @@ -309,7 +309,7 @@ details: params: - type: $x_kernel_handle_t name: hKernel - desc: "[in] handle for the Kernel to release" + desc: "[in][release] handle for the Kernel to release" --- #-------------------------------------------------------------------------- type: struct desc: "Properties for for $xKernelSetArgPointer." diff --git a/scripts/core/loader.yml b/scripts/core/loader.yml index b5ad1eadec..8a869d3061 100644 --- a/scripts/core/loader.yml +++ b/scripts/core/loader.yml @@ -52,7 +52,7 @@ details: params: - type: $x_loader_config_handle_t name: hLoaderConfig - desc: "[in] loader config handle to retain" + desc: "[in][retain] loader config handle to retain" --- #-------------------------------------------------------------------------- type: function desc: "Release config handle." @@ -67,7 +67,7 @@ details: params: - type: $x_loader_config_handle_t name: hLoaderConfig - desc: "[in] config handle to release" + desc: "[in][release] config handle to release" --- #-------------------------------------------------------------------------- type: enum desc: "Supported loader info" @@ -187,6 +187,59 @@ params: name: pUserData desc: "[in][out][optional] pointer to data to be passed to callback." --- #-------------------------------------------------------------------------- +type: enum +desc: "Callback override mode" +class: $xLoaderConfig +name: $x_callback_override_mode_t +etors: + - name: BEFORE + desc: "Invoke callback before function." + - name: REPLACE + desc: "Invoke callback instead of function." + - name: AFTER + desc: "Invoke callback after function." +--- #-------------------------------------------------------------------------- +type: fptr_typedef +desc: "Callback to replace or instrument generic mock functionality in the mock layer." +name: $x_mock_callback_t +return: $x_result_t +params: + - type: void* + name: pParams + desc: "[in][out] Pointer to the appropriate param struct for the function" +--- #-------------------------------------------------------------------------- +type: struct +desc: "Callback properties" +class: $xLoaderConfig +name: $x_mock_callback_properties_t +base: $x_base_properties_t +members: + - type: const char* + name: name + desc: "[in] Full name of the function to associate callback with." + - type: $x_callback_override_mode_t + name: mode + desc: "[in] Override mode for this callback." + - type: $x_mock_callback_t + name: pCallback + desc: "[in] Callback function pointer." +--- #-------------------------------------------------------------------------- +type: function +desc: "Set a callback to be called before, after or instead of a given entry point" +details: + - "The callback layer will pass the function's parameter struct (e.g. **$x_adapter_get_params_t**) to the $x_mock_callback_t so parameters can be accessed and modified." +class: $xLoaderConfig +loader_only: True +name: "SetMockCallbacks" +decl: static +params: + - type: $x_loader_config_handle_t + name: hLoaderConfig + desc: "[in] Handle to config object the layer will be enabled for." + - type: $x_mock_callback_properties_t* + name: pCallbackProperties + desc: "[in] Pointer to callback properties struct." +--- #-------------------------------------------------------------------------- type: function desc: "Initialize the $OneApi loader" class: $xLoader diff --git a/scripts/core/memory.yml b/scripts/core/memory.yml index c4009bc56e..667d8244c1 100644 --- a/scripts/core/memory.yml +++ b/scripts/core/memory.yml @@ -347,7 +347,7 @@ details: params: - type: $x_mem_handle_t name: hMem - desc: "[in] handle of the memory object to get access" + desc: "[in][retain] handle of the memory object to get access" returns: - $X_RESULT_ERROR_INVALID_MEM_OBJECT - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY @@ -364,7 +364,7 @@ analogue: params: - type: $x_mem_handle_t name: hMem - desc: "[in] handle of the memory object to release" + desc: "[in][release] handle of the memory object to release" returns: - $X_RESULT_ERROR_INVALID_MEM_OBJECT - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY diff --git a/scripts/core/program.yml b/scripts/core/program.yml index 45f7710d68..04b26c32ad 100644 --- a/scripts/core/program.yml +++ b/scripts/core/program.yml @@ -262,7 +262,7 @@ details: params: - type: $x_program_handle_t name: hProgram - desc: "[in] handle for the Program to retain" + desc: "[in][retain] handle for the Program to retain" --- #-------------------------------------------------------------------------- type: function desc: "Release Program." @@ -279,7 +279,7 @@ details: params: - type: $x_program_handle_t name: hProgram - desc: "[in] handle for the Program to release" + desc: "[in][release] handle for the Program to release" --- #-------------------------------------------------------------------------- type: function desc: "Retrieves a device function pointer to a user-defined function." diff --git a/scripts/core/queue.yml b/scripts/core/queue.yml index 816da179ba..3e9bba642b 100644 --- a/scripts/core/queue.yml +++ b/scripts/core/queue.yml @@ -180,7 +180,7 @@ details: params: - type: $x_queue_handle_t name: hQueue - desc: "[in] handle of the queue object to get access" + desc: "[in][retain] handle of the queue object to get access" returns: - $X_RESULT_ERROR_INVALID_QUEUE - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY @@ -200,7 +200,7 @@ details: params: - type: $x_queue_handle_t name: hQueue - desc: "[in] handle of the queue object to release" + desc: "[in][release] handle of the queue object to release" returns: - $X_RESULT_ERROR_INVALID_QUEUE - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY diff --git a/scripts/core/registry.yml b/scripts/core/registry.yml index 52585ade3a..702c2b31a1 100644 --- a/scripts/core/registry.yml +++ b/scripts/core/registry.yml @@ -589,6 +589,9 @@ etors: - name: KERNEL_GET_SUGGESTED_LOCAL_WORK_SIZE desc: Enumerator for $xKernelGetSuggestedLocalWorkSize value: '225' +- name: LOADER_CONFIG_SET_MOCK_CALLBACKS + desc: Enumerator for $xLoaderConfigSetMockCallbacks + value: '229' --- type: enum desc: Defines structure types @@ -699,3 +702,6 @@ etors: - name: USM_ALLOC_LOCATION_DESC desc: $x_usm_alloc_location_desc_t value: '35' +- name: MOCK_CALLBACK_PROPERTIES + desc: $x_mock_callback_properties_t + value: '38' diff --git a/scripts/core/sampler.yml b/scripts/core/sampler.yml index 7bb33b357e..6459277c6f 100644 --- a/scripts/core/sampler.yml +++ b/scripts/core/sampler.yml @@ -116,7 +116,7 @@ analogue: params: - type: $x_sampler_handle_t name: hSampler - desc: "[in] handle of the sampler object to get access" + desc: "[in][retain] handle of the sampler object to get access" returns: - $X_RESULT_ERROR_INVALID_SAMPLER - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY @@ -132,7 +132,7 @@ analogue: params: - type: $x_sampler_handle_t name: hSampler - desc: "[in] handle of the sampler object to release" + desc: "[in][release] handle of the sampler object to release" returns: - $X_RESULT_ERROR_INVALID_SAMPLER - $X_RESULT_ERROR_OUT_OF_HOST_MEMORY diff --git a/scripts/core/usm.yml b/scripts/core/usm.yml index 7ba75bc865..da5cd8c578 100644 --- a/scripts/core/usm.yml +++ b/scripts/core/usm.yml @@ -433,7 +433,7 @@ ordinal: "0" params: - type: $x_usm_pool_handle_t name: pPool - desc: "[in] pointer to USM memory pool" + desc: "[in][retain] pointer to USM memory pool" returns: - $X_RESULT_ERROR_INVALID_NULL_HANDLE --- #-------------------------------------------------------------------------- @@ -449,7 +449,7 @@ details: params: - type: $x_usm_pool_handle_t name: pPool - desc: "[in] pointer to USM memory pool" + desc: "[in][release] pointer to USM memory pool" returns: - $X_RESULT_ERROR_INVALID_NULL_HANDLE --- #-------------------------------------------------------------------------- diff --git a/scripts/core/virtual_memory.yml b/scripts/core/virtual_memory.yml index ba88d4be2e..5b12e1761e 100644 --- a/scripts/core/virtual_memory.yml +++ b/scripts/core/virtual_memory.yml @@ -292,7 +292,7 @@ name: Retain params: - type: $x_physical_mem_handle_t name: hPhysicalMem - desc: "[in] handle of the physical memory object to retain." + desc: "[in][retain] handle of the physical memory object to retain." --- #-------------------------------------------------------------------------- type: function @@ -302,4 +302,4 @@ name: Release params: - type: $x_physical_mem_handle_t name: hPhysicalMem - desc: "[in] handle of the physical memory object to release." + desc: "[in][release] handle of the physical memory object to release." diff --git a/scripts/generate_code.py b/scripts/generate_code.py index bdaa475a3e..e39f905924 100644 --- a/scripts/generate_code.py +++ b/scripts/generate_code.py @@ -287,6 +287,30 @@ def _mako_tracing_layer_cpp(path, namespace, tags, version, specs, meta): specs=specs, meta=meta) +""" + generates c/c++ files from the specification documents +""" +def _mako_mock_layer_cpp(path, namespace, tags, version, specs, meta): + dstpath = os.path.join(path, "mock") + os.makedirs(dstpath, exist_ok=True) + + template = "mockddi.cpp.mako" + fin = os.path.join(templates_dir, template) + + name = "%s_mockddi"%(namespace) + filename = "%s.cpp"%(name) + fout = os.path.join(dstpath, filename) + + print("Generating %s..."%fout) + return util.makoWrite( + fin, fout, + name=name, + ver=version, + namespace=namespace, + tags=tags, + specs=specs, + meta=meta) + """ generates c/c++ files from the specification documents """ @@ -417,6 +441,11 @@ def generate_layers(path, section, namespace, tags, version, specs, meta): loc += _mako_tracing_layer_cpp(layer_dstpath, namespace, tags, version, specs, meta) print("TRACING Generated %s lines of code.\n"%loc) + loc = 0 + loc += _mako_mock_layer_cpp(layer_dstpath, namespace, tags, version, specs, meta) + print("MOCK Generated %s lines of code.\n"%loc) + + """ Entry-point: generates common utilities for unified_runtime diff --git a/scripts/templates/helper.py b/scripts/templates/helper.py index dc6becfec5..89c3365075 100644 --- a/scripts/templates/helper.py +++ b/scripts/templates/helper.py @@ -371,6 +371,7 @@ class param_traits: RE_OPTIONAL = r".*\[optional\].*" RE_NOCHECK = r".*\[nocheck\].*" RE_RANGE = r".*\[range\((.+),\s*(.+)\)\][\S\s]*" + RE_RETAIN = r".*\[retain\].*" RE_RELEASE = r".*\[release\].*" RE_TYPENAME = r".*\[typename\((.+),\s(.+)\)\].*" RE_TAGGED = r".*\[tagged_by\((.+)\)].*" @@ -460,6 +461,13 @@ def range_end(cls, item): except: return None + @classmethod + def is_retain(cls, item): + try: + return True if re.match(cls.RE_RETAIN, item['desc']) else False + except: + return False + @classmethod def is_release(cls, item): try: @@ -907,6 +915,19 @@ def make_param_lines(namespace, tags, obj, decl=False, meta=None, format=["type" lines = ["void"] return lines +""" +Public: + searches params of function `obj` for a match to the given regex and + returns its full C++ name +""" +def find_param_name(name_re, namespace, tags, obj): + for param in obj['params']: + param_cpp_name = _get_param_name(namespace, tags, param) + print("searching {0} for pattner {1}".format(param_cpp_name, name_re)) + if re.search(name_re, param_cpp_name): + return param_cpp_name + return UNDEFINED + """ Public: returns a list of strings for the description @@ -1465,7 +1486,7 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta): obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", tname) fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", tname) - if param_traits.is_release(item) or param_traits.is_output(item) or param_traits.is_inoutput(item): + if param_traits.is_retain(item) or param_traits.is_release(item) or param_traits.is_output(item) or param_traits.is_inoutput(item): if type_traits.is_class_handle(item['type'], meta): if param_traits.is_range(item): range_start = param_traits.range_start(item) @@ -1475,6 +1496,7 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta): 'type': tname, 'obj': obj_name, 'factory': fty_name, + 'retain': param_traits.is_retain(item), 'release': param_traits.is_release(item), 'range': (range_start, range_end) }) @@ -1484,6 +1506,7 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta): 'type': tname, 'obj': obj_name, 'factory': fty_name, + 'retain': param_traits.is_retain(item), 'release': param_traits.is_release(item), 'optional': param_traits.is_optional(item) }) @@ -1521,6 +1544,7 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta): epilogue.append({ 'name': name, 'obj': obj_name, + 'retain': False, 'release': False, 'typename': typename, 'size': prop_size, diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index bbc7c7c7d0..840be4433a 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -279,6 +279,8 @@ namespace ur_loader %if item['release']: // release loader handle ${item['factory']}.release( ${item['name']} ); + %elif item['retain']: + // TODO: do we need to ref count the loader handles? %elif not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': try { diff --git a/scripts/templates/mockddi.cpp.mako b/scripts/templates/mockddi.cpp.mako new file mode 100644 index 0000000000..3932d9ae22 --- /dev/null +++ b/scripts/templates/mockddi.cpp.mako @@ -0,0 +1,191 @@ +<%! +import re +from templates import helper as th +%><% + n=namespace + N=n.upper() + + x=tags['$x'] + X=x.upper() + + handle_create_get_retain_release_funcs=th.get_handle_create_get_retain_release_functions(specs, n, tags) +%>/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file ${name}.cpp + * + */ +#include "${x}_mock_layer.hpp" + +namespace ur_mock_layer +{ + %for obj in th.get_adapter_functions(specs): + <% + func_name=th.make_func_name(n, tags, obj) + %> + /////////////////////////////////////////////////////////////////////////////// + /// @brief Intercept function for ${th.make_func_name(n, tags, obj)} + %if 'condition' in obj: + #if ${th.subt(n, tags, obj['condition'])} + %endif + __${x}dlllocal ${x}_result_t ${X}_APICALL + ${func_name}( + %for line in th.make_param_lines(n, tags, obj): + ${line} + %endfor + ) + { + auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}; + + if( nullptr == ${th.make_pfn_name(n, tags, obj)} ) { + return ${X}_RESULT_ERROR_UNINITIALIZED; + } + + ${th.make_pfncb_param_type(n, tags, obj)} params = { &${",&".join(th.make_param_lines(n, tags, obj, format=["name"]))} }; + + ${x}_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("${func_name}")); + if(beforeCallback) { + result = beforeCallback( ¶ms ); + if(result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("${func_name}")); + if(replaceCallback) { + result = replaceCallback( ¶ms ); + } else { + <% + # We can use the loader epilogue to know when we should be creating mock handles + epilogue = th.get_loader_epilogue(specs, n, tags, obj, meta) + %> + %if 'NativeHandle' in func_name: + <% func_class = th.subt(namespace, tags, obj['class'], False, True) %> + %if 'CreateWith' in func_name: + *ph${func_class} = reinterpret_cast(hNative${func_class}); + retainDummyHandle(*ph${func_class}); + %else: + *phNative${func_class} = reinterpret_cast(h${func_class}); + %endif + %else: + %if func_name == 'urAdapterGet' or func_name == 'urDeviceGet' or func_name == 'urPlatformGet': + <% + num_param = th.find_param_name(".*pNum.*", n, tags, obj) + %> + if(${num_param}) { + *${num_param} = 1; + } + %endif + %for item in epilogue: + %if item['release']: + releaseDummyHandle(${item['name']}); + %elif item['retain']: + retainDummyHandle(${item['name']}); + %elif 'type' in item: + %if 'range' in item or ('optional' in item and item['optional']): + // optional output handle + if(${item['name']}) { + *${item['name']} = createDummyHandle<${item['type']}>(); + } + %else: + *${item['name']} = createDummyHandle<${item['type']}>(); + %endif + %endif + %endfor + %endif + result = UR_RESULT_SUCCESS; + } + if(result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("${func_name}")); + if(afterCallback) { + return afterCallback( ¶ms ); + } + + return result; + } + %if 'condition' in obj: + #endif // ${th.subt(n, tags, obj['condition'])} + %endif + + %endfor + %for tbl in th.get_pfntables(specs, meta, n, tags): + /////////////////////////////////////////////////////////////////////////////// + /// @brief Exported function for filling application's ${tbl['name']} table + /// with current process' addresses + /// + /// @returns + /// - ::${X}_RESULT_SUCCESS + /// - ::${X}_RESULT_ERROR_INVALID_NULL_POINTER + /// - ::${X}_RESULT_ERROR_UNSUPPORTED_VERSION + ${X}_DLLEXPORT ${x}_result_t ${X}_APICALL + ${tbl['export']['name']}( + %for line in th.make_param_lines(n, tags, tbl['export']): + ${line} + %endfor + ) + { + auto& dditable = ur_mock_layer::context.${n}DdiTable.${tbl['name']}; + + if( nullptr == pDdiTable ) + return ${X}_RESULT_ERROR_INVALID_NULL_POINTER; + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > UR_MINOR_VERSION(version)) + return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION; + + ${x}_result_t result = ${X}_RESULT_SUCCESS; + + %for obj in tbl['functions']: + %if 'condition' in obj: + #if ${th.subt(n, tags, obj['condition'])} + %endif + dditable.${th.append_ws(th.make_pfn_name(n, tags, obj), 43)} = pDdiTable->${th.make_pfn_name(n, tags, obj)}; + pDdiTable->${th.append_ws(th.make_pfn_name(n, tags, obj), 41)} = ur_mock_layer::${th.make_func_name(n, tags, obj)}; + %if 'condition' in obj: + #else + dditable.${th.append_ws(th.make_pfn_name(n, tags, obj), 43)} = nullptr; + pDdiTable->${th.append_ws(th.make_pfn_name(n, tags, obj), 41)} = nullptr; + #endif + %endif + + %endfor + return result; + } + + %endfor + ${x}_result_t + context_t::init(ur_dditable_t *dditable, + const std::set &enabledLayerNames, + codeloc_data, api_callbacks apiCallbacks) { + ${x}_result_t result = ${X}_RESULT_SUCCESS; + + if(!enabledLayerNames.count(name)) { + return result; + } + + ur_mock_layer::context.apiCallbacks = apiCallbacks; + + %for tbl in th.get_pfntables(specs, meta, n, tags): + if ( ${X}_RESULT_SUCCESS == result ) + { + result = ur_mock_layer::${tbl['export']['name']}( ${X}_API_VERSION_CURRENT, &dditable->${tbl['name']} ); + } + + %endfor + return result; + } + +} // namespace ur_mock_layer diff --git a/scripts/templates/nullddi.cpp.mako b/scripts/templates/nullddi.cpp.mako index f503d4073c..9eb7ec6194 100644 --- a/scripts/templates/nullddi.cpp.mako +++ b/scripts/templates/nullddi.cpp.mako @@ -67,7 +67,7 @@ namespace driver %elif 'range' in item: for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i ) ${item['name']}[ i ] = reinterpret_cast<${item['type']}>( d_context.get() ); - %elif not item['release']: + %elif not item['release'] and not item['retain']: %if item['optional']: if( nullptr != ${item['name']} ) *${item['name']} = reinterpret_cast<${item['type']}>( d_context.get() ); %else: diff --git a/scripts/templates/trcddi.cpp.mako b/scripts/templates/trcddi.cpp.mako index 6f6579d5ac..b39c1815b8 100644 --- a/scripts/templates/trcddi.cpp.mako +++ b/scripts/templates/trcddi.cpp.mako @@ -111,7 +111,7 @@ namespace ur_tracing_layer ${x}_result_t context_t::init(ur_dditable_t *dditable, const std::set &enabledLayerNames, - codeloc_data codelocData) { + codeloc_data codelocData, api_callbacks) { ${x}_result_t result = ${X}_RESULT_SUCCESS; if(!enabledLayerNames.count(name)) { diff --git a/scripts/templates/valddi.cpp.mako b/scripts/templates/valddi.cpp.mako index c8905a7e8b..19cdb48faf 100644 --- a/scripts/templates/valddi.cpp.mako +++ b/scripts/templates/valddi.cpp.mako @@ -173,7 +173,7 @@ namespace ur_validation_layer ${x}_result_t context_t::init(ur_dditable_t *dditable, const std::set &enabledLayerNames, - codeloc_data) { + codeloc_data, api_callbacks) { ${x}_result_t result = ${X}_RESULT_SUCCESS; if (enabledLayerNames.count(nameFullValidation)) { diff --git a/source/adapters/null/ur_nullddi.cpp b/source/adapters/null/ur_nullddi.cpp index a713a385a7..b4a9733a20 100644 --- a/source/adapters/null/ur_nullddi.cpp +++ b/source/adapters/null/ur_nullddi.cpp @@ -49,7 +49,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRelease __urdlllocal ur_result_t UR_APICALL urAdapterRelease( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to release + ur_adapter_handle_t hAdapter ///< [in][release] Adapter handle to release ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -69,7 +69,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRetain __urdlllocal ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -428,7 +428,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( /// @brief Intercept function for urDeviceRetain __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -448,7 +448,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urDeviceRelease __urdlllocal ur_result_t UR_APICALL urDeviceRelease( - ur_device_handle_t hDevice ///< [in] handle of the device to release. + ur_device_handle_t + hDevice ///< [in][release] handle of the device to release. ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -641,7 +642,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( /// @brief Intercept function for urContextRetain __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -661,7 +662,8 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urContextRelease __urdlllocal ur_result_t UR_APICALL urContextRelease( - ur_context_handle_t hContext ///< [in] handle of the context to release. + ur_context_handle_t + hContext ///< [in][release] handle of the context to release. ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -864,7 +866,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRetain __urdlllocal ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -884,7 +887,8 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRelease __urdlllocal ur_result_t UR_APICALL urMemRelease( - ur_mem_handle_t hMem ///< [in] handle of the memory object to release + ur_mem_handle_t + hMem ///< [in][release] handle of the memory object to release ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -1124,7 +1128,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( /// @brief Intercept function for urSamplerRetain __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -1145,7 +1149,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( /// @brief Intercept function for urSamplerRelease __urdlllocal ur_result_t UR_APICALL urSamplerRelease( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to release + hSampler ///< [in][release] handle of the sampler object to release ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -1446,7 +1450,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRetain __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -1466,7 +1470,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRelease __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][release] pointer to USM memory pool ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -1759,7 +1763,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( /// @brief Intercept function for urPhysicalMemRetain __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -1780,7 +1784,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( /// @brief Intercept function for urPhysicalMemRelease __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to release. + hPhysicalMem ///< [in][release] handle of the physical memory object to release. ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -1932,7 +1936,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRetain __urdlllocal ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -1952,7 +1957,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRelease __urdlllocal ur_result_t UR_APICALL urProgramRelease( - ur_program_handle_t hProgram ///< [in] handle for the Program to release + ur_program_handle_t + hProgram ///< [in][release] handle for the Program to release ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -2399,7 +2405,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain __urdlllocal ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -2419,7 +2425,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRelease __urdlllocal ur_result_t UR_APICALL urKernelRelease( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release + ur_kernel_handle_t + hKernel ///< [in][release] handle for the Kernel to release ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -2744,7 +2751,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRetain __urdlllocal ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -2764,7 +2772,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRelease __urdlllocal ur_result_t UR_APICALL urQueueRelease( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to release + ur_queue_handle_t + hQueue ///< [in][release] handle of the queue object to release ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -2983,7 +2992,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRetain __urdlllocal ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -3003,7 +3012,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRelease __urdlllocal ur_result_t UR_APICALL urEventRelease( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][release] handle of the event object ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -4585,7 +4594,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( ur_context_handle_t hContext, ///< [in] handle of the context object ur_device_handle_t hDevice, ///< [in] handle of the device object ur_exp_interop_mem_handle_t - hInteropMem ///< [in] handle of interop memory to be freed + hInteropMem ///< [in][release] handle of interop memory to be freed ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -4766,7 +4775,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( /// @brief Intercept function for urCommandBufferRetainExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -4787,7 +4796,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( /// @brief Intercept function for urCommandBufferReleaseExp __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][release] Handle of the command-buffer object. ) try { ur_result_t result = UR_RESULT_SUCCESS; @@ -5357,7 +5366,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( /// @brief Intercept function for urCommandBufferReleaseCommandExp __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][release] Handle of the command-buffer command. ) try { ur_result_t result = UR_RESULT_SUCCESS; diff --git a/source/common/stype_map_helpers.def b/source/common/stype_map_helpers.def index 0c3e5b1cc1..240d8825c5 100644 --- a/source/common/stype_map_helpers.def +++ b/source/common/stype_map_helpers.def @@ -72,6 +72,8 @@ struct stype_map : stype_map_impl struct stype_map : stype_map_impl {}; template <> +struct stype_map : stype_map_impl {}; +template <> struct stype_map : stype_map_impl {}; template <> struct stype_map : stype_map_impl {}; diff --git a/source/loader/CMakeLists.txt b/source/loader/CMakeLists.txt index 075d9909b0..6bf53f45d0 100644 --- a/source/loader/CMakeLists.txt +++ b/source/loader/CMakeLists.txt @@ -107,9 +107,12 @@ target_sources(ur_loader ${CMAKE_CURRENT_SOURCE_DIR}/ur_lib.hpp ${CMAKE_CURRENT_SOURCE_DIR}/ur_lib.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ur_codeloc.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/ur_callbacks.hpp ${CMAKE_CURRENT_SOURCE_DIR}/ur_print.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layers/validation/ur_valddi.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layers/validation/ur_validation_layer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/mock/ur_mock_layer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/mock/ur_mockddi.cpp ) if(UR_ENABLE_TRACING) @@ -159,7 +162,6 @@ if(UR_ENABLE_SANITIZER) ) endif() - # link validation backtrace dependencies if(UNIX) find_package(Libbacktrace) diff --git a/source/loader/layers/mock/ur_mock_layer.cpp b/source/loader/layers/mock/ur_mock_layer.cpp new file mode 100644 index 0000000000..6305bc3a6e --- /dev/null +++ b/source/loader/layers/mock/ur_mock_layer.cpp @@ -0,0 +1,16 @@ +/* + * + * Copyright (C) 2023 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file ur_validation_layer.cpp + * + */ +#include "ur_mock_layer.hpp" + +namespace ur_mock_layer { +context_t context; +}; diff --git a/source/loader/layers/mock/ur_mock_layer.hpp b/source/loader/layers/mock/ur_mock_layer.hpp new file mode 100644 index 0000000000..f0ed45d86e --- /dev/null +++ b/source/loader/layers/mock/ur_mock_layer.hpp @@ -0,0 +1,79 @@ +/* + * + * Copyright (C) 2023-2024 Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file ur_layer.h + * + */ +#pragma once +#include "ur_ddi.h" +#include "ur_proxy_layer.hpp" +#include "ur_util.hpp" + +#include + +namespace ur_mock_layer { + +struct dummy_handle_t_ { + dummy_handle_t_(size_t dataSize = 0) + : storage(dataSize), data(storage.data()) {} + dummy_handle_t_(unsigned char *data) : data(data) {} + std::atomic refCounter = 1; + std::vector storage; + unsigned char *data = nullptr; +}; + +using dummy_handle_t = dummy_handle_t_ *; + +// Allocates a dummy handle of type T with support of reference counting. +// Takes optional 'Size' parameter which can be used to allocate additional +// memory. The handle has to be deallocated using 'releaseDummyHandle'. +template inline T createDummyHandle(size_t Size = 0) { + dummy_handle_t DummyHandlePtr = new dummy_handle_t_(Size); + return reinterpret_cast(DummyHandlePtr); +} + +// Decrement reference counter for the handle and deallocates it if the +// reference counter becomes zero +template inline void releaseDummyHandle(T Handle) { + auto DummyHandlePtr = reinterpret_cast(Handle); + const size_t NewValue = --DummyHandlePtr->refCounter; + if (NewValue == 0) { + delete DummyHandlePtr; + } +} + +// Increment reference counter for the handle +template inline void retainDummyHandle(T Handle) { + auto DummyHandlePtr = reinterpret_cast(Handle); + ++DummyHandlePtr->refCounter; +} + +/////////////////////////////////////////////////////////////////////////////// +class __urdlllocal context_t : public proxy_layer_context_t { + public: + ur_dditable_t urDdiTable = {}; + + context_t() {} + ~context_t() {} + + bool isAvailable() const override { return true; } + std::vector getNames() const override { return {name}; } + ur_result_t init(ur_dditable_t *dditable, + const std::set &enabledLayerNames, + codeloc_data codelocData, + api_callbacks apiCallbacks) override; + ur_result_t tearDown() override { return UR_RESULT_SUCCESS; } + + api_callbacks apiCallbacks; + + private: + const std::string name = "UR_LAYER_MOCK"; +}; + +extern context_t context; +}; // namespace ur_mock_layer diff --git a/source/loader/layers/mock/ur_mockddi.cpp b/source/loader/layers/mock/ur_mockddi.cpp new file mode 100644 index 0000000000..154d2fa5f2 --- /dev/null +++ b/source/loader/layers/mock/ur_mockddi.cpp @@ -0,0 +1,12080 @@ +/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file ur_mockddi.cpp + * + */ +#include "ur_mock_layer.hpp" + +namespace ur_mock_layer { + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urAdapterGet +__urdlllocal ur_result_t UR_APICALL urAdapterGet( + uint32_t + NumEntries, ///< [in] the number of adapters to be added to phAdapters. + ///< If phAdapters is not NULL, then NumEntries should be greater than + ///< zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, + ///< will be returned. + ur_adapter_handle_t * + phAdapters, ///< [out][optional][range(0, NumEntries)] array of handle of adapters. + ///< If NumEntries is less than the number of adapters available, then + ///< ::urAdapterGet shall only retrieve that number of platforms. + uint32_t * + pNumAdapters ///< [out][optional] returns the total number of adapters available. +) { + auto pfnAdapterGet = context.urDdiTable.Global.pfnAdapterGet; + + if (nullptr == pfnAdapterGet) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_adapter_get_params_t params = {&NumEntries, &phAdapters, &pNumAdapters}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urAdapterGet")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urAdapterGet")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + if (pNumAdapters) { + *pNumAdapters = 1; + } + // optional output handle + if (phAdapters) { + *phAdapters = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urAdapterGet")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urAdapterRelease +__urdlllocal ur_result_t UR_APICALL urAdapterRelease( + ur_adapter_handle_t hAdapter ///< [in][release] Adapter handle to release +) { + auto pfnAdapterRelease = context.urDdiTable.Global.pfnAdapterRelease; + + if (nullptr == pfnAdapterRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_adapter_release_params_t params = {&hAdapter}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urAdapterRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urAdapterRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hAdapter); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urAdapterRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urAdapterRetain +__urdlllocal ur_result_t UR_APICALL urAdapterRetain( + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain +) { + auto pfnAdapterRetain = context.urDdiTable.Global.pfnAdapterRetain; + + if (nullptr == pfnAdapterRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_adapter_retain_params_t params = {&hAdapter}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urAdapterRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urAdapterRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hAdapter); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urAdapterRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urAdapterGetLastError +__urdlllocal ur_result_t UR_APICALL urAdapterGetLastError( + ur_adapter_handle_t hAdapter, ///< [in] handle of the adapter instance + const char ** + ppMessage, ///< [out] pointer to a C string where the adapter specific error message + ///< will be stored. + int32_t * + pError ///< [out] pointer to an integer where the adapter specific error code will + ///< be stored. +) { + auto pfnAdapterGetLastError = + context.urDdiTable.Global.pfnAdapterGetLastError; + + if (nullptr == pfnAdapterGetLastError) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_adapter_get_last_error_params_t params = {&hAdapter, &ppMessage, + &pError}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urAdapterGetLastError")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urAdapterGetLastError")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urAdapterGetLastError")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urAdapterGetInfo +__urdlllocal ur_result_t UR_APICALL urAdapterGetInfo( + ur_adapter_handle_t hAdapter, ///< [in] handle of the adapter + ur_adapter_info_t propName, ///< [in] type of the info to retrieve + size_t propSize, ///< [in] the number of bytes pointed to by pPropValue. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. + ///< If Size is not equal to or greater to the real number of bytes needed + ///< to return the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is + ///< returned and pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual number of bytes being queried by pPropValue. +) { + auto pfnAdapterGetInfo = context.urDdiTable.Global.pfnAdapterGetInfo; + + if (nullptr == pfnAdapterGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_adapter_get_info_params_t params = {&hAdapter, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urAdapterGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urAdapterGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urAdapterGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urPlatformGet +__urdlllocal ur_result_t UR_APICALL urPlatformGet( + ur_adapter_handle_t * + phAdapters, ///< [in][range(0, NumAdapters)] array of adapters to query for platforms. + uint32_t NumAdapters, ///< [in] number of adapters pointed to by phAdapters + uint32_t + NumEntries, ///< [in] the number of platforms to be added to phPlatforms. + ///< If phPlatforms is not NULL, then NumEntries should be greater than + ///< zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, + ///< will be returned. + ur_platform_handle_t * + phPlatforms, ///< [out][optional][range(0, NumEntries)] array of handle of platforms. + ///< If NumEntries is less than the number of platforms available, then + ///< ::urPlatformGet shall only retrieve that number of platforms. + uint32_t * + pNumPlatforms ///< [out][optional] returns the total number of platforms available. +) { + auto pfnGet = context.urDdiTable.Platform.pfnGet; + + if (nullptr == pfnGet) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_platform_get_params_t params = {&phAdapters, &NumAdapters, &NumEntries, + &phPlatforms, &pNumPlatforms}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urPlatformGet")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urPlatformGet")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + if (pNumPlatforms) { + *pNumPlatforms = 1; + } + // optional output handle + if (phPlatforms) { + *phPlatforms = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urPlatformGet")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urPlatformGetInfo +__urdlllocal ur_result_t UR_APICALL urPlatformGetInfo( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform + ur_platform_info_t propName, ///< [in] type of the info to retrieve + size_t propSize, ///< [in] the number of bytes pointed to by pPlatformInfo. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. + ///< If Size is not equal to or greater to the real number of bytes needed + ///< to return the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is + ///< returned and pPlatformInfo is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual number of bytes being queried by pPlatformInfo. +) { + auto pfnGetInfo = context.urDdiTable.Platform.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_platform_get_info_params_t params = {&hPlatform, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urPlatformGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urPlatformGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urPlatformGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urPlatformGetApiVersion +__urdlllocal ur_result_t UR_APICALL urPlatformGetApiVersion( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform + ur_api_version_t *pVersion ///< [out] api version +) { + auto pfnGetApiVersion = context.urDdiTable.Platform.pfnGetApiVersion; + + if (nullptr == pfnGetApiVersion) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_platform_get_api_version_params_t params = {&hPlatform, &pVersion}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urPlatformGetApiVersion")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urPlatformGetApiVersion")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urPlatformGetApiVersion")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urPlatformGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform. + ur_native_handle_t * + phNativePlatform ///< [out] a pointer to the native handle of the platform. +) { + auto pfnGetNativeHandle = context.urDdiTable.Platform.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_platform_get_native_handle_params_t params = {&hPlatform, + &phNativePlatform}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urPlatformGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urPlatformGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phNativePlatform = reinterpret_cast(hPlatform); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urPlatformGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urPlatformCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( + ur_native_handle_t + hNativePlatform, ///< [in][nocheck] the native handle of the platform. + const ur_platform_native_properties_t * + pProperties, ///< [in][optional] pointer to native platform properties struct. + ur_platform_handle_t * + phPlatform ///< [out] pointer to the handle of the platform object created. +) { + auto pfnCreateWithNativeHandle = + context.urDdiTable.Platform.pfnCreateWithNativeHandle; + + if (nullptr == pfnCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_platform_create_with_native_handle_params_t params = { + &hNativePlatform, &pProperties, &phPlatform}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urPlatformCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urPlatformCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phPlatform = reinterpret_cast(hNativePlatform); + retainDummyHandle(*phPlatform); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urPlatformCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urPlatformGetBackendOption +__urdlllocal ur_result_t UR_APICALL urPlatformGetBackendOption( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance. + const char + *pFrontendOption, ///< [in] string containing the frontend option. + const char ** + ppPlatformOption ///< [out] returns the correct platform specific compiler option based on + ///< the frontend option. +) { + auto pfnGetBackendOption = context.urDdiTable.Platform.pfnGetBackendOption; + + if (nullptr == pfnGetBackendOption) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_platform_get_backend_option_params_t params = { + &hPlatform, &pFrontendOption, &ppPlatformOption}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urPlatformGetBackendOption")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urPlatformGetBackendOption")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urPlatformGetBackendOption")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urDeviceGet +__urdlllocal ur_result_t UR_APICALL urDeviceGet( + ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance + ur_device_type_t DeviceType, ///< [in] the type of the devices. + uint32_t + NumEntries, ///< [in] the number of devices to be added to phDevices. + ///< If phDevices is not NULL, then NumEntries should be greater than zero. + ///< Otherwise ::UR_RESULT_ERROR_INVALID_SIZE + ///< will be returned. + ur_device_handle_t * + phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices. + ///< If NumEntries is less than the number of devices available, then + ///< platform shall only retrieve that number of devices. + uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices. + ///< pNumDevices will be updated with the total number of devices available. +) { + auto pfnGet = context.urDdiTable.Device.pfnGet; + + if (nullptr == pfnGet) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_device_get_params_t params = {&hPlatform, &DeviceType, &NumEntries, + &phDevices, &pNumDevices}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urDeviceGet")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urDeviceGet")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + if (pNumDevices) { + *pNumDevices = 1; + } + // optional output handle + if (phDevices) { + *phDevices = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urDeviceGet")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urDeviceGetInfo +__urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( + ur_device_handle_t hDevice, ///< [in] handle of the device instance + ur_device_info_t propName, ///< [in] type of the info to retrieve + size_t propSize, ///< [in] the number of bytes pointed to by pPropValue. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. + ///< If propSize is not equal to or greater than the real number of bytes + ///< needed to return the info + ///< then the ::UR_RESULT_ERROR_INVALID_SIZE error is returned and + ///< pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName. +) { + auto pfnGetInfo = context.urDdiTable.Device.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_device_get_info_params_t params = {&hDevice, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urDeviceGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urDeviceGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urDeviceGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urDeviceRetain +__urdlllocal ur_result_t UR_APICALL urDeviceRetain( + ur_device_handle_t + hDevice ///< [in][retain] handle of the device to get a reference of. +) { + auto pfnRetain = context.urDdiTable.Device.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_device_retain_params_t params = {&hDevice}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urDeviceRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urDeviceRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hDevice); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urDeviceRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urDeviceRelease +__urdlllocal ur_result_t UR_APICALL urDeviceRelease( + ur_device_handle_t + hDevice ///< [in][release] handle of the device to release. +) { + auto pfnRelease = context.urDdiTable.Device.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_device_release_params_t params = {&hDevice}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urDeviceRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urDeviceRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hDevice); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urDeviceRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urDevicePartition +__urdlllocal ur_result_t UR_APICALL urDevicePartition( + ur_device_handle_t hDevice, ///< [in] handle of the device to partition. + const ur_device_partition_properties_t + *pProperties, ///< [in] Device partition properties. + uint32_t NumDevices, ///< [in] the number of sub-devices. + ur_device_handle_t * + phSubDevices, ///< [out][optional][range(0, NumDevices)] array of handle of devices. + ///< If NumDevices is less than the number of sub-devices available, then + ///< the function shall only retrieve that number of sub-devices. + uint32_t * + pNumDevicesRet ///< [out][optional] pointer to the number of sub-devices the device can be + ///< partitioned into according to the partitioning property. +) { + auto pfnPartition = context.urDdiTable.Device.pfnPartition; + + if (nullptr == pfnPartition) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_device_partition_params_t params = {&hDevice, &pProperties, &NumDevices, + &phSubDevices, &pNumDevicesRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urDevicePartition")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urDevicePartition")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phSubDevices) { + *phSubDevices = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urDevicePartition")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urDeviceSelectBinary +__urdlllocal ur_result_t UR_APICALL urDeviceSelectBinary( + ur_device_handle_t + hDevice, ///< [in] handle of the device to select binary for. + const ur_device_binary_t + *pBinaries, ///< [in] the array of binaries to select from. + uint32_t NumBinaries, ///< [in] the number of binaries passed in ppBinaries. + ///< Must greater than or equal to zero otherwise + ///< ::UR_RESULT_ERROR_INVALID_VALUE is returned. + uint32_t * + pSelectedBinary ///< [out] the index of the selected binary in the input array of binaries. + ///< If a suitable binary was not found the function returns ::UR_RESULT_ERROR_INVALID_BINARY. +) { + auto pfnSelectBinary = context.urDdiTable.Device.pfnSelectBinary; + + if (nullptr == pfnSelectBinary) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_device_select_binary_params_t params = {&hDevice, &pBinaries, + &NumBinaries, &pSelectedBinary}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urDeviceSelectBinary")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urDeviceSelectBinary")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urDeviceSelectBinary")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urDeviceGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle( + ur_device_handle_t hDevice, ///< [in] handle of the device. + ur_native_handle_t + *phNativeDevice ///< [out] a pointer to the native handle of the device. +) { + auto pfnGetNativeHandle = context.urDdiTable.Device.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_device_get_native_handle_params_t params = {&hDevice, &phNativeDevice}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urDeviceGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urDeviceGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phNativeDevice = reinterpret_cast(hDevice); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urDeviceGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urDeviceCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( + ur_native_handle_t + hNativeDevice, ///< [in][nocheck] the native handle of the device. + ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance + const ur_device_native_properties_t * + pProperties, ///< [in][optional] pointer to native device properties struct. + ur_device_handle_t + *phDevice ///< [out] pointer to the handle of the device object created. +) { + auto pfnCreateWithNativeHandle = + context.urDdiTable.Device.pfnCreateWithNativeHandle; + + if (nullptr == pfnCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_device_create_with_native_handle_params_t params = { + &hNativeDevice, &hPlatform, &pProperties, &phDevice}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urDeviceCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urDeviceCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phDevice = reinterpret_cast(hNativeDevice); + retainDummyHandle(*phDevice); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urDeviceCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urDeviceGetGlobalTimestamps +__urdlllocal ur_result_t UR_APICALL urDeviceGetGlobalTimestamps( + ur_device_handle_t hDevice, ///< [in] handle of the device instance + uint64_t * + pDeviceTimestamp, ///< [out][optional] pointer to the Device's global timestamp that + ///< correlates with the Host's global timestamp value + uint64_t * + pHostTimestamp ///< [out][optional] pointer to the Host's global timestamp that + ///< correlates with the Device's global timestamp value +) { + auto pfnGetGlobalTimestamps = + context.urDdiTable.Device.pfnGetGlobalTimestamps; + + if (nullptr == pfnGetGlobalTimestamps) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_device_get_global_timestamps_params_t params = { + &hDevice, &pDeviceTimestamp, &pHostTimestamp}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urDeviceGetGlobalTimestamps")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urDeviceGetGlobalTimestamps")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urDeviceGetGlobalTimestamps")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextCreate +__urdlllocal ur_result_t UR_APICALL urContextCreate( + uint32_t DeviceCount, ///< [in] the number of devices given in phDevices + const ur_device_handle_t + *phDevices, ///< [in][range(0, DeviceCount)] array of handle of devices. + const ur_context_properties_t * + pProperties, ///< [in][optional] pointer to context creation properties. + ur_context_handle_t + *phContext ///< [out] pointer to handle of context object created +) { + auto pfnCreate = context.urDdiTable.Context.pfnCreate; + + if (nullptr == pfnCreate) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_context_create_params_t params = {&DeviceCount, &phDevices, &pProperties, + &phContext}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urContextCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urContextCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phContext = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urContextCreate")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextRetain +__urdlllocal ur_result_t UR_APICALL urContextRetain( + ur_context_handle_t + hContext ///< [in][retain] handle of the context to get a reference of. +) { + auto pfnRetain = context.urDdiTable.Context.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_context_retain_params_t params = {&hContext}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urContextRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urContextRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hContext); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urContextRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextRelease +__urdlllocal ur_result_t UR_APICALL urContextRelease( + ur_context_handle_t + hContext ///< [in][release] handle of the context to release. +) { + auto pfnRelease = context.urDdiTable.Context.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_context_release_params_t params = {&hContext}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urContextRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urContextRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hContext); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urContextRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextGetInfo +__urdlllocal ur_result_t UR_APICALL urContextGetInfo( + ur_context_handle_t hContext, ///< [in] handle of the context + ur_context_info_t propName, ///< [in] type of the info to retrieve + size_t + propSize, ///< [in] the number of bytes of memory pointed to by pPropValue. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. + ///< if propSize is not equal to or greater than the real number of bytes + ///< needed to return + ///< the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is returned and + ///< pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName. +) { + auto pfnGetInfo = context.urDdiTable.Context.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_context_get_info_params_t params = {&hContext, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urContextGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urContextGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urContextGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle( + ur_context_handle_t hContext, ///< [in] handle of the context. + ur_native_handle_t * + phNativeContext ///< [out] a pointer to the native handle of the context. +) { + auto pfnGetNativeHandle = context.urDdiTable.Context.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_context_get_native_handle_params_t params = {&hContext, + &phNativeContext}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urContextGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urContextGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phNativeContext = reinterpret_cast(hContext); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urContextGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle( + ur_native_handle_t + hNativeContext, ///< [in][nocheck] the native handle of the context. + uint32_t numDevices, ///< [in] number of devices associated with the context + const ur_device_handle_t * + phDevices, ///< [in][range(0, numDevices)] list of devices associated with the context + const ur_context_native_properties_t * + pProperties, ///< [in][optional] pointer to native context properties struct + ur_context_handle_t * + phContext ///< [out] pointer to the handle of the context object created. +) { + auto pfnCreateWithNativeHandle = + context.urDdiTable.Context.pfnCreateWithNativeHandle; + + if (nullptr == pfnCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_context_create_with_native_handle_params_t params = { + &hNativeContext, &numDevices, &phDevices, &pProperties, &phContext}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urContextCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urContextCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phContext = reinterpret_cast(hNativeContext); + retainDummyHandle(*phContext); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urContextCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextSetExtendedDeleter +__urdlllocal ur_result_t UR_APICALL urContextSetExtendedDeleter( + ur_context_handle_t hContext, ///< [in] handle of the context. + ur_context_extended_deleter_t + pfnDeleter, ///< [in] Function pointer to extended deleter. + void * + pUserData ///< [in][out][optional] pointer to data to be passed to callback. +) { + auto pfnSetExtendedDeleter = + context.urDdiTable.Context.pfnSetExtendedDeleter; + + if (nullptr == pfnSetExtendedDeleter) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_context_set_extended_deleter_params_t params = {&hContext, &pfnDeleter, + &pUserData}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urContextSetExtendedDeleter")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urContextSetExtendedDeleter")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urContextSetExtendedDeleter")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemImageCreate +__urdlllocal ur_result_t UR_APICALL urMemImageCreate( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_mem_flags_t flags, ///< [in] allocation and usage information flags + const ur_image_format_t + *pImageFormat, ///< [in] pointer to image format specification + const ur_image_desc_t *pImageDesc, ///< [in] pointer to image description + void *pHost, ///< [in][optional] pointer to the buffer data + ur_mem_handle_t *phMem ///< [out] pointer to handle of image object created +) { + auto pfnImageCreate = context.urDdiTable.Mem.pfnImageCreate; + + if (nullptr == pfnImageCreate) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_image_create_params_t params = {&hContext, &flags, &pImageFormat, + &pImageDesc, &pHost, &phMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urMemImageCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urMemImageCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phMem = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urMemImageCreate")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemBufferCreate +__urdlllocal ur_result_t UR_APICALL urMemBufferCreate( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_mem_flags_t flags, ///< [in] allocation and usage information flags + size_t size, ///< [in] size in bytes of the memory object to be allocated + const ur_buffer_properties_t + *pProperties, ///< [in][optional] pointer to buffer creation properties + ur_mem_handle_t + *phBuffer ///< [out] pointer to handle of the memory buffer created +) { + auto pfnBufferCreate = context.urDdiTable.Mem.pfnBufferCreate; + + if (nullptr == pfnBufferCreate) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_buffer_create_params_t params = {&hContext, &flags, &size, + &pProperties, &phBuffer}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urMemBufferCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urMemBufferCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phBuffer = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urMemBufferCreate")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemRetain +__urdlllocal ur_result_t UR_APICALL urMemRetain( + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access +) { + auto pfnRetain = context.urDdiTable.Mem.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_retain_params_t params = {&hMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urMemRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urMemRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hMem); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urMemRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemRelease +__urdlllocal ur_result_t UR_APICALL urMemRelease( + ur_mem_handle_t + hMem ///< [in][release] handle of the memory object to release +) { + auto pfnRelease = context.urDdiTable.Mem.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_release_params_t params = {&hMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urMemRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urMemRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hMem); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urMemRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemBufferPartition +__urdlllocal ur_result_t UR_APICALL urMemBufferPartition( + ur_mem_handle_t + hBuffer, ///< [in] handle of the buffer object to allocate from + ur_mem_flags_t flags, ///< [in] allocation and usage information flags + ur_buffer_create_type_t bufferCreateType, ///< [in] buffer creation type + const ur_buffer_region_t + *pRegion, ///< [in] pointer to buffer create region information + ur_mem_handle_t + *phMem ///< [out] pointer to the handle of sub buffer created +) { + auto pfnBufferPartition = context.urDdiTable.Mem.pfnBufferPartition; + + if (nullptr == pfnBufferPartition) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_buffer_partition_params_t params = { + &hBuffer, &flags, &bufferCreateType, &pRegion, &phMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urMemBufferPartition")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urMemBufferPartition")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phMem = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urMemBufferPartition")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( + ur_mem_handle_t hMem, ///< [in] handle of the mem. + ur_device_handle_t + hDevice, ///< [in] handle of the device that the native handle will be resident on. + ur_native_handle_t + *phNativeMem ///< [out] a pointer to the native handle of the mem. +) { + auto pfnGetNativeHandle = context.urDdiTable.Mem.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_get_native_handle_params_t params = {&hMem, &hDevice, &phNativeMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urMemGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urMemGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phNativeMem = reinterpret_cast(hMem); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urMemGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemBufferCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle( + ur_native_handle_t + hNativeMem, ///< [in][nocheck] the native handle to the memory. + ur_context_handle_t hContext, ///< [in] handle of the context object. + const ur_mem_native_properties_t * + pProperties, ///< [in][optional] pointer to native memory creation properties. + ur_mem_handle_t + *phMem ///< [out] pointer to handle of buffer memory object created. +) { + auto pfnBufferCreateWithNativeHandle = + context.urDdiTable.Mem.pfnBufferCreateWithNativeHandle; + + if (nullptr == pfnBufferCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_buffer_create_with_native_handle_params_t params = { + &hNativeMem, &hContext, &pProperties, &phMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urMemBufferCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urMemBufferCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phMem = reinterpret_cast(hNativeMem); + retainDummyHandle(*phMem); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urMemBufferCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemImageCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle( + ur_native_handle_t + hNativeMem, ///< [in][nocheck] the native handle to the memory. + ur_context_handle_t hContext, ///< [in] handle of the context object. + const ur_image_format_t + *pImageFormat, ///< [in] pointer to image format specification. + const ur_image_desc_t *pImageDesc, ///< [in] pointer to image description. + const ur_mem_native_properties_t * + pProperties, ///< [in][optional] pointer to native memory creation properties. + ur_mem_handle_t + *phMem ///< [out] pointer to handle of image memory object created. +) { + auto pfnImageCreateWithNativeHandle = + context.urDdiTable.Mem.pfnImageCreateWithNativeHandle; + + if (nullptr == pfnImageCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_image_create_with_native_handle_params_t params = { + &hNativeMem, &hContext, &pImageFormat, + &pImageDesc, &pProperties, &phMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urMemImageCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urMemImageCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phMem = reinterpret_cast(hNativeMem); + retainDummyHandle(*phMem); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urMemImageCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemGetInfo +__urdlllocal ur_result_t UR_APICALL urMemGetInfo( + ur_mem_handle_t + hMemory, ///< [in] handle to the memory object being queried. + ur_mem_info_t propName, ///< [in] type of the info to retrieve. + size_t + propSize, ///< [in] the number of bytes of memory pointed to by pPropValue. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. + ///< If propSize is less than the real number of bytes needed to return + ///< the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is returned and + ///< pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName. +) { + auto pfnGetInfo = context.urDdiTable.Mem.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_get_info_params_t params = {&hMemory, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urMemGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urMemGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urMemGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemImageGetInfo +__urdlllocal ur_result_t UR_APICALL urMemImageGetInfo( + ur_mem_handle_t hMemory, ///< [in] handle to the image object being queried. + ur_image_info_t propName, ///< [in] type of image info to retrieve. + size_t + propSize, ///< [in] the number of bytes of memory pointer to by pPropValue. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. + ///< If propSize is less than the real number of bytes needed to return + ///< the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is returned and + ///< pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName. +) { + auto pfnImageGetInfo = context.urDdiTable.Mem.pfnImageGetInfo; + + if (nullptr == pfnImageGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_mem_image_get_info_params_t params = {&hMemory, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urMemImageGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urMemImageGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urMemImageGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urSamplerCreate +__urdlllocal ur_result_t UR_APICALL urSamplerCreate( + ur_context_handle_t hContext, ///< [in] handle of the context object + const ur_sampler_desc_t *pDesc, ///< [in] pointer to the sampler description + ur_sampler_handle_t + *phSampler ///< [out] pointer to handle of sampler object created +) { + auto pfnCreate = context.urDdiTable.Sampler.pfnCreate; + + if (nullptr == pfnCreate) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_sampler_create_params_t params = {&hContext, &pDesc, &phSampler}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urSamplerCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urSamplerCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phSampler = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urSamplerCreate")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urSamplerRetain +__urdlllocal ur_result_t UR_APICALL urSamplerRetain( + ur_sampler_handle_t + hSampler ///< [in][retain] handle of the sampler object to get access +) { + auto pfnRetain = context.urDdiTable.Sampler.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_sampler_retain_params_t params = {&hSampler}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urSamplerRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urSamplerRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hSampler); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urSamplerRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urSamplerRelease +__urdlllocal ur_result_t UR_APICALL urSamplerRelease( + ur_sampler_handle_t + hSampler ///< [in][release] handle of the sampler object to release +) { + auto pfnRelease = context.urDdiTable.Sampler.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_sampler_release_params_t params = {&hSampler}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urSamplerRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urSamplerRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hSampler); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urSamplerRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urSamplerGetInfo +__urdlllocal ur_result_t UR_APICALL urSamplerGetInfo( + ur_sampler_handle_t hSampler, ///< [in] handle of the sampler object + ur_sampler_info_t propName, ///< [in] name of the sampler property to query + size_t + propSize, ///< [in] size in bytes of the sampler property value provided + void * + pPropValue, ///< [out][typename(propName, propSize)][optional] value of the sampler + ///< property + size_t * + pPropSizeRet ///< [out][optional] size in bytes returned in sampler property value +) { + auto pfnGetInfo = context.urDdiTable.Sampler.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_sampler_get_info_params_t params = {&hSampler, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urSamplerGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urSamplerGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urSamplerGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urSamplerGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle( + ur_sampler_handle_t hSampler, ///< [in] handle of the sampler. + ur_native_handle_t * + phNativeSampler ///< [out] a pointer to the native handle of the sampler. +) { + auto pfnGetNativeHandle = context.urDdiTable.Sampler.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_sampler_get_native_handle_params_t params = {&hSampler, + &phNativeSampler}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urSamplerGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urSamplerGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phNativeSampler = reinterpret_cast(hSampler); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urSamplerGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urSamplerCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle( + ur_native_handle_t + hNativeSampler, ///< [in][nocheck] the native handle of the sampler. + ur_context_handle_t hContext, ///< [in] handle of the context object + const ur_sampler_native_properties_t * + pProperties, ///< [in][optional] pointer to native sampler properties struct. + ur_sampler_handle_t * + phSampler ///< [out] pointer to the handle of the sampler object created. +) { + auto pfnCreateWithNativeHandle = + context.urDdiTable.Sampler.pfnCreateWithNativeHandle; + + if (nullptr == pfnCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_sampler_create_with_native_handle_params_t params = { + &hNativeSampler, &hContext, &pProperties, &phSampler}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urSamplerCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urSamplerCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phSampler = reinterpret_cast(hNativeSampler); + retainDummyHandle(*phSampler); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urSamplerCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMHostAlloc +__urdlllocal ur_result_t UR_APICALL urUSMHostAlloc( + ur_context_handle_t hContext, ///< [in] handle of the context object + const ur_usm_desc_t + *pUSMDesc, ///< [in][optional] USM memory allocation descriptor + ur_usm_pool_handle_t + pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate + size_t + size, ///< [in] minimum size in bytes of the USM memory object to be allocated + void **ppMem ///< [out] pointer to USM host memory object +) { + auto pfnHostAlloc = context.urDdiTable.USM.pfnHostAlloc; + + if (nullptr == pfnHostAlloc) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_host_alloc_params_t params = {&hContext, &pUSMDesc, &pool, &size, + &ppMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMHostAlloc")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMHostAlloc")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMHostAlloc")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMDeviceAlloc +__urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_usm_desc_t + *pUSMDesc, ///< [in][optional] USM memory allocation descriptor + ur_usm_pool_handle_t + pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate + size_t + size, ///< [in] minimum size in bytes of the USM memory object to be allocated + void **ppMem ///< [out] pointer to USM device memory object +) { + auto pfnDeviceAlloc = context.urDdiTable.USM.pfnDeviceAlloc; + + if (nullptr == pfnDeviceAlloc) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_device_alloc_params_t params = {&hContext, &hDevice, &pUSMDesc, + &pool, &size, &ppMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMDeviceAlloc")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMDeviceAlloc")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMDeviceAlloc")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMSharedAlloc +__urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_usm_desc_t * + pUSMDesc, ///< [in][optional] Pointer to USM memory allocation descriptor. + ur_usm_pool_handle_t + pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate + size_t + size, ///< [in] minimum size in bytes of the USM memory object to be allocated + void **ppMem ///< [out] pointer to USM shared memory object +) { + auto pfnSharedAlloc = context.urDdiTable.USM.pfnSharedAlloc; + + if (nullptr == pfnSharedAlloc) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_shared_alloc_params_t params = {&hContext, &hDevice, &pUSMDesc, + &pool, &size, &ppMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMSharedAlloc")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMSharedAlloc")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMSharedAlloc")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMFree +__urdlllocal ur_result_t UR_APICALL urUSMFree( + ur_context_handle_t hContext, ///< [in] handle of the context object + void *pMem ///< [in] pointer to USM memory object +) { + auto pfnFree = context.urDdiTable.USM.pfnFree; + + if (nullptr == pfnFree) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_free_params_t params = {&hContext, &pMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMFree")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMFree")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMFree")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMGetMemAllocInfo +__urdlllocal ur_result_t UR_APICALL urUSMGetMemAllocInfo( + ur_context_handle_t hContext, ///< [in] handle of the context object + const void *pMem, ///< [in] pointer to USM memory object + ur_usm_alloc_info_t + propName, ///< [in] the name of the USM allocation property to query + size_t + propSize, ///< [in] size in bytes of the USM allocation property value + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] value of the USM + ///< allocation property + size_t * + pPropSizeRet ///< [out][optional] bytes returned in USM allocation property +) { + auto pfnGetMemAllocInfo = context.urDdiTable.USM.pfnGetMemAllocInfo; + + if (nullptr == pfnGetMemAllocInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_get_mem_alloc_info_params_t params = { + &hContext, &pMem, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMGetMemAllocInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMGetMemAllocInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMGetMemAllocInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMPoolCreate +__urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_usm_pool_desc_t * + pPoolDesc, ///< [in] pointer to USM pool descriptor. Can be chained with + ///< ::ur_usm_pool_limits_desc_t + ur_usm_pool_handle_t *ppPool ///< [out] pointer to USM memory pool +) { + auto pfnPoolCreate = context.urDdiTable.USM.pfnPoolCreate; + + if (nullptr == pfnPoolCreate) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_pool_create_params_t params = {&hContext, &pPoolDesc, &ppPool}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMPoolCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMPoolCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *ppPool = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMPoolCreate")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMPoolRetain +__urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool +) { + auto pfnPoolRetain = context.urDdiTable.USM.pfnPoolRetain; + + if (nullptr == pfnPoolRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_pool_retain_params_t params = {&pPool}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMPoolRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMPoolRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(pPool); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMPoolRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMPoolRelease +__urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( + ur_usm_pool_handle_t pPool ///< [in][release] pointer to USM memory pool +) { + auto pfnPoolRelease = context.urDdiTable.USM.pfnPoolRelease; + + if (nullptr == pfnPoolRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_pool_release_params_t params = {&pPool}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMPoolRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMPoolRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(pPool); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMPoolRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMPoolGetInfo +__urdlllocal ur_result_t UR_APICALL urUSMPoolGetInfo( + ur_usm_pool_handle_t hPool, ///< [in] handle of the USM memory pool + ur_usm_pool_info_t propName, ///< [in] name of the pool property to query + size_t propSize, ///< [in] size in bytes of the pool property value provided + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] value of the pool + ///< property + size_t * + pPropSizeRet ///< [out][optional] size in bytes returned in pool property value +) { + auto pfnPoolGetInfo = context.urDdiTable.USM.pfnPoolGetInfo; + + if (nullptr == pfnPoolGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_pool_get_info_params_t params = {&hPool, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMPoolGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMPoolGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMPoolGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urVirtualMemGranularityGetInfo +__urdlllocal ur_result_t UR_APICALL urVirtualMemGranularityGetInfo( + ur_context_handle_t hContext, ///< [in] handle of the context object. + ur_device_handle_t + hDevice, ///< [in][optional] is the device to get the granularity from, if the + ///< device is null then the granularity is suitable for all devices in context. + ur_virtual_mem_granularity_info_t + propName, ///< [in] type of the info to query. + size_t + propSize, ///< [in] size in bytes of the memory pointed to by pPropValue. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. If propSize is less than the real number of bytes needed to + ///< return the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is + ///< returned and pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName." +) { + auto pfnGranularityGetInfo = + context.urDdiTable.VirtualMem.pfnGranularityGetInfo; + + if (nullptr == pfnGranularityGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_virtual_mem_granularity_get_info_params_t params = { + &hContext, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urVirtualMemGranularityGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urVirtualMemGranularityGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urVirtualMemGranularityGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urVirtualMemReserve +__urdlllocal ur_result_t UR_APICALL urVirtualMemReserve( + ur_context_handle_t hContext, ///< [in] handle of the context object. + const void * + pStart, ///< [in][optional] pointer to the start of the virtual memory region to + ///< reserve, specifying a null value causes the implementation to select a + ///< start address. + size_t + size, ///< [in] size in bytes of the virtual address range to reserve. + void ** + ppStart ///< [out] pointer to the returned address at the start of reserved virtual + ///< memory range. +) { + auto pfnReserve = context.urDdiTable.VirtualMem.pfnReserve; + + if (nullptr == pfnReserve) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_virtual_mem_reserve_params_t params = {&hContext, &pStart, &size, + &ppStart}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urVirtualMemReserve")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urVirtualMemReserve")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urVirtualMemReserve")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urVirtualMemFree +__urdlllocal ur_result_t UR_APICALL urVirtualMemFree( + ur_context_handle_t hContext, ///< [in] handle of the context object. + const void * + pStart, ///< [in] pointer to the start of the virtual memory range to free. + size_t size ///< [in] size in bytes of the virtual memory range to free. +) { + auto pfnFree = context.urDdiTable.VirtualMem.pfnFree; + + if (nullptr == pfnFree) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_virtual_mem_free_params_t params = {&hContext, &pStart, &size}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urVirtualMemFree")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urVirtualMemFree")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urVirtualMemFree")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urVirtualMemMap +__urdlllocal ur_result_t UR_APICALL urVirtualMemMap( + ur_context_handle_t hContext, ///< [in] handle to the context object. + const void + *pStart, ///< [in] pointer to the start of the virtual memory range. + size_t size, ///< [in] size in bytes of the virtual memory range to map. + ur_physical_mem_handle_t + hPhysicalMem, ///< [in] handle of the physical memory to map pStart to. + size_t + offset, ///< [in] offset in bytes into the physical memory to map pStart to. + ur_virtual_mem_access_flags_t + flags ///< [in] access flags for the physical memory mapping. +) { + auto pfnMap = context.urDdiTable.VirtualMem.pfnMap; + + if (nullptr == pfnMap) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_virtual_mem_map_params_t params = {&hContext, &pStart, &size, + &hPhysicalMem, &offset, &flags}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urVirtualMemMap")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urVirtualMemMap")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urVirtualMemMap")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urVirtualMemUnmap +__urdlllocal ur_result_t UR_APICALL urVirtualMemUnmap( + ur_context_handle_t hContext, ///< [in] handle to the context object. + const void * + pStart, ///< [in] pointer to the start of the mapped virtual memory range + size_t size ///< [in] size in bytes of the virtual memory range. +) { + auto pfnUnmap = context.urDdiTable.VirtualMem.pfnUnmap; + + if (nullptr == pfnUnmap) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_virtual_mem_unmap_params_t params = {&hContext, &pStart, &size}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urVirtualMemUnmap")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urVirtualMemUnmap")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urVirtualMemUnmap")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urVirtualMemSetAccess +__urdlllocal ur_result_t UR_APICALL urVirtualMemSetAccess( + ur_context_handle_t hContext, ///< [in] handle to the context object. + const void + *pStart, ///< [in] pointer to the start of the virtual memory range. + size_t size, ///< [in] size in bytes of the virtual memory range. + ur_virtual_mem_access_flags_t + flags ///< [in] access flags to set for the mapped virtual memory range. +) { + auto pfnSetAccess = context.urDdiTable.VirtualMem.pfnSetAccess; + + if (nullptr == pfnSetAccess) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_virtual_mem_set_access_params_t params = {&hContext, &pStart, &size, + &flags}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urVirtualMemSetAccess")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urVirtualMemSetAccess")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urVirtualMemSetAccess")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urVirtualMemGetInfo +__urdlllocal ur_result_t UR_APICALL urVirtualMemGetInfo( + ur_context_handle_t hContext, ///< [in] handle to the context object. + const void + *pStart, ///< [in] pointer to the start of the virtual memory range. + size_t size, ///< [in] size in bytes of the virtual memory range. + ur_virtual_mem_info_t propName, ///< [in] type of the info to query. + size_t + propSize, ///< [in] size in bytes of the memory pointed to by pPropValue. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. If propSize is less than the real number of bytes needed to + ///< return the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is + ///< returned and pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName." +) { + auto pfnGetInfo = context.urDdiTable.VirtualMem.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_virtual_mem_get_info_params_t params = { + &hContext, &pStart, &size, &propName, + &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urVirtualMemGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urVirtualMemGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urVirtualMemGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urPhysicalMemCreate +__urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( + ur_context_handle_t hContext, ///< [in] handle of the context object. + ur_device_handle_t hDevice, ///< [in] handle of the device object. + size_t + size, ///< [in] size in bytes of physical memory to allocate, must be a multiple + ///< of ::UR_VIRTUAL_MEM_GRANULARITY_INFO_MINIMUM. + const ur_physical_mem_properties_t * + pProperties, ///< [in][optional] pointer to physical memory creation properties. + ur_physical_mem_handle_t * + phPhysicalMem ///< [out] pointer to handle of physical memory object created. +) { + auto pfnCreate = context.urDdiTable.PhysicalMem.pfnCreate; + + if (nullptr == pfnCreate) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_physical_mem_create_params_t params = {&hContext, &hDevice, &size, + &pProperties, &phPhysicalMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urPhysicalMemCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urPhysicalMemCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phPhysicalMem = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urPhysicalMemCreate")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urPhysicalMemRetain +__urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( + ur_physical_mem_handle_t + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. +) { + auto pfnRetain = context.urDdiTable.PhysicalMem.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_physical_mem_retain_params_t params = {&hPhysicalMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urPhysicalMemRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urPhysicalMemRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hPhysicalMem); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urPhysicalMemRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urPhysicalMemRelease +__urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( + ur_physical_mem_handle_t + hPhysicalMem ///< [in][release] handle of the physical memory object to release. +) { + auto pfnRelease = context.urDdiTable.PhysicalMem.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_physical_mem_release_params_t params = {&hPhysicalMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urPhysicalMemRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urPhysicalMemRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hPhysicalMem); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urPhysicalMemRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramCreateWithIL +__urdlllocal ur_result_t UR_APICALL urProgramCreateWithIL( + ur_context_handle_t hContext, ///< [in] handle of the context instance + const void *pIL, ///< [in] pointer to IL binary. + size_t length, ///< [in] length of `pIL` in bytes. + const ur_program_properties_t * + pProperties, ///< [in][optional] pointer to program creation properties. + ur_program_handle_t + *phProgram ///< [out] pointer to handle of program object created. +) { + auto pfnCreateWithIL = context.urDdiTable.Program.pfnCreateWithIL; + + if (nullptr == pfnCreateWithIL) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_create_with_il_params_t params = {&hContext, &pIL, &length, + &pProperties, &phProgram}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramCreateWithIL")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramCreateWithIL")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phProgram = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramCreateWithIL")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramCreateWithBinary +__urdlllocal ur_result_t UR_APICALL urProgramCreateWithBinary( + ur_context_handle_t hContext, ///< [in] handle of the context instance + ur_device_handle_t + hDevice, ///< [in] handle to device associated with binary. + size_t size, ///< [in] size in bytes. + const uint8_t *pBinary, ///< [in] pointer to binary. + const ur_program_properties_t * + pProperties, ///< [in][optional] pointer to program creation properties. + ur_program_handle_t + *phProgram ///< [out] pointer to handle of Program object created. +) { + auto pfnCreateWithBinary = context.urDdiTable.Program.pfnCreateWithBinary; + + if (nullptr == pfnCreateWithBinary) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_create_with_binary_params_t params = { + &hContext, &hDevice, &size, &pBinary, &pProperties, &phProgram}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramCreateWithBinary")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramCreateWithBinary")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phProgram = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramCreateWithBinary")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramBuild +__urdlllocal ur_result_t UR_APICALL urProgramBuild( + ur_context_handle_t hContext, ///< [in] handle of the context instance. + ur_program_handle_t hProgram, ///< [in] Handle of the program to build. + const char * + pOptions ///< [in][optional] pointer to build options null-terminated string. +) { + auto pfnBuild = context.urDdiTable.Program.pfnBuild; + + if (nullptr == pfnBuild) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_build_params_t params = {&hContext, &hProgram, &pOptions}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramBuild")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramBuild")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramBuild")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramCompile +__urdlllocal ur_result_t UR_APICALL urProgramCompile( + ur_context_handle_t hContext, ///< [in] handle of the context instance. + ur_program_handle_t + hProgram, ///< [in][out] handle of the program to compile. + const char * + pOptions ///< [in][optional] pointer to build options null-terminated string. +) { + auto pfnCompile = context.urDdiTable.Program.pfnCompile; + + if (nullptr == pfnCompile) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_compile_params_t params = {&hContext, &hProgram, &pOptions}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramCompile")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramCompile")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramCompile")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramLink +__urdlllocal ur_result_t UR_APICALL urProgramLink( + ur_context_handle_t hContext, ///< [in] handle of the context instance. + uint32_t count, ///< [in] number of program handles in `phPrograms`. + const ur_program_handle_t * + phPrograms, ///< [in][range(0, count)] pointer to array of program handles. + const char * + pOptions, ///< [in][optional] pointer to linker options null-terminated string. + ur_program_handle_t + *phProgram ///< [out] pointer to handle of program object created. +) { + auto pfnLink = context.urDdiTable.Program.pfnLink; + + if (nullptr == pfnLink) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_link_params_t params = {&hContext, &count, &phPrograms, + &pOptions, &phProgram}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramLink")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramLink")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phProgram = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramLink")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramRetain +__urdlllocal ur_result_t UR_APICALL urProgramRetain( + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain +) { + auto pfnRetain = context.urDdiTable.Program.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_retain_params_t params = {&hProgram}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hProgram); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramRelease +__urdlllocal ur_result_t UR_APICALL urProgramRelease( + ur_program_handle_t + hProgram ///< [in][release] handle for the Program to release +) { + auto pfnRelease = context.urDdiTable.Program.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_release_params_t params = {&hProgram}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hProgram); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramGetFunctionPointer +__urdlllocal ur_result_t UR_APICALL urProgramGetFunctionPointer( + ur_device_handle_t + hDevice, ///< [in] handle of the device to retrieve pointer for. + ur_program_handle_t + hProgram, ///< [in] handle of the program to search for function in. + ///< The program must already be built to the specified device, or + ///< otherwise ::UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE is returned. + const char * + pFunctionName, ///< [in] A null-terminates string denoting the mangled function name. + void ** + ppFunctionPointer ///< [out] Returns the pointer to the function if it is found in the program. +) { + auto pfnGetFunctionPointer = + context.urDdiTable.Program.pfnGetFunctionPointer; + + if (nullptr == pfnGetFunctionPointer) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_get_function_pointer_params_t params = { + &hDevice, &hProgram, &pFunctionName, &ppFunctionPointer}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urProgramGetFunctionPointer")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urProgramGetFunctionPointer")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramGetFunctionPointer")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramGetGlobalVariablePointer +__urdlllocal ur_result_t UR_APICALL urProgramGetGlobalVariablePointer( + ur_device_handle_t + hDevice, ///< [in] handle of the device to retrieve the pointer for. + ur_program_handle_t + hProgram, ///< [in] handle of the program where the global variable is. + const char * + pGlobalVariableName, ///< [in] mangled name of the global variable to retrieve the pointer for. + size_t * + pGlobalVariableSizeRet, ///< [out][optional] Returns the size of the global variable if it is found + ///< in the program. + void ** + ppGlobalVariablePointerRet ///< [out] Returns the pointer to the global variable if it is found in the program. +) { + auto pfnGetGlobalVariablePointer = + context.urDdiTable.Program.pfnGetGlobalVariablePointer; + + if (nullptr == pfnGetGlobalVariablePointer) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_get_global_variable_pointer_params_t params = { + &hDevice, &hProgram, &pGlobalVariableName, &pGlobalVariableSizeRet, + &ppGlobalVariablePointerRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urProgramGetGlobalVariablePointer")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urProgramGetGlobalVariablePointer")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urProgramGetGlobalVariablePointer")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramGetInfo +__urdlllocal ur_result_t UR_APICALL urProgramGetInfo( + ur_program_handle_t hProgram, ///< [in] handle of the Program object + ur_program_info_t propName, ///< [in] name of the Program property to query + size_t propSize, ///< [in] the size of the Program property. + void * + pPropValue, ///< [in,out][optional][typename(propName, propSize)] array of bytes of + ///< holding the program info property. + ///< If propSize is not equal to or greater than the real number of bytes + ///< needed to return + ///< the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is returned and + ///< pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName. +) { + auto pfnGetInfo = context.urDdiTable.Program.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_get_info_params_t params = {&hProgram, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramGetBuildInfo +__urdlllocal ur_result_t UR_APICALL urProgramGetBuildInfo( + ur_program_handle_t hProgram, ///< [in] handle of the Program object + ur_device_handle_t hDevice, ///< [in] handle of the Device object + ur_program_build_info_t + propName, ///< [in] name of the Program build info to query + size_t propSize, ///< [in] size of the Program build info property. + void * + pPropValue, ///< [in,out][optional][typename(propName, propSize)] value of the Program + ///< build property. + ///< If propSize is not equal to or greater than the real number of bytes + ///< needed to return the info then the ::UR_RESULT_ERROR_INVALID_SIZE + ///< error is returned and pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of data being + ///< queried by propName. +) { + auto pfnGetBuildInfo = context.urDdiTable.Program.pfnGetBuildInfo; + + if (nullptr == pfnGetBuildInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_get_build_info_params_t params = { + &hProgram, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramGetBuildInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramGetBuildInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramGetBuildInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramSetSpecializationConstants +__urdlllocal ur_result_t UR_APICALL urProgramSetSpecializationConstants( + ur_program_handle_t hProgram, ///< [in] handle of the Program object + uint32_t count, ///< [in] the number of elements in the pSpecConstants array + const ur_specialization_constant_info_t * + pSpecConstants ///< [in][range(0, count)] array of specialization constant value + ///< descriptions +) { + auto pfnSetSpecializationConstants = + context.urDdiTable.Program.pfnSetSpecializationConstants; + + if (nullptr == pfnSetSpecializationConstants) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_set_specialization_constants_params_t params = { + &hProgram, &count, &pSpecConstants}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urProgramSetSpecializationConstants")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urProgramSetSpecializationConstants")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urProgramSetSpecializationConstants")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle( + ur_program_handle_t hProgram, ///< [in] handle of the program. + ur_native_handle_t * + phNativeProgram ///< [out] a pointer to the native handle of the program. +) { + auto pfnGetNativeHandle = context.urDdiTable.Program.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_get_native_handle_params_t params = {&hProgram, + &phNativeProgram}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phNativeProgram = reinterpret_cast(hProgram); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle( + ur_native_handle_t + hNativeProgram, ///< [in][nocheck] the native handle of the program. + ur_context_handle_t hContext, ///< [in] handle of the context instance + const ur_program_native_properties_t * + pProperties, ///< [in][optional] pointer to native program properties struct. + ur_program_handle_t * + phProgram ///< [out] pointer to the handle of the program object created. +) { + auto pfnCreateWithNativeHandle = + context.urDdiTable.Program.pfnCreateWithNativeHandle; + + if (nullptr == pfnCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_create_with_native_handle_params_t params = { + &hNativeProgram, &hContext, &pProperties, &phProgram}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urProgramCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urProgramCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phProgram = reinterpret_cast(hNativeProgram); + retainDummyHandle(*phProgram); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urProgramCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelCreate +__urdlllocal ur_result_t UR_APICALL urKernelCreate( + ur_program_handle_t hProgram, ///< [in] handle of the program instance + const char *pKernelName, ///< [in] pointer to null-terminated string. + ur_kernel_handle_t + *phKernel ///< [out] pointer to handle of kernel object created. +) { + auto pfnCreate = context.urDdiTable.Kernel.pfnCreate; + + if (nullptr == pfnCreate) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_create_params_t params = {&hProgram, &pKernelName, &phKernel}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phKernel = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelCreate")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetArgValue +__urdlllocal ur_result_t UR_APICALL urKernelSetArgValue( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] + size_t argSize, ///< [in] size of argument type + const ur_kernel_arg_value_properties_t + *pProperties, ///< [in][optional] pointer to value properties. + const void + *pArgValue ///< [in] argument value represented as matching arg type. +) { + auto pfnSetArgValue = context.urDdiTable.Kernel.pfnSetArgValue; + + if (nullptr == pfnSetArgValue) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_set_arg_value_params_t params = {&hKernel, &argIndex, &argSize, + &pProperties, &pArgValue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelSetArgValue")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelSetArgValue")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelSetArgValue")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetArgLocal +__urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] + size_t + argSize, ///< [in] size of the local buffer to be allocated by the runtime + const ur_kernel_arg_local_properties_t + *pProperties ///< [in][optional] pointer to local buffer properties. +) { + auto pfnSetArgLocal = context.urDdiTable.Kernel.pfnSetArgLocal; + + if (nullptr == pfnSetArgLocal) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_set_arg_local_params_t params = {&hKernel, &argIndex, &argSize, + &pProperties}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelSetArgLocal")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelSetArgLocal")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelSetArgLocal")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelGetInfo +__urdlllocal ur_result_t UR_APICALL urKernelGetInfo( + ur_kernel_handle_t hKernel, ///< [in] handle of the Kernel object + ur_kernel_info_t propName, ///< [in] name of the Kernel property to query + size_t propSize, ///< [in] the size of the Kernel property value. + void * + pPropValue, ///< [in,out][optional][typename(propName, propSize)] array of bytes + ///< holding the kernel info property. + ///< If propSize is not equal to or greater than the real number of bytes + ///< needed to return + ///< the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is returned and + ///< pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of data being + ///< queried by propName. +) { + auto pfnGetInfo = context.urDdiTable.Kernel.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_get_info_params_t params = {&hKernel, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelGetGroupInfo +__urdlllocal ur_result_t UR_APICALL urKernelGetGroupInfo( + ur_kernel_handle_t hKernel, ///< [in] handle of the Kernel object + ur_device_handle_t hDevice, ///< [in] handle of the Device object + ur_kernel_group_info_t + propName, ///< [in] name of the work Group property to query + size_t propSize, ///< [in] size of the Kernel Work Group property value + void * + pPropValue, ///< [in,out][optional][typename(propName, propSize)] value of the Kernel + ///< Work Group property. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of data being + ///< queried by propName. +) { + auto pfnGetGroupInfo = context.urDdiTable.Kernel.pfnGetGroupInfo; + + if (nullptr == pfnGetGroupInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_get_group_info_params_t params = { + &hKernel, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelGetGroupInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelGetGroupInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelGetGroupInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelGetSubGroupInfo +__urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( + ur_kernel_handle_t hKernel, ///< [in] handle of the Kernel object + ur_device_handle_t hDevice, ///< [in] handle of the Device object + ur_kernel_sub_group_info_t + propName, ///< [in] name of the SubGroup property to query + size_t propSize, ///< [in] size of the Kernel SubGroup property value + void * + pPropValue, ///< [in,out][optional][typename(propName, propSize)] value of the Kernel + ///< SubGroup property. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of data being + ///< queried by propName. +) { + auto pfnGetSubGroupInfo = context.urDdiTable.Kernel.pfnGetSubGroupInfo; + + if (nullptr == pfnGetSubGroupInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_get_sub_group_info_params_t params = { + &hKernel, &hDevice, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelGetSubGroupInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelGetSubGroupInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelGetSubGroupInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelRetain +__urdlllocal ur_result_t UR_APICALL urKernelRetain( + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain +) { + auto pfnRetain = context.urDdiTable.Kernel.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_retain_params_t params = {&hKernel}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hKernel); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelRelease +__urdlllocal ur_result_t UR_APICALL urKernelRelease( + ur_kernel_handle_t + hKernel ///< [in][release] handle for the Kernel to release +) { + auto pfnRelease = context.urDdiTable.Kernel.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_release_params_t params = {&hKernel}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hKernel); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetArgPointer +__urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] + const ur_kernel_arg_pointer_properties_t + *pProperties, ///< [in][optional] pointer to USM pointer properties. + const void * + pArgValue ///< [in][optional] USM pointer to memory location holding the argument + ///< value. If null then argument value is considered null. +) { + auto pfnSetArgPointer = context.urDdiTable.Kernel.pfnSetArgPointer; + + if (nullptr == pfnSetArgPointer) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_set_arg_pointer_params_t params = {&hKernel, &argIndex, + &pProperties, &pArgValue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelSetArgPointer")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelSetArgPointer")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelSetArgPointer")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetExecInfo +__urdlllocal ur_result_t UR_APICALL urKernelSetExecInfo( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + ur_kernel_exec_info_t propName, ///< [in] name of the execution attribute + size_t propSize, ///< [in] size in byte the attribute value + const ur_kernel_exec_info_properties_t + *pProperties, ///< [in][optional] pointer to execution info properties. + const void * + pPropValue ///< [in][typename(propName, propSize)] pointer to memory location holding + ///< the property value. +) { + auto pfnSetExecInfo = context.urDdiTable.Kernel.pfnSetExecInfo; + + if (nullptr == pfnSetExecInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_set_exec_info_params_t params = {&hKernel, &propName, &propSize, + &pProperties, &pPropValue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelSetExecInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelSetExecInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelSetExecInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetArgSampler +__urdlllocal ur_result_t UR_APICALL urKernelSetArgSampler( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] + const ur_kernel_arg_sampler_properties_t + *pProperties, ///< [in][optional] pointer to sampler properties. + ur_sampler_handle_t hArgValue ///< [in] handle of Sampler object. +) { + auto pfnSetArgSampler = context.urDdiTable.Kernel.pfnSetArgSampler; + + if (nullptr == pfnSetArgSampler) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_set_arg_sampler_params_t params = {&hKernel, &argIndex, + &pProperties, &hArgValue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelSetArgSampler")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelSetArgSampler")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelSetArgSampler")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetArgMemObj +__urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] + const ur_kernel_arg_mem_obj_properties_t + *pProperties, ///< [in][optional] pointer to Memory object properties. + ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object. +) { + auto pfnSetArgMemObj = context.urDdiTable.Kernel.pfnSetArgMemObj; + + if (nullptr == pfnSetArgMemObj) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_set_arg_mem_obj_params_t params = {&hKernel, &argIndex, + &pProperties, &hArgValue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelSetArgMemObj")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelSetArgMemObj")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelSetArgMemObj")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetSpecializationConstants +__urdlllocal ur_result_t UR_APICALL urKernelSetSpecializationConstants( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t count, ///< [in] the number of elements in the pSpecConstants array + const ur_specialization_constant_info_t * + pSpecConstants ///< [in] array of specialization constant value descriptions +) { + auto pfnSetSpecializationConstants = + context.urDdiTable.Kernel.pfnSetSpecializationConstants; + + if (nullptr == pfnSetSpecializationConstants) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_set_specialization_constants_params_t params = {&hKernel, &count, + &pSpecConstants}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urKernelSetSpecializationConstants")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urKernelSetSpecializationConstants")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urKernelSetSpecializationConstants")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel. + ur_native_handle_t + *phNativeKernel ///< [out] a pointer to the native handle of the kernel. +) { + auto pfnGetNativeHandle = context.urDdiTable.Kernel.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_get_native_handle_params_t params = {&hKernel, &phNativeKernel}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urKernelGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urKernelGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phNativeKernel = reinterpret_cast(hKernel); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urKernelGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle( + ur_native_handle_t + hNativeKernel, ///< [in][nocheck] the native handle of the kernel. + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_program_handle_t + hProgram, ///< [in] handle of the program associated with the kernel + const ur_kernel_native_properties_t * + pProperties, ///< [in][optional] pointer to native kernel properties struct + ur_kernel_handle_t + *phKernel ///< [out] pointer to the handle of the kernel object created. +) { + auto pfnCreateWithNativeHandle = + context.urDdiTable.Kernel.pfnCreateWithNativeHandle; + + if (nullptr == pfnCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_create_with_native_handle_params_t params = { + &hNativeKernel, &hContext, &hProgram, &pProperties, &phKernel}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urKernelCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urKernelCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phKernel = reinterpret_cast(hNativeKernel); + retainDummyHandle(*phKernel); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urKernelCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelGetSuggestedLocalWorkSize +__urdlllocal ur_result_t UR_APICALL urKernelGetSuggestedLocalWorkSize( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + uint32_t + numWorkDim, ///< [in] number of dimensions, from 1 to 3, to specify the global + ///< and work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of numWorkDim unsigned values that specify + ///< the offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of numWorkDim unsigned values that specify + ///< the number of global work-items in workDim that will execute the + ///< kernel function + size_t * + pSuggestedLocalWorkSize ///< [out] pointer to an array of numWorkDim unsigned values that specify + ///< suggested local work size that will contain the result of the query +) { + auto pfnGetSuggestedLocalWorkSize = + context.urDdiTable.Kernel.pfnGetSuggestedLocalWorkSize; + + if (nullptr == pfnGetSuggestedLocalWorkSize) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_get_suggested_local_work_size_params_t params = { + &hKernel, &hQueue, &numWorkDim, + &pGlobalWorkOffset, &pGlobalWorkSize, &pSuggestedLocalWorkSize}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urKernelGetSuggestedLocalWorkSize")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urKernelGetSuggestedLocalWorkSize")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urKernelGetSuggestedLocalWorkSize")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urQueueGetInfo +__urdlllocal ur_result_t UR_APICALL urQueueGetInfo( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_queue_info_t propName, ///< [in] name of the queue property to query + size_t + propSize, ///< [in] size in bytes of the queue property value provided + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] value of the queue + ///< property + size_t * + pPropSizeRet ///< [out][optional] size in bytes returned in queue property value +) { + auto pfnGetInfo = context.urDdiTable.Queue.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_queue_get_info_params_t params = {&hQueue, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urQueueGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urQueueGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urQueueGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urQueueCreate +__urdlllocal ur_result_t UR_APICALL urQueueCreate( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_queue_properties_t + *pProperties, ///< [in][optional] pointer to queue creation properties. + ur_queue_handle_t + *phQueue ///< [out] pointer to handle of queue object created +) { + auto pfnCreate = context.urDdiTable.Queue.pfnCreate; + + if (nullptr == pfnCreate) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_queue_create_params_t params = {&hContext, &hDevice, &pProperties, + &phQueue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urQueueCreate")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urQueueCreate")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phQueue = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urQueueCreate")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urQueueRetain +__urdlllocal ur_result_t UR_APICALL urQueueRetain( + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access +) { + auto pfnRetain = context.urDdiTable.Queue.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_queue_retain_params_t params = {&hQueue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urQueueRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urQueueRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hQueue); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urQueueRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urQueueRelease +__urdlllocal ur_result_t UR_APICALL urQueueRelease( + ur_queue_handle_t + hQueue ///< [in][release] handle of the queue object to release +) { + auto pfnRelease = context.urDdiTable.Queue.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_queue_release_params_t params = {&hQueue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urQueueRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urQueueRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hQueue); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urQueueRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urQueueGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle( + ur_queue_handle_t hQueue, ///< [in] handle of the queue. + ur_queue_native_desc_t + *pDesc, ///< [in][optional] pointer to native descriptor + ur_native_handle_t + *phNativeQueue ///< [out] a pointer to the native handle of the queue. +) { + auto pfnGetNativeHandle = context.urDdiTable.Queue.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_queue_get_native_handle_params_t params = {&hQueue, &pDesc, + &phNativeQueue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urQueueGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urQueueGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phNativeQueue = reinterpret_cast(hQueue); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urQueueGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urQueueCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle( + ur_native_handle_t + hNativeQueue, ///< [in][nocheck] the native handle of the queue. + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_queue_native_properties_t * + pProperties, ///< [in][optional] pointer to native queue properties struct + ur_queue_handle_t + *phQueue ///< [out] pointer to the handle of the queue object created. +) { + auto pfnCreateWithNativeHandle = + context.urDdiTable.Queue.pfnCreateWithNativeHandle; + + if (nullptr == pfnCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_queue_create_with_native_handle_params_t params = { + &hNativeQueue, &hContext, &hDevice, &pProperties, &phQueue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urQueueCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urQueueCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phQueue = reinterpret_cast(hNativeQueue); + retainDummyHandle(*phQueue); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urQueueCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urQueueFinish +__urdlllocal ur_result_t UR_APICALL urQueueFinish( + ur_queue_handle_t hQueue ///< [in] handle of the queue to be finished. +) { + auto pfnFinish = context.urDdiTable.Queue.pfnFinish; + + if (nullptr == pfnFinish) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_queue_finish_params_t params = {&hQueue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urQueueFinish")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urQueueFinish")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urQueueFinish")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urQueueFlush +__urdlllocal ur_result_t UR_APICALL urQueueFlush( + ur_queue_handle_t hQueue ///< [in] handle of the queue to be flushed. +) { + auto pfnFlush = context.urDdiTable.Queue.pfnFlush; + + if (nullptr == pfnFlush) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_queue_flush_params_t params = {&hQueue}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urQueueFlush")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urQueueFlush")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urQueueFlush")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEventGetInfo +__urdlllocal ur_result_t UR_APICALL urEventGetInfo( + ur_event_handle_t hEvent, ///< [in] handle of the event object + ur_event_info_t propName, ///< [in] the name of the event property to query + size_t propSize, ///< [in] size in bytes of the event property value + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] value of the event + ///< property + size_t *pPropSizeRet ///< [out][optional] bytes returned in event property +) { + auto pfnGetInfo = context.urDdiTable.Event.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_event_get_info_params_t params = {&hEvent, &propName, &propSize, + &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEventGetInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEventGetInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEventGetInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEventGetProfilingInfo +__urdlllocal ur_result_t UR_APICALL urEventGetProfilingInfo( + ur_event_handle_t hEvent, ///< [in] handle of the event object + ur_profiling_info_t + propName, ///< [in] the name of the profiling property to query + size_t propSize, ///< [in] size in bytes of the profiling property value + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] value of the profiling + ///< property + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes returned in + ///< propValue +) { + auto pfnGetProfilingInfo = context.urDdiTable.Event.pfnGetProfilingInfo; + + if (nullptr == pfnGetProfilingInfo) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_event_get_profiling_info_params_t params = { + &hEvent, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEventGetProfilingInfo")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEventGetProfilingInfo")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEventGetProfilingInfo")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEventWait +__urdlllocal ur_result_t UR_APICALL urEventWait( + uint32_t numEvents, ///< [in] number of events in the event list + const ur_event_handle_t * + phEventWaitList ///< [in][range(0, numEvents)] pointer to a list of events to wait for + ///< completion +) { + auto pfnWait = context.urDdiTable.Event.pfnWait; + + if (nullptr == pfnWait) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_event_wait_params_t params = {&numEvents, &phEventWaitList}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEventWait")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEventWait")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEventWait")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEventRetain +__urdlllocal ur_result_t UR_APICALL urEventRetain( + ur_event_handle_t hEvent ///< [in][retain] handle of the event object +) { + auto pfnRetain = context.urDdiTable.Event.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_event_retain_params_t params = {&hEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEventRetain")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEventRetain")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hEvent); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEventRetain")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEventRelease +__urdlllocal ur_result_t UR_APICALL urEventRelease( + ur_event_handle_t hEvent ///< [in][release] handle of the event object +) { + auto pfnRelease = context.urDdiTable.Event.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_event_release_params_t params = {&hEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEventRelease")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEventRelease")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hEvent); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEventRelease")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEventGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle( + ur_event_handle_t hEvent, ///< [in] handle of the event. + ur_native_handle_t + *phNativeEvent ///< [out] a pointer to the native handle of the event. +) { + auto pfnGetNativeHandle = context.urDdiTable.Event.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_event_get_native_handle_params_t params = {&hEvent, &phNativeEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEventGetNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEventGetNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phNativeEvent = reinterpret_cast(hEvent); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEventGetNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEventCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle( + ur_native_handle_t + hNativeEvent, ///< [in][nocheck] the native handle of the event. + ur_context_handle_t hContext, ///< [in] handle of the context object + const ur_event_native_properties_t * + pProperties, ///< [in][optional] pointer to native event properties struct + ur_event_handle_t + *phEvent ///< [out] pointer to the handle of the event object created. +) { + auto pfnCreateWithNativeHandle = + context.urDdiTable.Event.pfnCreateWithNativeHandle; + + if (nullptr == pfnCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_event_create_with_native_handle_params_t params = { + &hNativeEvent, &hContext, &pProperties, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urEventCreateWithNativeHandle")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEventCreateWithNativeHandle")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phEvent = reinterpret_cast(hNativeEvent); + retainDummyHandle(*phEvent); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urEventCreateWithNativeHandle")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEventSetCallback +__urdlllocal ur_result_t UR_APICALL urEventSetCallback( + ur_event_handle_t hEvent, ///< [in] handle of the event object + ur_execution_info_t execStatus, ///< [in] execution status of the event + ur_event_callback_t pfnNotify, ///< [in] execution status of the event + void * + pUserData ///< [in][out][optional] pointer to data to be passed to callback. +) { + auto pfnSetCallback = context.urDdiTable.Event.pfnSetCallback; + + if (nullptr == pfnSetCallback) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_event_set_callback_params_t params = {&hEvent, &execStatus, &pfnNotify, + &pUserData}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEventSetCallback")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEventSetCallback")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEventSetCallback")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueKernelLaunch +__urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnKernelLaunch = context.urDdiTable.Enqueue.pfnKernelLaunch; + + if (nullptr == pfnKernelLaunch) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_kernel_launch_params_t params = {&hQueue, + &hKernel, + &workDim, + &pGlobalWorkOffset, + &pGlobalWorkSize, + &pLocalWorkSize, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueKernelLaunch")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueKernelLaunch")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueKernelLaunch")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueEventsWait +__urdlllocal ur_result_t UR_APICALL urEnqueueEventsWait( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that all + ///< previously enqueued commands + ///< must be complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnEventsWait = context.urDdiTable.Enqueue.pfnEventsWait; + + if (nullptr == pfnEventsWait) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_events_wait_params_t params = {&hQueue, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueEventsWait")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueEventsWait")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueEventsWait")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueEventsWaitWithBarrier +__urdlllocal ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that all + ///< previously enqueued commands + ///< must be complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnEventsWaitWithBarrier = + context.urDdiTable.Enqueue.pfnEventsWaitWithBarrier; + + if (nullptr == pfnEventsWaitWithBarrier) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_events_wait_with_barrier_params_t params = { + &hQueue, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urEnqueueEventsWaitWithBarrier")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEnqueueEventsWaitWithBarrier")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urEnqueueEventsWaitWithBarrier")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferRead +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferRead( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object + bool blockingRead, ///< [in] indicates blocking (true), non-blocking (false) + size_t offset, ///< [in] offset in bytes in the buffer object + size_t size, ///< [in] size in bytes of data being read + void *pDst, ///< [in] pointer to host memory where data is to be read into + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferRead = context.urDdiTable.Enqueue.pfnMemBufferRead; + + if (nullptr == pfnMemBufferRead) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_buffer_read_params_t params = { + &hQueue, &hBuffer, &blockingRead, &offset, + &size, &pDst, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemBufferRead")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueMemBufferRead")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemBufferRead")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferWrite +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWrite( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object + bool + blockingWrite, ///< [in] indicates blocking (true), non-blocking (false) + size_t offset, ///< [in] offset in bytes in the buffer object + size_t size, ///< [in] size in bytes of data being written + const void + *pSrc, ///< [in] pointer to host memory where data is to be written from + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferWrite = context.urDdiTable.Enqueue.pfnMemBufferWrite; + + if (nullptr == pfnMemBufferWrite) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_buffer_write_params_t params = { + &hQueue, &hBuffer, &blockingWrite, &offset, + &size, &pSrc, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemBufferWrite")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueMemBufferWrite")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemBufferWrite")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferReadRect +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferReadRect( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(bufferOrigin, region)] handle of the buffer object + bool blockingRead, ///< [in] indicates blocking (true), non-blocking (false) + ur_rect_offset_t bufferOrigin, ///< [in] 3D offset in the buffer + ur_rect_offset_t hostOrigin, ///< [in] 3D offset in the host region + ur_rect_region_t + region, ///< [in] 3D rectangular region descriptor: width, height, depth + size_t + bufferRowPitch, ///< [in] length of each row in bytes in the buffer object + size_t + bufferSlicePitch, ///< [in] length of each 2D slice in bytes in the buffer object being read + size_t + hostRowPitch, ///< [in] length of each row in bytes in the host memory region pointed by + ///< dst + size_t + hostSlicePitch, ///< [in] length of each 2D slice in bytes in the host memory region + ///< pointed by dst + void *pDst, ///< [in] pointer to host memory where data is to be read into + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferReadRect = context.urDdiTable.Enqueue.pfnMemBufferReadRect; + + if (nullptr == pfnMemBufferReadRect) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_buffer_read_rect_params_t params = {&hQueue, + &hBuffer, + &blockingRead, + &bufferOrigin, + &hostOrigin, + ®ion, + &bufferRowPitch, + &bufferSlicePitch, + &hostRowPitch, + &hostSlicePitch, + &pDst, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemBufferReadRect")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEnqueueMemBufferReadRect")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemBufferReadRect")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferWriteRect +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWriteRect( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(bufferOrigin, region)] handle of the buffer object + bool + blockingWrite, ///< [in] indicates blocking (true), non-blocking (false) + ur_rect_offset_t bufferOrigin, ///< [in] 3D offset in the buffer + ur_rect_offset_t hostOrigin, ///< [in] 3D offset in the host region + ur_rect_region_t + region, ///< [in] 3D rectangular region descriptor: width, height, depth + size_t + bufferRowPitch, ///< [in] length of each row in bytes in the buffer object + size_t + bufferSlicePitch, ///< [in] length of each 2D slice in bytes in the buffer object being + ///< written + size_t + hostRowPitch, ///< [in] length of each row in bytes in the host memory region pointed by + ///< src + size_t + hostSlicePitch, ///< [in] length of each 2D slice in bytes in the host memory region + ///< pointed by src + void + *pSrc, ///< [in] pointer to host memory where data is to be written from + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] points to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferWriteRect = + context.urDdiTable.Enqueue.pfnMemBufferWriteRect; + + if (nullptr == pfnMemBufferWriteRect) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_buffer_write_rect_params_t params = {&hQueue, + &hBuffer, + &blockingWrite, + &bufferOrigin, + &hostOrigin, + ®ion, + &bufferRowPitch, + &bufferSlicePitch, + &hostRowPitch, + &hostSlicePitch, + &pSrc, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urEnqueueMemBufferWriteRect")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEnqueueMemBufferWriteRect")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemBufferWriteRect")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferCopy +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopy( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBufferSrc, ///< [in][bounds(srcOffset, size)] handle of the src buffer object + ur_mem_handle_t + hBufferDst, ///< [in][bounds(dstOffset, size)] handle of the dest buffer object + size_t srcOffset, ///< [in] offset into hBufferSrc to begin copying from + size_t dstOffset, ///< [in] offset info hBufferDst to begin copying into + size_t size, ///< [in] size in bytes of data being copied + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferCopy = context.urDdiTable.Enqueue.pfnMemBufferCopy; + + if (nullptr == pfnMemBufferCopy) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_buffer_copy_params_t params = { + &hQueue, &hBufferSrc, &hBufferDst, &srcOffset, &dstOffset, + &size, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemBufferCopy")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueMemBufferCopy")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemBufferCopy")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferCopyRect +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBufferSrc, ///< [in][bounds(srcOrigin, region)] handle of the source buffer object + ur_mem_handle_t + hBufferDst, ///< [in][bounds(dstOrigin, region)] handle of the dest buffer object + ur_rect_offset_t srcOrigin, ///< [in] 3D offset in the source buffer + ur_rect_offset_t dstOrigin, ///< [in] 3D offset in the destination buffer + ur_rect_region_t + region, ///< [in] source 3D rectangular region descriptor: width, height, depth + size_t + srcRowPitch, ///< [in] length of each row in bytes in the source buffer object + size_t + srcSlicePitch, ///< [in] length of each 2D slice in bytes in the source buffer object + size_t + dstRowPitch, ///< [in] length of each row in bytes in the destination buffer object + size_t + dstSlicePitch, ///< [in] length of each 2D slice in bytes in the destination buffer object + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferCopyRect = context.urDdiTable.Enqueue.pfnMemBufferCopyRect; + + if (nullptr == pfnMemBufferCopyRect) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_buffer_copy_rect_params_t params = { + &hQueue, &hBufferSrc, &hBufferDst, &srcOrigin, + &dstOrigin, ®ion, &srcRowPitch, &srcSlicePitch, + &dstRowPitch, &dstSlicePitch, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemBufferCopyRect")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEnqueueMemBufferCopyRect")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemBufferCopyRect")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferFill +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferFill( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object + const void *pPattern, ///< [in] pointer to the fill pattern + size_t patternSize, ///< [in] size in bytes of the pattern + size_t offset, ///< [in] offset into the buffer + size_t size, ///< [in] fill size in bytes, must be a multiple of patternSize + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferFill = context.urDdiTable.Enqueue.pfnMemBufferFill; + + if (nullptr == pfnMemBufferFill) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_buffer_fill_params_t params = {&hQueue, + &hBuffer, + &pPattern, + &patternSize, + &offset, + &size, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemBufferFill")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueMemBufferFill")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemBufferFill")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemImageRead +__urdlllocal ur_result_t UR_APICALL urEnqueueMemImageRead( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hImage, ///< [in][bounds(origin, region)] handle of the image object + bool blockingRead, ///< [in] indicates blocking (true), non-blocking (false) + ur_rect_offset_t + origin, ///< [in] defines the (x,y,z) offset in pixels in the 1D, 2D, or 3D image + ur_rect_region_t + region, ///< [in] defines the (width, height, depth) in pixels of the 1D, 2D, or 3D + ///< image + size_t rowPitch, ///< [in] length of each row in bytes + size_t slicePitch, ///< [in] length of each 2D slice of the 3D image + void *pDst, ///< [in] pointer to host memory where image is to be read into + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemImageRead = context.urDdiTable.Enqueue.pfnMemImageRead; + + if (nullptr == pfnMemImageRead) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_image_read_params_t params = { + &hQueue, &hImage, &blockingRead, + &origin, ®ion, &rowPitch, + &slicePitch, &pDst, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemImageRead")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueMemImageRead")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemImageRead")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemImageWrite +__urdlllocal ur_result_t UR_APICALL urEnqueueMemImageWrite( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hImage, ///< [in][bounds(origin, region)] handle of the image object + bool + blockingWrite, ///< [in] indicates blocking (true), non-blocking (false) + ur_rect_offset_t + origin, ///< [in] defines the (x,y,z) offset in pixels in the 1D, 2D, or 3D image + ur_rect_region_t + region, ///< [in] defines the (width, height, depth) in pixels of the 1D, 2D, or 3D + ///< image + size_t rowPitch, ///< [in] length of each row in bytes + size_t slicePitch, ///< [in] length of each 2D slice of the 3D image + void *pSrc, ///< [in] pointer to host memory where image is to be read into + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemImageWrite = context.urDdiTable.Enqueue.pfnMemImageWrite; + + if (nullptr == pfnMemImageWrite) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_image_write_params_t params = { + &hQueue, &hImage, &blockingWrite, + &origin, ®ion, &rowPitch, + &slicePitch, &pSrc, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemImageWrite")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueMemImageWrite")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemImageWrite")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemImageCopy +__urdlllocal ur_result_t UR_APICALL urEnqueueMemImageCopy( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hImageSrc, ///< [in][bounds(srcOrigin, region)] handle of the src image object + ur_mem_handle_t + hImageDst, ///< [in][bounds(dstOrigin, region)] handle of the dest image object + ur_rect_offset_t + srcOrigin, ///< [in] defines the (x,y,z) offset in pixels in the source 1D, 2D, or 3D + ///< image + ur_rect_offset_t + dstOrigin, ///< [in] defines the (x,y,z) offset in pixels in the destination 1D, 2D, + ///< or 3D image + ur_rect_region_t + region, ///< [in] defines the (width, height, depth) in pixels of the 1D, 2D, or 3D + ///< image + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemImageCopy = context.urDdiTable.Enqueue.pfnMemImageCopy; + + if (nullptr == pfnMemImageCopy) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_image_copy_params_t params = { + &hQueue, &hImageSrc, &hImageDst, &srcOrigin, &dstOrigin, + ®ion, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemImageCopy")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueMemImageCopy")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemImageCopy")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferMap +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object + bool blockingMap, ///< [in] indicates blocking (true), non-blocking (false) + ur_map_flags_t mapFlags, ///< [in] flags for read, write, readwrite mapping + size_t offset, ///< [in] offset in bytes of the buffer region being mapped + size_t size, ///< [in] size in bytes of the buffer region being mapped + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent, ///< [out][optional] return an event object that identifies this particular + ///< command instance. + void **ppRetMap ///< [out] return mapped pointer. TODO: move it before + ///< numEventsInWaitList? +) { + auto pfnMemBufferMap = context.urDdiTable.Enqueue.pfnMemBufferMap; + + if (nullptr == pfnMemBufferMap) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_buffer_map_params_t params = { + &hQueue, &hBuffer, &blockingMap, &mapFlags, + &offset, &size, &numEventsInWaitList, &phEventWaitList, + &phEvent, &ppRetMap}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemBufferMap")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueMemBufferMap")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemBufferMap")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemUnmap +__urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hMem, ///< [in] handle of the memory (buffer or image) object + void *pMappedPtr, ///< [in] mapped host address + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemUnmap = context.urDdiTable.Enqueue.pfnMemUnmap; + + if (nullptr == pfnMemUnmap) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_mem_unmap_params_t params = { + &hQueue, &hMem, &pMappedPtr, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueMemUnmap")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueMemUnmap")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueMemUnmap")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMFill +__urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + void *pMem, ///< [in][bounds(0, size)] pointer to USM memory object + size_t + patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less + ///< than or equal to width. + const void + *pPattern, ///< [in] pointer with the bytes of the pattern to set. + size_t + size, ///< [in] size in bytes to be set. Must be a multiple of patternSize. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnUSMFill = context.urDdiTable.Enqueue.pfnUSMFill; + + if (nullptr == pfnUSMFill) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_usm_fill_params_t params = { + &hQueue, &pMem, &patternSize, + &pPattern, &size, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueUSMFill")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueUSMFill")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueUSMFill")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMMemcpy +__urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + bool blocking, ///< [in] blocking or non-blocking copy + void * + pDst, ///< [in][bounds(0, size)] pointer to the destination USM memory object + const void * + pSrc, ///< [in][bounds(0, size)] pointer to the source USM memory object + size_t size, ///< [in] size in bytes to be copied + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnUSMMemcpy = context.urDdiTable.Enqueue.pfnUSMMemcpy; + + if (nullptr == pfnUSMMemcpy) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_usm_memcpy_params_t params = { + &hQueue, &blocking, &pDst, &pSrc, &size, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueUSMMemcpy")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueUSMMemcpy")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueUSMMemcpy")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMPrefetch +__urdlllocal ur_result_t UR_APICALL urEnqueueUSMPrefetch( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + const void + *pMem, ///< [in][bounds(0, size)] pointer to the USM memory object + size_t size, ///< [in] size in bytes to be fetched + ur_usm_migration_flags_t flags, ///< [in] USM prefetch flags + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnUSMPrefetch = context.urDdiTable.Enqueue.pfnUSMPrefetch; + + if (nullptr == pfnUSMPrefetch) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_usm_prefetch_params_t params = { + &hQueue, &pMem, &size, &flags, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueUSMPrefetch")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueUSMPrefetch")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueUSMPrefetch")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMAdvise +__urdlllocal ur_result_t UR_APICALL urEnqueueUSMAdvise( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + const void + *pMem, ///< [in][bounds(0, size)] pointer to the USM memory object + size_t size, ///< [in] size in bytes to be advised + ur_usm_advice_flags_t advice, ///< [in] USM memory advice + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnUSMAdvise = context.urDdiTable.Enqueue.pfnUSMAdvise; + + if (nullptr == pfnUSMAdvise) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_usm_advise_params_t params = {&hQueue, &pMem, &size, &advice, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueUSMAdvise")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueUSMAdvise")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueUSMAdvise")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMFill2D +__urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill2D( + ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to. + void * + pMem, ///< [in][bounds(0, pitch * height)] pointer to memory to be filled. + size_t + pitch, ///< [in] the total width of the destination memory including padding. + size_t + patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less + ///< than or equal to width. + const void + *pPattern, ///< [in] pointer with the bytes of the pattern to set. + size_t + width, ///< [in] the width in bytes of each row to fill. Must be a multiple of + ///< patternSize. + size_t height, ///< [in] the height of the columns to fill. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnUSMFill2D = context.urDdiTable.Enqueue.pfnUSMFill2D; + + if (nullptr == pfnUSMFill2D) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_usm_fill_2d_params_t params = { + &hQueue, &pMem, &pitch, &patternSize, + &pPattern, &width, &height, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueUSMFill2D")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueUSMFill2D")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueUSMFill2D")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMMemcpy2D +__urdlllocal ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( + ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to. + bool blocking, ///< [in] indicates if this operation should block the host. + void * + pDst, ///< [in][bounds(0, dstPitch * height)] pointer to memory where data will + ///< be copied. + size_t + dstPitch, ///< [in] the total width of the source memory including padding. + const void * + pSrc, ///< [in][bounds(0, srcPitch * height)] pointer to memory to be copied. + size_t + srcPitch, ///< [in] the total width of the source memory including padding. + size_t width, ///< [in] the width in bytes of each row to be copied. + size_t height, ///< [in] the height of columns to be copied. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnUSMMemcpy2D = context.urDdiTable.Enqueue.pfnUSMMemcpy2D; + + if (nullptr == pfnUSMMemcpy2D) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_usm_memcpy_2d_params_t params = { + &hQueue, &blocking, &pDst, + &dstPitch, &pSrc, &srcPitch, + &width, &height, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueUSMMemcpy2D")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueUSMMemcpy2D")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueUSMMemcpy2D")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueDeviceGlobalVariableWrite +__urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite( + ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to. + ur_program_handle_t + hProgram, ///< [in] handle of the program containing the device global variable. + const char + *name, ///< [in] the unique identifier for the device global variable. + bool blockingWrite, ///< [in] indicates if this operation should block. + size_t count, ///< [in] the number of bytes to copy. + size_t + offset, ///< [in] the byte offset into the device global variable to start copying. + const void *pSrc, ///< [in] pointer to where the data must be copied from. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list. + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnDeviceGlobalVariableWrite = + context.urDdiTable.Enqueue.pfnDeviceGlobalVariableWrite; + + if (nullptr == pfnDeviceGlobalVariableWrite) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_device_global_variable_write_params_t params = { + &hQueue, &hProgram, &name, &blockingWrite, + &count, &offset, &pSrc, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urEnqueueDeviceGlobalVariableWrite")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEnqueueDeviceGlobalVariableWrite")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urEnqueueDeviceGlobalVariableWrite")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueDeviceGlobalVariableRead +__urdlllocal ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( + ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to. + ur_program_handle_t + hProgram, ///< [in] handle of the program containing the device global variable. + const char + *name, ///< [in] the unique identifier for the device global variable. + bool blockingRead, ///< [in] indicates if this operation should block. + size_t count, ///< [in] the number of bytes to copy. + size_t + offset, ///< [in] the byte offset into the device global variable to start copying. + void *pDst, ///< [in] pointer to where the data must be copied to. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list. + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnDeviceGlobalVariableRead = + context.urDdiTable.Enqueue.pfnDeviceGlobalVariableRead; + + if (nullptr == pfnDeviceGlobalVariableRead) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_device_global_variable_read_params_t params = { + &hQueue, &hProgram, &name, &blockingRead, + &count, &offset, &pDst, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urEnqueueDeviceGlobalVariableRead")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEnqueueDeviceGlobalVariableRead")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urEnqueueDeviceGlobalVariableRead")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueReadHostPipe +__urdlllocal ur_result_t UR_APICALL urEnqueueReadHostPipe( + ur_queue_handle_t + hQueue, ///< [in] a valid host command-queue in which the read command + ///< will be queued. hQueue and hProgram must be created with the same + ///< UR context. + ur_program_handle_t + hProgram, ///< [in] a program object with a successfully built executable. + const char * + pipe_symbol, ///< [in] the name of the program scope pipe global variable. + bool + blocking, ///< [in] indicate if the read operation is blocking or non-blocking. + void * + pDst, ///< [in] a pointer to buffer in host memory that will hold resulting data + ///< from pipe. + size_t size, ///< [in] size of the memory region to read, in bytes. + uint32_t numEventsInWaitList, ///< [in] number of events in the wait list. + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the host pipe read. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event. + ur_event_handle_t * + phEvent ///< [out][optional] returns an event object that identifies this read + ///< command + ///< and can be used to query or queue a wait for this command to complete. +) { + auto pfnReadHostPipe = context.urDdiTable.Enqueue.pfnReadHostPipe; + + if (nullptr == pfnReadHostPipe) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_read_host_pipe_params_t params = { + &hQueue, &hProgram, &pipe_symbol, &blocking, + &pDst, &size, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueReadHostPipe")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueReadHostPipe")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueReadHostPipe")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueWriteHostPipe +__urdlllocal ur_result_t UR_APICALL urEnqueueWriteHostPipe( + ur_queue_handle_t + hQueue, ///< [in] a valid host command-queue in which the write command + ///< will be queued. hQueue and hProgram must be created with the same + ///< UR context. + ur_program_handle_t + hProgram, ///< [in] a program object with a successfully built executable. + const char * + pipe_symbol, ///< [in] the name of the program scope pipe global variable. + bool + blocking, ///< [in] indicate if the read and write operations are blocking or + ///< non-blocking. + void * + pSrc, ///< [in] a pointer to buffer in host memory that holds data to be written + ///< to the host pipe. + size_t size, ///< [in] size of the memory region to read or write, in bytes. + uint32_t numEventsInWaitList, ///< [in] number of events in the wait list. + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the host pipe write. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event. + ur_event_handle_t * + phEvent ///< [out][optional] returns an event object that identifies this write command + ///< and can be used to query or queue a wait for this command to complete. +) { + auto pfnWriteHostPipe = context.urDdiTable.Enqueue.pfnWriteHostPipe; + + if (nullptr == pfnWriteHostPipe) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_write_host_pipe_params_t params = { + &hQueue, &hProgram, &pipe_symbol, &blocking, + &pSrc, &size, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urEnqueueWriteHostPipe")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urEnqueueWriteHostPipe")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urEnqueueWriteHostPipe")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMPitchedAllocExp +__urdlllocal ur_result_t UR_APICALL urUSMPitchedAllocExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_usm_desc_t * + pUSMDesc, ///< [in][optional] Pointer to USM memory allocation descriptor. + ur_usm_pool_handle_t + pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate + size_t + widthInBytes, ///< [in] width in bytes of the USM memory object to be allocated + size_t height, ///< [in] height of the USM memory object to be allocated + size_t + elementSizeBytes, ///< [in] size in bytes of an element in the allocation + void **ppMem, ///< [out] pointer to USM shared memory object + size_t *pResultPitch ///< [out] pitch of the allocation +) { + auto pfnPitchedAllocExp = context.urDdiTable.USMExp.pfnPitchedAllocExp; + + if (nullptr == pfnPitchedAllocExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_pitched_alloc_exp_params_t params = { + &hContext, &hDevice, &pUSMDesc, &pool, &widthInBytes, + &height, &elementSizeBytes, &ppMem, &pResultPitch}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMPitchedAllocExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMPitchedAllocExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMPitchedAllocExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesUnsampledImageHandleDestroyExp +__urdlllocal ur_result_t UR_APICALL +urBindlessImagesUnsampledImageHandleDestroyExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_image_handle_t + hImage ///< [in] pointer to handle of image object to destroy +) { + auto pfnUnsampledImageHandleDestroyExp = + context.urDdiTable.BindlessImagesExp.pfnUnsampledImageHandleDestroyExp; + + if (nullptr == pfnUnsampledImageHandleDestroyExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_unsampled_image_handle_destroy_exp_params_t params = { + &hContext, &hDevice, &hImage}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesUnsampledImageHandleDestroyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesUnsampledImageHandleDestroyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesUnsampledImageHandleDestroyExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesSampledImageHandleDestroyExp +__urdlllocal ur_result_t UR_APICALL +urBindlessImagesSampledImageHandleDestroyExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_image_handle_t + hImage ///< [in] pointer to handle of image object to destroy +) { + auto pfnSampledImageHandleDestroyExp = + context.urDdiTable.BindlessImagesExp.pfnSampledImageHandleDestroyExp; + + if (nullptr == pfnSampledImageHandleDestroyExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_sampled_image_handle_destroy_exp_params_t params = { + &hContext, &hDevice, &hImage}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesSampledImageHandleDestroyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesSampledImageHandleDestroyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesSampledImageHandleDestroyExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesImageAllocateExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesImageAllocateExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_image_format_t + *pImageFormat, ///< [in] pointer to image format specification + const ur_image_desc_t *pImageDesc, ///< [in] pointer to image description + ur_exp_image_mem_handle_t + *phImageMem ///< [out] pointer to handle of image memory allocated +) { + auto pfnImageAllocateExp = + context.urDdiTable.BindlessImagesExp.pfnImageAllocateExp; + + if (nullptr == pfnImageAllocateExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_image_allocate_exp_params_t params = { + &hContext, &hDevice, &pImageFormat, &pImageDesc, &phImageMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesImageAllocateExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesImageAllocateExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phImageMem = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesImageAllocateExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesImageFreeExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesImageFreeExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_image_mem_handle_t + hImageMem ///< [in] handle of image memory to be freed +) { + auto pfnImageFreeExp = context.urDdiTable.BindlessImagesExp.pfnImageFreeExp; + + if (nullptr == pfnImageFreeExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_image_free_exp_params_t params = {&hContext, &hDevice, + &hImageMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesImageFreeExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesImageFreeExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesImageFreeExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesUnsampledImageCreateExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesUnsampledImageCreateExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_image_mem_handle_t + hImageMem, ///< [in] handle to memory from which to create the image + const ur_image_format_t + *pImageFormat, ///< [in] pointer to image format specification + const ur_image_desc_t *pImageDesc, ///< [in] pointer to image description + ur_exp_image_handle_t + *phImage ///< [out] pointer to handle of image object created +) { + auto pfnUnsampledImageCreateExp = + context.urDdiTable.BindlessImagesExp.pfnUnsampledImageCreateExp; + + if (nullptr == pfnUnsampledImageCreateExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_unsampled_image_create_exp_params_t params = { + &hContext, &hDevice, &hImageMem, &pImageFormat, &pImageDesc, &phImage}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesUnsampledImageCreateExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesUnsampledImageCreateExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phImage = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesUnsampledImageCreateExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesSampledImageCreateExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesSampledImageCreateExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_image_mem_handle_t + hImageMem, ///< [in] handle to memory from which to create the image + const ur_image_format_t + *pImageFormat, ///< [in] pointer to image format specification + const ur_image_desc_t *pImageDesc, ///< [in] pointer to image description + ur_sampler_handle_t hSampler, ///< [in] sampler to be used + ur_exp_image_handle_t + *phImage ///< [out] pointer to handle of image object created +) { + auto pfnSampledImageCreateExp = + context.urDdiTable.BindlessImagesExp.pfnSampledImageCreateExp; + + if (nullptr == pfnSampledImageCreateExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_sampled_image_create_exp_params_t params = { + &hContext, &hDevice, &hImageMem, &pImageFormat, + &pImageDesc, &hSampler, &phImage}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesSampledImageCreateExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesSampledImageCreateExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phImage = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesSampledImageCreateExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesImageCopyExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesImageCopyExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + void *pDst, ///< [in] location the data will be copied to + void *pSrc, ///< [in] location the data will be copied from + const ur_image_format_t + *pImageFormat, ///< [in] pointer to image format specification + const ur_image_desc_t *pImageDesc, ///< [in] pointer to image description + ur_exp_image_copy_flags_t + imageCopyFlags, ///< [in] flags describing copy direction e.g. H2D or D2H + ur_rect_offset_t + srcOffset, ///< [in] defines the (x,y,z) source offset in pixels in the 1D, 2D, or 3D + ///< image + ur_rect_offset_t + dstOffset, ///< [in] defines the (x,y,z) destination offset in pixels in the 1D, 2D, + ///< or 3D image + ur_rect_region_t + copyExtent, ///< [in] defines the (width, height, depth) in pixels of the 1D, 2D, or 3D + ///< region to copy + ur_rect_region_t + hostExtent, ///< [in] defines the (width, height, depth) in pixels of the 1D, 2D, or 3D + ///< region on the host + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that all + ///< previously enqueued commands + ///< must be complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnImageCopyExp = context.urDdiTable.BindlessImagesExp.pfnImageCopyExp; + + if (nullptr == pfnImageCopyExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_image_copy_exp_params_t params = {&hQueue, + &pDst, + &pSrc, + &pImageFormat, + &pImageDesc, + &imageCopyFlags, + &srcOffset, + &dstOffset, + ©Extent, + &hostExtent, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesImageCopyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesImageCopyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesImageCopyExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesImageGetInfoExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesImageGetInfoExp( + ur_exp_image_mem_handle_t hImageMem, ///< [in] handle to the image memory + ur_image_info_t propName, ///< [in] queried info name + void *pPropValue, ///< [out][optional] returned query value + size_t *pPropSizeRet ///< [out][optional] returned query value size +) { + auto pfnImageGetInfoExp = + context.urDdiTable.BindlessImagesExp.pfnImageGetInfoExp; + + if (nullptr == pfnImageGetInfoExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_image_get_info_exp_params_t params = { + &hImageMem, &propName, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesImageGetInfoExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesImageGetInfoExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesImageGetInfoExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesMipmapGetLevelExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapGetLevelExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_image_mem_handle_t + hImageMem, ///< [in] memory handle to the mipmap image + uint32_t mipmapLevel, ///< [in] requested level of the mipmap + ur_exp_image_mem_handle_t + *phImageMem ///< [out] returning memory handle to the individual image +) { + auto pfnMipmapGetLevelExp = + context.urDdiTable.BindlessImagesExp.pfnMipmapGetLevelExp; + + if (nullptr == pfnMipmapGetLevelExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_mipmap_get_level_exp_params_t params = { + &hContext, &hDevice, &hImageMem, &mipmapLevel, &phImageMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesMipmapGetLevelExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesMipmapGetLevelExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phImageMem = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesMipmapGetLevelExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesMipmapFreeExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesMipmapFreeExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_image_mem_handle_t hMem ///< [in] handle of image memory to be freed +) { + auto pfnMipmapFreeExp = + context.urDdiTable.BindlessImagesExp.pfnMipmapFreeExp; + + if (nullptr == pfnMipmapFreeExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_mipmap_free_exp_params_t params = {&hContext, &hDevice, + &hMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesMipmapFreeExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesMipmapFreeExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesMipmapFreeExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesImportOpaqueFDExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesImportOpaqueFDExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + size_t size, ///< [in] size of the external memory + ur_exp_interop_mem_desc_t + *pInteropMemDesc, ///< [in] the interop memory descriptor + ur_exp_interop_mem_handle_t + *phInteropMem ///< [out] interop memory handle to the external memory +) { + auto pfnImportOpaqueFDExp = + context.urDdiTable.BindlessImagesExp.pfnImportOpaqueFDExp; + + if (nullptr == pfnImportOpaqueFDExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_import_opaque_fd_exp_params_t params = { + &hContext, &hDevice, &size, &pInteropMemDesc, &phInteropMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesImportOpaqueFDExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesImportOpaqueFDExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phInteropMem = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesImportOpaqueFDExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesMapExternalArrayExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesMapExternalArrayExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_image_format_t + *pImageFormat, ///< [in] pointer to image format specification + const ur_image_desc_t *pImageDesc, ///< [in] pointer to image description + ur_exp_interop_mem_handle_t + hInteropMem, ///< [in] interop memory handle to the external memory + ur_exp_image_mem_handle_t * + phImageMem ///< [out] image memory handle to the externally allocated memory +) { + auto pfnMapExternalArrayExp = + context.urDdiTable.BindlessImagesExp.pfnMapExternalArrayExp; + + if (nullptr == pfnMapExternalArrayExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_map_external_array_exp_params_t params = { + &hContext, &hDevice, &pImageFormat, + &pImageDesc, &hInteropMem, &phImageMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesMapExternalArrayExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesMapExternalArrayExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phImageMem = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesMapExternalArrayExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesReleaseInteropExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_interop_mem_handle_t + hInteropMem ///< [in][release] handle of interop memory to be freed +) { + auto pfnReleaseInteropExp = + context.urDdiTable.BindlessImagesExp.pfnReleaseInteropExp; + + if (nullptr == pfnReleaseInteropExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_release_interop_exp_params_t params = { + &hContext, &hDevice, &hInteropMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesReleaseInteropExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesReleaseInteropExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hInteropMem); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesReleaseInteropExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesImportExternalSemaphoreOpaqueFDExp +__urdlllocal ur_result_t UR_APICALL +urBindlessImagesImportExternalSemaphoreOpaqueFDExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_interop_semaphore_desc_t + *pInteropSemaphoreDesc, ///< [in] the interop semaphore descriptor + ur_exp_interop_semaphore_handle_t * + phInteropSemaphore ///< [out] interop semaphore handle to the external semaphore +) { + auto pfnImportExternalSemaphoreOpaqueFDExp = + context.urDdiTable.BindlessImagesExp + .pfnImportExternalSemaphoreOpaqueFDExp; + + if (nullptr == pfnImportExternalSemaphoreOpaqueFDExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_import_external_semaphore_opaque_fd_exp_params_t params = + {&hContext, &hDevice, &pInteropSemaphoreDesc, &phInteropSemaphore}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesImportExternalSemaphoreOpaqueFDExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesImportExternalSemaphoreOpaqueFDExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phInteropSemaphore = + createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesImportExternalSemaphoreOpaqueFDExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesDestroyExternalSemaphoreExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesDestroyExternalSemaphoreExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + ur_exp_interop_semaphore_handle_t + hInteropSemaphore ///< [in] handle of interop semaphore to be destroyed +) { + auto pfnDestroyExternalSemaphoreExp = + context.urDdiTable.BindlessImagesExp.pfnDestroyExternalSemaphoreExp; + + if (nullptr == pfnDestroyExternalSemaphoreExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_destroy_external_semaphore_exp_params_t params = { + &hContext, &hDevice, &hInteropSemaphore}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesDestroyExternalSemaphoreExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesDestroyExternalSemaphoreExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesDestroyExternalSemaphoreExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesWaitExternalSemaphoreExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesWaitExternalSemaphoreExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_exp_interop_semaphore_handle_t + hSemaphore, ///< [in] interop semaphore handle + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that all + ///< previously enqueued commands + ///< must be complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnWaitExternalSemaphoreExp = + context.urDdiTable.BindlessImagesExp.pfnWaitExternalSemaphoreExp; + + if (nullptr == pfnWaitExternalSemaphoreExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_wait_external_semaphore_exp_params_t params = { + &hQueue, &hSemaphore, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesWaitExternalSemaphoreExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesWaitExternalSemaphoreExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesWaitExternalSemaphoreExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urBindlessImagesSignalExternalSemaphoreExp +__urdlllocal ur_result_t UR_APICALL urBindlessImagesSignalExternalSemaphoreExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_exp_interop_semaphore_handle_t + hSemaphore, ///< [in] interop semaphore handle + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that all + ///< previously enqueued commands + ///< must be complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnSignalExternalSemaphoreExp = + context.urDdiTable.BindlessImagesExp.pfnSignalExternalSemaphoreExp; + + if (nullptr == pfnSignalExternalSemaphoreExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_bindless_images_signal_external_semaphore_exp_params_t params = { + &hQueue, &hSemaphore, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urBindlessImagesSignalExternalSemaphoreExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urBindlessImagesSignalExternalSemaphoreExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urBindlessImagesSignalExternalSemaphoreExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferCreateExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( + ur_context_handle_t hContext, ///< [in] Handle of the context object. + ur_device_handle_t hDevice, ///< [in] Handle of the device object. + const ur_exp_command_buffer_desc_t + *pCommandBufferDesc, ///< [in][optional] command-buffer descriptor. + ur_exp_command_buffer_handle_t + *phCommandBuffer ///< [out] Pointer to command-Buffer handle. +) { + auto pfnCreateExp = context.urDdiTable.CommandBufferExp.pfnCreateExp; + + if (nullptr == pfnCreateExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_create_exp_params_t params = { + &hContext, &hDevice, &pCommandBufferDesc, &phCommandBuffer}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urCommandBufferCreateExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urCommandBufferCreateExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phCommandBuffer = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urCommandBufferCreateExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferRetainExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( + ur_exp_command_buffer_handle_t + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. +) { + auto pfnRetainExp = context.urDdiTable.CommandBufferExp.pfnRetainExp; + + if (nullptr == pfnRetainExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_retain_exp_params_t params = {&hCommandBuffer}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urCommandBufferRetainExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urCommandBufferRetainExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + retainDummyHandle(hCommandBuffer); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urCommandBufferRetainExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferReleaseExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( + ur_exp_command_buffer_handle_t + hCommandBuffer ///< [in][release] Handle of the command-buffer object. +) { + auto pfnReleaseExp = context.urDdiTable.CommandBufferExp.pfnReleaseExp; + + if (nullptr == pfnReleaseExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_release_exp_params_t params = {&hCommandBuffer}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urCommandBufferReleaseExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urCommandBufferReleaseExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hCommandBuffer); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urCommandBufferReleaseExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferFinalizeExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferFinalizeExp( + ur_exp_command_buffer_handle_t + hCommandBuffer ///< [in] Handle of the command-buffer object. +) { + auto pfnFinalizeExp = context.urDdiTable.CommandBufferExp.pfnFinalizeExp; + + if (nullptr == pfnFinalizeExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_finalize_exp_params_t params = {&hCommandBuffer}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urCommandBufferFinalizeExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferFinalizeExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urCommandBufferFinalizeExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] Handle of the command-buffer object. + ur_kernel_handle_t hKernel, ///< [in] Kernel to append. + uint32_t workDim, ///< [in] Dimension of the kernel execution. + const size_t + *pGlobalWorkOffset, ///< [in] Offset to use when executing kernel. + const size_t * + pGlobalWorkSize, ///< [in] Global work size to use when executing kernel. + const size_t * + pLocalWorkSize, ///< [in][optional] Local work size to use when executing kernel. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint, ///< [out][optional] Sync point associated with this command. + ur_exp_command_buffer_command_handle_t + *phCommand ///< [out][optional] Handle to this command. +) { + auto pfnAppendKernelLaunchExp = + context.urDdiTable.CommandBufferExp.pfnAppendKernelLaunchExp; + + if (nullptr == pfnAppendKernelLaunchExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_kernel_launch_exp_params_t params = { + &hCommandBuffer, + &hKernel, + &workDim, + &pGlobalWorkOffset, + &pGlobalWorkSize, + &pLocalWorkSize, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint, + &phCommand}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendKernelLaunchExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendKernelLaunchExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phCommand) { + *phCommand = + createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendKernelLaunchExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendUSMMemcpyExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] Handle of the command-buffer object. + void *pDst, ///< [in] Location the data will be copied to. + const void *pSrc, ///< [in] The data to be copied. + size_t size, ///< [in] The number of bytes to copy. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] Sync point associated with this command. +) { + auto pfnAppendUSMMemcpyExp = + context.urDdiTable.CommandBufferExp.pfnAppendUSMMemcpyExp; + + if (nullptr == pfnAppendUSMMemcpyExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_usm_memcpy_exp_params_t params = { + &hCommandBuffer, &pDst, &pSrc, &size, &numSyncPointsInWaitList, + &pSyncPointWaitList, &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendUSMMemcpyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendUSMMemcpyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendUSMMemcpyExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendUSMFillExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] handle of the command-buffer object. + void *pMemory, ///< [in] pointer to USM allocated memory to fill. + const void *pPattern, ///< [in] pointer to the fill pattern. + size_t patternSize, ///< [in] size in bytes of the pattern. + size_t + size, ///< [in] fill size in bytes, must be a multiple of patternSize. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] sync point associated with this command. +) { + auto pfnAppendUSMFillExp = + context.urDdiTable.CommandBufferExp.pfnAppendUSMFillExp; + + if (nullptr == pfnAppendUSMFillExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_usm_fill_exp_params_t params = { + &hCommandBuffer, &pMemory, &pPattern, + &patternSize, &size, &numSyncPointsInWaitList, + &pSyncPointWaitList, &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendUSMFillExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendUSMFillExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendUSMFillExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendMemBufferCopyExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] Handle of the command-buffer object. + ur_mem_handle_t hSrcMem, ///< [in] The data to be copied. + ur_mem_handle_t hDstMem, ///< [in] The location the data will be copied to. + size_t srcOffset, ///< [in] Offset into the source memory. + size_t dstOffset, ///< [in] Offset into the destination memory + size_t size, ///< [in] The number of bytes to be copied. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] Sync point associated with this command. +) { + auto pfnAppendMemBufferCopyExp = + context.urDdiTable.CommandBufferExp.pfnAppendMemBufferCopyExp; + + if (nullptr == pfnAppendMemBufferCopyExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_mem_buffer_copy_exp_params_t params = { + &hCommandBuffer, + &hSrcMem, + &hDstMem, + &srcOffset, + &dstOffset, + &size, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendMemBufferCopyExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendMemBufferCopyExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendMemBufferCopyExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendMemBufferWriteExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] Handle of the command-buffer object. + ur_mem_handle_t hBuffer, ///< [in] Handle of the buffer object. + size_t offset, ///< [in] Offset in bytes in the buffer object. + size_t size, ///< [in] Size in bytes of data being written. + const void * + pSrc, ///< [in] Pointer to host memory where data is to be written from. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] Sync point associated with this command. +) { + auto pfnAppendMemBufferWriteExp = + context.urDdiTable.CommandBufferExp.pfnAppendMemBufferWriteExp; + + if (nullptr == pfnAppendMemBufferWriteExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_mem_buffer_write_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &offset, + &size, + &pSrc, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendMemBufferWriteExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendMemBufferWriteExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendMemBufferWriteExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendMemBufferReadExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] Handle of the command-buffer object. + ur_mem_handle_t hBuffer, ///< [in] Handle of the buffer object. + size_t offset, ///< [in] Offset in bytes in the buffer object. + size_t size, ///< [in] Size in bytes of data being written. + void *pDst, ///< [in] Pointer to host memory where data is to be written to. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] Sync point associated with this command. +) { + auto pfnAppendMemBufferReadExp = + context.urDdiTable.CommandBufferExp.pfnAppendMemBufferReadExp; + + if (nullptr == pfnAppendMemBufferReadExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_mem_buffer_read_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &offset, + &size, + &pDst, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendMemBufferReadExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendMemBufferReadExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendMemBufferReadExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendMemBufferCopyRectExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] Handle of the command-buffer object. + ur_mem_handle_t hSrcMem, ///< [in] The data to be copied. + ur_mem_handle_t hDstMem, ///< [in] The location the data will be copied to. + ur_rect_offset_t + srcOrigin, ///< [in] Origin for the region of data to be copied from the source. + ur_rect_offset_t + dstOrigin, ///< [in] Origin for the region of data to be copied to in the destination. + ur_rect_region_t + region, ///< [in] The extents describing the region to be copied. + size_t srcRowPitch, ///< [in] Row pitch of the source memory. + size_t srcSlicePitch, ///< [in] Slice pitch of the source memory. + size_t dstRowPitch, ///< [in] Row pitch of the destination memory. + size_t dstSlicePitch, ///< [in] Slice pitch of the destination memory. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] Sync point associated with this command. +) { + auto pfnAppendMemBufferCopyRectExp = + context.urDdiTable.CommandBufferExp.pfnAppendMemBufferCopyRectExp; + + if (nullptr == pfnAppendMemBufferCopyRectExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_mem_buffer_copy_rect_exp_params_t params = { + &hCommandBuffer, + &hSrcMem, + &hDstMem, + &srcOrigin, + &dstOrigin, + ®ion, + &srcRowPitch, + &srcSlicePitch, + &dstRowPitch, + &dstSlicePitch, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendMemBufferCopyRectExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendMemBufferCopyRectExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendMemBufferCopyRectExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendMemBufferWriteRectExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] Handle of the command-buffer object. + ur_mem_handle_t hBuffer, ///< [in] Handle of the buffer object. + ur_rect_offset_t bufferOffset, ///< [in] 3D offset in the buffer. + ur_rect_offset_t hostOffset, ///< [in] 3D offset in the host region. + ur_rect_region_t + region, ///< [in] 3D rectangular region descriptor: width, height, depth. + size_t + bufferRowPitch, ///< [in] Length of each row in bytes in the buffer object. + size_t + bufferSlicePitch, ///< [in] Length of each 2D slice in bytes in the buffer object being + ///< written. + size_t + hostRowPitch, ///< [in] Length of each row in bytes in the host memory region pointed to + ///< by pSrc. + size_t + hostSlicePitch, ///< [in] Length of each 2D slice in bytes in the host memory region + ///< pointed to by pSrc. + void * + pSrc, ///< [in] Pointer to host memory where data is to be written from. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] Sync point associated with this command. +) { + auto pfnAppendMemBufferWriteRectExp = + context.urDdiTable.CommandBufferExp.pfnAppendMemBufferWriteRectExp; + + if (nullptr == pfnAppendMemBufferWriteRectExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_mem_buffer_write_rect_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &bufferOffset, + &hostOffset, + ®ion, + &bufferRowPitch, + &bufferSlicePitch, + &hostRowPitch, + &hostSlicePitch, + &pSrc, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendMemBufferWriteRectExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendMemBufferWriteRectExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendMemBufferWriteRectExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendMemBufferReadRectExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] Handle of the command-buffer object. + ur_mem_handle_t hBuffer, ///< [in] Handle of the buffer object. + ur_rect_offset_t bufferOffset, ///< [in] 3D offset in the buffer. + ur_rect_offset_t hostOffset, ///< [in] 3D offset in the host region. + ur_rect_region_t + region, ///< [in] 3D rectangular region descriptor: width, height, depth. + size_t + bufferRowPitch, ///< [in] Length of each row in bytes in the buffer object. + size_t + bufferSlicePitch, ///< [in] Length of each 2D slice in bytes in the buffer object being read. + size_t + hostRowPitch, ///< [in] Length of each row in bytes in the host memory region pointed to + ///< by pDst. + size_t + hostSlicePitch, ///< [in] Length of each 2D slice in bytes in the host memory region + ///< pointed to by pDst. + void *pDst, ///< [in] Pointer to host memory where data is to be read into. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] Sync point associated with this command. +) { + auto pfnAppendMemBufferReadRectExp = + context.urDdiTable.CommandBufferExp.pfnAppendMemBufferReadRectExp; + + if (nullptr == pfnAppendMemBufferReadRectExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_mem_buffer_read_rect_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &bufferOffset, + &hostOffset, + ®ion, + &bufferRowPitch, + &bufferSlicePitch, + &hostRowPitch, + &hostSlicePitch, + &pDst, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendMemBufferReadRectExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendMemBufferReadRectExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendMemBufferReadRectExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendMemBufferFillExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] handle of the command-buffer object. + ur_mem_handle_t hBuffer, ///< [in] handle of the buffer object. + const void *pPattern, ///< [in] pointer to the fill pattern. + size_t patternSize, ///< [in] size in bytes of the pattern. + size_t offset, ///< [in] offset into the buffer. + size_t + size, ///< [in] fill size in bytes, must be a multiple of patternSize. + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] sync point associated with this command. +) { + auto pfnAppendMemBufferFillExp = + context.urDdiTable.CommandBufferExp.pfnAppendMemBufferFillExp; + + if (nullptr == pfnAppendMemBufferFillExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_mem_buffer_fill_exp_params_t params = { + &hCommandBuffer, + &hBuffer, + &pPattern, + &patternSize, + &offset, + &size, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendMemBufferFillExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendMemBufferFillExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendMemBufferFillExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendUSMPrefetchExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] handle of the command-buffer object. + const void *pMemory, ///< [in] pointer to USM allocated memory to prefetch. + size_t size, ///< [in] size in bytes to be fetched. + ur_usm_migration_flags_t flags, ///< [in] USM prefetch flags + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] sync point associated with this command. +) { + auto pfnAppendUSMPrefetchExp = + context.urDdiTable.CommandBufferExp.pfnAppendUSMPrefetchExp; + + if (nullptr == pfnAppendUSMPrefetchExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_usm_prefetch_exp_params_t params = { + &hCommandBuffer, + &pMemory, + &size, + &flags, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendUSMPrefetchExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendUSMPrefetchExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendUSMPrefetchExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferAppendUSMAdviseExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] handle of the command-buffer object. + const void *pMemory, ///< [in] pointer to the USM memory object. + size_t size, ///< [in] size in bytes to be advised. + ur_usm_advice_flags_t advice, ///< [in] USM memory advice + uint32_t + numSyncPointsInWaitList, ///< [in] The number of sync points in the provided dependency list. + const ur_exp_command_buffer_sync_point_t * + pSyncPointWaitList, ///< [in][optional] A list of sync points that this command depends on. May + ///< be ignored if command-buffer is in-order. + ur_exp_command_buffer_sync_point_t * + pSyncPoint ///< [out][optional] sync point associated with this command. +) { + auto pfnAppendUSMAdviseExp = + context.urDdiTable.CommandBufferExp.pfnAppendUSMAdviseExp; + + if (nullptr == pfnAppendUSMAdviseExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_append_usm_advise_exp_params_t params = { + &hCommandBuffer, + &pMemory, + &size, + &advice, + &numSyncPointsInWaitList, + &pSyncPointWaitList, + &pSyncPoint}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferAppendUSMAdviseExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferAppendUSMAdviseExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferAppendUSMAdviseExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferEnqueueExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] Handle of the command-buffer object. + ur_queue_handle_t + hQueue, ///< [in] The queue to submit this command-buffer for execution. + uint32_t numEventsInWaitList, ///< [in] Size of the event wait list. + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the command-buffer execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating no wait events. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command-buffer execution instance. +) { + auto pfnEnqueueExp = context.urDdiTable.CommandBufferExp.pfnEnqueueExp; + + if (nullptr == pfnEnqueueExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_enqueue_exp_params_t params = { + &hCommandBuffer, &hQueue, &numEventsInWaitList, &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urCommandBufferEnqueueExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urCommandBufferEnqueueExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urCommandBufferEnqueueExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferRetainCommandExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( + ur_exp_command_buffer_command_handle_t + hCommand ///< [in] Handle of the command-buffer command. +) { + auto pfnRetainCommandExp = + context.urDdiTable.CommandBufferExp.pfnRetainCommandExp; + + if (nullptr == pfnRetainCommandExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_retain_command_exp_params_t params = {&hCommand}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferRetainCommandExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferRetainCommandExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferRetainCommandExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferReleaseCommandExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( + ur_exp_command_buffer_command_handle_t + hCommand ///< [in][release] Handle of the command-buffer command. +) { + auto pfnReleaseCommandExp = + context.urDdiTable.CommandBufferExp.pfnReleaseCommandExp; + + if (nullptr == pfnReleaseCommandExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_release_command_exp_params_t params = {&hCommand}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferReleaseCommandExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferReleaseCommandExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + releaseDummyHandle(hCommand); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferReleaseCommandExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferUpdateKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( + ur_exp_command_buffer_command_handle_t + hCommand, ///< [in] Handle of the command-buffer kernel command to update. + const ur_exp_command_buffer_update_kernel_launch_desc_t * + pUpdateKernelLaunch ///< [in] Struct defining how the kernel command is to be updated. +) { + auto pfnUpdateKernelLaunchExp = + context.urDdiTable.CommandBufferExp.pfnUpdateKernelLaunchExp; + + if (nullptr == pfnUpdateKernelLaunchExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_update_kernel_launch_exp_params_t params = { + &hCommand, &pUpdateKernelLaunch}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferUpdateKernelLaunchExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferUpdateKernelLaunchExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferUpdateKernelLaunchExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferGetInfoExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferGetInfoExp( + ur_exp_command_buffer_handle_t + hCommandBuffer, ///< [in] handle of the command-buffer object + ur_exp_command_buffer_info_t + propName, ///< [in] the name of the command-buffer property to query + size_t + propSize, ///< [in] size in bytes of the command-buffer property value + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] value of the + ///< command-buffer property + size_t * + pPropSizeRet ///< [out][optional] bytes returned in command-buffer property +) { + auto pfnGetInfoExp = context.urDdiTable.CommandBufferExp.pfnGetInfoExp; + + if (nullptr == pfnGetInfoExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_get_info_exp_params_t params = { + &hCommandBuffer, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urCommandBufferGetInfoExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urCommandBufferGetInfoExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urCommandBufferGetInfoExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urCommandBufferCommandGetInfoExp +__urdlllocal ur_result_t UR_APICALL urCommandBufferCommandGetInfoExp( + ur_exp_command_buffer_command_handle_t + hCommand, ///< [in] handle of the command-buffer command object + ur_exp_command_buffer_command_info_t + propName, ///< [in] the name of the command-buffer command property to query + size_t + propSize, ///< [in] size in bytes of the command-buffer command property value + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] value of the + ///< command-buffer command property + size_t * + pPropSizeRet ///< [out][optional] bytes returned in command-buffer command property +) { + auto pfnCommandGetInfoExp = + context.urDdiTable.CommandBufferExp.pfnCommandGetInfoExp; + + if (nullptr == pfnCommandGetInfoExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_command_buffer_command_get_info_exp_params_t params = { + &hCommand, &propName, &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urCommandBufferCommandGetInfoExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urCommandBufferCommandGetInfoExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urCommandBufferCommandGetInfoExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnCooperativeKernelLaunchExp = + context.urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; + + if (nullptr == pfnCooperativeKernelLaunchExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_cooperative_kernel_launch_exp_params_t params = { + &hQueue, + &hKernel, + &workDim, + &pGlobalWorkOffset, + &pGlobalWorkSize, + &pLocalWorkSize, + &numEventsInWaitList, + &phEventWaitList, + &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urEnqueueCooperativeKernelLaunchExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEnqueueCooperativeKernelLaunchExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + // optional output handle + if (phEvent) { + *phEvent = createDummyHandle(); + } + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urEnqueueCooperativeKernelLaunchExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp +__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + size_t + localWorkSize, ///< [in] number of local work-items that will form a work-group when the + ///< kernel is launched + size_t + dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes, + ///< that will be used when the kernel is launched + uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups +) { + auto pfnSuggestMaxCooperativeGroupCountExp = + context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; + + if (nullptr == pfnSuggestMaxCooperativeGroupCountExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = { + &hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urKernelSuggestMaxCooperativeGroupCountExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urKernelSuggestMaxCooperativeGroupCountExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urKernelSuggestMaxCooperativeGroupCountExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueTimestampRecordingExp +__urdlllocal ur_result_t UR_APICALL urEnqueueTimestampRecordingExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + bool + blocking, ///< [in] indicates whether the call to this function should block until + ///< until the device timestamp recording command has executed on the + ///< device. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating no wait + ///< events. + ur_event_handle_t * + phEvent ///< [in,out] return an event object that identifies this particular kernel + ///< execution instance. Profiling information can be queried + ///< from this event as if `hQueue` had profiling enabled. Querying + ///< `UR_PROFILING_INFO_COMMAND_QUEUED` or `UR_PROFILING_INFO_COMMAND_SUBMIT` + ///< reports the timestamp at the time of the call to this function. + ///< Querying `UR_PROFILING_INFO_COMMAND_START` or `UR_PROFILING_INFO_COMMAND_END` + ///< reports the timestamp recorded when the command is executed on the device. +) { + auto pfnTimestampRecordingExp = + context.urDdiTable.EnqueueExp.pfnTimestampRecordingExp; + + if (nullptr == pfnTimestampRecordingExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_timestamp_recording_exp_params_t params = { + &hQueue, &blocking, &numEventsInWaitList, &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urEnqueueTimestampRecordingExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEnqueueTimestampRecordingExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phEvent = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urEnqueueTimestampRecordingExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueKernelLaunchCustomExp +__urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. If nullptr, the runtime implementation + ///< will choose the work-group size. + uint32_t numPropsInLaunchPropList, ///< [in] size of the launch prop list + const ur_exp_launch_property_t * + launchPropList, ///< [in][range(0, numPropsInLaunchPropList)] pointer to a list of launch + ///< properties + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. If nullptr, + ///< the numEventsInWaitList must be 0, indicating that no wait event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnKernelLaunchCustomExp = + context.urDdiTable.EnqueueExp.pfnKernelLaunchCustomExp; + + if (nullptr == pfnKernelLaunchCustomExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_enqueue_kernel_launch_custom_exp_params_t params = { + &hQueue, &hKernel, + &workDim, &pGlobalWorkSize, + &pLocalWorkSize, &numPropsInLaunchPropList, + &launchPropList, &numEventsInWaitList, + &phEventWaitList, &phEvent}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urEnqueueKernelLaunchCustomExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urEnqueueKernelLaunchCustomExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urEnqueueKernelLaunchCustomExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramBuildExp +__urdlllocal ur_result_t UR_APICALL urProgramBuildExp( + ur_program_handle_t hProgram, ///< [in] Handle of the program to build. + uint32_t numDevices, ///< [in] number of devices + ur_device_handle_t * + phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles + const char * + pOptions ///< [in][optional] pointer to build options null-terminated string. +) { + auto pfnBuildExp = context.urDdiTable.ProgramExp.pfnBuildExp; + + if (nullptr == pfnBuildExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_build_exp_params_t params = {&hProgram, &numDevices, &phDevices, + &pOptions}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramBuildExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramBuildExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramBuildExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramCompileExp +__urdlllocal ur_result_t UR_APICALL urProgramCompileExp( + ur_program_handle_t + hProgram, ///< [in][out] handle of the program to compile. + uint32_t numDevices, ///< [in] number of devices + ur_device_handle_t * + phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles + const char * + pOptions ///< [in][optional] pointer to build options null-terminated string. +) { + auto pfnCompileExp = context.urDdiTable.ProgramExp.pfnCompileExp; + + if (nullptr == pfnCompileExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_compile_exp_params_t params = {&hProgram, &numDevices, + &phDevices, &pOptions}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramCompileExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramCompileExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramCompileExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramLinkExp +__urdlllocal ur_result_t UR_APICALL urProgramLinkExp( + ur_context_handle_t hContext, ///< [in] handle of the context instance. + uint32_t numDevices, ///< [in] number of devices + ur_device_handle_t * + phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles + uint32_t count, ///< [in] number of program handles in `phPrograms`. + const ur_program_handle_t * + phPrograms, ///< [in][range(0, count)] pointer to array of program handles. + const char * + pOptions, ///< [in][optional] pointer to linker options null-terminated string. + ur_program_handle_t + *phProgram ///< [out] pointer to handle of program object created. +) { + auto pfnLinkExp = context.urDdiTable.ProgramExp.pfnLinkExp; + + if (nullptr == pfnLinkExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_program_link_exp_params_t params = {&hContext, &numDevices, &phDevices, + &count, &phPrograms, &pOptions, + &phProgram}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urProgramLinkExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urProgramLinkExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + *phProgram = createDummyHandle(); + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urProgramLinkExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMImportExp +__urdlllocal ur_result_t UR_APICALL urUSMImportExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + void *pMem, ///< [in] pointer to host memory object + size_t size ///< [in] size in bytes of the host memory object to be imported +) { + auto pfnImportExp = context.urDdiTable.USMExp.pfnImportExp; + + if (nullptr == pfnImportExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_import_exp_params_t params = {&hContext, &pMem, &size}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMImportExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMImportExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMImportExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMReleaseExp +__urdlllocal ur_result_t UR_APICALL urUSMReleaseExp( + ur_context_handle_t hContext, ///< [in] handle of the context object + void *pMem ///< [in] pointer to host memory object +) { + auto pfnReleaseExp = context.urDdiTable.USMExp.pfnReleaseExp; + + if (nullptr == pfnReleaseExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_release_exp_params_t params = {&hContext, &pMem}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback("urUSMReleaseExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback("urUSMReleaseExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUSMReleaseExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUsmP2PEnablePeerAccessExp +__urdlllocal ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp( + ur_device_handle_t + commandDevice, ///< [in] handle of the command device object + ur_device_handle_t peerDevice ///< [in] handle of the peer device object +) { + auto pfnEnablePeerAccessExp = + context.urDdiTable.UsmP2PExp.pfnEnablePeerAccessExp; + + if (nullptr == pfnEnablePeerAccessExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_p2p_enable_peer_access_exp_params_t params = {&commandDevice, + &peerDevice}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urUsmP2PEnablePeerAccessExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urUsmP2PEnablePeerAccessExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback("urUsmP2PEnablePeerAccessExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUsmP2PDisablePeerAccessExp +__urdlllocal ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp( + ur_device_handle_t + commandDevice, ///< [in] handle of the command device object + ur_device_handle_t peerDevice ///< [in] handle of the peer device object +) { + auto pfnDisablePeerAccessExp = + context.urDdiTable.UsmP2PExp.pfnDisablePeerAccessExp; + + if (nullptr == pfnDisablePeerAccessExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_p2p_disable_peer_access_exp_params_t params = {&commandDevice, + &peerDevice}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urUsmP2PDisablePeerAccessExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urUsmP2PDisablePeerAccessExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urUsmP2PDisablePeerAccessExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUsmP2PPeerAccessGetInfoExp +__urdlllocal ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp( + ur_device_handle_t + commandDevice, ///< [in] handle of the command device object + ur_device_handle_t peerDevice, ///< [in] handle of the peer device object + ur_exp_peer_info_t propName, ///< [in] type of the info to retrieve + size_t propSize, ///< [in] the number of bytes pointed to by pPropValue. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. + ///< If propSize is not equal to or greater than the real number of bytes + ///< needed to return the info + ///< then the ::UR_RESULT_ERROR_INVALID_SIZE error is returned and + ///< pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName. +) { + auto pfnPeerAccessGetInfoExp = + context.urDdiTable.UsmP2PExp.pfnPeerAccessGetInfoExp; + + if (nullptr == pfnPeerAccessGetInfoExp) { + return UR_RESULT_ERROR_UNINITIALIZED; + } + + ur_usm_p2p_peer_access_get_info_exp_params_t params = { + &commandDevice, &peerDevice, &propName, + &propSize, &pPropValue, &pPropSizeRet}; + + ur_result_t result = UR_RESULT_SUCCESS; + + auto beforeCallback = reinterpret_cast( + context.apiCallbacks.get_before_callback( + "urUsmP2PPeerAccessGetInfoExp")); + if (beforeCallback) { + result = beforeCallback(¶ms); + if (result != UR_RESULT_SUCCESS) { + return result; + } + } + + auto replaceCallback = reinterpret_cast( + context.apiCallbacks.get_replace_callback( + "urUsmP2PPeerAccessGetInfoExp")); + if (replaceCallback) { + result = replaceCallback(¶ms); + } else { + + result = UR_RESULT_SUCCESS; + } + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto afterCallback = reinterpret_cast( + context.apiCallbacks.get_after_callback( + "urUsmP2PPeerAccessGetInfoExp")); + if (afterCallback) { + return afterCallback(¶ms); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Global table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetGlobalProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_global_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Global; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnAdapterGet = pDdiTable->pfnAdapterGet; + pDdiTable->pfnAdapterGet = ur_mock_layer::urAdapterGet; + + dditable.pfnAdapterRelease = pDdiTable->pfnAdapterRelease; + pDdiTable->pfnAdapterRelease = ur_mock_layer::urAdapterRelease; + + dditable.pfnAdapterRetain = pDdiTable->pfnAdapterRetain; + pDdiTable->pfnAdapterRetain = ur_mock_layer::urAdapterRetain; + + dditable.pfnAdapterGetLastError = pDdiTable->pfnAdapterGetLastError; + pDdiTable->pfnAdapterGetLastError = ur_mock_layer::urAdapterGetLastError; + + dditable.pfnAdapterGetInfo = pDdiTable->pfnAdapterGetInfo; + pDdiTable->pfnAdapterGetInfo = ur_mock_layer::urAdapterGetInfo; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's BindlessImagesExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetBindlessImagesExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_bindless_images_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.BindlessImagesExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnUnsampledImageHandleDestroyExp = + pDdiTable->pfnUnsampledImageHandleDestroyExp; + pDdiTable->pfnUnsampledImageHandleDestroyExp = + ur_mock_layer::urBindlessImagesUnsampledImageHandleDestroyExp; + + dditable.pfnSampledImageHandleDestroyExp = + pDdiTable->pfnSampledImageHandleDestroyExp; + pDdiTable->pfnSampledImageHandleDestroyExp = + ur_mock_layer::urBindlessImagesSampledImageHandleDestroyExp; + + dditable.pfnImageAllocateExp = pDdiTable->pfnImageAllocateExp; + pDdiTable->pfnImageAllocateExp = + ur_mock_layer::urBindlessImagesImageAllocateExp; + + dditable.pfnImageFreeExp = pDdiTable->pfnImageFreeExp; + pDdiTable->pfnImageFreeExp = ur_mock_layer::urBindlessImagesImageFreeExp; + + dditable.pfnUnsampledImageCreateExp = pDdiTable->pfnUnsampledImageCreateExp; + pDdiTable->pfnUnsampledImageCreateExp = + ur_mock_layer::urBindlessImagesUnsampledImageCreateExp; + + dditable.pfnSampledImageCreateExp = pDdiTable->pfnSampledImageCreateExp; + pDdiTable->pfnSampledImageCreateExp = + ur_mock_layer::urBindlessImagesSampledImageCreateExp; + + dditable.pfnImageCopyExp = pDdiTable->pfnImageCopyExp; + pDdiTable->pfnImageCopyExp = ur_mock_layer::urBindlessImagesImageCopyExp; + + dditable.pfnImageGetInfoExp = pDdiTable->pfnImageGetInfoExp; + pDdiTable->pfnImageGetInfoExp = + ur_mock_layer::urBindlessImagesImageGetInfoExp; + + dditable.pfnMipmapGetLevelExp = pDdiTable->pfnMipmapGetLevelExp; + pDdiTable->pfnMipmapGetLevelExp = + ur_mock_layer::urBindlessImagesMipmapGetLevelExp; + + dditable.pfnMipmapFreeExp = pDdiTable->pfnMipmapFreeExp; + pDdiTable->pfnMipmapFreeExp = ur_mock_layer::urBindlessImagesMipmapFreeExp; + + dditable.pfnImportOpaqueFDExp = pDdiTable->pfnImportOpaqueFDExp; + pDdiTable->pfnImportOpaqueFDExp = + ur_mock_layer::urBindlessImagesImportOpaqueFDExp; + + dditable.pfnMapExternalArrayExp = pDdiTable->pfnMapExternalArrayExp; + pDdiTable->pfnMapExternalArrayExp = + ur_mock_layer::urBindlessImagesMapExternalArrayExp; + + dditable.pfnReleaseInteropExp = pDdiTable->pfnReleaseInteropExp; + pDdiTable->pfnReleaseInteropExp = + ur_mock_layer::urBindlessImagesReleaseInteropExp; + + dditable.pfnImportExternalSemaphoreOpaqueFDExp = + pDdiTable->pfnImportExternalSemaphoreOpaqueFDExp; + pDdiTable->pfnImportExternalSemaphoreOpaqueFDExp = + ur_mock_layer::urBindlessImagesImportExternalSemaphoreOpaqueFDExp; + + dditable.pfnDestroyExternalSemaphoreExp = + pDdiTable->pfnDestroyExternalSemaphoreExp; + pDdiTable->pfnDestroyExternalSemaphoreExp = + ur_mock_layer::urBindlessImagesDestroyExternalSemaphoreExp; + + dditable.pfnWaitExternalSemaphoreExp = + pDdiTable->pfnWaitExternalSemaphoreExp; + pDdiTable->pfnWaitExternalSemaphoreExp = + ur_mock_layer::urBindlessImagesWaitExternalSemaphoreExp; + + dditable.pfnSignalExternalSemaphoreExp = + pDdiTable->pfnSignalExternalSemaphoreExp; + pDdiTable->pfnSignalExternalSemaphoreExp = + ur_mock_layer::urBindlessImagesSignalExternalSemaphoreExp; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's CommandBufferExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_command_buffer_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.CommandBufferExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCreateExp = pDdiTable->pfnCreateExp; + pDdiTable->pfnCreateExp = ur_mock_layer::urCommandBufferCreateExp; + + dditable.pfnRetainExp = pDdiTable->pfnRetainExp; + pDdiTable->pfnRetainExp = ur_mock_layer::urCommandBufferRetainExp; + + dditable.pfnReleaseExp = pDdiTable->pfnReleaseExp; + pDdiTable->pfnReleaseExp = ur_mock_layer::urCommandBufferReleaseExp; + + dditable.pfnFinalizeExp = pDdiTable->pfnFinalizeExp; + pDdiTable->pfnFinalizeExp = ur_mock_layer::urCommandBufferFinalizeExp; + + dditable.pfnAppendKernelLaunchExp = pDdiTable->pfnAppendKernelLaunchExp; + pDdiTable->pfnAppendKernelLaunchExp = + ur_mock_layer::urCommandBufferAppendKernelLaunchExp; + + dditable.pfnAppendUSMMemcpyExp = pDdiTable->pfnAppendUSMMemcpyExp; + pDdiTable->pfnAppendUSMMemcpyExp = + ur_mock_layer::urCommandBufferAppendUSMMemcpyExp; + + dditable.pfnAppendUSMFillExp = pDdiTable->pfnAppendUSMFillExp; + pDdiTable->pfnAppendUSMFillExp = + ur_mock_layer::urCommandBufferAppendUSMFillExp; + + dditable.pfnAppendMemBufferCopyExp = pDdiTable->pfnAppendMemBufferCopyExp; + pDdiTable->pfnAppendMemBufferCopyExp = + ur_mock_layer::urCommandBufferAppendMemBufferCopyExp; + + dditable.pfnAppendMemBufferWriteExp = pDdiTable->pfnAppendMemBufferWriteExp; + pDdiTable->pfnAppendMemBufferWriteExp = + ur_mock_layer::urCommandBufferAppendMemBufferWriteExp; + + dditable.pfnAppendMemBufferReadExp = pDdiTable->pfnAppendMemBufferReadExp; + pDdiTable->pfnAppendMemBufferReadExp = + ur_mock_layer::urCommandBufferAppendMemBufferReadExp; + + dditable.pfnAppendMemBufferCopyRectExp = + pDdiTable->pfnAppendMemBufferCopyRectExp; + pDdiTable->pfnAppendMemBufferCopyRectExp = + ur_mock_layer::urCommandBufferAppendMemBufferCopyRectExp; + + dditable.pfnAppendMemBufferWriteRectExp = + pDdiTable->pfnAppendMemBufferWriteRectExp; + pDdiTable->pfnAppendMemBufferWriteRectExp = + ur_mock_layer::urCommandBufferAppendMemBufferWriteRectExp; + + dditable.pfnAppendMemBufferReadRectExp = + pDdiTable->pfnAppendMemBufferReadRectExp; + pDdiTable->pfnAppendMemBufferReadRectExp = + ur_mock_layer::urCommandBufferAppendMemBufferReadRectExp; + + dditable.pfnAppendMemBufferFillExp = pDdiTable->pfnAppendMemBufferFillExp; + pDdiTable->pfnAppendMemBufferFillExp = + ur_mock_layer::urCommandBufferAppendMemBufferFillExp; + + dditable.pfnAppendUSMPrefetchExp = pDdiTable->pfnAppendUSMPrefetchExp; + pDdiTable->pfnAppendUSMPrefetchExp = + ur_mock_layer::urCommandBufferAppendUSMPrefetchExp; + + dditable.pfnAppendUSMAdviseExp = pDdiTable->pfnAppendUSMAdviseExp; + pDdiTable->pfnAppendUSMAdviseExp = + ur_mock_layer::urCommandBufferAppendUSMAdviseExp; + + dditable.pfnEnqueueExp = pDdiTable->pfnEnqueueExp; + pDdiTable->pfnEnqueueExp = ur_mock_layer::urCommandBufferEnqueueExp; + + dditable.pfnRetainCommandExp = pDdiTable->pfnRetainCommandExp; + pDdiTable->pfnRetainCommandExp = + ur_mock_layer::urCommandBufferRetainCommandExp; + + dditable.pfnReleaseCommandExp = pDdiTable->pfnReleaseCommandExp; + pDdiTable->pfnReleaseCommandExp = + ur_mock_layer::urCommandBufferReleaseCommandExp; + + dditable.pfnUpdateKernelLaunchExp = pDdiTable->pfnUpdateKernelLaunchExp; + pDdiTable->pfnUpdateKernelLaunchExp = + ur_mock_layer::urCommandBufferUpdateKernelLaunchExp; + + dditable.pfnGetInfoExp = pDdiTable->pfnGetInfoExp; + pDdiTable->pfnGetInfoExp = ur_mock_layer::urCommandBufferGetInfoExp; + + dditable.pfnCommandGetInfoExp = pDdiTable->pfnCommandGetInfoExp; + pDdiTable->pfnCommandGetInfoExp = + ur_mock_layer::urCommandBufferCommandGetInfoExp; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Context table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetContextProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_context_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Context; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCreate = pDdiTable->pfnCreate; + pDdiTable->pfnCreate = ur_mock_layer::urContextCreate; + + dditable.pfnRetain = pDdiTable->pfnRetain; + pDdiTable->pfnRetain = ur_mock_layer::urContextRetain; + + dditable.pfnRelease = pDdiTable->pfnRelease; + pDdiTable->pfnRelease = ur_mock_layer::urContextRelease; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urContextGetInfo; + + dditable.pfnGetNativeHandle = pDdiTable->pfnGetNativeHandle; + pDdiTable->pfnGetNativeHandle = ur_mock_layer::urContextGetNativeHandle; + + dditable.pfnCreateWithNativeHandle = pDdiTable->pfnCreateWithNativeHandle; + pDdiTable->pfnCreateWithNativeHandle = + ur_mock_layer::urContextCreateWithNativeHandle; + + dditable.pfnSetExtendedDeleter = pDdiTable->pfnSetExtendedDeleter; + pDdiTable->pfnSetExtendedDeleter = + ur_mock_layer::urContextSetExtendedDeleter; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Enqueue table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Enqueue; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnKernelLaunch = pDdiTable->pfnKernelLaunch; + pDdiTable->pfnKernelLaunch = ur_mock_layer::urEnqueueKernelLaunch; + + dditable.pfnEventsWait = pDdiTable->pfnEventsWait; + pDdiTable->pfnEventsWait = ur_mock_layer::urEnqueueEventsWait; + + dditable.pfnEventsWaitWithBarrier = pDdiTable->pfnEventsWaitWithBarrier; + pDdiTable->pfnEventsWaitWithBarrier = + ur_mock_layer::urEnqueueEventsWaitWithBarrier; + + dditable.pfnMemBufferRead = pDdiTable->pfnMemBufferRead; + pDdiTable->pfnMemBufferRead = ur_mock_layer::urEnqueueMemBufferRead; + + dditable.pfnMemBufferWrite = pDdiTable->pfnMemBufferWrite; + pDdiTable->pfnMemBufferWrite = ur_mock_layer::urEnqueueMemBufferWrite; + + dditable.pfnMemBufferReadRect = pDdiTable->pfnMemBufferReadRect; + pDdiTable->pfnMemBufferReadRect = ur_mock_layer::urEnqueueMemBufferReadRect; + + dditable.pfnMemBufferWriteRect = pDdiTable->pfnMemBufferWriteRect; + pDdiTable->pfnMemBufferWriteRect = + ur_mock_layer::urEnqueueMemBufferWriteRect; + + dditable.pfnMemBufferCopy = pDdiTable->pfnMemBufferCopy; + pDdiTable->pfnMemBufferCopy = ur_mock_layer::urEnqueueMemBufferCopy; + + dditable.pfnMemBufferCopyRect = pDdiTable->pfnMemBufferCopyRect; + pDdiTable->pfnMemBufferCopyRect = ur_mock_layer::urEnqueueMemBufferCopyRect; + + dditable.pfnMemBufferFill = pDdiTable->pfnMemBufferFill; + pDdiTable->pfnMemBufferFill = ur_mock_layer::urEnqueueMemBufferFill; + + dditable.pfnMemImageRead = pDdiTable->pfnMemImageRead; + pDdiTable->pfnMemImageRead = ur_mock_layer::urEnqueueMemImageRead; + + dditable.pfnMemImageWrite = pDdiTable->pfnMemImageWrite; + pDdiTable->pfnMemImageWrite = ur_mock_layer::urEnqueueMemImageWrite; + + dditable.pfnMemImageCopy = pDdiTable->pfnMemImageCopy; + pDdiTable->pfnMemImageCopy = ur_mock_layer::urEnqueueMemImageCopy; + + dditable.pfnMemBufferMap = pDdiTable->pfnMemBufferMap; + pDdiTable->pfnMemBufferMap = ur_mock_layer::urEnqueueMemBufferMap; + + dditable.pfnMemUnmap = pDdiTable->pfnMemUnmap; + pDdiTable->pfnMemUnmap = ur_mock_layer::urEnqueueMemUnmap; + + dditable.pfnUSMFill = pDdiTable->pfnUSMFill; + pDdiTable->pfnUSMFill = ur_mock_layer::urEnqueueUSMFill; + + dditable.pfnUSMMemcpy = pDdiTable->pfnUSMMemcpy; + pDdiTable->pfnUSMMemcpy = ur_mock_layer::urEnqueueUSMMemcpy; + + dditable.pfnUSMPrefetch = pDdiTable->pfnUSMPrefetch; + pDdiTable->pfnUSMPrefetch = ur_mock_layer::urEnqueueUSMPrefetch; + + dditable.pfnUSMAdvise = pDdiTable->pfnUSMAdvise; + pDdiTable->pfnUSMAdvise = ur_mock_layer::urEnqueueUSMAdvise; + + dditable.pfnUSMFill2D = pDdiTable->pfnUSMFill2D; + pDdiTable->pfnUSMFill2D = ur_mock_layer::urEnqueueUSMFill2D; + + dditable.pfnUSMMemcpy2D = pDdiTable->pfnUSMMemcpy2D; + pDdiTable->pfnUSMMemcpy2D = ur_mock_layer::urEnqueueUSMMemcpy2D; + + dditable.pfnDeviceGlobalVariableWrite = + pDdiTable->pfnDeviceGlobalVariableWrite; + pDdiTable->pfnDeviceGlobalVariableWrite = + ur_mock_layer::urEnqueueDeviceGlobalVariableWrite; + + dditable.pfnDeviceGlobalVariableRead = + pDdiTable->pfnDeviceGlobalVariableRead; + pDdiTable->pfnDeviceGlobalVariableRead = + ur_mock_layer::urEnqueueDeviceGlobalVariableRead; + + dditable.pfnReadHostPipe = pDdiTable->pfnReadHostPipe; + pDdiTable->pfnReadHostPipe = ur_mock_layer::urEnqueueReadHostPipe; + + dditable.pfnWriteHostPipe = pDdiTable->pfnWriteHostPipe; + pDdiTable->pfnWriteHostPipe = ur_mock_layer::urEnqueueWriteHostPipe; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.EnqueueExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnKernelLaunchCustomExp = pDdiTable->pfnKernelLaunchCustomExp; + pDdiTable->pfnKernelLaunchCustomExp = + ur_mock_layer::urEnqueueKernelLaunchCustomExp; + + dditable.pfnCooperativeKernelLaunchExp = + pDdiTable->pfnCooperativeKernelLaunchExp; + pDdiTable->pfnCooperativeKernelLaunchExp = + ur_mock_layer::urEnqueueCooperativeKernelLaunchExp; + + dditable.pfnTimestampRecordingExp = pDdiTable->pfnTimestampRecordingExp; + pDdiTable->pfnTimestampRecordingExp = + ur_mock_layer::urEnqueueTimestampRecordingExp; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Event table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetEventProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_event_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Event; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urEventGetInfo; + + dditable.pfnGetProfilingInfo = pDdiTable->pfnGetProfilingInfo; + pDdiTable->pfnGetProfilingInfo = ur_mock_layer::urEventGetProfilingInfo; + + dditable.pfnWait = pDdiTable->pfnWait; + pDdiTable->pfnWait = ur_mock_layer::urEventWait; + + dditable.pfnRetain = pDdiTable->pfnRetain; + pDdiTable->pfnRetain = ur_mock_layer::urEventRetain; + + dditable.pfnRelease = pDdiTable->pfnRelease; + pDdiTable->pfnRelease = ur_mock_layer::urEventRelease; + + dditable.pfnGetNativeHandle = pDdiTable->pfnGetNativeHandle; + pDdiTable->pfnGetNativeHandle = ur_mock_layer::urEventGetNativeHandle; + + dditable.pfnCreateWithNativeHandle = pDdiTable->pfnCreateWithNativeHandle; + pDdiTable->pfnCreateWithNativeHandle = + ur_mock_layer::urEventCreateWithNativeHandle; + + dditable.pfnSetCallback = pDdiTable->pfnSetCallback; + pDdiTable->pfnSetCallback = ur_mock_layer::urEventSetCallback; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Kernel table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Kernel; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCreate = pDdiTable->pfnCreate; + pDdiTable->pfnCreate = ur_mock_layer::urKernelCreate; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urKernelGetInfo; + + dditable.pfnGetGroupInfo = pDdiTable->pfnGetGroupInfo; + pDdiTable->pfnGetGroupInfo = ur_mock_layer::urKernelGetGroupInfo; + + dditable.pfnGetSubGroupInfo = pDdiTable->pfnGetSubGroupInfo; + pDdiTable->pfnGetSubGroupInfo = ur_mock_layer::urKernelGetSubGroupInfo; + + dditable.pfnRetain = pDdiTable->pfnRetain; + pDdiTable->pfnRetain = ur_mock_layer::urKernelRetain; + + dditable.pfnRelease = pDdiTable->pfnRelease; + pDdiTable->pfnRelease = ur_mock_layer::urKernelRelease; + + dditable.pfnGetNativeHandle = pDdiTable->pfnGetNativeHandle; + pDdiTable->pfnGetNativeHandle = ur_mock_layer::urKernelGetNativeHandle; + + dditable.pfnCreateWithNativeHandle = pDdiTable->pfnCreateWithNativeHandle; + pDdiTable->pfnCreateWithNativeHandle = + ur_mock_layer::urKernelCreateWithNativeHandle; + + dditable.pfnGetSuggestedLocalWorkSize = + pDdiTable->pfnGetSuggestedLocalWorkSize; + pDdiTable->pfnGetSuggestedLocalWorkSize = + ur_mock_layer::urKernelGetSuggestedLocalWorkSize; + + dditable.pfnSetArgValue = pDdiTable->pfnSetArgValue; + pDdiTable->pfnSetArgValue = ur_mock_layer::urKernelSetArgValue; + + dditable.pfnSetArgLocal = pDdiTable->pfnSetArgLocal; + pDdiTable->pfnSetArgLocal = ur_mock_layer::urKernelSetArgLocal; + + dditable.pfnSetArgPointer = pDdiTable->pfnSetArgPointer; + pDdiTable->pfnSetArgPointer = ur_mock_layer::urKernelSetArgPointer; + + dditable.pfnSetExecInfo = pDdiTable->pfnSetExecInfo; + pDdiTable->pfnSetExecInfo = ur_mock_layer::urKernelSetExecInfo; + + dditable.pfnSetArgSampler = pDdiTable->pfnSetArgSampler; + pDdiTable->pfnSetArgSampler = ur_mock_layer::urKernelSetArgSampler; + + dditable.pfnSetArgMemObj = pDdiTable->pfnSetArgMemObj; + pDdiTable->pfnSetArgMemObj = ur_mock_layer::urKernelSetArgMemObj; + + dditable.pfnSetSpecializationConstants = + pDdiTable->pfnSetSpecializationConstants; + pDdiTable->pfnSetSpecializationConstants = + ur_mock_layer::urKernelSetSpecializationConstants; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.KernelExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnSuggestMaxCooperativeGroupCountExp = + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp; + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = + ur_mock_layer::urKernelSuggestMaxCooperativeGroupCountExp; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Mem table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetMemProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_mem_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Mem; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnImageCreate = pDdiTable->pfnImageCreate; + pDdiTable->pfnImageCreate = ur_mock_layer::urMemImageCreate; + + dditable.pfnBufferCreate = pDdiTable->pfnBufferCreate; + pDdiTable->pfnBufferCreate = ur_mock_layer::urMemBufferCreate; + + dditable.pfnRetain = pDdiTable->pfnRetain; + pDdiTable->pfnRetain = ur_mock_layer::urMemRetain; + + dditable.pfnRelease = pDdiTable->pfnRelease; + pDdiTable->pfnRelease = ur_mock_layer::urMemRelease; + + dditable.pfnBufferPartition = pDdiTable->pfnBufferPartition; + pDdiTable->pfnBufferPartition = ur_mock_layer::urMemBufferPartition; + + dditable.pfnGetNativeHandle = pDdiTable->pfnGetNativeHandle; + pDdiTable->pfnGetNativeHandle = ur_mock_layer::urMemGetNativeHandle; + + dditable.pfnBufferCreateWithNativeHandle = + pDdiTable->pfnBufferCreateWithNativeHandle; + pDdiTable->pfnBufferCreateWithNativeHandle = + ur_mock_layer::urMemBufferCreateWithNativeHandle; + + dditable.pfnImageCreateWithNativeHandle = + pDdiTable->pfnImageCreateWithNativeHandle; + pDdiTable->pfnImageCreateWithNativeHandle = + ur_mock_layer::urMemImageCreateWithNativeHandle; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urMemGetInfo; + + dditable.pfnImageGetInfo = pDdiTable->pfnImageGetInfo; + pDdiTable->pfnImageGetInfo = ur_mock_layer::urMemImageGetInfo; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's PhysicalMem table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetPhysicalMemProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_physical_mem_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.PhysicalMem; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCreate = pDdiTable->pfnCreate; + pDdiTable->pfnCreate = ur_mock_layer::urPhysicalMemCreate; + + dditable.pfnRetain = pDdiTable->pfnRetain; + pDdiTable->pfnRetain = ur_mock_layer::urPhysicalMemRetain; + + dditable.pfnRelease = pDdiTable->pfnRelease; + pDdiTable->pfnRelease = ur_mock_layer::urPhysicalMemRelease; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Platform table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetPlatformProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_platform_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Platform; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnGet = pDdiTable->pfnGet; + pDdiTable->pfnGet = ur_mock_layer::urPlatformGet; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urPlatformGetInfo; + + dditable.pfnGetNativeHandle = pDdiTable->pfnGetNativeHandle; + pDdiTable->pfnGetNativeHandle = ur_mock_layer::urPlatformGetNativeHandle; + + dditable.pfnCreateWithNativeHandle = pDdiTable->pfnCreateWithNativeHandle; + pDdiTable->pfnCreateWithNativeHandle = + ur_mock_layer::urPlatformCreateWithNativeHandle; + + dditable.pfnGetApiVersion = pDdiTable->pfnGetApiVersion; + pDdiTable->pfnGetApiVersion = ur_mock_layer::urPlatformGetApiVersion; + + dditable.pfnGetBackendOption = pDdiTable->pfnGetBackendOption; + pDdiTable->pfnGetBackendOption = ur_mock_layer::urPlatformGetBackendOption; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Program table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_program_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Program; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCreateWithIL = pDdiTable->pfnCreateWithIL; + pDdiTable->pfnCreateWithIL = ur_mock_layer::urProgramCreateWithIL; + + dditable.pfnCreateWithBinary = pDdiTable->pfnCreateWithBinary; + pDdiTable->pfnCreateWithBinary = ur_mock_layer::urProgramCreateWithBinary; + + dditable.pfnBuild = pDdiTable->pfnBuild; + pDdiTable->pfnBuild = ur_mock_layer::urProgramBuild; + + dditable.pfnCompile = pDdiTable->pfnCompile; + pDdiTable->pfnCompile = ur_mock_layer::urProgramCompile; + + dditable.pfnLink = pDdiTable->pfnLink; + pDdiTable->pfnLink = ur_mock_layer::urProgramLink; + + dditable.pfnRetain = pDdiTable->pfnRetain; + pDdiTable->pfnRetain = ur_mock_layer::urProgramRetain; + + dditable.pfnRelease = pDdiTable->pfnRelease; + pDdiTable->pfnRelease = ur_mock_layer::urProgramRelease; + + dditable.pfnGetFunctionPointer = pDdiTable->pfnGetFunctionPointer; + pDdiTable->pfnGetFunctionPointer = + ur_mock_layer::urProgramGetFunctionPointer; + + dditable.pfnGetGlobalVariablePointer = + pDdiTable->pfnGetGlobalVariablePointer; + pDdiTable->pfnGetGlobalVariablePointer = + ur_mock_layer::urProgramGetGlobalVariablePointer; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urProgramGetInfo; + + dditable.pfnGetBuildInfo = pDdiTable->pfnGetBuildInfo; + pDdiTable->pfnGetBuildInfo = ur_mock_layer::urProgramGetBuildInfo; + + dditable.pfnSetSpecializationConstants = + pDdiTable->pfnSetSpecializationConstants; + pDdiTable->pfnSetSpecializationConstants = + ur_mock_layer::urProgramSetSpecializationConstants; + + dditable.pfnGetNativeHandle = pDdiTable->pfnGetNativeHandle; + pDdiTable->pfnGetNativeHandle = ur_mock_layer::urProgramGetNativeHandle; + + dditable.pfnCreateWithNativeHandle = pDdiTable->pfnCreateWithNativeHandle; + pDdiTable->pfnCreateWithNativeHandle = + ur_mock_layer::urProgramCreateWithNativeHandle; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's ProgramExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_program_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.ProgramExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnBuildExp = pDdiTable->pfnBuildExp; + pDdiTable->pfnBuildExp = ur_mock_layer::urProgramBuildExp; + + dditable.pfnCompileExp = pDdiTable->pfnCompileExp; + pDdiTable->pfnCompileExp = ur_mock_layer::urProgramCompileExp; + + dditable.pfnLinkExp = pDdiTable->pfnLinkExp; + pDdiTable->pfnLinkExp = ur_mock_layer::urProgramLinkExp; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Queue table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetQueueProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_queue_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Queue; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urQueueGetInfo; + + dditable.pfnCreate = pDdiTable->pfnCreate; + pDdiTable->pfnCreate = ur_mock_layer::urQueueCreate; + + dditable.pfnRetain = pDdiTable->pfnRetain; + pDdiTable->pfnRetain = ur_mock_layer::urQueueRetain; + + dditable.pfnRelease = pDdiTable->pfnRelease; + pDdiTable->pfnRelease = ur_mock_layer::urQueueRelease; + + dditable.pfnGetNativeHandle = pDdiTable->pfnGetNativeHandle; + pDdiTable->pfnGetNativeHandle = ur_mock_layer::urQueueGetNativeHandle; + + dditable.pfnCreateWithNativeHandle = pDdiTable->pfnCreateWithNativeHandle; + pDdiTable->pfnCreateWithNativeHandle = + ur_mock_layer::urQueueCreateWithNativeHandle; + + dditable.pfnFinish = pDdiTable->pfnFinish; + pDdiTable->pfnFinish = ur_mock_layer::urQueueFinish; + + dditable.pfnFlush = pDdiTable->pfnFlush; + pDdiTable->pfnFlush = ur_mock_layer::urQueueFlush; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Sampler table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetSamplerProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_sampler_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Sampler; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCreate = pDdiTable->pfnCreate; + pDdiTable->pfnCreate = ur_mock_layer::urSamplerCreate; + + dditable.pfnRetain = pDdiTable->pfnRetain; + pDdiTable->pfnRetain = ur_mock_layer::urSamplerRetain; + + dditable.pfnRelease = pDdiTable->pfnRelease; + pDdiTable->pfnRelease = ur_mock_layer::urSamplerRelease; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urSamplerGetInfo; + + dditable.pfnGetNativeHandle = pDdiTable->pfnGetNativeHandle; + pDdiTable->pfnGetNativeHandle = ur_mock_layer::urSamplerGetNativeHandle; + + dditable.pfnCreateWithNativeHandle = pDdiTable->pfnCreateWithNativeHandle; + pDdiTable->pfnCreateWithNativeHandle = + ur_mock_layer::urSamplerCreateWithNativeHandle; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's USM table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetUSMProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_usm_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.USM; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnHostAlloc = pDdiTable->pfnHostAlloc; + pDdiTable->pfnHostAlloc = ur_mock_layer::urUSMHostAlloc; + + dditable.pfnDeviceAlloc = pDdiTable->pfnDeviceAlloc; + pDdiTable->pfnDeviceAlloc = ur_mock_layer::urUSMDeviceAlloc; + + dditable.pfnSharedAlloc = pDdiTable->pfnSharedAlloc; + pDdiTable->pfnSharedAlloc = ur_mock_layer::urUSMSharedAlloc; + + dditable.pfnFree = pDdiTable->pfnFree; + pDdiTable->pfnFree = ur_mock_layer::urUSMFree; + + dditable.pfnGetMemAllocInfo = pDdiTable->pfnGetMemAllocInfo; + pDdiTable->pfnGetMemAllocInfo = ur_mock_layer::urUSMGetMemAllocInfo; + + dditable.pfnPoolCreate = pDdiTable->pfnPoolCreate; + pDdiTable->pfnPoolCreate = ur_mock_layer::urUSMPoolCreate; + + dditable.pfnPoolRetain = pDdiTable->pfnPoolRetain; + pDdiTable->pfnPoolRetain = ur_mock_layer::urUSMPoolRetain; + + dditable.pfnPoolRelease = pDdiTable->pfnPoolRelease; + pDdiTable->pfnPoolRelease = ur_mock_layer::urUSMPoolRelease; + + dditable.pfnPoolGetInfo = pDdiTable->pfnPoolGetInfo; + pDdiTable->pfnPoolGetInfo = ur_mock_layer::urUSMPoolGetInfo; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's USMExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetUSMExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_usm_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.USMExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnPitchedAllocExp = pDdiTable->pfnPitchedAllocExp; + pDdiTable->pfnPitchedAllocExp = ur_mock_layer::urUSMPitchedAllocExp; + + dditable.pfnImportExp = pDdiTable->pfnImportExp; + pDdiTable->pfnImportExp = ur_mock_layer::urUSMImportExp; + + dditable.pfnReleaseExp = pDdiTable->pfnReleaseExp; + pDdiTable->pfnReleaseExp = ur_mock_layer::urUSMReleaseExp; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's UsmP2PExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetUsmP2PExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_usm_p2p_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.UsmP2PExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnEnablePeerAccessExp = pDdiTable->pfnEnablePeerAccessExp; + pDdiTable->pfnEnablePeerAccessExp = + ur_mock_layer::urUsmP2PEnablePeerAccessExp; + + dditable.pfnDisablePeerAccessExp = pDdiTable->pfnDisablePeerAccessExp; + pDdiTable->pfnDisablePeerAccessExp = + ur_mock_layer::urUsmP2PDisablePeerAccessExp; + + dditable.pfnPeerAccessGetInfoExp = pDdiTable->pfnPeerAccessGetInfoExp; + pDdiTable->pfnPeerAccessGetInfoExp = + ur_mock_layer::urUsmP2PPeerAccessGetInfoExp; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's VirtualMem table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetVirtualMemProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_virtual_mem_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.VirtualMem; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnGranularityGetInfo = pDdiTable->pfnGranularityGetInfo; + pDdiTable->pfnGranularityGetInfo = + ur_mock_layer::urVirtualMemGranularityGetInfo; + + dditable.pfnReserve = pDdiTable->pfnReserve; + pDdiTable->pfnReserve = ur_mock_layer::urVirtualMemReserve; + + dditable.pfnFree = pDdiTable->pfnFree; + pDdiTable->pfnFree = ur_mock_layer::urVirtualMemFree; + + dditable.pfnMap = pDdiTable->pfnMap; + pDdiTable->pfnMap = ur_mock_layer::urVirtualMemMap; + + dditable.pfnUnmap = pDdiTable->pfnUnmap; + pDdiTable->pfnUnmap = ur_mock_layer::urVirtualMemUnmap; + + dditable.pfnSetAccess = pDdiTable->pfnSetAccess; + pDdiTable->pfnSetAccess = ur_mock_layer::urVirtualMemSetAccess; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urVirtualMemGetInfo; + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Device table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetDeviceProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_device_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_mock_layer::context.urDdiTable.Device; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_mock_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_mock_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnGet = pDdiTable->pfnGet; + pDdiTable->pfnGet = ur_mock_layer::urDeviceGet; + + dditable.pfnGetInfo = pDdiTable->pfnGetInfo; + pDdiTable->pfnGetInfo = ur_mock_layer::urDeviceGetInfo; + + dditable.pfnRetain = pDdiTable->pfnRetain; + pDdiTable->pfnRetain = ur_mock_layer::urDeviceRetain; + + dditable.pfnRelease = pDdiTable->pfnRelease; + pDdiTable->pfnRelease = ur_mock_layer::urDeviceRelease; + + dditable.pfnPartition = pDdiTable->pfnPartition; + pDdiTable->pfnPartition = ur_mock_layer::urDevicePartition; + + dditable.pfnSelectBinary = pDdiTable->pfnSelectBinary; + pDdiTable->pfnSelectBinary = ur_mock_layer::urDeviceSelectBinary; + + dditable.pfnGetNativeHandle = pDdiTable->pfnGetNativeHandle; + pDdiTable->pfnGetNativeHandle = ur_mock_layer::urDeviceGetNativeHandle; + + dditable.pfnCreateWithNativeHandle = pDdiTable->pfnCreateWithNativeHandle; + pDdiTable->pfnCreateWithNativeHandle = + ur_mock_layer::urDeviceCreateWithNativeHandle; + + dditable.pfnGetGlobalTimestamps = pDdiTable->pfnGetGlobalTimestamps; + pDdiTable->pfnGetGlobalTimestamps = + ur_mock_layer::urDeviceGetGlobalTimestamps; + + return result; +} + +ur_result_t context_t::init(ur_dditable_t *dditable, + const std::set &enabledLayerNames, + codeloc_data, api_callbacks apiCallbacks) { + ur_result_t result = UR_RESULT_SUCCESS; + + if (!enabledLayerNames.count(name)) { + return result; + } + + ur_mock_layer::context.apiCallbacks = apiCallbacks; + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetGlobalProcAddrTable(UR_API_VERSION_CURRENT, + &dditable->Global); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetBindlessImagesExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->BindlessImagesExp); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetCommandBufferExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->CommandBufferExp); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetContextProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Context); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetEnqueueProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Enqueue); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetEnqueueExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->EnqueueExp); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetEventProcAddrTable(UR_API_VERSION_CURRENT, + &dditable->Event); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetKernelProcAddrTable(UR_API_VERSION_CURRENT, + &dditable->Kernel); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetKernelExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->KernelExp); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetMemProcAddrTable(UR_API_VERSION_CURRENT, + &dditable->Mem); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetPhysicalMemProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->PhysicalMem); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetPlatformProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Platform); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetProgramProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Program); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetProgramExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->ProgramExp); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetQueueProcAddrTable(UR_API_VERSION_CURRENT, + &dditable->Queue); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetSamplerProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Sampler); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetUSMProcAddrTable(UR_API_VERSION_CURRENT, + &dditable->USM); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetUSMExpProcAddrTable(UR_API_VERSION_CURRENT, + &dditable->USMExp); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetUsmP2PExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->UsmP2PExp); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetVirtualMemProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->VirtualMem); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_mock_layer::urGetDeviceProcAddrTable(UR_API_VERSION_CURRENT, + &dditable->Device); + } + + return result; +} + +} // namespace ur_mock_layer diff --git a/source/loader/layers/sanitizer/ur_sanddi.cpp b/source/loader/layers/sanitizer/ur_sanddi.cpp index e352ad69b8..95ca54c0fd 100644 --- a/source/loader/layers/sanitizer/ur_sanddi.cpp +++ b/source/loader/layers/sanitizer/ur_sanddi.cpp @@ -1507,7 +1507,8 @@ __urdlllocal ur_result_t UR_APICALL urGetUSMProcAddrTable( ur_result_t context_t::init(ur_dditable_t *dditable, const std::set &enabledLayerNames, - [[maybe_unused]] codeloc_data codelocData) { + [[maybe_unused]] codeloc_data codelocData, + [[maybe_unused]] api_callbacks apiCallbacks) { ur_result_t result = UR_RESULT_SUCCESS; if (enabledLayerNames.count("UR_LAYER_ASAN")) { diff --git a/source/loader/layers/sanitizer/ur_sanitizer_layer.hpp b/source/loader/layers/sanitizer/ur_sanitizer_layer.hpp index 32f01103d5..9572767fa4 100644 --- a/source/loader/layers/sanitizer/ur_sanitizer_layer.hpp +++ b/source/loader/layers/sanitizer/ur_sanitizer_layer.hpp @@ -46,7 +46,8 @@ class __urdlllocal context_t : public proxy_layer_context_t { } ur_result_t init(ur_dditable_t *dditable, const std::set &enabledLayerNames, - codeloc_data codelocData) override; + codeloc_data codelocData, + api_callbacks apiCallbacks) override; ur_result_t tearDown() override; }; diff --git a/source/loader/layers/tracing/ur_tracing_layer.hpp b/source/loader/layers/tracing/ur_tracing_layer.hpp index 84a109fb4f..dabb24d4fb 100644 --- a/source/loader/layers/tracing/ur_tracing_layer.hpp +++ b/source/loader/layers/tracing/ur_tracing_layer.hpp @@ -36,7 +36,7 @@ class __urdlllocal context_t : public proxy_layer_context_t { std::vector getNames() const override { return {name}; } ur_result_t init(ur_dditable_t *dditable, const std::set &enabledLayerNames, - codeloc_data codelocData) override; + codeloc_data codelocDatam, api_callbacks) override; ur_result_t tearDown() override { return UR_RESULT_SUCCESS; } uint64_t notify_begin(uint32_t id, const char *name, void *args); void notify_end(uint32_t id, const char *name, void *args, diff --git a/source/loader/layers/tracing/ur_trcddi.cpp b/source/loader/layers/tracing/ur_trcddi.cpp index da61c34992..a62bdd93b2 100644 --- a/source/loader/layers/tracing/ur_trcddi.cpp +++ b/source/loader/layers/tracing/ur_trcddi.cpp @@ -57,7 +57,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRelease __urdlllocal ur_result_t UR_APICALL urAdapterRelease( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to release + ur_adapter_handle_t hAdapter ///< [in][release] Adapter handle to release ) { auto pfnAdapterRelease = context.urDdiTable.Global.pfnAdapterRelease; @@ -87,7 +87,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRetain __urdlllocal ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) { auto pfnAdapterRetain = context.urDdiTable.Global.pfnAdapterRetain; @@ -516,7 +516,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( /// @brief Intercept function for urDeviceRetain __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) { auto pfnRetain = context.urDdiTable.Device.pfnRetain; @@ -546,7 +546,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urDeviceRelease __urdlllocal ur_result_t UR_APICALL urDeviceRelease( - ur_device_handle_t hDevice ///< [in] handle of the device to release. + ur_device_handle_t + hDevice ///< [in][release] handle of the device to release. ) { auto pfnRelease = context.urDdiTable.Device.pfnRelease; @@ -813,7 +814,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( /// @brief Intercept function for urContextRetain __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) { auto pfnRetain = context.urDdiTable.Context.pfnRetain; @@ -843,7 +844,8 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urContextRelease __urdlllocal ur_result_t UR_APICALL urContextRelease( - ur_context_handle_t hContext ///< [in] handle of the context to release. + ur_context_handle_t + hContext ///< [in][release] handle of the context to release. ) { auto pfnRelease = context.urDdiTable.Context.pfnRelease; @@ -1108,7 +1110,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRetain __urdlllocal ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) { auto pfnRetain = context.urDdiTable.Mem.pfnRetain; @@ -1137,7 +1140,8 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRelease __urdlllocal ur_result_t UR_APICALL urMemRelease( - ur_mem_handle_t hMem ///< [in] handle of the memory object to release + ur_mem_handle_t + hMem ///< [in][release] handle of the memory object to release ) { auto pfnRelease = context.urDdiTable.Mem.pfnRelease; @@ -1446,7 +1450,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( /// @brief Intercept function for urSamplerRetain __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) { auto pfnRetain = context.urDdiTable.Sampler.pfnRetain; @@ -1477,7 +1481,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( /// @brief Intercept function for urSamplerRelease __urdlllocal ur_result_t UR_APICALL urSamplerRelease( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to release + hSampler ///< [in][release] handle of the sampler object to release ) { auto pfnRelease = context.urDdiTable.Sampler.pfnRelease; @@ -1846,7 +1850,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRetain __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) { auto pfnPoolRetain = context.urDdiTable.USM.pfnPoolRetain; @@ -1876,7 +1880,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRelease __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][release] pointer to USM memory pool ) { auto pfnPoolRelease = context.urDdiTable.USM.pfnPoolRelease; @@ -2264,7 +2268,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( /// @brief Intercept function for urPhysicalMemRetain __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) { auto pfnRetain = context.urDdiTable.PhysicalMem.pfnRetain; @@ -2295,7 +2299,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( /// @brief Intercept function for urPhysicalMemRelease __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to release. + hPhysicalMem ///< [in][release] handle of the physical memory object to release. ) { auto pfnRelease = context.urDdiTable.PhysicalMem.pfnRelease; @@ -2510,7 +2514,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRetain __urdlllocal ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) { auto pfnRetain = context.urDdiTable.Program.pfnRetain; @@ -2540,7 +2545,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRelease __urdlllocal ur_result_t UR_APICALL urProgramRelease( - ur_program_handle_t hProgram ///< [in] handle for the Program to release + ur_program_handle_t + hProgram ///< [in][release] handle for the Program to release ) { auto pfnRelease = context.urDdiTable.Program.pfnRelease; @@ -3100,7 +3106,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain __urdlllocal ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) { auto pfnRetain = context.urDdiTable.Kernel.pfnRetain; @@ -3130,7 +3136,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRelease __urdlllocal ur_result_t UR_APICALL urKernelRelease( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release + ur_kernel_handle_t + hKernel ///< [in][release] handle for the Kernel to release ) { auto pfnRelease = context.urDdiTable.Kernel.pfnRelease; @@ -3550,7 +3557,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRetain __urdlllocal ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) { auto pfnRetain = context.urDdiTable.Queue.pfnRetain; @@ -3580,7 +3588,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRelease __urdlllocal ur_result_t UR_APICALL urQueueRelease( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to release + ur_queue_handle_t + hQueue ///< [in][release] handle of the queue object to release ) { auto pfnRelease = context.urDdiTable.Queue.pfnRelease; @@ -3858,7 +3867,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRetain __urdlllocal ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) { auto pfnRetain = context.urDdiTable.Event.pfnRetain; @@ -3888,7 +3897,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRelease __urdlllocal ur_result_t UR_APICALL urEventRelease( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][release] handle of the event object ) { auto pfnRelease = context.urDdiTable.Event.pfnRelease; @@ -5975,7 +5984,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( ur_context_handle_t hContext, ///< [in] handle of the context object ur_device_handle_t hDevice, ///< [in] handle of the device object ur_exp_interop_mem_handle_t - hInteropMem ///< [in] handle of interop memory to be freed + hInteropMem ///< [in][release] handle of interop memory to be freed ) { auto pfnReleaseInteropExp = context.urDdiTable.BindlessImagesExp.pfnReleaseInteropExp; @@ -6231,7 +6240,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( /// @brief Intercept function for urCommandBufferRetainExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) { auto pfnRetainExp = context.urDdiTable.CommandBufferExp.pfnRetainExp; @@ -6263,7 +6272,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( /// @brief Intercept function for urCommandBufferReleaseExp __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][release] Handle of the command-buffer object. ) { auto pfnReleaseExp = context.urDdiTable.CommandBufferExp.pfnReleaseExp; @@ -7129,7 +7138,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( /// @brief Intercept function for urCommandBufferReleaseCommandExp __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][release] Handle of the command-buffer command. ) { auto pfnReleaseCommandExp = context.urDdiTable.CommandBufferExp.pfnReleaseCommandExp; @@ -9091,7 +9100,7 @@ __urdlllocal ur_result_t UR_APICALL urGetDeviceProcAddrTable( ur_result_t context_t::init(ur_dditable_t *dditable, const std::set &enabledLayerNames, - codeloc_data codelocData) { + codeloc_data codelocData, api_callbacks) { ur_result_t result = UR_RESULT_SUCCESS; if (!enabledLayerNames.count(name)) { diff --git a/source/loader/layers/ur_proxy_layer.hpp b/source/loader/layers/ur_proxy_layer.hpp index 2b710f3287..5e726a3e17 100644 --- a/source/loader/layers/ur_proxy_layer.hpp +++ b/source/loader/layers/ur_proxy_layer.hpp @@ -12,6 +12,7 @@ #ifndef UR_PROXY_LAYER_H #define UR_PROXY_LAYER_H 1 +#include "ur_callbacks.hpp" #include "ur_codeloc.hpp" #include "ur_ddi.h" #include "ur_util.hpp" @@ -27,7 +28,8 @@ class __urdlllocal proxy_layer_context_t { virtual bool isAvailable() const = 0; virtual ur_result_t init(ur_dditable_t *dditable, const std::set &enabledLayerNames, - codeloc_data codelocData) = 0; + codeloc_data codelocData, + api_callbacks apiCallbacks) = 0; virtual ur_result_t tearDown() = 0; }; diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index 9998ae91bb..24fdfb8a1b 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -53,7 +53,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRelease __urdlllocal ur_result_t UR_APICALL urAdapterRelease( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to release + ur_adapter_handle_t hAdapter ///< [in][release] Adapter handle to release ) { auto pfnAdapterRelease = context.urDdiTable.Global.pfnAdapterRelease; @@ -79,7 +79,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRetain __urdlllocal ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) { auto pfnAdapterRetain = context.urDdiTable.Global.pfnAdapterRetain; @@ -520,7 +520,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( /// @brief Intercept function for urDeviceRetain __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) { auto pfnRetain = context.urDdiTable.Device.pfnRetain; @@ -546,7 +546,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urDeviceRelease __urdlllocal ur_result_t UR_APICALL urDeviceRelease( - ur_device_handle_t hDevice ///< [in] handle of the device to release. + ur_device_handle_t + hDevice ///< [in][release] handle of the device to release. ) { auto pfnRelease = context.urDdiTable.Device.pfnRelease; @@ -815,7 +816,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( /// @brief Intercept function for urContextRetain __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) { auto pfnRetain = context.urDdiTable.Context.pfnRetain; @@ -841,7 +842,8 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urContextRelease __urdlllocal ur_result_t UR_APICALL urContextRelease( - ur_context_handle_t hContext ///< [in] handle of the context to release. + ur_context_handle_t + hContext ///< [in][release] handle of the context to release. ) { auto pfnRelease = context.urDdiTable.Context.pfnRelease; @@ -1170,7 +1172,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRetain __urdlllocal ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) { auto pfnRetain = context.urDdiTable.Mem.pfnRetain; @@ -1196,7 +1199,8 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRelease __urdlllocal ur_result_t UR_APICALL urMemRelease( - ur_mem_handle_t hMem ///< [in] handle of the memory object to release + ur_mem_handle_t + hMem ///< [in][release] handle of the memory object to release ) { auto pfnRelease = context.urDdiTable.Mem.pfnRelease; @@ -1581,7 +1585,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( /// @brief Intercept function for urSamplerRetain __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) { auto pfnRetain = context.urDdiTable.Sampler.pfnRetain; @@ -1608,7 +1612,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( /// @brief Intercept function for urSamplerRelease __urdlllocal ur_result_t UR_APICALL urSamplerRelease( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to release + hSampler ///< [in][release] handle of the sampler object to release ) { auto pfnRelease = context.urDdiTable.Sampler.pfnRelease; @@ -2077,7 +2081,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRetain __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) { auto pfnPoolRetain = context.urDdiTable.USM.pfnPoolRetain; @@ -2103,7 +2107,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRelease __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][release] pointer to USM memory pool ) { auto pfnPoolRelease = context.urDdiTable.USM.pfnPoolRelease; @@ -2554,7 +2558,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( /// @brief Intercept function for urPhysicalMemRetain __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) { auto pfnRetain = context.urDdiTable.PhysicalMem.pfnRetain; @@ -2581,7 +2585,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( /// @brief Intercept function for urPhysicalMemRelease __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to release. + hPhysicalMem ///< [in][release] handle of the physical memory object to release. ) { auto pfnRelease = context.urDdiTable.PhysicalMem.pfnRelease; @@ -2860,7 +2864,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRetain __urdlllocal ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) { auto pfnRetain = context.urDdiTable.Program.pfnRetain; @@ -2886,7 +2891,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRelease __urdlllocal ur_result_t UR_APICALL urProgramRelease( - ur_program_handle_t hProgram ///< [in] handle for the Program to release + ur_program_handle_t + hProgram ///< [in][release] handle for the Program to release ) { auto pfnRelease = context.urDdiTable.Program.pfnRelease; @@ -3521,7 +3527,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain __urdlllocal ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) { auto pfnRetain = context.urDdiTable.Kernel.pfnRetain; @@ -3547,7 +3553,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRelease __urdlllocal ur_result_t UR_APICALL urKernelRelease( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release + ur_kernel_handle_t + hKernel ///< [in][release] handle for the Kernel to release ) { auto pfnRelease = context.urDdiTable.Kernel.pfnRelease; @@ -4042,7 +4049,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRetain __urdlllocal ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) { auto pfnRetain = context.urDdiTable.Queue.pfnRetain; @@ -4068,7 +4076,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRelease __urdlllocal ur_result_t UR_APICALL urQueueRelease( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to release + ur_queue_handle_t + hQueue ///< [in][release] handle of the queue object to release ) { auto pfnRelease = context.urDdiTable.Queue.pfnRelease; @@ -4360,7 +4369,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRetain __urdlllocal ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) { auto pfnRetain = context.urDdiTable.Event.pfnRetain; @@ -4386,7 +4395,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRelease __urdlllocal ur_result_t UR_APICALL urEventRelease( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][release] handle of the event object ) { auto pfnRelease = context.urDdiTable.Event.pfnRelease; @@ -7504,7 +7513,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( ur_context_handle_t hContext, ///< [in] handle of the context object ur_device_handle_t hDevice, ///< [in] handle of the device object ur_exp_interop_mem_handle_t - hInteropMem ///< [in] handle of interop memory to be freed + hInteropMem ///< [in][release] handle of interop memory to be freed ) { auto pfnReleaseInteropExp = context.urDdiTable.BindlessImagesExp.pfnReleaseInteropExp; @@ -7796,7 +7805,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( /// @brief Intercept function for urCommandBufferRetainExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) { auto pfnRetainExp = context.urDdiTable.CommandBufferExp.pfnRetainExp; @@ -7819,7 +7828,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( /// @brief Intercept function for urCommandBufferReleaseExp __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][release] Handle of the command-buffer object. ) { auto pfnReleaseExp = context.urDdiTable.CommandBufferExp.pfnReleaseExp; @@ -8712,7 +8721,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( /// @brief Intercept function for urCommandBufferReleaseCommandExp __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][release] Handle of the command-buffer command. ) { auto pfnReleaseCommandExp = context.urDdiTable.CommandBufferExp.pfnReleaseCommandExp; @@ -10780,7 +10789,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetDeviceProcAddrTable( ur_result_t context_t::init(ur_dditable_t *dditable, const std::set &enabledLayerNames, - codeloc_data) { + codeloc_data, api_callbacks) { ur_result_t result = UR_RESULT_SUCCESS; if (enabledLayerNames.count(nameFullValidation)) { diff --git a/source/loader/layers/validation/ur_validation_layer.hpp b/source/loader/layers/validation/ur_validation_layer.hpp index c72932c453..91a3621e50 100644 --- a/source/loader/layers/validation/ur_validation_layer.hpp +++ b/source/loader/layers/validation/ur_validation_layer.hpp @@ -37,7 +37,7 @@ class __urdlllocal context_t : public proxy_layer_context_t { } ur_result_t init(ur_dditable_t *dditable, const std::set &enabledLayerNames, - codeloc_data codelocData) override; + codeloc_data codelocData, api_callbacks) override; ur_result_t tearDown() override; private: diff --git a/source/loader/loader.def.in b/source/loader/loader.def.in index b68c1ab6c1..24000e9975 100644 --- a/source/loader/loader.def.in +++ b/source/loader/loader.def.in @@ -141,6 +141,7 @@ EXPORTS urLoaderConfigRelease urLoaderConfigRetain urLoaderConfigSetCodeLocationCallback + urLoaderConfigSetMockCallbacks urLoaderInit urLoaderTearDown urMemBufferCreate @@ -194,6 +195,7 @@ EXPORTS urPrintBufferCreateType urPrintBufferProperties urPrintBufferRegion + urPrintCallbackOverrideMode urPrintCodeLocation urPrintCommand urPrintCommandBufferAppendKernelLaunchExpParams @@ -355,6 +357,7 @@ EXPORTS urPrintLoaderConfigReleaseParams urPrintLoaderConfigRetainParams urPrintLoaderConfigSetCodeLocationCallbackParams + urPrintLoaderConfigSetMockCallbacksParams urPrintLoaderInitParams urPrintLoaderTearDownParams urPrintMapFlags @@ -374,6 +377,7 @@ EXPORTS urPrintMemType urPrintMemoryOrderCapabilityFlags urPrintMemoryScopeCapabilityFlags + urPrintMockCallbackProperties urPrintPhysicalMemCreateParams urPrintPhysicalMemFlags urPrintPhysicalMemProperties diff --git a/source/loader/loader.map.in b/source/loader/loader.map.in index fd390f7fc4..210663ace2 100644 --- a/source/loader/loader.map.in +++ b/source/loader/loader.map.in @@ -141,6 +141,7 @@ urLoaderConfigRelease; urLoaderConfigRetain; urLoaderConfigSetCodeLocationCallback; + urLoaderConfigSetMockCallbacks; urLoaderInit; urLoaderTearDown; urMemBufferCreate; @@ -194,6 +195,7 @@ urPrintBufferCreateType; urPrintBufferProperties; urPrintBufferRegion; + urPrintCallbackOverrideMode; urPrintCodeLocation; urPrintCommand; urPrintCommandBufferAppendKernelLaunchExpParams; @@ -355,6 +357,7 @@ urPrintLoaderConfigReleaseParams; urPrintLoaderConfigRetainParams; urPrintLoaderConfigSetCodeLocationCallbackParams; + urPrintLoaderConfigSetMockCallbacksParams; urPrintLoaderInitParams; urPrintLoaderTearDownParams; urPrintMapFlags; @@ -374,6 +377,7 @@ urPrintMemType; urPrintMemoryOrderCapabilityFlags; urPrintMemoryScopeCapabilityFlags; + urPrintMockCallbackProperties; urPrintPhysicalMemCreateParams; urPrintPhysicalMemFlags; urPrintPhysicalMemProperties; diff --git a/source/loader/ur_callbacks.hpp b/source/loader/ur_callbacks.hpp new file mode 100644 index 0000000000..af3b827b5f --- /dev/null +++ b/source/loader/ur_callbacks.hpp @@ -0,0 +1,91 @@ +/* + * + * Copyright (C) 2023 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file ur_codeloc.hpp + * + */ + +#ifndef UR_CALLBACKS_HPP +#define UR_CALLBACKS_HPP 1 + +#include + +#include +#include + +struct api_callbacks { + void set_before_callback(std::string name, ur_mock_callback_t callback) { + beforeCallbacks[name] = callback; + } + + ur_mock_callback_t get_before_callback(std::string name) { + auto callback = beforeCallbacks.find(name); + + if (callback != beforeCallbacks.end()) { + return callback->second; + } + return nullptr; + } + + void set_replace_callback(std::string name, ur_mock_callback_t callback) { + replaceCallbacks[name] = callback; + } + + ur_mock_callback_t get_replace_callback(std::string name) { + auto callback = replaceCallbacks.find(name); + + if (callback != replaceCallbacks.end()) { + return callback->second; + } + return nullptr; + } + + void set_after_callback(std::string name, ur_mock_callback_t callback) { + afterCallbacks[name] = callback; + } + + ur_mock_callback_t get_after_callback(std::string name) { + auto callback = afterCallbacks.find(name); + + if (callback != afterCallbacks.end()) { + return callback->second; + } + return nullptr; + } + + ur_result_t + add_callback(ur_mock_callback_properties_t *callback_properties) { + if (!callback_properties->pCallback) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + switch (callback_properties->mode) { + case UR_CALLBACK_OVERRIDE_MODE_BEFORE: + set_before_callback(callback_properties->name, + callback_properties->pCallback); + break; + case UR_CALLBACK_OVERRIDE_MODE_REPLACE: + set_replace_callback(callback_properties->name, + callback_properties->pCallback); + break; + case UR_CALLBACK_OVERRIDE_MODE_AFTER: + set_after_callback(callback_properties->name, + callback_properties->pCallback); + break; + default: + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + return UR_RESULT_SUCCESS; + } + + private: + std::unordered_map beforeCallbacks; + std::unordered_map replaceCallbacks; + std::unordered_map afterCallbacks; +}; + +#endif /* UR_CODELOC_HPP */ diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index fb392dd607..081b6f044b 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -85,7 +85,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRelease __urdlllocal ur_result_t UR_APICALL urAdapterRelease( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to release + ur_adapter_handle_t hAdapter ///< [in][release] Adapter handle to release ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -102,13 +102,20 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( // forward to device-platform result = pfnAdapterRelease(hAdapter); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_adapter_factory.release(hAdapter); + return result; } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterRetain __urdlllocal ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -125,6 +132,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( // forward to device-platform result = pfnAdapterRetain(hAdapter); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } @@ -589,7 +602,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( /// @brief Intercept function for urDeviceRetain __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -606,13 +619,20 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( // forward to device-platform result = pfnRetain(hDevice); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urDeviceRelease __urdlllocal ur_result_t UR_APICALL urDeviceRelease( - ur_device_handle_t hDevice ///< [in] handle of the device to release. + ur_device_handle_t + hDevice ///< [in][release] handle of the device to release. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -629,6 +649,13 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( // forward to device-platform result = pfnRelease(hDevice); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_device_factory.release(hDevice); + return result; } @@ -865,7 +892,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate( /// @brief Intercept function for urContextRetain __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -882,13 +909,20 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( // forward to device-platform result = pfnRetain(hContext); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urContextRelease __urdlllocal ur_result_t UR_APICALL urContextRelease( - ur_context_handle_t hContext ///< [in] handle of the context to release. + ur_context_handle_t + hContext ///< [in][release] handle of the context to release. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -905,6 +939,13 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( // forward to device-platform result = pfnRelease(hContext); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_context_factory.release(hContext); + return result; } @@ -1169,7 +1210,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRetain __urdlllocal ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -1186,13 +1228,20 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( // forward to device-platform result = pfnRetain(hMem); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urMemRelease __urdlllocal ur_result_t UR_APICALL urMemRelease( - ur_mem_handle_t hMem ///< [in] handle of the memory object to release + ur_mem_handle_t + hMem ///< [in][release] handle of the memory object to release ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -1209,6 +1258,13 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( // forward to device-platform result = pfnRelease(hMem); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_mem_factory.release(hMem); + return result; } @@ -1523,7 +1579,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreate( /// @brief Intercept function for urSamplerRetain __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -1540,6 +1596,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( // forward to device-platform result = pfnRetain(hSampler); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } @@ -1547,7 +1609,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( /// @brief Intercept function for urSamplerRelease __urdlllocal ur_result_t UR_APICALL urSamplerRelease( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to release + hSampler ///< [in][release] handle of the sampler object to release ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -1564,6 +1626,13 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( // forward to device-platform result = pfnRelease(hSampler); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_sampler_factory.release(hSampler); + return result; } @@ -1958,7 +2027,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRetain __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -1975,13 +2044,19 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( // forward to device-platform result = pfnPoolRetain(pPool); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urUSMPoolRelease __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][release] pointer to USM memory pool ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -1998,6 +2073,13 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( // forward to device-platform result = pfnPoolRelease(pPool); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_usm_pool_factory.release(pPool); + return result; } @@ -2343,7 +2425,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate( /// @brief Intercept function for urPhysicalMemRetain __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -2362,6 +2444,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( // forward to device-platform result = pfnRetain(hPhysicalMem); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } @@ -2369,7 +2457,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( /// @brief Intercept function for urPhysicalMemRelease __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to release. + hPhysicalMem ///< [in][release] handle of the physical memory object to release. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -2388,6 +2476,13 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( // forward to device-platform result = pfnRelease(hPhysicalMem); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_physical_mem_factory.release(hPhysicalMem); + return result; } @@ -2591,7 +2686,8 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRetain __urdlllocal ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -2608,13 +2704,20 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( // forward to device-platform result = pfnRetain(hProgram); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urProgramRelease __urdlllocal ur_result_t UR_APICALL urProgramRelease( - ur_program_handle_t hProgram ///< [in] handle for the Program to release + ur_program_handle_t + hProgram ///< [in][release] handle for the Program to release ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -2631,6 +2734,13 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( // forward to device-platform result = pfnRelease(hProgram); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_program_factory.release(hProgram); + return result; } @@ -3180,7 +3290,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetSubGroupInfo( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain __urdlllocal ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -3197,13 +3307,20 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( // forward to device-platform result = pfnRetain(hKernel); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRelease __urdlllocal ur_result_t UR_APICALL urKernelRelease( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release + ur_kernel_handle_t + hKernel ///< [in][release] handle for the Kernel to release ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -3220,6 +3337,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( // forward to device-platform result = pfnRelease(hKernel); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_kernel_factory.release(hKernel); + return result; } @@ -3625,7 +3749,8 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRetain __urdlllocal ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -3642,13 +3767,20 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( // forward to device-platform result = pfnRetain(hQueue); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urQueueRelease __urdlllocal ur_result_t UR_APICALL urQueueRelease( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to release + ur_queue_handle_t + hQueue ///< [in][release] handle of the queue object to release ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -3665,6 +3797,13 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( // forward to device-platform result = pfnRelease(hQueue); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_queue_factory.release(hQueue); + return result; } @@ -3932,7 +4071,7 @@ __urdlllocal ur_result_t UR_APICALL urEventWait( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRetain __urdlllocal ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -3949,13 +4088,19 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( // forward to device-platform result = pfnRetain(hEvent); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventRelease __urdlllocal ur_result_t UR_APICALL urEventRelease( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][release] handle of the event object ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -3972,6 +4117,13 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( // forward to device-platform result = pfnRelease(hEvent); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_event_factory.release(hEvent); + return result; } @@ -6349,7 +6501,7 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( ur_context_handle_t hContext, ///< [in] handle of the context object ur_device_handle_t hDevice, ///< [in] handle of the device object ur_exp_interop_mem_handle_t - hInteropMem ///< [in] handle of interop memory to be freed + hInteropMem ///< [in][release] handle of interop memory to be freed ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -6374,6 +6526,13 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( // forward to device-platform result = pfnReleaseInteropExp(hContext, hDevice, hInteropMem); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_exp_interop_mem_factory.release(hInteropMem); + return result; } @@ -6640,7 +6799,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferCreateExp( /// @brief Intercept function for urCommandBufferRetainExp __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -6661,6 +6820,12 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( // forward to device-platform result = pfnRetainExp(hCommandBuffer); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // TODO: do we need to ref count the loader handles? + return result; } @@ -6668,7 +6833,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( /// @brief Intercept function for urCommandBufferReleaseExp __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][release] Handle of the command-buffer object. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -6689,6 +6854,13 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( // forward to device-platform result = pfnReleaseExp(hCommandBuffer); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_exp_command_buffer_factory.release(hCommandBuffer); + return result; } @@ -7413,7 +7585,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( /// @brief Intercept function for urCommandBufferReleaseCommandExp __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][release] Handle of the command-buffer command. ) { ur_result_t result = UR_RESULT_SUCCESS; @@ -7435,6 +7607,13 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( // forward to device-platform result = pfnReleaseCommandExp(hCommand); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + // release loader handle + ur_exp_command_buffer_command_factory.release(hCommand); + return result; } diff --git a/source/loader/ur_lib.cpp b/source/loader/ur_lib.cpp index d2ed4853a8..ace9f12c75 100644 --- a/source/loader/ur_lib.cpp +++ b/source/loader/ur_lib.cpp @@ -64,7 +64,8 @@ void context_t::parseEnvEnabledLayers() { void context_t::initLayers() const { for (auto &l : layers) { if (l->isAvailable()) { - l->init(&context->urDdiTable, enabledLayerNames, codelocData); + l->init(&context->urDdiTable, enabledLayerNames, codelocData, + apiCallbacks); } } } @@ -93,6 +94,7 @@ __urdlllocal ur_result_t context_t::Init( if (hLoaderConfig) { codelocData = hLoaderConfig->codelocData; + apiCallbacks = hLoaderConfig->apiCallbacks; enabledLayerNames.merge(hLoaderConfig->getEnabledLayerNames()); } @@ -215,6 +217,38 @@ urLoaderConfigSetCodeLocationCallback(ur_loader_config_handle_t hLoaderConfig, return UR_RESULT_SUCCESS; } +ur_result_t urLoaderConfigSetMockCallbacks( + ur_loader_config_handle_t hLoaderConfig, + ur_mock_callback_properties_t *pCallbackProperties) { + if (!hLoaderConfig) { + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + } + + if (!pCallbackProperties) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + auto result = hLoaderConfig->apiCallbacks.add_callback(pCallbackProperties); + if (result != UR_RESULT_SUCCESS) { + return result; + } + + auto nextProps = + static_cast(pCallbackProperties->pNext); + while (nextProps) { + if (nextProps->stype != UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES) { + break; + } + auto result = hLoaderConfig->apiCallbacks.add_callback( + reinterpret_cast(nextProps)); + if (result != UR_RESULT_SUCCESS) { + return result; + } + nextProps = static_cast(nextProps->pNext); + } + return UR_RESULT_SUCCESS; +} + ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, ur_device_type_t DeviceType, uint32_t NumEntries, diff --git a/source/loader/ur_lib.hpp b/source/loader/ur_lib.hpp index 839c0041d9..2bc9dd3244 100644 --- a/source/loader/ur_lib.hpp +++ b/source/loader/ur_lib.hpp @@ -14,11 +14,13 @@ #define UR_LOADER_LIB_H 1 #include "ur_api.h" +#include "ur_callbacks.hpp" #include "ur_codeloc.hpp" #include "ur_ddi.h" #include "ur_proxy_layer.hpp" #include "ur_util.hpp" +#include "mock/ur_mock_layer.hpp" #include "validation/ur_validation_layer.hpp" #if UR_ENABLE_TRACING #include "tracing/ur_tracing_layer.hpp" @@ -48,6 +50,7 @@ struct ur_loader_config_handle_t_ { std::set &getEnabledLayerNames() { return enabledLayers; } codeloc_data codelocData; + api_callbacks apiCallbacks; }; namespace ur_lib { @@ -75,13 +78,14 @@ class __urdlllocal context_t { &ur_tracing_layer::context, #endif #if UR_ENABLE_SANITIZER - &ur_sanitizer_layer::context + &ur_sanitizer_layer::context, #endif - }; + &ur_mock_layer::context}; std::string availableLayers; std::set enabledLayerNames; codeloc_data codelocData; + api_callbacks apiCallbacks; bool layerExists(const std::string &layerName) const; void parseEnvEnabledLayers(); @@ -104,6 +108,9 @@ ur_result_t urLoaderConfigSetCodeLocationCallback(ur_loader_config_handle_t hLoaderConfig, ur_code_location_callback_t pfnCodeloc, void *pUserData); +ur_result_t urLoaderConfigSetMockCallbacks( + ur_loader_config_handle_t hLoaderConfig, + ur_mock_callback_properties_t *pCallbackProperties); ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform, ur_device_type_t DeviceType, diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 668fbd07ad..16a8045d16 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -52,7 +52,7 @@ ur_result_t UR_APICALL urLoaderConfigCreate( /// + `NULL == hLoaderConfig` ur_result_t UR_APICALL urLoaderConfigRetain( ur_loader_config_handle_t - hLoaderConfig ///< [in] loader config handle to retain + hLoaderConfig ///< [in][retain] loader config handle to retain ) try { return ur_lib::urLoaderConfigRetain(hLoaderConfig); } catch (...) { @@ -76,7 +76,8 @@ ur_result_t UR_APICALL urLoaderConfigRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hLoaderConfig` ur_result_t UR_APICALL urLoaderConfigRelease( - ur_loader_config_handle_t hLoaderConfig ///< [in] config handle to release + ur_loader_config_handle_t + hLoaderConfig ///< [in][release] config handle to release ) try { return ur_lib::urLoaderConfigRelease(hLoaderConfig); } catch (...) { @@ -191,6 +192,40 @@ ur_result_t UR_APICALL urLoaderConfigSetCodeLocationCallback( return exceptionToResult(std::current_exception()); } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Set a callback to be called before, after or instead of a given entry +/// point +/// +/// @details +/// - The callback layer will pass the function's parameter struct (e.g. +/// **::ur_adapter_get_params_t**) to the ::ur_mock_callback_t so +/// parameters can be accessed and modified. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hLoaderConfig` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == pCallbackProperties` +/// + `NULL == pCallbackProperties->name` +/// + `NULL == pCallbackProperties->pCallback` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_CALLBACK_OVERRIDE_MODE_AFTER < pCallbackProperties->mode` +ur_result_t UR_APICALL urLoaderConfigSetMockCallbacks( + ur_loader_config_handle_t + hLoaderConfig, ///< [in] Handle to config object the layer will be enabled for. + ur_mock_callback_properties_t + *pCallbackProperties ///< [in] Pointer to callback properties struct. + ) try { + return ur_lib::urLoaderConfigSetMockCallbacks(hLoaderConfig, + pCallbackProperties); +} catch (...) { + return exceptionToResult(std::current_exception()); +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Initialize the 'oneAPI' loader /// @@ -313,7 +348,7 @@ ur_result_t UR_APICALL urAdapterGet( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hAdapter` ur_result_t UR_APICALL urAdapterRelease( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to release + ur_adapter_handle_t hAdapter ///< [in][release] Adapter handle to release ) try { auto pfnAdapterRelease = ur_lib::context->urDdiTable.Global.pfnAdapterRelease; @@ -340,7 +375,7 @@ ur_result_t UR_APICALL urAdapterRelease( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hAdapter` ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) try { auto pfnAdapterRetain = ur_lib::context->urDdiTable.Global.pfnAdapterRetain; if (nullptr == pfnAdapterRetain) { @@ -906,7 +941,7 @@ ur_result_t UR_APICALL urDeviceGetInfo( /// + `NULL == hDevice` ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) try { auto pfnRetain = ur_lib::context->urDdiTable.Device.pfnRetain; if (nullptr == pfnRetain) { @@ -942,7 +977,8 @@ ur_result_t UR_APICALL urDeviceRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hDevice` ur_result_t UR_APICALL urDeviceRelease( - ur_device_handle_t hDevice ///< [in] handle of the device to release. + ur_device_handle_t + hDevice ///< [in][release] handle of the device to release. ) try { auto pfnRelease = ur_lib::context->urDdiTable.Device.pfnRelease; if (nullptr == pfnRelease) { @@ -1246,7 +1282,7 @@ ur_result_t UR_APICALL urContextCreate( /// + `NULL == hContext` ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) try { auto pfnRetain = ur_lib::context->urDdiTable.Context.pfnRetain; if (nullptr == pfnRetain) { @@ -1278,7 +1314,8 @@ ur_result_t UR_APICALL urContextRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hContext` ur_result_t UR_APICALL urContextRelease( - ur_context_handle_t hContext ///< [in] handle of the context to release. + ur_context_handle_t + hContext ///< [in][release] handle of the context to release. ) try { auto pfnRelease = ur_lib::context->urDdiTable.Context.pfnRelease; if (nullptr == pfnRelease) { @@ -1609,7 +1646,8 @@ ur_result_t UR_APICALL urMemBufferCreate( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) try { auto pfnRetain = ur_lib::context->urDdiTable.Mem.pfnRetain; if (nullptr == pfnRetain) { @@ -1639,7 +1677,8 @@ ur_result_t UR_APICALL urMemRetain( /// - ::UR_RESULT_ERROR_INVALID_MEM_OBJECT /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urMemRelease( - ur_mem_handle_t hMem ///< [in] handle of the memory object to release + ur_mem_handle_t + hMem ///< [in][release] handle of the memory object to release ) try { auto pfnRelease = ur_lib::context->urDdiTable.Mem.pfnRelease; if (nullptr == pfnRelease) { @@ -2005,7 +2044,7 @@ ur_result_t UR_APICALL urSamplerCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) try { auto pfnRetain = ur_lib::context->urDdiTable.Sampler.pfnRetain; if (nullptr == pfnRetain) { @@ -2037,7 +2076,7 @@ ur_result_t UR_APICALL urSamplerRetain( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urSamplerRelease( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to release + hSampler ///< [in][release] handle of the sampler object to release ) try { auto pfnRelease = ur_lib::context->urDdiTable.Sampler.pfnRelease; if (nullptr == pfnRelease) { @@ -2474,7 +2513,7 @@ ur_result_t UR_APICALL urUSMPoolCreate( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == pPool` ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) try { auto pfnPoolRetain = ur_lib::context->urDdiTable.USM.pfnPoolRetain; if (nullptr == pfnPoolRetain) { @@ -2504,7 +2543,7 @@ ur_result_t UR_APICALL urUSMPoolRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == pPool` ur_result_t UR_APICALL urUSMPoolRelease( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][release] pointer to USM memory pool ) try { auto pfnPoolRelease = ur_lib::context->urDdiTable.USM.pfnPoolRelease; if (nullptr == pfnPoolRelease) { @@ -2866,7 +2905,7 @@ ur_result_t UR_APICALL urPhysicalMemCreate( /// + `NULL == hPhysicalMem` ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) try { auto pfnRetain = ur_lib::context->urDdiTable.PhysicalMem.pfnRetain; if (nullptr == pfnRetain) { @@ -2890,7 +2929,7 @@ ur_result_t UR_APICALL urPhysicalMemRetain( /// + `NULL == hPhysicalMem` ur_result_t UR_APICALL urPhysicalMemRelease( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to release. + hPhysicalMem ///< [in][release] handle of the physical memory object to release. ) try { auto pfnRelease = ur_lib::context->urDdiTable.PhysicalMem.pfnRelease; if (nullptr == pfnRelease) { @@ -3162,7 +3201,8 @@ ur_result_t UR_APICALL urProgramLink( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hProgram` ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) try { auto pfnRetain = ur_lib::context->urDdiTable.Program.pfnRetain; if (nullptr == pfnRetain) { @@ -3195,7 +3235,8 @@ ur_result_t UR_APICALL urProgramRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hProgram` ur_result_t UR_APICALL urProgramRelease( - ur_program_handle_t hProgram ///< [in] handle for the Program to release + ur_program_handle_t + hProgram ///< [in][release] handle for the Program to release ) try { auto pfnRelease = ur_lib::context->urDdiTable.Program.pfnRelease; if (nullptr == pfnRelease) { @@ -3786,7 +3827,7 @@ ur_result_t UR_APICALL urKernelGetSubGroupInfo( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hKernel` ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) try { auto pfnRetain = ur_lib::context->urDdiTable.Kernel.pfnRetain; if (nullptr == pfnRetain) { @@ -3819,7 +3860,8 @@ ur_result_t UR_APICALL urKernelRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hKernel` ur_result_t UR_APICALL urKernelRelease( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release + ur_kernel_handle_t + hKernel ///< [in][release] handle for the Kernel to release ) try { auto pfnRelease = ur_lib::context->urDdiTable.Kernel.pfnRelease; if (nullptr == pfnRelease) { @@ -4283,7 +4325,8 @@ ur_result_t UR_APICALL urQueueCreate( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) try { auto pfnRetain = ur_lib::context->urDdiTable.Queue.pfnRetain; if (nullptr == pfnRetain) { @@ -4320,7 +4363,8 @@ ur_result_t UR_APICALL urQueueRetain( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urQueueRelease( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to release + ur_queue_handle_t + hQueue ///< [in][release] handle of the queue object to release ) try { auto pfnRelease = ur_lib::context->urDdiTable.Queue.pfnRelease; if (nullptr == pfnRelease) { @@ -4640,7 +4684,7 @@ ur_result_t UR_APICALL urEventWait( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) try { auto pfnRetain = ur_lib::context->urDdiTable.Event.pfnRetain; if (nullptr == pfnRetain) { @@ -4671,7 +4715,7 @@ ur_result_t UR_APICALL urEventRetain( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urEventRelease( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][release] handle of the event object ) try { auto pfnRelease = ur_lib::context->urDdiTable.Event.pfnRelease; if (nullptr == pfnRelease) { @@ -7054,7 +7098,7 @@ ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( ur_context_handle_t hContext, ///< [in] handle of the context object ur_device_handle_t hDevice, ///< [in] handle of the device object ur_exp_interop_mem_handle_t - hInteropMem ///< [in] handle of interop memory to be freed + hInteropMem ///< [in][release] handle of interop memory to be freed ) try { auto pfnReleaseInteropExp = ur_lib::context->urDdiTable.BindlessImagesExp.pfnReleaseInteropExp; @@ -7291,7 +7335,7 @@ ur_result_t UR_APICALL urCommandBufferCreateExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) try { auto pfnRetainExp = ur_lib::context->urDdiTable.CommandBufferExp.pfnRetainExp; @@ -7320,7 +7364,7 @@ ur_result_t UR_APICALL urCommandBufferRetainExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferReleaseExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][release] Handle of the command-buffer object. ) try { auto pfnReleaseExp = ur_lib::context->urDdiTable.CommandBufferExp.pfnReleaseExp; @@ -8120,7 +8164,7 @@ ur_result_t UR_APICALL urCommandBufferRetainCommandExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][release] Handle of the command-buffer command. ) try { auto pfnReleaseCommandExp = ur_lib::context->urDdiTable.CommandBufferExp.pfnReleaseCommandExp; diff --git a/source/loader/ur_print.cpp b/source/loader/ur_print.cpp index 79107c733d..2726de9117 100644 --- a/source/loader/ur_print.cpp +++ b/source/loader/ur_print.cpp @@ -114,6 +114,23 @@ ur_result_t urPrintCodeLocation(const struct ur_code_location_t params, return str_copy(&ss, buffer, buff_size, out_size); } +ur_result_t urPrintCallbackOverrideMode(enum ur_callback_override_mode_t value, + char *buffer, const size_t buff_size, + size_t *out_size) { + std::stringstream ss; + ss << value; + return str_copy(&ss, buffer, buff_size, out_size); +} + +ur_result_t +urPrintMockCallbackProperties(const struct ur_mock_callback_properties_t params, + char *buffer, const size_t buff_size, + size_t *out_size) { + std::stringstream ss; + ss << params; + return str_copy(&ss, buffer, buff_size, out_size); +} + ur_result_t urPrintAdapterInfo(enum ur_adapter_info_t value, char *buffer, const size_t buff_size, size_t *out_size) { std::stringstream ss; @@ -1960,6 +1977,14 @@ ur_result_t urPrintLoaderConfigSetCodeLocationCallbackParams( return str_copy(&ss, buffer, buff_size, out_size); } +ur_result_t urPrintLoaderConfigSetMockCallbacksParams( + const struct ur_loader_config_set_mock_callbacks_params_t *params, + char *buffer, const size_t buff_size, size_t *out_size) { + std::stringstream ss; + ss << params; + return str_copy(&ss, buffer, buff_size, out_size); +} + ur_result_t urPrintMemImageCreateParams(const struct ur_mem_image_create_params_t *params, char *buffer, const size_t buff_size, diff --git a/source/ur_api.cpp b/source/ur_api.cpp index f7b4bb017f..f6a21c7ac8 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -48,7 +48,7 @@ ur_result_t UR_APICALL urLoaderConfigCreate( /// + `NULL == hLoaderConfig` ur_result_t UR_APICALL urLoaderConfigRetain( ur_loader_config_handle_t - hLoaderConfig ///< [in] loader config handle to retain + hLoaderConfig ///< [in][retain] loader config handle to retain ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -71,7 +71,8 @@ ur_result_t UR_APICALL urLoaderConfigRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hLoaderConfig` ur_result_t UR_APICALL urLoaderConfigRelease( - ur_loader_config_handle_t hLoaderConfig ///< [in] config handle to release + ur_loader_config_handle_t + hLoaderConfig ///< [in][release] config handle to release ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -180,6 +181,38 @@ ur_result_t UR_APICALL urLoaderConfigSetCodeLocationCallback( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Set a callback to be called before, after or instead of a given entry +/// point +/// +/// @details +/// - The callback layer will pass the function's parameter struct (e.g. +/// **::ur_adapter_get_params_t**) to the ::ur_mock_callback_t so +/// parameters can be accessed and modified. +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_DEVICE_LOST +/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC +/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE +/// + `NULL == hLoaderConfig` +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// + `NULL == pCallbackProperties` +/// + `NULL == pCallbackProperties->name` +/// + `NULL == pCallbackProperties->pCallback` +/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION +/// + `::UR_CALLBACK_OVERRIDE_MODE_AFTER < pCallbackProperties->mode` +ur_result_t UR_APICALL urLoaderConfigSetMockCallbacks( + ur_loader_config_handle_t + hLoaderConfig, ///< [in] Handle to config object the layer will be enabled for. + ur_mock_callback_properties_t + *pCallbackProperties ///< [in] Pointer to callback properties struct. +) { + ur_result_t result = UR_RESULT_SUCCESS; + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Initialize the 'oneAPI' loader /// @@ -284,7 +317,7 @@ ur_result_t UR_APICALL urAdapterGet( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hAdapter` ur_result_t UR_APICALL urAdapterRelease( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to release + ur_adapter_handle_t hAdapter ///< [in][release] Adapter handle to release ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -304,7 +337,7 @@ ur_result_t UR_APICALL urAdapterRelease( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hAdapter` ur_result_t UR_APICALL urAdapterRetain( - ur_adapter_handle_t hAdapter ///< [in] Adapter handle to retain + ur_adapter_handle_t hAdapter ///< [in][retain] Adapter handle to retain ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -794,7 +827,7 @@ ur_result_t UR_APICALL urDeviceGetInfo( /// + `NULL == hDevice` ur_result_t UR_APICALL urDeviceRetain( ur_device_handle_t - hDevice ///< [in] handle of the device to get a reference of. + hDevice ///< [in][retain] handle of the device to get a reference of. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -824,7 +857,8 @@ ur_result_t UR_APICALL urDeviceRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hDevice` ur_result_t UR_APICALL urDeviceRelease( - ur_device_handle_t hDevice ///< [in] handle of the device to release. + ur_device_handle_t + hDevice ///< [in][release] handle of the device to release. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -1081,7 +1115,7 @@ ur_result_t UR_APICALL urContextCreate( /// + `NULL == hContext` ur_result_t UR_APICALL urContextRetain( ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. + hContext ///< [in][retain] handle of the context to get a reference of. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -1107,7 +1141,8 @@ ur_result_t UR_APICALL urContextRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hContext` ur_result_t UR_APICALL urContextRelease( - ur_context_handle_t hContext ///< [in] handle of the context to release. + ur_context_handle_t + hContext ///< [in][release] handle of the context to release. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -1391,7 +1426,8 @@ ur_result_t UR_APICALL urMemBufferCreate( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access + ur_mem_handle_t + hMem ///< [in][retain] handle of the memory object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -1415,7 +1451,8 @@ ur_result_t UR_APICALL urMemRetain( /// - ::UR_RESULT_ERROR_INVALID_MEM_OBJECT /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urMemRelease( - ur_mem_handle_t hMem ///< [in] handle of the memory object to release + ur_mem_handle_t + hMem ///< [in][release] handle of the memory object to release ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -1726,7 +1763,7 @@ ur_result_t UR_APICALL urSamplerCreate( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urSamplerRetain( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to get access + hSampler ///< [in][retain] handle of the sampler object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -1752,7 +1789,7 @@ ur_result_t UR_APICALL urSamplerRetain( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urSamplerRelease( ur_sampler_handle_t - hSampler ///< [in] handle of the sampler object to release + hSampler ///< [in][release] handle of the sampler object to release ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -2124,7 +2161,7 @@ ur_result_t UR_APICALL urUSMPoolCreate( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == pPool` ur_result_t UR_APICALL urUSMPoolRetain( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][retain] pointer to USM memory pool ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -2148,7 +2185,7 @@ ur_result_t UR_APICALL urUSMPoolRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == pPool` ur_result_t UR_APICALL urUSMPoolRelease( - ur_usm_pool_handle_t pPool ///< [in] pointer to USM memory pool + ur_usm_pool_handle_t pPool ///< [in][release] pointer to USM memory pool ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -2447,7 +2484,7 @@ ur_result_t UR_APICALL urPhysicalMemCreate( /// + `NULL == hPhysicalMem` ur_result_t UR_APICALL urPhysicalMemRetain( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to retain. + hPhysicalMem ///< [in][retain] handle of the physical memory object to retain. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -2465,7 +2502,7 @@ ur_result_t UR_APICALL urPhysicalMemRetain( /// + `NULL == hPhysicalMem` ur_result_t UR_APICALL urPhysicalMemRelease( ur_physical_mem_handle_t - hPhysicalMem ///< [in] handle of the physical memory object to release. + hPhysicalMem ///< [in][release] handle of the physical memory object to release. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -2699,7 +2736,8 @@ ur_result_t UR_APICALL urProgramLink( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hProgram` ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t hProgram ///< [in] handle for the Program to retain + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -2726,7 +2764,8 @@ ur_result_t UR_APICALL urProgramRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hProgram` ur_result_t UR_APICALL urProgramRelease( - ur_program_handle_t hProgram ///< [in] handle for the Program to release + ur_program_handle_t + hProgram ///< [in][release] handle for the Program to release ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -3220,7 +3259,7 @@ ur_result_t UR_APICALL urKernelGetSubGroupInfo( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hKernel` ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain + ur_kernel_handle_t hKernel ///< [in][retain] handle for the Kernel to retain ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -3247,7 +3286,8 @@ ur_result_t UR_APICALL urKernelRetain( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hKernel` ur_result_t UR_APICALL urKernelRelease( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release + ur_kernel_handle_t + hKernel ///< [in][release] handle for the Kernel to release ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -3638,7 +3678,8 @@ ur_result_t UR_APICALL urQueueCreate( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urQueueRetain( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to get access + ur_queue_handle_t + hQueue ///< [in][retain] handle of the queue object to get access ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -3669,7 +3710,8 @@ ur_result_t UR_APICALL urQueueRetain( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urQueueRelease( - ur_queue_handle_t hQueue ///< [in] handle of the queue object to release + ur_queue_handle_t + hQueue ///< [in][release] handle of the queue object to release ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -3936,7 +3978,7 @@ ur_result_t UR_APICALL urEventWait( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urEventRetain( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][retain] handle of the event object ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -3961,7 +4003,7 @@ ur_result_t UR_APICALL urEventRetain( /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urEventRelease( - ur_event_handle_t hEvent ///< [in] handle of the event object + ur_event_handle_t hEvent ///< [in][release] handle of the event object ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -6012,7 +6054,7 @@ ur_result_t UR_APICALL urBindlessImagesReleaseInteropExp( ur_context_handle_t hContext, ///< [in] handle of the context object ur_device_handle_t hDevice, ///< [in] handle of the device object ur_exp_interop_mem_handle_t - hInteropMem ///< [in] handle of interop memory to be freed + hInteropMem ///< [in][release] handle of interop memory to be freed ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -6200,7 +6242,7 @@ ur_result_t UR_APICALL urCommandBufferCreateExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferRetainExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][retain] Handle of the command-buffer object. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -6222,7 +6264,7 @@ ur_result_t UR_APICALL urCommandBufferRetainExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferReleaseExp( ur_exp_command_buffer_handle_t - hCommandBuffer ///< [in] Handle of the command-buffer object. + hCommandBuffer ///< [in][release] Handle of the command-buffer object. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; @@ -6878,7 +6920,7 @@ ur_result_t UR_APICALL urCommandBufferRetainCommandExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( ur_exp_command_buffer_command_handle_t - hCommand ///< [in] Handle of the command-buffer command. + hCommand ///< [in][release] Handle of the command-buffer command. ) { ur_result_t result = UR_RESULT_SUCCESS; return result; diff --git a/test/layers/CMakeLists.txt b/test/layers/CMakeLists.txt index 2c10a08518..c7feca2ee3 100644 --- a/test/layers/CMakeLists.txt +++ b/test/layers/CMakeLists.txt @@ -4,6 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception add_subdirectory(validation) +add_subdirectory(mock) if(UR_ENABLE_TRACING) add_subdirectory(tracing) diff --git a/test/layers/mock/CMakeLists.txt b/test/layers/mock/CMakeLists.txt new file mode 100644 index 0000000000..574e5a0fe9 --- /dev/null +++ b/test/layers/mock/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (C) 2024 Intel Corporation +# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +# See LICENSE.TXT +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +set(MOCK_TEST_NAME test-mock) + +add_ur_executable(${MOCK_TEST_NAME} mock.cpp) +target_link_libraries(${MOCK_TEST_NAME} + PRIVATE + ${PROJECT_NAME}::loader + ${PROJECT_NAME}::headers + ${PROJECT_NAME}::testing + GTest::gtest_main) + +add_test(NAME ${MOCK_TEST_NAME} + COMMAND ${MOCK_TEST_NAME} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + +set_tests_properties(${MOCK_TEST_NAME} PROPERTIES LABELS "mock") + +set_property(TEST ${MOCK_TEST_NAME} PROPERTY ENVIRONMENT + "UR_ADAPTERS_FORCE_LOAD=\"$\"") diff --git a/test/layers/mock/mock.cpp b/test/layers/mock/mock.cpp new file mode 100644 index 0000000000..75fcc5ef32 --- /dev/null +++ b/test/layers/mock/mock.cpp @@ -0,0 +1,141 @@ +/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file codeloc.cpp + * + */ + +#include "uur/raii.h" +#include +#include + +TEST(Mock, NullProperties) { + uur::raii::LoaderConfig loader_config; + ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderConfigSetMockCallbacks(loader_config, nullptr), + UR_RESULT_ERROR_INVALID_NULL_POINTER); +} + +TEST(Mock, NullCallback) { + uur::raii::LoaderConfig loader_config; + ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); + + ur_mock_callback_properties_t callback_properties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, nullptr, "urAdapterGet", + UR_CALLBACK_OVERRIDE_MODE_REPLACE, nullptr}; + + ASSERT_EQ( + urLoaderConfigSetMockCallbacks(loader_config, &callback_properties), + UR_RESULT_ERROR_INVALID_NULL_POINTER); +} + +ur_result_t generic_callback(void *) { return UR_RESULT_SUCCESS; } + +TEST(Mock, NullHandle) { + ur_mock_callback_properties_t callback_properties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, nullptr, "urAdapterGet", + UR_CALLBACK_OVERRIDE_MODE_REPLACE, &generic_callback}; + + ASSERT_EQ(urLoaderConfigSetMockCallbacks(nullptr, &callback_properties), + UR_RESULT_ERROR_INVALID_NULL_HANDLE); +} + +TEST(Mock, DefaultBehavior) { + uur::raii::LoaderConfig loader_config; + ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderConfigEnableLayer(loader_config, "UR_LAYER_MOCK"), + UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderInit(0, loader_config), UR_RESULT_SUCCESS); + + // Set up as far as device and check we're getting sensible, different + // handles created. + ur_adapter_handle_t adapter = nullptr; + ur_platform_handle_t platform = nullptr; + ur_device_handle_t device = nullptr; + + ASSERT_EQ(urAdapterGet(1, &adapter, nullptr), UR_RESULT_SUCCESS); + ASSERT_EQ(urPlatformGet(&adapter, 1, 1, &platform, nullptr), + UR_RESULT_SUCCESS); + ASSERT_EQ(urDeviceGet(platform, UR_DEVICE_TYPE_ALL, 1, &device, nullptr), + UR_RESULT_SUCCESS); + + ASSERT_NE(adapter, nullptr); + ASSERT_NE(platform, nullptr); + ASSERT_NE(device, nullptr); + + ASSERT_NE(static_cast(adapter), static_cast(platform)); + ASSERT_NE(static_cast(adapter), static_cast(device)); + ASSERT_NE(static_cast(platform), static_cast(device)); + + ASSERT_EQ(urDeviceRelease(device), UR_RESULT_SUCCESS); +} + +void checkPreInitAdapter(ur_adapter_handle_t adapter) { + ur_adapter_handle_t preInitAdapter = + reinterpret_cast(0xF00DCAFE); + ASSERT_EQ(adapter, preInitAdapter); +} + +ur_result_t beforeUrAdapterGet(void *pParams) { + auto params = reinterpret_cast(pParams); + checkPreInitAdapter(**params->pphAdapters); + return UR_RESULT_SUCCESS; +} + +ur_result_t replaceUrAdapterGet(void *pParams) { + auto params = reinterpret_cast(pParams); + **params->pphAdapters = reinterpret_cast(0xDEADBEEF); + return UR_RESULT_SUCCESS; +} + +void checkPostInitAdapter(ur_adapter_handle_t adapter) { + ur_adapter_handle_t postInitAdapter = + reinterpret_cast(0xDEADBEEF); + ASSERT_EQ(adapter, postInitAdapter); +} + +ur_result_t afterUrAdapterGet(void *pParams) { + auto params = reinterpret_cast(pParams); + checkPostInitAdapter(**params->pphAdapters); + return UR_RESULT_SUCCESS; +} + +TEST(Mock, Callbacks) { + uur::raii::LoaderConfig loader_config; + ASSERT_EQ(urLoaderConfigCreate(loader_config.ptr()), UR_RESULT_SUCCESS); + + // This callback is set up to check *phAdapters is still the pre-call + // init value we set below + ur_mock_callback_properties_t adapterGetBeforeProperties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, nullptr, "urAdapterGet", + UR_CALLBACK_OVERRIDE_MODE_BEFORE, &beforeUrAdapterGet}; + + // This callback is set up to return a distinct test value in phAdapters + // rather than the default generic handle + ur_mock_callback_properties_t adapterGetReplaceProperties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, &adapterGetBeforeProperties, + "urAdapterGet", UR_CALLBACK_OVERRIDE_MODE_REPLACE, + &replaceUrAdapterGet}; + + // This callback is set up to check our replace callback did its job + ur_mock_callback_properties_t adapterGetAfterProperties = { + UR_STRUCTURE_TYPE_MOCK_CALLBACK_PROPERTIES, + &adapterGetReplaceProperties, "urAdapterGet", + UR_CALLBACK_OVERRIDE_MODE_AFTER, &afterUrAdapterGet}; + + ASSERT_EQ(urLoaderConfigSetMockCallbacks(loader_config, + &adapterGetAfterProperties), + UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderConfigEnableLayer(loader_config, "UR_LAYER_MOCK"), + UR_RESULT_SUCCESS); + ASSERT_EQ(urLoaderInit(0, loader_config), UR_RESULT_SUCCESS); + + ur_adapter_handle_t adapter = + reinterpret_cast(0xF00DCAFE); + ASSERT_EQ(urAdapterGet(1, &adapter, nullptr), UR_RESULT_SUCCESS); +}