From bf45c900230b30b9d68ebdc07c86e28e209faae5 Mon Sep 17 00:00:00 2001 From: Ross Brunton Date: Tue, 19 Aug 2025 15:51:33 +0100 Subject: [PATCH] [UR][Offload] Various small fixes for offload adapter Some functions were added (urEnqueueMemBufferCopy, urEnqueueUSMMemcpy and urQueueFlush), the barrier event is now reference counted (rather than being dropped when the event is destroyed), and releasing an empty event (one with no underlying offload event) no longer causes an error. --- .../source/adapters/offload/context.hpp | 18 ++++++- .../source/adapters/offload/enqueue.cpp | 48 +++++++++++++++++-- .../source/adapters/offload/event.cpp | 8 ++-- .../source/adapters/offload/queue.cpp | 4 ++ .../source/adapters/offload/queue.hpp | 5 +- .../source/adapters/offload/ur_interface_loader.cpp | 6 +-- .../source/adapters/offload/usm.cpp | 13 +++-- 7 files changed, 84 insertions(+), 18 deletions(-) diff --git a/unified-runtime/source/adapters/offload/context.hpp b/unified-runtime/source/adapters/offload/context.hpp index 38857446c47f8..b40d17ad3ae9c 100644 --- a/unified-runtime/source/adapters/offload/context.hpp +++ b/unified-runtime/source/adapters/offload/context.hpp @@ -17,6 +17,11 @@ #include #include +struct alloc_info_t { + ol_alloc_type_t Type; + size_t Size; +}; + struct ur_context_handle_t_ : RefCounted { ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} { urDeviceRetain(Device); @@ -24,5 +29,16 @@ struct ur_context_handle_t_ : RefCounted { ~ur_context_handle_t_() { urDeviceRelease(Device); } ur_device_handle_t Device; - std::unordered_map AllocTypeMap; + std::unordered_map AllocTypeMap; + + std::optional getAllocType(const void *UsmPtr) { + for (auto &pair : AllocTypeMap) { + if (UsmPtr >= pair.first && + reinterpret_cast(UsmPtr) < + reinterpret_cast(pair.first) + pair.second.Size) { + return pair.second; + } + } + return std::nullopt; + } }; diff --git a/unified-runtime/source/adapters/offload/enqueue.cpp b/unified-runtime/source/adapters/offload/enqueue.cpp index 6b9a013538bb4..87bd45eb3f817 100644 --- a/unified-runtime/source/adapters/offload/enqueue.cpp +++
b/unified-runtime/source/adapters/offload/enqueue.cpp @@ -93,16 +93,19 @@ ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, phEvent)); if constexpr (Barrier) { - ol_event_handle_t BarrierEvent; + ur_event_handle_t BarrierEvent; if (phEvent) { - BarrierEvent = (*phEvent)->OffloadEvent; + BarrierEvent = *phEvent; + urEventRetain(BarrierEvent); } else { - OL_RETURN_ON_ERR(olCreateEvent(TargetQueue, &BarrierEvent)); + OL_RETURN_ON_ERR(makeEvent(TYPE, TargetQueue, hQueue, &BarrierEvent)); } // Ensure any newly created work waits on this barrier if (hQueue->Barrier) { - OL_RETURN_ON_ERR(olDestroyEvent(hQueue->Barrier)); + if (auto Err = urEventRelease(hQueue->Barrier)) { + return Err; + } } hQueue->Barrier = BarrierEvent; @@ -114,7 +117,7 @@ ur_result_t doWait(ur_queue_handle_t hQueue, uint32_t numEventsInWaitList, if (Q == TargetQueue) { continue; } - OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent, 1)); + OL_RETURN_ON_ERR(olWaitEvents(Q, &BarrierEvent->OffloadEvent, 1)); } } @@ -260,6 +263,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite( blockingWrite, numEventsInWaitList, phEventWaitList, phEvent); } +UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy( + ur_queue_handle_t hQueue, ur_mem_handle_t hBufferSrc, + ur_mem_handle_t hBufferDst, size_t srcOffset, size_t dstOffset, size_t size, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { + char *DevPtrSrc = + reinterpret_cast(std::get(hBufferSrc->Mem).Ptr); + char *DevPtrDst = + reinterpret_cast(std::get(hBufferDst->Mem).Ptr); + + return doMemcpy(UR_COMMAND_MEM_BUFFER_COPY, hQueue, DevPtrDst + dstOffset, + hQueue->OffloadDevice, DevPtrSrc + srcOffset, + hQueue->OffloadDevice, size, false, numEventsInWaitList, + phEventWaitList, phEvent); +} + UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( ur_queue_handle_t hQueue, ur_program_handle_t 
hProgram, const char *name, bool blockingRead, size_t count, size_t offset, void *pDst, @@ -366,3 +385,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap( return Result; } + +UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( + ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc, + size_t size, uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { + auto GetDevice = [&](const void *Ptr) { + auto Res = hQueue->UrContext->getAllocType(Ptr); + if (!Res) + return Adapter->HostDevice; + return Res->Type == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice + : hQueue->OffloadDevice; + }; + + return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, GetDevice(pDst), pSrc, + GetDevice(pSrc), size, blocking, numEventsInWaitList, + phEventWaitList, phEvent); + + return UR_RESULT_SUCCESS; +} diff --git a/unified-runtime/source/adapters/offload/event.cpp b/unified-runtime/source/adapters/offload/event.cpp index aab41ed3d2d0e..ee326df79dd6f 100644 --- a/unified-runtime/source/adapters/offload/event.cpp +++ b/unified-runtime/source/adapters/offload/event.cpp @@ -64,9 +64,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { if (--hEvent->RefCount == 0) { - auto Res = olDestroyEvent(hEvent->OffloadEvent); - if (Res) { - return offloadResultToUR(Res); + if (hEvent->OffloadEvent) { + auto Res = olDestroyEvent(hEvent->OffloadEvent); + if (Res) { + return offloadResultToUR(Res); + } } delete hEvent; } diff --git a/unified-runtime/source/adapters/offload/queue.cpp b/unified-runtime/source/adapters/offload/queue.cpp index 43647d0041496..26a5d34e2ed0c 100644 --- a/unified-runtime/source/adapters/offload/queue.cpp +++ b/unified-runtime/source/adapters/offload/queue.cpp @@ -105,3 +105,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle( const ur_queue_native_properties_t *, ur_queue_handle_t *) { 
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } + +UR_APIEXPORT ur_result_t UR_APICALL urQueueFlush(ur_queue_handle_t) { + return UR_RESULT_SUCCESS; +} diff --git a/unified-runtime/source/adapters/offload/queue.hpp b/unified-runtime/source/adapters/offload/queue.hpp index 8f887a9c3be01..a7106b2411939 100644 --- a/unified-runtime/source/adapters/offload/queue.hpp +++ b/unified-runtime/source/adapters/offload/queue.hpp @@ -14,6 +14,7 @@ #include #include "common.hpp" +#include "event.hpp" constexpr size_t OOO_QUEUE_POOL_SIZE = 32; @@ -38,7 +39,7 @@ struct ur_queue_handle_t_ : RefCounted { // Mutex guarding the offset and barrier for out of order queues std::mutex OooMutex; size_t QueueOffset; - ol_event_handle_t Barrier; + ur_event_handle_t Barrier; ol_device_handle_t OffloadDevice; ur_context_handle_t UrContext; ur_queue_flags_t Flags; @@ -54,7 +55,7 @@ struct ur_queue_handle_t_ : RefCounted { } if (auto Event = Barrier) { - if (auto Res = olWaitEvents(Slot, &Event, 1)) { + if (auto Res = olWaitEvents(Slot, &Event->OffloadEvent, 1)) { return Res; } } diff --git a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp index 5b4c8bd13bc50..145f78d0f90b9 100644 --- a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp @@ -173,7 +173,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( pDdiTable->pfnEventsWait = urEnqueueEventsWait; pDdiTable->pfnEventsWaitWithBarrier = urEnqueueEventsWaitWithBarrier; pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch; - pDdiTable->pfnMemBufferCopy = nullptr; + pDdiTable->pfnMemBufferCopy = urEnqueueMemBufferCopy; pDdiTable->pfnMemBufferCopyRect = nullptr; pDdiTable->pfnMemBufferFill = nullptr; pDdiTable->pfnMemBufferMap = urEnqueueMemBufferMap; @@ -189,7 +189,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( pDdiTable->pfnUSMFill = nullptr; 
pDdiTable->pfnUSMAdvise = nullptr; pDdiTable->pfnUSMMemcpy2D = urEnqueueUSMMemcpy2D; - pDdiTable->pfnUSMMemcpy = nullptr; + pDdiTable->pfnUSMMemcpy = urEnqueueUSMMemcpy; pDdiTable->pfnUSMPrefetch = nullptr; pDdiTable->pfnReadHostPipe = nullptr; pDdiTable->pfnWriteHostPipe = nullptr; @@ -221,7 +221,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetQueueProcAddrTable( pDdiTable->pfnCreate = urQueueCreate; pDdiTable->pfnCreateWithNativeHandle = urQueueCreateWithNativeHandle; pDdiTable->pfnFinish = urQueueFinish; - pDdiTable->pfnFlush = nullptr; + pDdiTable->pfnFlush = urQueueFlush; pDdiTable->pfnGetInfo = urQueueGetInfo; pDdiTable->pfnGetNativeHandle = urQueueGetNativeHandle; pDdiTable->pfnRelease = urQueueRelease; diff --git a/unified-runtime/source/adapters/offload/usm.cpp b/unified-runtime/source/adapters/offload/usm.cpp index 99f7931e9ddd7..f427689618d69 100644 --- a/unified-runtime/source/adapters/offload/usm.cpp +++ b/unified-runtime/source/adapters/offload/usm.cpp @@ -23,7 +23,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext, OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice, OL_ALLOC_TYPE_HOST, size, ppMem)); - hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_HOST); + hContext->AllocTypeMap.insert_or_assign( + *ppMem, alloc_info_t{OL_ALLOC_TYPE_HOST, size}); return UR_RESULT_SUCCESS; } @@ -33,7 +34,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc( OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, ppMem)); - hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_DEVICE); + hContext->AllocTypeMap.insert_or_assign( + *ppMem, alloc_info_t{OL_ALLOC_TYPE_DEVICE, size}); return UR_RESULT_SUCCESS; } @@ -43,10 +45,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc( OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice, OL_ALLOC_TYPE_MANAGED, size, ppMem)); - hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_MANAGED); + 
hContext->AllocTypeMap.insert_or_assign( + *ppMem, alloc_info_t{OL_ALLOC_TYPE_MANAGED, size}); return UR_RESULT_SUCCESS; } -UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t, void *pMem) { +UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext, + void *pMem) { + hContext->AllocTypeMap.erase(pMem); return offloadResultToUR(olMemFree(pMem)); }