Skip to content

Commit

Permalink
SIMD: add generator constructors (kokkos#6347)
Browse files Browse the repository at this point in the history
* wip unit test

* Added gen ctors for AVX512

* Added gen ctors for avx2

* Added gen ctors for neon

* Quick fix to missing lines
  • Loading branch information
ldh4 committed Aug 10, 2023
1 parent 2e744dc commit da49ee2
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 9 deletions.
69 changes: 66 additions & 3 deletions simd/src/Kokkos_SIMD_AVX2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ class simd_mask<double, simd_abi::avx2_fixed_size<4>> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd_mask(value_type value)
: m_value(_mm256_castsi256_pd(_mm256_set1_epi64x(-std::int64_t(value)))) {
}
template <class G,
std::enable_if_t<
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
G&& gen) noexcept
: m_value(_mm256_castsi256_pd(_mm256_setr_epi64x(
-std::int64_t(gen(std::integral_constant<std::size_t, 0>())),
-std::int64_t(gen(std::integral_constant<std::size_t, 1>())),
-std::int64_t(gen(std::integral_constant<std::size_t, 2>())),
-std::int64_t(gen(std::integral_constant<std::size_t, 3>()))))) {}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask(
simd_mask<std::int32_t, simd_abi::avx2_fixed_size<4>> const& i32_mask);
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() {
Expand Down Expand Up @@ -159,6 +171,18 @@ class simd_mask<std::int32_t, simd_abi::avx2_fixed_size<4>> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
__m128i const& value_in)
: m_value(value_in) {}
template <class G,
std::enable_if_t<
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
G&& gen) noexcept
: m_value(_mm_setr_epi32(
-std::int32_t(gen(std::integral_constant<std::size_t, 0>())),
-std::int32_t(gen(std::integral_constant<std::size_t, 1>())),
-std::int32_t(gen(std::integral_constant<std::size_t, 2>())),
-std::int32_t(gen(std::integral_constant<std::size_t, 3>())))) {}
template <class U>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask(
simd_mask<U, abi_type> const& other) {
Expand Down Expand Up @@ -244,6 +268,18 @@ class simd_mask<std::int64_t, simd_abi::avx2_fixed_size<4>> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
__m256i const& value_in)
: m_value(value_in) {}
template <class G,
std::enable_if_t<
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
G&& gen) noexcept
: m_value(_mm256_setr_epi64x(
-std::int64_t(gen(std::integral_constant<std::size_t, 0>())),
-std::int64_t(gen(std::integral_constant<std::size_t, 1>())),
-std::int64_t(gen(std::integral_constant<std::size_t, 2>())),
-std::int64_t(gen(std::integral_constant<std::size_t, 3>())))) {}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask(
simd_mask<std::int32_t, abi_type> const& other)
: m_value(_mm256_cvtepi32_epi64(static_cast<__m128i>(other))) {}
Expand Down Expand Up @@ -328,6 +364,18 @@ class simd_mask<std::uint64_t, simd_abi::avx2_fixed_size<4>> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
__m256i const& value_in)
: m_value(value_in) {}
template <class G,
std::enable_if_t<
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
G&& gen) noexcept
: m_value(_mm256_setr_epi64x(
-std::int64_t(gen(std::integral_constant<std::size_t, 0>())),
-std::int64_t(gen(std::integral_constant<std::size_t, 1>())),
-std::int64_t(gen(std::integral_constant<std::size_t, 2>())),
-std::int64_t(gen(std::integral_constant<std::size_t, 3>())))) {}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator __m256i()
const {
return m_value;
Expand Down Expand Up @@ -393,6 +441,18 @@ class simd<double, simd_abi::avx2_fixed_size<4>> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
__m256d const& value_in)
: m_value(value_in) {}
template <class G,
std::enable_if_t<
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept
: m_value(_mm256_setr_pd(gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
gen(std::integral_constant<std::size_t, 2>()),
gen(std::integral_constant<std::size_t, 3>()))) {
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) {
return reinterpret_cast<value_type*>(&m_value)[i];
}
Expand Down Expand Up @@ -589,7 +649,8 @@ class simd<std::int32_t, simd_abi::avx2_fixed_size<4>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen)
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept
: m_value(_mm_setr_epi32(gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
gen(std::integral_constant<std::size_t, 2>()),
Expand Down Expand Up @@ -741,7 +802,8 @@ class simd<std::int64_t, simd_abi::avx2_fixed_size<4>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen)
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept
: m_value(_mm256_setr_epi64x(
gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
Expand Down Expand Up @@ -908,7 +970,8 @@ class simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen)
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept
: m_value(_mm256_setr_epi64x(
gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
Expand Down
98 changes: 97 additions & 1 deletion simd/src/Kokkos_SIMD_AVX512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,29 @@ class simd_mask<T, simd_abi::avx512_fixed_size<8>> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask(
simd_mask<U, simd_abi::avx512_fixed_size<8>> const& other)
: m_value(static_cast<__mmask8>(other)) {}
template <class G,
std::enable_if_t<
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask(G&& gen) : m_value(false) {
reference(m_value, int(0)) =
static_cast<bool>(gen(std::integral_constant<std::size_t, 0>()));
reference(m_value, int(1)) =
static_cast<bool>(gen(std::integral_constant<std::size_t, 1>()));
reference(m_value, int(2)) =
static_cast<bool>(gen(std::integral_constant<std::size_t, 2>()));
reference(m_value, int(3)) =
static_cast<bool>(gen(std::integral_constant<std::size_t, 3>()));
reference(m_value, int(4)) =
static_cast<bool>(gen(std::integral_constant<std::size_t, 4>()));
reference(m_value, int(5)) =
static_cast<bool>(gen(std::integral_constant<std::size_t, 5>()));
reference(m_value, int(6)) =
static_cast<bool>(gen(std::integral_constant<std::size_t, 6>()));
reference(m_value, int(7)) =
static_cast<bool>(gen(std::integral_constant<std::size_t, 7>()));
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() {
return 8;
}
Expand Down Expand Up @@ -145,7 +168,8 @@ class simd<std::int32_t, simd_abi::avx512_fixed_size<8>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen)
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept
: m_value(
_mm256_setr_epi32(gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
Expand Down Expand Up @@ -309,6 +333,24 @@ class simd<std::uint32_t, simd_abi::avx512_fixed_size<8>> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd(
simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& other)
: m_value(static_cast<__m256i>(other)) {}
template <class G,
std::enable_if_t<
// basically, can you do { value_type r =
// gen(std::integral_constant<std::size_t, i>()); }
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept
: m_value(
_mm256_setr_epi32(gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
gen(std::integral_constant<std::size_t, 2>()),
gen(std::integral_constant<std::size_t, 3>()),
gen(std::integral_constant<std::size_t, 4>()),
gen(std::integral_constant<std::size_t, 5>()),
gen(std::integral_constant<std::size_t, 6>()),
gen(std::integral_constant<std::size_t, 7>()))) {}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) {
return reinterpret_cast<value_type*>(&m_value)[i];
}
Expand Down Expand Up @@ -455,6 +497,24 @@ class simd<std::int64_t, simd_abi::avx512_fixed_size<8>> {
: m_value(_mm512_cvtepi32_epi64(static_cast<__m256i>(other))) {}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd(
simd<std::uint64_t, simd_abi::avx512_fixed_size<8>> const& other);
template <class G,
std::enable_if_t<
// basically, can you do { value_type r =
// gen(std::integral_constant<std::size_t, i>()); }
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept
: m_value(
_mm512_setr_epi64(gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
gen(std::integral_constant<std::size_t, 2>()),
gen(std::integral_constant<std::size_t, 3>()),
gen(std::integral_constant<std::size_t, 4>()),
gen(std::integral_constant<std::size_t, 5>()),
gen(std::integral_constant<std::size_t, 6>()),
gen(std::integral_constant<std::size_t, 7>()))) {}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr simd(__m512i const& value_in)
: m_value(value_in) {}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) {
Expand Down Expand Up @@ -606,6 +666,24 @@ class simd<std::uint64_t, simd_abi::avx512_fixed_size<8>> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd(
simd<std::int32_t, abi_type> const& other)
: m_value(_mm512_cvtepi32_epi64(static_cast<__m256i>(other))) {}
template <class G,
std::enable_if_t<
// basically, can you do { value_type r =
// gen(std::integral_constant<std::size_t, i>()); }
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept
: m_value(
_mm512_setr_epi64(gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
gen(std::integral_constant<std::size_t, 2>()),
gen(std::integral_constant<std::size_t, 3>()),
gen(std::integral_constant<std::size_t, 4>()),
gen(std::integral_constant<std::size_t, 5>()),
gen(std::integral_constant<std::size_t, 6>()),
gen(std::integral_constant<std::size_t, 7>()))) {}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd(
simd<std::int64_t, abi_type> const& other)
: m_value(static_cast<__m512i>(other)) {}
Expand Down Expand Up @@ -766,6 +844,24 @@ class simd<double, simd_abi::avx512_fixed_size<8>> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
__m512d const& value_in)
: m_value(value_in) {}
template <class G,
std::enable_if_t<
// basically, can you do { value_type r =
// gen(std::integral_constant<std::size_t, i>()); }
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept
: m_value(_mm512_setr_pd(gen(std::integral_constant<std::size_t, 0>()),
gen(std::integral_constant<std::size_t, 1>()),
gen(std::integral_constant<std::size_t, 2>()),
gen(std::integral_constant<std::size_t, 3>()),
gen(std::integral_constant<std::size_t, 4>()),
gen(std::integral_constant<std::size_t, 5>()),
gen(std::integral_constant<std::size_t, 6>()),
gen(std::integral_constant<std::size_t, 7>()))) {
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) {
return reinterpret_cast<value_type*>(&m_value)[i];
}
Expand Down
50 changes: 46 additions & 4 deletions simd/src/Kokkos_SIMD_NEON.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ class neon_mask<Derived, 64> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask() = default;
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit neon_mask(value_type value)
: m_value(vmovq_n_u64(value ? 0xFFFFFFFFFFFFFFFFULL : 0)) {}
template <class G,
std::enable_if_t<
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit neon_mask(
G&& gen) noexcept {
m_value = vsetq_lane_u64(
(gen(std::integral_constant<std::size_t, 0>()) ? 0xFFFFFFFFFFFFFFFFULL
: 0),
m_value, 0);
m_value = vsetq_lane_u64(
(gen(std::integral_constant<std::size_t, 1>()) ? 0xFFFFFFFFFFFFFFFFULL
: 0),
m_value, 1);
}
template <class U>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask(
neon_mask<U, 32> const& other) {
Expand Down Expand Up @@ -175,6 +191,20 @@ class neon_mask<Derived, 32> {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask() = default;
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit neon_mask(value_type value)
: m_value(vmov_n_u32(value ? 0xFFFFFFFFU : 0)) {}
template <class G,
std::enable_if_t<
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit neon_mask(
G&& gen) noexcept {
m_value = vset_lane_u32(
(gen(std::integral_constant<std::size_t, 0>()) ? 0xFFFFFFFFU : 0),
m_value, 0);
m_value = vset_lane_u32(
(gen(std::integral_constant<std::size_t, 1>()) ? 0xFFFFFFFFU : 0),
m_value, 1);
}
template <class U>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask(neon_mask<U, 64> const& other)
: m_value(vqmovn_u64(static_cast<uint64x2_t>(other))) {}
Expand Down Expand Up @@ -246,6 +276,14 @@ class simd_mask<T, simd_abi::neon_fixed_size<2>>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
implementation_type const& value)
: base_type(value) {}
template <class G,
std::enable_if_t<
std::is_invocable_r_v<typename base_type::value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
G&& gen) noexcept
: base_type(gen) {}
};

template <>
Expand Down Expand Up @@ -299,7 +337,8 @@ class simd<double, simd_abi::neon_fixed_size<2>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen) {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept {
m_value = vsetq_lane_f64(gen(std::integral_constant<std::size_t, 0>()),
m_value, 0);
m_value = vsetq_lane_f64(gen(std::integral_constant<std::size_t, 1>()),
Expand Down Expand Up @@ -502,7 +541,8 @@ class simd<std::int32_t, simd_abi::neon_fixed_size<2>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen) {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept {
m_value = vset_lane_s32(gen(std::integral_constant<std::size_t, 0>()),
m_value, 0);
m_value = vset_lane_s32(gen(std::integral_constant<std::size_t, 1>()),
Expand Down Expand Up @@ -678,7 +718,8 @@ class simd<std::int64_t, simd_abi::neon_fixed_size<2>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen) {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept {
m_value = vsetq_lane_s64(gen(std::integral_constant<std::size_t, 0>()),
m_value, 0);
m_value = vsetq_lane_s64(gen(std::integral_constant<std::size_t, 1>()),
Expand Down Expand Up @@ -855,7 +896,8 @@ class simd<std::uint64_t, simd_abi::neon_fixed_size<2>> {
std::is_invocable_r_v<value_type, G,
std::integral_constant<std::size_t, 0>>,
bool> = false>
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen) {
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
G&& gen) noexcept {
m_value = vsetq_lane_u64(gen(std::integral_constant<std::size_t, 0>()),
m_value, 0);
m_value = vsetq_lane_u64(gen(std::integral_constant<std::size_t, 1>()),
Expand Down

0 comments on commit da49ee2

Please sign in to comment.