Skip to content

Commit

Permalink
[host] Fix hostQueryLocalSizeForSubGroupCount
Browse files Browse the repository at this point in the history
When we fixed HostKernel::queryLocalSizeForSubGroupCount last year, we
missed that the JIT and non-JIT paths go through different code paths,
and the same fix is also needed for hostQueryLocalSizeForSubGroupCount.
  • Loading branch information
hvdijk committed May 22, 2024
1 parent 0a12601 commit 70d6521
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions modules/mux/targets/host/source/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,28 @@ mux_result_t hostQueryLocalSizeForSubGroupCount(mux_kernel_t kernel,
size_t *local_size_x,
size_t *local_size_y,
size_t *local_size_z) {
(void)kernel;
// FIXME: For a single sub-group, we know we can satisfy that with a
// work-group of 1,1,1. For any other sub-group count, we should ensure that
// the work-group size we report comes back through getKernelVariantForWGSize
// when it comes to run it. See CA-4784.
if (sub_group_count == 1) {
*local_size_x = 1;
host::kernel_variant_s variant;
auto host_kernel = static_cast<host::kernel_s *>(kernel);
const auto &info = *host_kernel->device->info;
const auto max_local_size_x = info.max_work_group_size_x;
auto err =
host_kernel->getKernelVariantForWGSize(max_local_size_x, 1, 1, &variant);
if (err != mux_success) {
return err;
}

// If we've compiled with degenerate sub-groups, the work-group size is the
// sub-group size.
const auto local_size = [&]() -> size_t {
if (variant.sub_group_size == 0) {
return sub_group_count == 1 ? max_local_size_x : 0;
} else {
const auto local_size = sub_group_count * variant.sub_group_size;
return local_size <= max_local_size_x ? local_size : 0;
}
}();
if (local_size) {
*local_size_x = local_size;
*local_size_y = 1;
*local_size_z = 1;
} else {
Expand Down

0 comments on commit 70d6521

Please sign in to comment.