Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 10 additions & 116 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,116 +153,6 @@ ur_exp_command_buffer_command_handle_t_::
urKernelRelease(Kernel);
}

/// Helper function for calculating work dimensions for kernels
ur_result_t calculateKernelWorkDimensions(
ur_kernel_handle_t Kernel, ur_device_handle_t Device,
ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3],
uint32_t WorkDim, const size_t *GlobalWorkSize,
const size_t *LocalWorkSize) {

UR_ASSERT(GlobalWorkSize, UR_RESULT_ERROR_INVALID_VALUE);
// If LocalWorkSize is not provided then Kernel must be provided to query
// suggested group size.
UR_ASSERT(LocalWorkSize || Kernel, UR_RESULT_ERROR_INVALID_VALUE);

// New variable needed because GlobalWorkSize parameter might not be of size 3
size_t GlobalWorkSize3D[3]{1, 1, 1};
std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);

if (LocalWorkSize) {
WG[0] = ur_cast<uint32_t>(LocalWorkSize[0]);
WG[1] = WorkDim >= 2 ? ur_cast<uint32_t>(LocalWorkSize[1]) : 1;
WG[2] = WorkDim == 3 ? ur_cast<uint32_t>(LocalWorkSize[2]) : 1;
} else {
// We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize3D
// values do not fit to 32-bit that the API only supports currently.
bool SuggestGroupSize = true;
for (int I : {0, 1, 2}) {
if (GlobalWorkSize3D[I] > UINT32_MAX) {
SuggestGroupSize = false;
}
}
if (SuggestGroupSize) {
ZE2UR_CALL(zeKernelSuggestGroupSize,
(Kernel->ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1],
GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2]));
} else {
for (int I : {0, 1, 2}) {
// Try to find a I-dimension WG size that the GlobalWorkSize3D[I] is
// fully divisable with. Start with the max possible size in
// each dimension.
uint32_t GroupSize[] = {
Device->ZeDeviceComputeProperties->maxGroupSizeX,
Device->ZeDeviceComputeProperties->maxGroupSizeY,
Device->ZeDeviceComputeProperties->maxGroupSizeZ};
GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]);
while (GlobalWorkSize3D[I] % GroupSize[I]) {
--GroupSize[I];
}
if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) {
logger::debug("calculateKernelWorkDimensions: can't find a WG size "
"suitable for global work size > UINT32_MAX");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
WG[I] = GroupSize[I];
}
logger::debug("calculateKernelWorkDimensions: using computed WG "
"size = {{{}, {}, {}}}",
WG[0], WG[1], WG[2]);
}
}

// TODO: assert if sizes do not fit into 32-bit?
switch (WorkDim) {
case 3:
ZeThreadGroupDimensions.groupCountX =
ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
ZeThreadGroupDimensions.groupCountY =
ur_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
ZeThreadGroupDimensions.groupCountZ =
ur_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]);
break;
case 2:
ZeThreadGroupDimensions.groupCountX =
ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
ZeThreadGroupDimensions.groupCountY =
ur_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
WG[2] = 1;
break;
case 1:
ZeThreadGroupDimensions.groupCountX =
ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
WG[1] = WG[2] = 1;
break;

default:
logger::error("calculateKernelWorkDimensions: unsupported work_dim");
return UR_RESULT_ERROR_INVALID_VALUE;
}

// Error handling for non-uniform group size case
if (GlobalWorkSize3D[0] !=
size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) {
logger::error("calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a multiple of the group size in the 1st dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
if (GlobalWorkSize3D[1] !=
size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) {
logger::error("calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a multiple of the group size in the 2nd dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
if (GlobalWorkSize3D[2] !=
size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) {
logger::error("calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a multiple of the group size in the 3rd dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}

return UR_RESULT_SUCCESS;
}

/// Helper function for finding the Level Zero events associated with the
/// commands in a command-buffer, each event is pointed to by a sync-point in
/// the wait list.
Expand Down Expand Up @@ -626,8 +516,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
}

ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
(Kernel->ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
GlobalWorkOffset[2]));
(Kernel->getZeKernel(CommandBuffer->Device), GlobalWorkOffset[0],
GlobalWorkOffset[1], GlobalWorkOffset[2]));
}

// If there are any pending arguments set them now.
Expand All @@ -640,7 +530,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
CommandBuffer->Device));
}
ZE2UR_CALL(zeKernelSetArgumentValue,
(Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
(Kernel->getZeKernel(CommandBuffer->Device), Arg.Index, Arg.Size,
ZeHandlePtr));
}
Kernel->PendingArguments.clear();

Expand All @@ -651,7 +542,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
ZeThreadGroupDimensions, WG, WorkDim,
GlobalWorkSize, LocalWorkSize));

ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
ZE2UR_CALL(zeKernelSetGroupSize,
(Kernel->getZeKernel(CommandBuffer->Device), WG[0], WG[1], WG[2]));

CommandBuffer->KernelsList.push_back(Kernel);
// Increment the reference count of the Kernel and indicate that the Kernel
Expand Down Expand Up @@ -691,7 +583,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(

if (CommandBuffer->IsInOrderCmdList) {
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
(CommandBuffer->ZeCommandList, Kernel->ZeKernel,
(CommandBuffer->ZeCommandList,
Kernel->getZeKernel(CommandBuffer->Device),
&ZeThreadGroupDimensions, nullptr, 0, nullptr));

logger::debug("calling zeCommandListAppendLaunchKernel()");
Expand All @@ -712,7 +605,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
}

ZE2UR_CALL(zeCommandListAppendLaunchKernel,
(CommandBuffer->ZeCommandList, Kernel->ZeKernel,
(CommandBuffer->ZeCommandList,
Kernel->getZeKernel(CommandBuffer->Device),
&ZeThreadGroupDimensions, LaunchEvent->ZeEvent,
ZeEventList.size(), ZeEventList.data()));

Expand Down
Loading