diff --git a/unified-runtime/source/adapters/hip/CMakeLists.txt b/unified-runtime/source/adapters/hip/CMakeLists.txt index d0674ad504d4e..e6a3a4ed82769 100644 --- a/unified-runtime/source/adapters/hip/CMakeLists.txt +++ b/unified-runtime/source/adapters/hip/CMakeLists.txt @@ -203,6 +203,12 @@ else() message(FATAL_ERROR "Unspecified UR HIP platform please set UR_HIP_PLATFORM to 'AMD' or 'NVIDIA'") endif() +if(UMF_ENABLE_POOL_TRACKING) + target_compile_definitions(${TARGET_NAME} PRIVATE UMF_ENABLE_POOL_TRACKING) +else() + message(WARNING "HIP adapter USM pools are disabled, set UMF_ENABLE_POOL_TRACKING to enable them") +endif() + target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/../../" ) diff --git a/unified-runtime/source/adapters/hip/usm.cpp b/unified-runtime/source/adapters/hip/usm.cpp index e926ad426856e..5f0ebe7483ec4 100644 --- a/unified-runtime/source/adapters/hip/usm.cpp +++ b/unified-runtime/source/adapters/hip/usm.cpp @@ -37,7 +37,13 @@ urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc, return USMHostAllocImpl(ppMem, hContext, /* flags */ 0, size, alignment); } - return umfPoolMallocHelper(hPool, ppMem, size, alignment); + auto UMFPool = hPool->HostMemPool.get(); + *ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment); + if (*ppMem == nullptr) { + auto umfErr = umfPoolGetLastAllocationError(UMFPool); + return umf::umf2urResult(umfErr); + } + return UR_RESULT_SUCCESS; } /// USM: Implements USM device allocations using a normal HIP device pointer @@ -54,7 +60,13 @@ urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice, alignment); } - return umfPoolMallocHelper(hPool, ppMem, size, alignment); + auto UMFPool = hPool->DeviceMemPool.get(); + *ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment); + if (*ppMem == nullptr) { + auto umfErr = umfPoolGetLastAllocationError(UMFPool); + return umf::umf2urResult(umfErr); + } + return UR_RESULT_SUCCESS; } /// USM: Implements USM Shared allocations using HIP Managed Memory @@ -71,7 +83,13 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice, /*device flags*/ 0, size, alignment); } - return umfPoolMallocHelper(hPool, ppMem, size, alignment); + auto UMFPool = hPool->SharedMemPool.get(); + *ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment); + if (*ppMem == nullptr) { + auto umfErr = umfPoolGetLastAllocationError(UMFPool); + return umf::umf2urResult(umfErr); + } + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL @@ -330,15 +348,25 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size, ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_usm_pool_desc_t *PoolDesc) : Context(Context) { - if (PoolDesc) { - if (auto *Limits = find_stype_node(PoolDesc)) { + + const void *pNext = PoolDesc->pNext; + while (pNext != nullptr) { + const ur_base_desc_t *BaseDesc = static_cast(pNext); + switch (BaseDesc->stype) { + case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: { + const ur_usm_pool_limits_desc_t *Limits = + reinterpret_cast(BaseDesc); for (auto &config : DisjointPoolConfigs.Configs) { config.MaxPoolableSize = Limits->maxPoolableSize; config.SlabMinSize = Limits->minDriverAllocSize; } - } else { + break; + } + default: { throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT); } + } + pNext = BaseDesc->pNext; } auto MemProvider = @@ -468,17 +496,6 @@ bool checkUSMImplAlignment(uint32_t Alignment, void **ResultPtr) { reinterpret_cast(*ResultPtr) % Alignment == 0; } -ur_result_t umfPoolMallocHelper(ur_usm_pool_handle_t hPool, void **ppMem, - size_t size, uint32_t alignment) { - auto UMFPool = hPool->DeviceMemPool.get(); - *ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment); - if (*ppMem == nullptr) { - auto umfErr = umfPoolGetLastAllocationError(UMFPool); - return umf::umf2urResult(umfErr); - } - return UR_RESULT_SUCCESS; -} - UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreateExp(ur_context_handle_t, ur_device_handle_t, ur_usm_pool_desc_t *, diff --git a/unified-runtime/source/adapters/hip/usm.hpp b/unified-runtime/source/adapters/hip/usm.hpp index 2149ac26ba046..04b2b38a58836 100644 --- a/unified-runtime/source/adapters/hip/usm.hpp +++ b/unified-runtime/source/adapters/hip/usm.hpp @@ -140,6 +140,3 @@ ur_result_t USMHostAllocImpl(void **ResultPtr, ur_context_handle_t Context, bool checkUSMAlignment(uint32_t &alignment, const ur_usm_desc_t *pUSMDesc); bool checkUSMImplAlignment(uint32_t Alignment, void **ResultPtr); - -ur_result_t umfPoolMallocHelper(ur_usm_pool_handle_t hPool, void **ppMem, - size_t size, uint32_t alignment); diff --git a/unified-runtime/test/conformance/usm/urUSMPoolCreate.cpp b/unified-runtime/test/conformance/usm/urUSMPoolCreate.cpp index eda4dc79449b2..ac9912271ed5b 100644 --- a/unified-runtime/test/conformance/usm/urUSMPoolCreate.cpp +++ b/unified-runtime/test/conformance/usm/urUSMPoolCreate.cpp @@ -29,7 +29,7 @@ TEST_P(urUSMPoolCreateTest, Success) { } TEST_P(urUSMPoolCreateTest, SuccessWithFlag) { - UUR_KNOWN_FAILURE_ON(uur::CUDA{}); + UUR_KNOWN_FAILURE_ON(uur::CUDA{}, uur::HIP{}); ur_usm_pool_desc_t pool_desc{UR_STRUCTURE_TYPE_USM_POOL_DESC, nullptr, UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK};