From 6835607f718c625f3d7d939168655ed8c608840c Mon Sep 17 00:00:00 2001 From: Igor Chorazewicz Date: Tue, 14 May 2024 19:54:00 +0000 Subject: [PATCH] [L0] Optimize kernel lookup in enqueue Change unordered_map to vector as number of entries is expected to be low (1 in common case). Also, do not store mapping between subdevices and the kernel handles. Instead, just look up kernel based on RootDevice handle for a sub device. Additionally, use unique_ptr to store kernel_handle to avoid memory leaks in case zeKernelCreate fails. --- source/adapters/level_zero/command_buffer.cpp | 126 +------ source/adapters/level_zero/kernel.cpp | 309 ++++++++---------- source/adapters/level_zero/kernel.hpp | 51 ++- 3 files changed, 192 insertions(+), 294 deletions(-) diff --git a/source/adapters/level_zero/command_buffer.cpp b/source/adapters/level_zero/command_buffer.cpp index 67415a0de0..1ed7db7fac 100644 --- a/source/adapters/level_zero/command_buffer.cpp +++ b/source/adapters/level_zero/command_buffer.cpp @@ -153,116 +153,6 @@ ur_exp_command_buffer_command_handle_t_:: urKernelRelease(Kernel); } -/// Helper function for calculating work dimensions for kernels -ur_result_t calculateKernelWorkDimensions( - ur_kernel_handle_t Kernel, ur_device_handle_t Device, - ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3], - uint32_t WorkDim, const size_t *GlobalWorkSize, - const size_t *LocalWorkSize) { - - UR_ASSERT(GlobalWorkSize, UR_RESULT_ERROR_INVALID_VALUE); - // If LocalWorkSize is not provided then Kernel must be provided to query - // suggested group size. - UR_ASSERT(LocalWorkSize || Kernel, UR_RESULT_ERROR_INVALID_VALUE); - - // New variable needed because GlobalWorkSize parameter might not be of size 3 - size_t GlobalWorkSize3D[3]{1, 1, 1}; - std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); - - if (LocalWorkSize) { - WG[0] = ur_cast(LocalWorkSize[0]); - WG[1] = WorkDim >= 2 ? ur_cast(LocalWorkSize[1]) : 1; - WG[2] = WorkDim == 3 ? ur_cast(LocalWorkSize[2]) : 1; - } else { - // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize3D - // values do not fit to 32-bit that the API only supports currently. - bool SuggestGroupSize = true; - for (int I : {0, 1, 2}) { - if (GlobalWorkSize3D[I] > UINT32_MAX) { - SuggestGroupSize = false; - } - } - if (SuggestGroupSize) { - ZE2UR_CALL(zeKernelSuggestGroupSize, - (Kernel->ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1], - GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2])); - } else { - for (int I : {0, 1, 2}) { - // Try to find a I-dimension WG size that the GlobalWorkSize3D[I] is - // fully divisable with. Start with the max possible size in - // each dimension. - uint32_t GroupSize[] = { - Device->ZeDeviceComputeProperties->maxGroupSizeX, - Device->ZeDeviceComputeProperties->maxGroupSizeY, - Device->ZeDeviceComputeProperties->maxGroupSizeZ}; - GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]); - while (GlobalWorkSize3D[I] % GroupSize[I]) { - --GroupSize[I]; - } - if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) { - logger::debug("calculateKernelWorkDimensions: can't find a WG size " - "suitable for global work size > UINT32_MAX"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - WG[I] = GroupSize[I]; - } - logger::debug("calculateKernelWorkDimensions: using computed WG " - "size = {{{}, {}, {}}}", - WG[0], WG[1], WG[2]); - } - } - - // TODO: assert if sizes do not fit into 32-bit? - switch (WorkDim) { - case 3: - ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize3D[0] / WG[0]); - ZeThreadGroupDimensions.groupCountY = - ur_cast(GlobalWorkSize3D[1] / WG[1]); - ZeThreadGroupDimensions.groupCountZ = - ur_cast(GlobalWorkSize3D[2] / WG[2]); - break; - case 2: - ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize3D[0] / WG[0]); - ZeThreadGroupDimensions.groupCountY = - ur_cast(GlobalWorkSize3D[1] / WG[1]); - WG[2] = 1; - break; - case 1: - ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize3D[0] / WG[0]); - WG[1] = WG[2] = 1; - break; - - default: - logger::error("calculateKernelWorkDimensions: unsupported work_dim"); - return UR_RESULT_ERROR_INVALID_VALUE; - } - - // Error handling for non-uniform group size case - if (GlobalWorkSize3D[0] != - size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " - "is not a multiple of the group size in the 1st dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - if (GlobalWorkSize3D[1] != - size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " - "is not a multiple of the group size in the 2nd dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - if (GlobalWorkSize3D[2] != - size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " - "is not a multiple of the group size in the 3rd dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - - return UR_RESULT_SUCCESS; -} - /// Helper function for finding the Level Zero events associated with the /// commands in a command-buffer, each event is pointed to by a sync-point in /// the wait list. @@ -626,8 +516,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( } ZE2UR_CALL(zeKernelSetGlobalOffsetExp, - (Kernel->ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1], - GlobalWorkOffset[2])); + (Kernel->getZeKernel(CommandBuffer->Device), GlobalWorkOffset[0], + GlobalWorkOffset[1], GlobalWorkOffset[2])); } // If there are any pending arguments set them now. @@ -640,7 +530,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( CommandBuffer->Device)); } ZE2UR_CALL(zeKernelSetArgumentValue, - (Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); + (Kernel->getZeKernel(CommandBuffer->Device), Arg.Index, Arg.Size, + ZeHandlePtr)); } Kernel->PendingArguments.clear(); @@ -651,7 +542,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( ZeThreadGroupDimensions, WG, WorkDim, GlobalWorkSize, LocalWorkSize)); - ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2])); + ZE2UR_CALL(zeKernelSetGroupSize, + (Kernel->getZeKernel(CommandBuffer->Device), WG[0], WG[1], WG[2])); CommandBuffer->KernelsList.push_back(Kernel); // Increment the reference count of the Kernel and indicate that the Kernel @@ -691,7 +583,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( if (CommandBuffer->IsInOrderCmdList) { ZE2UR_CALL(zeCommandListAppendLaunchKernel, - (CommandBuffer->ZeCommandList, Kernel->ZeKernel, + (CommandBuffer->ZeCommandList, + Kernel->getZeKernel(CommandBuffer->Device), &ZeThreadGroupDimensions, nullptr, 0, nullptr)); logger::debug("calling zeCommandListAppendLaunchKernel()"); @@ -712,7 +605,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( } ZE2UR_CALL(zeCommandListAppendLaunchKernel, - (CommandBuffer->ZeCommandList, Kernel->ZeKernel, + (CommandBuffer->ZeCommandList, + Kernel->getZeKernel(CommandBuffer->Device), &ZeThreadGroupDimensions, LaunchEvent->ZeEvent, ZeEventList.size(), ZeEventList.data())); diff --git a/source/adapters/level_zero/kernel.cpp b/source/adapters/level_zero/kernel.cpp index 61cc247cd9..71babe7fd3 100644 --- a/source/adapters/level_zero/kernel.cpp +++ b/source/adapters/level_zero/kernel.cpp @@ -13,6 +13,117 @@ #include "ur_api.h" #include "ur_level_zero.hpp" +/// Helper function for calculating work dimensions for kernels +ur_result_t calculateKernelWorkDimensions( + ur_kernel_handle_t Kernel, ur_device_handle_t Device, + ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3], + uint32_t WorkDim, const size_t *GlobalWorkSize, + const size_t *LocalWorkSize) { + + UR_ASSERT(GlobalWorkSize, UR_RESULT_ERROR_INVALID_VALUE); + // If LocalWorkSize is not provided then Kernel must be provided to query + // suggested group size. + UR_ASSERT(LocalWorkSize || Kernel, UR_RESULT_ERROR_INVALID_VALUE); + + // New variable needed because GlobalWorkSize parameter might not be of size 3 + size_t GlobalWorkSize3D[3]{1, 1, 1}; + std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); + + if (LocalWorkSize) { + WG[0] = ur_cast(LocalWorkSize[0]); + WG[1] = WorkDim >= 2 ? ur_cast(LocalWorkSize[1]) : 1; + WG[2] = WorkDim == 3 ? ur_cast(LocalWorkSize[2]) : 1; + } else { + // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize3D + // values do not fit to 32-bit that the API only supports currently. + bool SuggestGroupSize = true; + for (int I : {0, 1, 2}) { + if (GlobalWorkSize3D[I] > UINT32_MAX) { + SuggestGroupSize = false; + } + } + if (SuggestGroupSize) { + ZE2UR_CALL(zeKernelSuggestGroupSize, + (Kernel->getZeKernel(Device), GlobalWorkSize3D[0], + GlobalWorkSize3D[1], GlobalWorkSize3D[2], &WG[0], &WG[1], + &WG[2])); + } else { + for (int I : {0, 1, 2}) { + // Try to find a I-dimension WG size that the GlobalWorkSize3D[I] is + // fully divisable with. Start with the max possible size in + // each dimension. + uint32_t GroupSize[] = { + Device->ZeDeviceComputeProperties->maxGroupSizeX, + Device->ZeDeviceComputeProperties->maxGroupSizeY, + Device->ZeDeviceComputeProperties->maxGroupSizeZ}; + GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]); + while (GlobalWorkSize3D[I] % GroupSize[I]) { + --GroupSize[I]; + } + if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) { + logger::debug("calculateKernelWorkDimensions: can't find a WG size " + "suitable for global work size > UINT32_MAX"); + return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; + } + WG[I] = GroupSize[I]; + } + logger::debug("calculateKernelWorkDimensions: using computed WG " + "size = {{{}, {}, {}}}", + WG[0], WG[1], WG[2]); + } + } + + // TODO: assert if sizes do not fit into 32-bit? + switch (WorkDim) { + case 3: + ZeThreadGroupDimensions.groupCountX = + ur_cast(GlobalWorkSize3D[0] / WG[0]); + ZeThreadGroupDimensions.groupCountY = + ur_cast(GlobalWorkSize3D[1] / WG[1]); + ZeThreadGroupDimensions.groupCountZ = + ur_cast(GlobalWorkSize3D[2] / WG[2]); + break; + case 2: + ZeThreadGroupDimensions.groupCountX = + ur_cast(GlobalWorkSize3D[0] / WG[0]); + ZeThreadGroupDimensions.groupCountY = + ur_cast(GlobalWorkSize3D[1] / WG[1]); + WG[2] = 1; + break; + case 1: + ZeThreadGroupDimensions.groupCountX = + ur_cast(GlobalWorkSize3D[0] / WG[0]); + WG[1] = WG[2] = 1; + break; + + default: + logger::error("calculateKernelWorkDimensions: unsupported work_dim"); + return UR_RESULT_ERROR_INVALID_VALUE; + } + + // Error handling for non-uniform group size case + if (GlobalWorkSize3D[0] != + size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { + logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " + "is not a multiple of the group size in the 1st dimension"); + return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; + } + if (GlobalWorkSize3D[1] != + size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { + logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " + "is not a multiple of the group size in the 2nd dimension"); + return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; + } + if (GlobalWorkSize3D[2] != + size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { + logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " + "is not a multiple of the group size in the 3rd dimension"); + return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; + } + + return UR_RESULT_SUCCESS; +} + UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( ur_queue_handle_t Queue, ///< [in] handle of the queue object ur_kernel_handle_t Kernel, ///< [in] handle of the kernel object @@ -43,19 +154,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( *OutEvent ///< [in,out][optional] return an event object that identifies ///< this particular kernel execution instance. ) { - auto ZeDevice = Queue->Device->ZeDevice; + ze_kernel_handle_t ZeKernel = Kernel->getZeKernel(Queue->Device); - ze_kernel_handle_t ZeKernel{}; - if (Kernel->ZeKernelMap.empty()) { - ZeKernel = Kernel->ZeKernel; - } else { - auto It = Kernel->ZeKernelMap.find(ZeDevice); - if (It == Kernel->ZeKernelMap.end()) { - /* kernel and queue don't match */ - return UR_RESULT_ERROR_INVALID_QUEUE; - } - ZeKernel = It->second; + if (!ZeKernel) { + return UR_RESULT_ERROR_INVALID_QUEUE; } + // Lock automatically releases when this goes out of scope. std::scoped_lock Lock( Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex); @@ -87,109 +191,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( ze_group_count_t ZeThreadGroupDimensions{1, 1, 1}; uint32_t WG[3]{}; - // New variable needed because GlobalWorkSize parameter might not be of size 3 - size_t GlobalWorkSize3D[3]{1, 1, 1}; - std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); - if (LocalWorkSize) { - // L0 UR_ASSERT(LocalWorkSize[0] < (std::numeric_limits::max)(), UR_RESULT_ERROR_INVALID_VALUE); UR_ASSERT(LocalWorkSize[1] < (std::numeric_limits::max)(), UR_RESULT_ERROR_INVALID_VALUE); UR_ASSERT(LocalWorkSize[2] < (std::numeric_limits::max)(), UR_RESULT_ERROR_INVALID_VALUE); - WG[0] = static_cast(LocalWorkSize[0]); - WG[1] = static_cast(LocalWorkSize[1]); - WG[2] = static_cast(LocalWorkSize[2]); - } else { - // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize - // values do not fit to 32-bit that the API only supports currently. - bool SuggestGroupSize = true; - for (int I : {0, 1, 2}) { - if (GlobalWorkSize3D[I] > UINT32_MAX) { - SuggestGroupSize = false; - } - } - if (SuggestGroupSize) { - ZE2UR_CALL(zeKernelSuggestGroupSize, - (ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1], - GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2])); - } else { - for (int I : {0, 1, 2}) { - // Try to find a I-dimension WG size that the GlobalWorkSize[I] is - // fully divisable with. Start with the max possible size in - // each dimension. - uint32_t GroupSize[] = { - Queue->Device->ZeDeviceComputeProperties->maxGroupSizeX, - Queue->Device->ZeDeviceComputeProperties->maxGroupSizeY, - Queue->Device->ZeDeviceComputeProperties->maxGroupSizeZ}; - GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]); - while (GlobalWorkSize3D[I] % GroupSize[I]) { - --GroupSize[I]; - } - - if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) { - logger::error("urEnqueueKernelLaunch: can't find a WG size " - "suitable for global work size > UINT32_MAX"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - WG[I] = GroupSize[I]; - } - logger::debug( - "urEnqueueKernelLaunch: using computed WG size = {{{}, {}, {}}}", - WG[0], WG[1], WG[2]); - } - } - - // TODO: assert if sizes do not fit into 32-bit? - - switch (WorkDim) { - case 3: - ZeThreadGroupDimensions.groupCountX = - static_cast(GlobalWorkSize3D[0] / WG[0]); - ZeThreadGroupDimensions.groupCountY = - static_cast(GlobalWorkSize3D[1] / WG[1]); - ZeThreadGroupDimensions.groupCountZ = - static_cast(GlobalWorkSize3D[2] / WG[2]); - break; - case 2: - ZeThreadGroupDimensions.groupCountX = - static_cast(GlobalWorkSize3D[0] / WG[0]); - ZeThreadGroupDimensions.groupCountY = - static_cast(GlobalWorkSize3D[1] / WG[1]); - WG[2] = 1; - break; - case 1: - ZeThreadGroupDimensions.groupCountX = - static_cast(GlobalWorkSize3D[0] / WG[0]); - WG[1] = WG[2] = 1; - break; - - default: - logger::error("urEnqueueKernelLaunch: unsupported work_dim"); - return UR_RESULT_ERROR_INVALID_VALUE; } - // Error handling for non-uniform group size case - if (GlobalWorkSize3D[0] != - size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { - logger::error("urEnqueueKernelLaunch: invalid work_dim. The range is not a " - "multiple of the group size in the 1st dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - if (GlobalWorkSize3D[1] != - size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { - logger::error("urEnqueueKernelLaunch: invalid work_dim. The range is not a " - "multiple of the group size in the 2nd dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - if (GlobalWorkSize3D[2] != - size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { - logger::debug("urEnqueueKernelLaunch: invalid work_dim. The range is not a " - "multiple of the group size in the 3rd dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } + UR_CALL(calculateKernelWorkDimensions(Kernel, Queue->Device, + ZeThreadGroupDimensions, WG, WorkDim, + GlobalWorkSize, LocalWorkSize)); ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2])); @@ -384,55 +397,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate( const char *KernelName, ///< [in] pointer to null-terminated string. ur_kernel_handle_t *RetKernel ///< [out] pointer to handle of kernel object created. -) { + ) try { std::shared_lock Guard(Program->Mutex); if (Program->State != ur_program_handle_t_::state::Exe) { return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE; } - try { - ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_(true, Program); - *RetKernel = reinterpret_cast(UrKernel); - } catch (const std::bad_alloc &) { - return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } catch (...) { - return UR_RESULT_ERROR_UNKNOWN; - } + auto UrKernel = std::make_unique(true, Program); - for (auto It : Program->ZeModuleMap) { - auto ZeModule = It.second; + for (auto [ZeDevice, ZeModule] : Program->ZeModuleMap) { ZeStruct ZeKernelDesc; - ZeKernelDesc.flags = 0; ZeKernelDesc.pKernelName = KernelName; ze_kernel_handle_t ZeKernel; ZE2UR_CALL(zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel)); - - auto ZeDevice = It.first; - - // Store the kernel in the ZeKernelMap so the correct - // kernel can be retrieved later for a specific device - // where a queue is being submitted. - (*RetKernel)->ZeKernelMap[ZeDevice] = ZeKernel; - (*RetKernel)->ZeKernels.push_back(ZeKernel); - - // If the device used to create the module's kernel is a root-device - // then store the kernel also using the sub-devices, since application - // could submit the root-device's kernel to a sub-device's queue. - uint32_t SubDevicesCount = 0; - zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, nullptr); - std::vector ZeSubDevices(SubDevicesCount); - zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, ZeSubDevices.data()); - for (auto ZeSubDevice : ZeSubDevices) { - (*RetKernel)->ZeKernelMap[ZeSubDevice] = ZeKernel; - } + UrKernel->ZeKernels.emplace_back(ZeDevice, ZeKernel); } - (*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap.begin()->second; + UR_CALL(UrKernel->initialize()); - UR_CALL((*RetKernel)->initialize()); + *RetKernel = reinterpret_cast(UrKernel.release()); return UR_RESULT_SUCCESS; +} catch (...) { + return exceptionToResult(std::current_exception()); } UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue( @@ -461,16 +449,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue( } std::scoped_lock Guard(Kernel->Mutex); - if (Kernel->ZeKernelMap.empty()) { - auto ZeKernel = Kernel->ZeKernel; + for (auto [_, ZeKernel] : Kernel->ZeKernels) { ZE2UR_CALL(zeKernelSetArgumentValue, (ZeKernel, ArgIndex, ArgSize, PArgValue)); - } else { - for (auto It : Kernel->ZeKernelMap) { - auto ZeKernel = It.second; - ZE2UR_CALL(zeKernelSetArgumentValue, - (ZeKernel, ArgIndex, ArgSize, PArgValue)); - } } return UR_RESULT_SUCCESS; @@ -530,10 +511,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo( try { uint32_t Size; ZE2UR_CALL(zeKernelGetSourceAttributes, - (Kernel->ZeKernel, &Size, nullptr)); + (Kernel->ZeKernels[0].second, &Size, nullptr)); char *attributes = new char[Size]; ZE2UR_CALL(zeKernelGetSourceAttributes, - (Kernel->ZeKernel, &Size, &attributes)); + (Kernel->ZeKernels[0].second, &Size, &attributes)); auto Res = ReturnValue(attributes); delete[] attributes; return Res; @@ -586,14 +567,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetGroupInfo( ZeStruct kernelProperties; kernelProperties.pNext = &workGroupProperties; - // Set the Kernel to use as the ZeKernel initally for native handle support. - // This makes the assumption that this device is the same device where this - // kernel was created. - auto ZeKernelDevice = Kernel->ZeKernel; - auto It = Kernel->ZeKernelMap.find(Device->ZeDevice); - if (It != Kernel->ZeKernelMap.end()) { - ZeKernelDevice = Kernel->ZeKernelMap[Device->ZeDevice]; - } + + auto ZeKernelDevice = Kernel->getZeKernel(Device); if (ZeKernelDevice) { auto ZeResult = ZE_CALL_NOCHECK(zeKernelGetProperties, (ZeKernelDevice, &kernelProperties)); @@ -680,14 +655,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease( auto KernelProgram = Kernel->Program; if (Kernel->OwnNativeHandle) { - for (auto &ZeKernel : Kernel->ZeKernels) { + for (auto &[_, ZeKernel] : Kernel->ZeKernels) { auto ZeResult = ZE_CALL_NOCHECK(zeKernelDestroy, (ZeKernel)); // Gracefully handle the case that L0 was already unloaded. if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED) return ze2urResult(ZeResult); } } - Kernel->ZeKernelMap.clear(); + if (IndirectAccessTrackingEnabled) { UR_CALL(urContextRelease(KernelProgram->Context)); } @@ -728,7 +703,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo( std::ignore = PropSize; std::ignore = Properties; - auto ZeKernel = Kernel->ZeKernel; + assert(Kernel->ZeKernels.size()); + auto ZeKernel = Kernel->ZeKernels[0].second; std::scoped_lock Guard(Kernel->Mutex); if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS && *(static_cast(PropValue)) == true) { @@ -771,7 +747,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgSampler( ) { std::ignore = Properties; std::scoped_lock Guard(Kernel->Mutex); - ZE2UR_CALL(zeKernelSetArgumentValue, (Kernel->ZeKernel, ArgIndex, + assert(Kernel->ZeKernels.size()); + ZE2UR_CALL(zeKernelSetArgumentValue, (Kernel->ZeKernels[0].second, ArgIndex, sizeof(void *), &ArgValue->ZeSampler)); return UR_RESULT_SUCCESS; @@ -822,7 +799,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle( ) { std::shared_lock Guard(Kernel->Mutex); - *NativeKernel = reinterpret_cast(Kernel->ZeKernel); + assert(Kernel->ZeKernels.size()); + *NativeKernel = + reinterpret_cast(Kernel->ZeKernels[0].second); return UR_RESULT_SUCCESS; } @@ -873,15 +852,17 @@ ur_result_t ur_kernel_handle_t_::initialize() { // Set up how to obtain kernel properties when needed. ZeKernelProperties.Compute = [this](ze_kernel_properties_t &Properties) { - ZE_CALL_NOCHECK(zeKernelGetProperties, (ZeKernel, &Properties)); + assert(ZeKernels.size()); + ZE_CALL_NOCHECK(zeKernelGetProperties, (ZeKernels[0].second, &Properties)); }; // Cache kernel name. ZeKernelName.Compute = [this](std::string &Name) { + assert(ZeKernels.size()); size_t Size = 0; - ZE_CALL_NOCHECK(zeKernelGetName, (ZeKernel, &Size, nullptr)); + ZE_CALL_NOCHECK(zeKernelGetName, (ZeKernels[0].second, &Size, nullptr)); char *KernelName = new char[Size]; - ZE_CALL_NOCHECK(zeKernelGetName, (ZeKernel, &Size, KernelName)); + ZE_CALL_NOCHECK(zeKernelGetName, (ZeKernels[0].second, &Size, KernelName)); Name = KernelName; delete[] KernelName; }; diff --git a/source/adapters/level_zero/kernel.hpp b/source/adapters/level_zero/kernel.hpp index 1cc146d262..c017610ef2 100644 --- a/source/adapters/level_zero/kernel.hpp +++ b/source/adapters/level_zero/kernel.hpp @@ -15,16 +15,36 @@ struct ur_kernel_handle_t_ : _ur_object { ur_kernel_handle_t_(bool OwnZeHandle, ur_program_handle_t Program) - : Context{nullptr}, Program{Program}, ZeKernel{nullptr}, - SubmissionsCount{0}, MemAllocs{} { + : Context{nullptr}, Program{Program}, SubmissionsCount{0}, MemAllocs{} { OwnNativeHandle = OwnZeHandle; } ur_kernel_handle_t_(ze_kernel_handle_t Kernel, bool OwnZeHandle, ur_context_handle_t Context) - : Context{Context}, Program{nullptr}, ZeKernel{Kernel}, - SubmissionsCount{0}, MemAllocs{} { + : Context{Context}, Program{nullptr}, SubmissionsCount{0}, MemAllocs{} { OwnNativeHandle = OwnZeHandle; + createdFromNativeHandle = true; + ZeKernels.push_back({nullptr, Kernel}); + } + + ze_kernel_handle_t getZeKernel(ur_device_handle_t Device) { + assert(Device); + + if (createdFromNativeHandle) { + assert(ZeKernels.size() == 1); + assert(ZeKernels[0].first == nullptr); + + return ZeKernels.at(0).second; + } + + auto ZeDevice = + Device->RootDevice ? Device->RootDevice->ZeDevice : Device->ZeDevice; + + for (auto &K : ZeKernels) { + if (K.first == ZeDevice) + return K.second; + } + return nullptr; } // Keep the program of the kernel. @@ -33,17 +53,13 @@ struct ur_kernel_handle_t_ : _ur_object { // Keep the program of the kernel. ur_program_handle_t Program; - // Level Zero function handle. - ze_kernel_handle_t ZeKernel; + // Vector of L0 kernels for all devices. + std::vector> ZeKernels; - // Map of L0 kernels created for all the devices for which a UR Program - // has been built. It may contain duplicated kernel entries for a root - // device and its sub-devices. - std::unordered_map ZeKernelMap; - - // Vector of L0 kernels. Each entry is unique, so this is used for - // destroying the kernels instead of ZeKernelMap - std::vector ZeKernels; + // Whether this kernel has been created from a native handle. + // In such case, ZeKernels contains only one element with device handle + // set to nullptr. + bool createdFromNativeHandle{false}; // Counter to track the number of submissions of the kernel. // When this value is zero, it means that kernel is not submitted for an @@ -107,3 +123,10 @@ struct ur_kernel_handle_t_ : _ur_object { ZeCache> ZeKernelProperties; ZeCache ZeKernelName; }; + +/// Helper function for calculating work dimensions for kernels +ur_result_t calculateKernelWorkDimensions( + ur_kernel_handle_t Kernel, ur_device_handle_t Device, + ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3], + uint32_t WorkDim, const size_t *GlobalWorkSize, + const size_t *LocalWorkSize);