diff --git a/source/adapters/level_zero/command_buffer.cpp b/source/adapters/level_zero/command_buffer.cpp index 67415a0de0..1ed7db7fac 100644 --- a/source/adapters/level_zero/command_buffer.cpp +++ b/source/adapters/level_zero/command_buffer.cpp @@ -153,116 +153,6 @@ ur_exp_command_buffer_command_handle_t_:: urKernelRelease(Kernel); } -/// Helper function for calculating work dimensions for kernels -ur_result_t calculateKernelWorkDimensions( - ur_kernel_handle_t Kernel, ur_device_handle_t Device, - ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3], - uint32_t WorkDim, const size_t *GlobalWorkSize, - const size_t *LocalWorkSize) { - - UR_ASSERT(GlobalWorkSize, UR_RESULT_ERROR_INVALID_VALUE); - // If LocalWorkSize is not provided then Kernel must be provided to query - // suggested group size. - UR_ASSERT(LocalWorkSize || Kernel, UR_RESULT_ERROR_INVALID_VALUE); - - // New variable needed because GlobalWorkSize parameter might not be of size 3 - size_t GlobalWorkSize3D[3]{1, 1, 1}; - std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); - - if (LocalWorkSize) { - WG[0] = ur_cast(LocalWorkSize[0]); - WG[1] = WorkDim >= 2 ? ur_cast(LocalWorkSize[1]) : 1; - WG[2] = WorkDim == 3 ? ur_cast(LocalWorkSize[2]) : 1; - } else { - // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize3D - // values do not fit to 32-bit that the API only supports currently. - bool SuggestGroupSize = true; - for (int I : {0, 1, 2}) { - if (GlobalWorkSize3D[I] > UINT32_MAX) { - SuggestGroupSize = false; - } - } - if (SuggestGroupSize) { - ZE2UR_CALL(zeKernelSuggestGroupSize, - (Kernel->ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1], - GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2])); - } else { - for (int I : {0, 1, 2}) { - // Try to find a I-dimension WG size that the GlobalWorkSize3D[I] is - // fully divisable with. Start with the max possible size in - // each dimension. - uint32_t GroupSize[] = { - Device->ZeDeviceComputeProperties->maxGroupSizeX, - Device->ZeDeviceComputeProperties->maxGroupSizeY, - Device->ZeDeviceComputeProperties->maxGroupSizeZ}; - GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]); - while (GlobalWorkSize3D[I] % GroupSize[I]) { - --GroupSize[I]; - } - if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) { - logger::debug("calculateKernelWorkDimensions: can't find a WG size " - "suitable for global work size > UINT32_MAX"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - WG[I] = GroupSize[I]; - } - logger::debug("calculateKernelWorkDimensions: using computed WG " - "size = {{{}, {}, {}}}", - WG[0], WG[1], WG[2]); - } - } - - // TODO: assert if sizes do not fit into 32-bit? - switch (WorkDim) { - case 3: - ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize3D[0] / WG[0]); - ZeThreadGroupDimensions.groupCountY = - ur_cast(GlobalWorkSize3D[1] / WG[1]); - ZeThreadGroupDimensions.groupCountZ = - ur_cast(GlobalWorkSize3D[2] / WG[2]); - break; - case 2: - ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize3D[0] / WG[0]); - ZeThreadGroupDimensions.groupCountY = - ur_cast(GlobalWorkSize3D[1] / WG[1]); - WG[2] = 1; - break; - case 1: - ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize3D[0] / WG[0]); - WG[1] = WG[2] = 1; - break; - - default: - logger::error("calculateKernelWorkDimensions: unsupported work_dim"); - return UR_RESULT_ERROR_INVALID_VALUE; - } - - // Error handling for non-uniform group size case - if (GlobalWorkSize3D[0] != - size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " - "is not a multiple of the group size in the 1st dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - if (GlobalWorkSize3D[1] != - size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " - "is not a multiple of the group size in the 2nd dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - if (GlobalWorkSize3D[2] != - size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " - "is not a multiple of the group size in the 3rd dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - - return UR_RESULT_SUCCESS; -} - /// Helper function for finding the Level Zero events associated with the /// commands in a command-buffer, each event is pointed to by a sync-point in /// the wait list. @@ -626,8 +516,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( } ZE2UR_CALL(zeKernelSetGlobalOffsetExp, - (Kernel->ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1], - GlobalWorkOffset[2])); + (Kernel->getZeKernel(CommandBuffer->Device), GlobalWorkOffset[0], + GlobalWorkOffset[1], GlobalWorkOffset[2])); } // If there are any pending arguments set them now. @@ -640,7 +530,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( CommandBuffer->Device)); } ZE2UR_CALL(zeKernelSetArgumentValue, - (Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); + (Kernel->getZeKernel(CommandBuffer->Device), Arg.Index, Arg.Size, + ZeHandlePtr)); } Kernel->PendingArguments.clear(); @@ -651,7 +542,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( ZeThreadGroupDimensions, WG, WorkDim, GlobalWorkSize, LocalWorkSize)); - ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2])); + ZE2UR_CALL(zeKernelSetGroupSize, + (Kernel->getZeKernel(CommandBuffer->Device), WG[0], WG[1], WG[2])); CommandBuffer->KernelsList.push_back(Kernel); // Increment the reference count of the Kernel and indicate that the Kernel @@ -691,7 +583,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( if (CommandBuffer->IsInOrderCmdList) { ZE2UR_CALL(zeCommandListAppendLaunchKernel, - (CommandBuffer->ZeCommandList, Kernel->ZeKernel, + (CommandBuffer->ZeCommandList, + Kernel->getZeKernel(CommandBuffer->Device), &ZeThreadGroupDimensions, nullptr, 0, nullptr)); logger::debug("calling zeCommandListAppendLaunchKernel()"); @@ -712,7 +605,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( } ZE2UR_CALL(zeCommandListAppendLaunchKernel, - (CommandBuffer->ZeCommandList, Kernel->ZeKernel, + (CommandBuffer->ZeCommandList, + Kernel->getZeKernel(CommandBuffer->Device), &ZeThreadGroupDimensions, LaunchEvent->ZeEvent, ZeEventList.size(), ZeEventList.data())); diff --git a/source/adapters/level_zero/kernel.cpp b/source/adapters/level_zero/kernel.cpp index 61cc247cd9..71babe7fd3 100644 --- a/source/adapters/level_zero/kernel.cpp +++ b/source/adapters/level_zero/kernel.cpp @@ -13,6 +13,117 @@ #include "ur_api.h" #include "ur_level_zero.hpp" +/// Helper function for calculating work dimensions for kernels +ur_result_t calculateKernelWorkDimensions( + ur_kernel_handle_t Kernel, ur_device_handle_t Device, + ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3], + uint32_t WorkDim, const size_t *GlobalWorkSize, + const size_t *LocalWorkSize) { + + UR_ASSERT(GlobalWorkSize, UR_RESULT_ERROR_INVALID_VALUE); + // If LocalWorkSize is not provided then Kernel must be provided to query + // suggested group size. + UR_ASSERT(LocalWorkSize || Kernel, UR_RESULT_ERROR_INVALID_VALUE); + + // New variable needed because GlobalWorkSize parameter might not be of size 3 + size_t GlobalWorkSize3D[3]{1, 1, 1}; + std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); + + if (LocalWorkSize) { + WG[0] = ur_cast(LocalWorkSize[0]); + WG[1] = WorkDim >= 2 ? ur_cast(LocalWorkSize[1]) : 1; + WG[2] = WorkDim == 3 ? ur_cast(LocalWorkSize[2]) : 1; + } else { + // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize3D + // values do not fit to 32-bit that the API only supports currently. + bool SuggestGroupSize = true; + for (int I : {0, 1, 2}) { + if (GlobalWorkSize3D[I] > UINT32_MAX) { + SuggestGroupSize = false; + } + } + if (SuggestGroupSize) { + ZE2UR_CALL(zeKernelSuggestGroupSize, + (Kernel->getZeKernel(Device), GlobalWorkSize3D[0], + GlobalWorkSize3D[1], GlobalWorkSize3D[2], &WG[0], &WG[1], + &WG[2])); + } else { + for (int I : {0, 1, 2}) { + // Try to find a I-dimension WG size that the GlobalWorkSize3D[I] is + // fully divisable with. Start with the max possible size in + // each dimension. + uint32_t GroupSize[] = { + Device->ZeDeviceComputeProperties->maxGroupSizeX, + Device->ZeDeviceComputeProperties->maxGroupSizeY, + Device->ZeDeviceComputeProperties->maxGroupSizeZ}; + GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]); + while (GlobalWorkSize3D[I] % GroupSize[I]) { + --GroupSize[I]; + } + if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) { + logger::debug("calculateKernelWorkDimensions: can't find a WG size " + "suitable for global work size > UINT32_MAX"); + return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; + } + WG[I] = GroupSize[I]; + } + logger::debug("calculateKernelWorkDimensions: using computed WG " + "size = {{{}, {}, {}}}", + WG[0], WG[1], WG[2]); + } + } + + // TODO: assert if sizes do not fit into 32-bit? + switch (WorkDim) { + case 3: + ZeThreadGroupDimensions.groupCountX = + ur_cast(GlobalWorkSize3D[0] / WG[0]); + ZeThreadGroupDimensions.groupCountY = + ur_cast(GlobalWorkSize3D[1] / WG[1]); + ZeThreadGroupDimensions.groupCountZ = + ur_cast(GlobalWorkSize3D[2] / WG[2]); + break; + case 2: + ZeThreadGroupDimensions.groupCountX = + ur_cast(GlobalWorkSize3D[0] / WG[0]); + ZeThreadGroupDimensions.groupCountY = + ur_cast(GlobalWorkSize3D[1] / WG[1]); + WG[2] = 1; + break; + case 1: + ZeThreadGroupDimensions.groupCountX = + ur_cast(GlobalWorkSize3D[0] / WG[0]); + WG[1] = WG[2] = 1; + break; + + default: + logger::error("calculateKernelWorkDimensions: unsupported work_dim"); + return UR_RESULT_ERROR_INVALID_VALUE; + } + + // Error handling for non-uniform group size case + if (GlobalWorkSize3D[0] != + size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { + logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " + "is not a multiple of the group size in the 1st dimension"); + return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; + } + if (GlobalWorkSize3D[1] != + size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { + logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " + "is not a multiple of the group size in the 2nd dimension"); + return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; + } + if (GlobalWorkSize3D[2] != + size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { + logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " + "is not a multiple of the group size in the 3rd dimension"); + return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; + } + + return UR_RESULT_SUCCESS; +} + UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( ur_queue_handle_t Queue, ///< [in] handle of the queue object ur_kernel_handle_t Kernel, ///< [in] handle of the kernel object @@ -43,19 +154,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( *OutEvent ///< [in,out][optional] return an event object that identifies ///< this particular kernel execution instance. ) { - auto ZeDevice = Queue->Device->ZeDevice; + ze_kernel_handle_t ZeKernel = Kernel->getZeKernel(Queue->Device); - ze_kernel_handle_t ZeKernel{}; - if (Kernel->ZeKernelMap.empty()) { - ZeKernel = Kernel->ZeKernel; - } else { - auto It = Kernel->ZeKernelMap.find(ZeDevice); - if (It == Kernel->ZeKernelMap.end()) { - /* kernel and queue don't match */ - return UR_RESULT_ERROR_INVALID_QUEUE; - } - ZeKernel = It->second; + if (!ZeKernel) { + return UR_RESULT_ERROR_INVALID_QUEUE; } + // Lock automatically releases when this goes out of scope. std::scoped_lock Lock( Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex); @@ -87,109 +191,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( ze_group_count_t ZeThreadGroupDimensions{1, 1, 1}; uint32_t WG[3]{}; - // New variable needed because GlobalWorkSize parameter might not be of size 3 - size_t GlobalWorkSize3D[3]{1, 1, 1}; - std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); - if (LocalWorkSize) { - // L0 UR_ASSERT(LocalWorkSize[0] < (std::numeric_limits::max)(), UR_RESULT_ERROR_INVALID_VALUE); UR_ASSERT(LocalWorkSize[1] < (std::numeric_limits::max)(), UR_RESULT_ERROR_INVALID_VALUE); UR_ASSERT(LocalWorkSize[2] < (std::numeric_limits::max)(), UR_RESULT_ERROR_INVALID_VALUE); - WG[0] = static_cast(LocalWorkSize[0]); - WG[1] = static_cast(LocalWorkSize[1]); - WG[2] = static_cast(LocalWorkSize[2]); - } else { - // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize - // values do not fit to 32-bit that the API only supports currently. - bool SuggestGroupSize = true; - for (int I : {0, 1, 2}) { - if (GlobalWorkSize3D[I] > UINT32_MAX) { - SuggestGroupSize = false; - } - } - if (SuggestGroupSize) { - ZE2UR_CALL(zeKernelSuggestGroupSize, - (ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1], - GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2])); - } else { - for (int I : {0, 1, 2}) { - // Try to find a I-dimension WG size that the GlobalWorkSize[I] is - // fully divisable with. Start with the max possible size in - // each dimension. - uint32_t GroupSize[] = { - Queue->Device->ZeDeviceComputeProperties->maxGroupSizeX, - Queue->Device->ZeDeviceComputeProperties->maxGroupSizeY, - Queue->Device->ZeDeviceComputeProperties->maxGroupSizeZ}; - GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]); - while (GlobalWorkSize3D[I] % GroupSize[I]) { - --GroupSize[I]; - } - - if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) { - logger::error("urEnqueueKernelLaunch: can't find a WG size " - "suitable for global work size > UINT32_MAX"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - WG[I] = GroupSize[I]; - } - logger::debug( - "urEnqueueKernelLaunch: using computed WG size = {{{}, {}, {}}}", - WG[0], WG[1], WG[2]); - } - } - - // TODO: assert if sizes do not fit into 32-bit? - - switch (WorkDim) { - case 3: - ZeThreadGroupDimensions.groupCountX = - static_cast(GlobalWorkSize3D[0] / WG[0]); - ZeThreadGroupDimensions.groupCountY = - static_cast(GlobalWorkSize3D[1] / WG[1]); - ZeThreadGroupDimensions.groupCountZ = - static_cast(GlobalWorkSize3D[2] / WG[2]); - break; - case 2: - ZeThreadGroupDimensions.groupCountX = - static_cast(GlobalWorkSize3D[0] / WG[0]); - ZeThreadGroupDimensions.groupCountY = - static_cast(GlobalWorkSize3D[1] / WG[1]); - WG[2] = 1; - break; - case 1: - ZeThreadGroupDimensions.groupCountX = - static_cast(GlobalWorkSize3D[0] / WG[0]); - WG[1] = WG[2] = 1; - break; - - default: - logger::error("urEnqueueKernelLaunch: unsupported work_dim"); - return UR_RESULT_ERROR_INVALID_VALUE; } - // Error handling for non-uniform group size case - if (GlobalWorkSize3D[0] != - size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { - logger::error("urEnqueueKernelLaunch: invalid work_dim. The range is not a " - "multiple of the group size in the 1st dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - if (GlobalWorkSize3D[1] != - size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { - logger::error("urEnqueueKernelLaunch: invalid work_dim. The range is not a " - "multiple of the group size in the 2nd dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } - if (GlobalWorkSize3D[2] != - size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { - logger::debug("urEnqueueKernelLaunch: invalid work_dim. The range is not a " - "multiple of the group size in the 3rd dimension"); - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; - } + UR_CALL(calculateKernelWorkDimensions(Kernel, Queue->Device, + ZeThreadGroupDimensions, WG, WorkDim, + GlobalWorkSize, LocalWorkSize)); ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2])); @@ -384,55 +397,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate( const char *KernelName, ///< [in] pointer to null-terminated string. ur_kernel_handle_t *RetKernel ///< [out] pointer to handle of kernel object created. -) { + ) try { std::shared_lock Guard(Program->Mutex); if (Program->State != ur_program_handle_t_::state::Exe) { return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE; } - try { - ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_(true, Program); - *RetKernel = reinterpret_cast(UrKernel); - } catch (const std::bad_alloc &) { - return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } catch (...) { - return UR_RESULT_ERROR_UNKNOWN; - } + auto UrKernel = std::make_unique(true, Program); - for (auto It : Program->ZeModuleMap) { - auto ZeModule = It.second; + for (auto [ZeDevice, ZeModule] : Program->ZeModuleMap) { ZeStruct ZeKernelDesc; - ZeKernelDesc.flags = 0; ZeKernelDesc.pKernelName = KernelName; ze_kernel_handle_t ZeKernel; ZE2UR_CALL(zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel)); - - auto ZeDevice = It.first; - - // Store the kernel in the ZeKernelMap so the correct - // kernel can be retrieved later for a specific device - // where a queue is being submitted. - (*RetKernel)->ZeKernelMap[ZeDevice] = ZeKernel; - (*RetKernel)->ZeKernels.push_back(ZeKernel); - - // If the device used to create the module's kernel is a root-device - // then store the kernel also using the sub-devices, since application - // could submit the root-device's kernel to a sub-device's queue. - uint32_t SubDevicesCount = 0; - zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, nullptr); - std::vector ZeSubDevices(SubDevicesCount); - zeDeviceGetSubDevices(ZeDevice, &SubDevicesCount, ZeSubDevices.data()); - for (auto ZeSubDevice : ZeSubDevices) { - (*RetKernel)->ZeKernelMap[ZeSubDevice] = ZeKernel; - } + UrKernel->ZeKernels.emplace_back(ZeDevice, ZeKernel); } - (*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap.begin()->second; + UR_CALL(UrKernel->initialize()); - UR_CALL((*RetKernel)->initialize()); + *RetKernel = reinterpret_cast(UrKernel.release()); return UR_RESULT_SUCCESS; +} catch (...) { + return exceptionToResult(std::current_exception()); } UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue( @@ -461,16 +449,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue( } std::scoped_lock Guard(Kernel->Mutex); - if (Kernel->ZeKernelMap.empty()) { - auto ZeKernel = Kernel->ZeKernel; + for (auto [_, ZeKernel] : Kernel->ZeKernels) { ZE2UR_CALL(zeKernelSetArgumentValue, (ZeKernel, ArgIndex, ArgSize, PArgValue)); - } else { - for (auto It : Kernel->ZeKernelMap) { - auto ZeKernel = It.second; - ZE2UR_CALL(zeKernelSetArgumentValue, - (ZeKernel, ArgIndex, ArgSize, PArgValue)); - } } return UR_RESULT_SUCCESS; @@ -530,10 +511,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo( try { uint32_t Size; ZE2UR_CALL(zeKernelGetSourceAttributes, - (Kernel->ZeKernel, &Size, nullptr)); + (Kernel->ZeKernels[0].second, &Size, nullptr)); char *attributes = new char[Size]; ZE2UR_CALL(zeKernelGetSourceAttributes, - (Kernel->ZeKernel, &Size, &attributes)); + (Kernel->ZeKernels[0].second, &Size, &attributes)); auto Res = ReturnValue(attributes); delete[] attributes; return Res; @@ -586,14 +567,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetGroupInfo( ZeStruct kernelProperties; kernelProperties.pNext = &workGroupProperties; - // Set the Kernel to use as the ZeKernel initally for native handle support. - // This makes the assumption that this device is the same device where this - // kernel was created. - auto ZeKernelDevice = Kernel->ZeKernel; - auto It = Kernel->ZeKernelMap.find(Device->ZeDevice); - if (It != Kernel->ZeKernelMap.end()) { - ZeKernelDevice = Kernel->ZeKernelMap[Device->ZeDevice]; - } + + auto ZeKernelDevice = Kernel->getZeKernel(Device); if (ZeKernelDevice) { auto ZeResult = ZE_CALL_NOCHECK(zeKernelGetProperties, (ZeKernelDevice, &kernelProperties)); @@ -680,14 +655,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease( auto KernelProgram = Kernel->Program; if (Kernel->OwnNativeHandle) { - for (auto &ZeKernel : Kernel->ZeKernels) { + for (auto &[_, ZeKernel] : Kernel->ZeKernels) { auto ZeResult = ZE_CALL_NOCHECK(zeKernelDestroy, (ZeKernel)); // Gracefully handle the case that L0 was already unloaded. if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED) return ze2urResult(ZeResult); } } - Kernel->ZeKernelMap.clear(); + if (IndirectAccessTrackingEnabled) { UR_CALL(urContextRelease(KernelProgram->Context)); } @@ -728,7 +703,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetExecInfo( std::ignore = PropSize; std::ignore = Properties; - auto ZeKernel = Kernel->ZeKernel; + assert(Kernel->ZeKernels.size()); + auto ZeKernel = Kernel->ZeKernels[0].second; std::scoped_lock Guard(Kernel->Mutex); if (PropName == UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS && *(static_cast(PropValue)) == true) { @@ -771,7 +747,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgSampler( ) { std::ignore = Properties; std::scoped_lock Guard(Kernel->Mutex); - ZE2UR_CALL(zeKernelSetArgumentValue, (Kernel->ZeKernel, ArgIndex, + assert(Kernel->ZeKernels.size()); + ZE2UR_CALL(zeKernelSetArgumentValue, (Kernel->ZeKernels[0].second, ArgIndex, sizeof(void *), &ArgValue->ZeSampler)); return UR_RESULT_SUCCESS; @@ -822,7 +799,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle( ) { std::shared_lock Guard(Kernel->Mutex); - *NativeKernel = reinterpret_cast(Kernel->ZeKernel); + assert(Kernel->ZeKernels.size()); + *NativeKernel = + reinterpret_cast(Kernel->ZeKernels[0].second); return UR_RESULT_SUCCESS; } @@ -873,15 +852,17 @@ ur_result_t ur_kernel_handle_t_::initialize() { // Set up how to obtain kernel properties when needed. ZeKernelProperties.Compute = [this](ze_kernel_properties_t &Properties) { - ZE_CALL_NOCHECK(zeKernelGetProperties, (ZeKernel, &Properties)); + assert(ZeKernels.size()); + ZE_CALL_NOCHECK(zeKernelGetProperties, (ZeKernels[0].second, &Properties)); }; // Cache kernel name. ZeKernelName.Compute = [this](std::string &Name) { + assert(ZeKernels.size()); size_t Size = 0; - ZE_CALL_NOCHECK(zeKernelGetName, (ZeKernel, &Size, nullptr)); + ZE_CALL_NOCHECK(zeKernelGetName, (ZeKernels[0].second, &Size, nullptr)); char *KernelName = new char[Size]; - ZE_CALL_NOCHECK(zeKernelGetName, (ZeKernel, &Size, KernelName)); + ZE_CALL_NOCHECK(zeKernelGetName, (ZeKernels[0].second, &Size, KernelName)); Name = KernelName; delete[] KernelName; }; diff --git a/source/adapters/level_zero/kernel.hpp b/source/adapters/level_zero/kernel.hpp index 1cc146d262..c017610ef2 100644 --- a/source/adapters/level_zero/kernel.hpp +++ b/source/adapters/level_zero/kernel.hpp @@ -15,16 +15,36 @@ struct ur_kernel_handle_t_ : _ur_object { ur_kernel_handle_t_(bool OwnZeHandle, ur_program_handle_t Program) - : Context{nullptr}, Program{Program}, ZeKernel{nullptr}, - SubmissionsCount{0}, MemAllocs{} { + : Context{nullptr}, Program{Program}, SubmissionsCount{0}, MemAllocs{} { OwnNativeHandle = OwnZeHandle; } ur_kernel_handle_t_(ze_kernel_handle_t Kernel, bool OwnZeHandle, ur_context_handle_t Context) - : Context{Context}, Program{nullptr}, ZeKernel{Kernel}, - SubmissionsCount{0}, MemAllocs{} { + : Context{Context}, Program{nullptr}, SubmissionsCount{0}, MemAllocs{} { OwnNativeHandle = OwnZeHandle; + createdFromNativeHandle = true; + ZeKernels.push_back({nullptr, Kernel}); + } + + ze_kernel_handle_t getZeKernel(ur_device_handle_t Device) { + assert(Device); + + if (createdFromNativeHandle) { + assert(ZeKernels.size() == 1); + assert(ZeKernels[0].first == nullptr); + + return ZeKernels.at(0).second; + } + + auto ZeDevice = + Device->RootDevice ? Device->RootDevice->ZeDevice : Device->ZeDevice; + + for (auto &K : ZeKernels) { + if (K.first == ZeDevice) + return K.second; + } + return nullptr; } // Keep the program of the kernel. @@ -33,17 +53,13 @@ struct ur_kernel_handle_t_ : _ur_object { // Keep the program of the kernel. ur_program_handle_t Program; - // Level Zero function handle. - ze_kernel_handle_t ZeKernel; + // Vector of L0 kernels for all devices. + std::vector> ZeKernels; - // Map of L0 kernels created for all the devices for which a UR Program - // has been built. It may contain duplicated kernel entries for a root - // device and its sub-devices. - std::unordered_map ZeKernelMap; - - // Vector of L0 kernels. Each entry is unique, so this is used for - // destroying the kernels instead of ZeKernelMap - std::vector ZeKernels; + // Whether this kernel has been created from a native handle. + // In such case, ZeKernels contains only one element with device handle + // set to nullptr. + bool createdFromNativeHandle{false}; // Counter to track the number of submissions of the kernel. // When this value is zero, it means that kernel is not submitted for an @@ -107,3 +123,10 @@ struct ur_kernel_handle_t_ : _ur_object { ZeCache> ZeKernelProperties; ZeCache ZeKernelName; }; + +/// Helper function for calculating work dimensions for kernels +ur_result_t calculateKernelWorkDimensions( + ur_kernel_handle_t Kernel, ur_device_handle_t Device, + ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3], + uint32_t WorkDim, const size_t *GlobalWorkSize, + const size_t *LocalWorkSize);