Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sycl/include/sycl/detail/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,11 @@ getSPIRVMemorySemanticsMask(const access::fence_space AccessSpace,

// To ensure loop unrolling is done when processing dimensions.
template <size_t... Inds, class F>
void loop_impl(std::integer_sequence<size_t, Inds...>, F &&f) {
constexpr void loop_impl(std::integer_sequence<size_t, Inds...>, F &&f) {
(f(std::integral_constant<size_t, Inds>{}), ...);
}

template <size_t count, class F> void loop(F &&f) {
template <size_t count, class F> constexpr void loop(F &&f) {
loop_impl(std::make_index_sequence<count>{}, std::forward<F>(f));
}
inline constexpr bool is_power_of_two(int x) { return (x & (x - 1)) == 0; }
Expand Down
50 changes: 26 additions & 24 deletions sycl/include/sycl/ext/oneapi/experimental/invoke_simd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <sycl/ext/oneapi/experimental/detail/invoke_simd_types.hpp>
#include <sycl/ext/oneapi/experimental/uniform.hpp>

#include <sycl/detail/boost/mp11.hpp>
#include <sycl/sub_group.hpp>

#include <functional>
Expand Down Expand Up @@ -71,8 +70,6 @@ namespace ext::oneapi::experimental {
// --- Helpers
namespace detail {

namespace __MP11_NS = sycl::detail::boost::mp11;

// This structure performs the SPMD-to-SIMD parameter type conversion as defined
// by the spec.
template <class T, int N, class = void> struct spmd2simd;
Expand Down Expand Up @@ -154,8 +151,7 @@ struct is_simd_or_mask_type<simd_mask<T, N>> : std::true_type {};
// Checks if all the types in the parameter pack are uniform<T>.
template <class... SpmdArgs> struct all_uniform_types {
constexpr operator bool() {
using TypeList = __MP11_NS::mp_list<SpmdArgs...>;
return __MP11_NS::mp_all_of<TypeList, is_uniform_type>::value;
return ((is_uniform_type<SpmdArgs>::value && ...));
}
};

Expand Down Expand Up @@ -193,26 +189,32 @@ constexpr void verify_return_type_matches_sg_size() {
// as prescribed by the spec assuming this subgroup size. One and only one
// subgroup size should conform.
template <class SimdCallable, class... SpmdArgs> struct sg_size {
template <class N>
using IsInvocableSgSize = __MP11_NS::mp_bool<std::is_invocable_v<
SimdCallable, typename spmd2simd<SpmdArgs, N::value>::type...>>;

__DPCPP_SYCL_EXTERNAL constexpr operator int() {
using SupportedSgSizes = __MP11_NS::mp_list_c<int, 1, 2, 4, 8, 16, 32>;
using InvocableSgSizes =
__MP11_NS::mp_copy_if<SupportedSgSizes, IsInvocableSgSize>;
constexpr auto found_invoke_simd_target =
__MP11_NS::mp_empty<InvocableSgSizes>::value != 1;
if constexpr (found_invoke_simd_target) {
static_assert((__MP11_NS::mp_size<InvocableSgSizes>::value == 1) &&
"multiple invoke_simd targets found");
return __MP11_NS::mp_front<InvocableSgSizes>::value;
}
static_assert(
found_invoke_simd_target,
"No callable invoke_simd target found. Confirm the "
"invoke_simd invocation argument types are convertible to the "
"invoke_simd target argument types");
constexpr auto x = []() constexpr {
constexpr int supported_sg_sizes[] = {1, 2, 4, 8, 16, 32};
int num_found = 0;
int found_sg_size = 0;
sycl::detail::loop<std::size(supported_sg_sizes)>([&](auto idx) {
constexpr auto sg_size = supported_sg_sizes[idx];
if (std::is_invocable_v<
SimdCallable, typename spmd2simd<SpmdArgs, sg_size>::type...>) {
++num_found;
found_sg_size = sg_size;
}
});
return std::pair{num_found, found_sg_size};
}();

constexpr auto num_found = x.first;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: slightly less verbose with structured binding.

Copy link
Contributor Author

@aelovikov-intel aelovikov-intel Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not allowed in constexpr context, sadly. I tried constexpr auto [...] = []() { .... }();.

constexpr auto found_sg_size = x.second;

static_assert(num_found != 0,
"No callable invoke_simd target found. Confirm the "
"invoke_simd invocation argument types are convertible to "
"the invoke_simd target argument types");
static_assert(num_found == 1, "Multiple invoke_simd targets found!");

return found_sg_size;
}
};

Expand Down
2 changes: 1 addition & 1 deletion sycl/test/invoke_simd/no_callee_found.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ void foo() {

int main() {
foo();
// CHECK: {{.*}}error:{{.*}}static assertion failed due to requirement 'found_invoke_simd_target': No callable invoke_simd target found. Confirm the invoke_simd invocation argument types are convertible to the invoke_simd target argument types{{.*}}
// CHECK: {{.*}}error:{{.*}}static assertion failed due to requirement 'num_found != 0': No callable invoke_simd target found. Confirm the invoke_simd invocation argument types are convertible to the invoke_simd target argument types{{.*}}
}
Loading