diff --git a/libclc/clc/lib/generic/workitem/clc_get_sub_group_size.cl b/libclc/clc/lib/generic/workitem/clc_get_sub_group_size.cl index 8ab4afe1ae05f..70f357c015b4a 100644 --- a/libclc/clc/lib/generic/workitem/clc_get_sub_group_size.cl +++ b/libclc/clc/lib/generic/workitem/clc_get_sub_group_size.cl @@ -13,14 +13,11 @@ #include _CLC_OVERLOAD _CLC_DEF uint __clc_get_sub_group_size() { - if (__clc_get_sub_group_id() != __clc_get_num_sub_groups() - 1) { - return __clc_get_max_sub_group_size(); - } - size_t size_x = __clc_get_local_size(0); - size_t size_y = __clc_get_local_size(1); - size_t size_z = __clc_get_local_size(2); - size_t linear_size = size_z * size_y * size_x; - size_t uniform_groups = __clc_get_num_sub_groups() - 1; - size_t uniform_size = __clc_get_max_sub_group_size() * uniform_groups; - return linear_size - uniform_size; + size_t linear_size = __clc_get_local_size(0) * __clc_get_local_size(1) * + __clc_get_local_size(2); + uint remainder = linear_size % __clc_get_max_sub_group_size(); + bool full_sub_group = (remainder == 0) || (__clc_get_sub_group_id() < + __clc_get_num_sub_groups() - 1); + + return full_sub_group ? __clc_get_max_sub_group_size() : remainder; }