Skip to content

Commit

Permalink
Added gather_from and scatter_to for AVX2 and AVX512 simd
Browse files Browse the repository at this point in the history
Added gather_from and scatter_to for NEON simd
  • Loading branch information
ldh4 committed Aug 24, 2023
1 parent c740f55 commit dab159d
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 4 deletions.
54 changes: 51 additions & 3 deletions simd/src/Kokkos_SIMD_AVX2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,9 +1108,9 @@ class where_expression<simd_mask<double, simd_abi::avx2_fixed_size<4>>,
// Masked gather of doubles: every lane selected by m_mask is loaded from
// mem[index[lane]].
//
// m_value itself is handed to the intrinsic as the source operand, so lanes
// that are not selected keep the value they already hold. Only one gather is
// issued: a preceding gather with a zero source would clear the unselected
// lanes before this one could preserve them.
//
// The trailing 8 is the byte scale applied to each index (sizeof(double)).
void gather_from(
    double const* mem,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
  m_value = value_type(_mm256_mask_i32gather_pd(
      static_cast<__m256d>(m_value), mem, static_cast<__m128i>(index),
      static_cast<__m256d>(m_mask), 8));
}
template <class U,
std::enable_if_t<std::is_convertible_v<
Expand Down Expand Up @@ -1148,6 +1148,14 @@ class const_where_expression<
_mm_maskstore_epi32(mem, static_cast<__m128i>(m_mask),
static_cast<__m128i>(m_value));
}
// Masked scatter of 32-bit integers, done one lane at a time.
//
// For each of the four lanes, in ascending order, the lane's value is stored
// to mem[index[lane]] when the corresponding mask entry is set; lanes with a
// clear mask entry do not touch memory. Because stores are issued in lane
// order, a later selected lane wins if two selected lanes share an index.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::int32_t* mem,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) const {
  constexpr std::size_t lane_count = 4;
  for (std::size_t i = 0; i != lane_count; ++i) {
    if (!m_mask[i]) continue;  // lane not selected: leave memory alone
    mem[index[i]] = m_value[i];
  }
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1175,6 +1183,14 @@ class where_expression<simd_mask<std::int32_t, simd_abi::avx2_fixed_size<4>>,
void copy_from(std::int32_t const* mem, element_aligned_tag) {
m_value = value_type(_mm_maskload_epi32(mem, static_cast<__m128i>(m_mask)));
}
// Masked gather of 32-bit integers: every lane selected by m_mask is loaded
// from mem[index[lane]]; unselected lanes keep their current m_value
// contents because m_value is passed as the intrinsic's source operand.
//
// The source operand is converted with an explicit static_cast, matching how
// m_value, m_mask and index are handed to intrinsics everywhere else in this
// file. The trailing 4 is the byte scale applied to each index
// (sizeof(std::int32_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::int32_t const* mem,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
  m_value = value_type(_mm_mask_i32gather_epi32(
      static_cast<__m128i>(m_value), mem, static_cast<__m128i>(index),
      static_cast<__m128i>(m_mask), 4));
}
template <
class U,
std::enable_if_t<std::is_convertible_v<
Expand Down Expand Up @@ -1214,6 +1230,14 @@ class const_where_expression<
static_cast<__m256i>(m_mask),
static_cast<__m256i>(m_value));
}
// Masked scatter of 64-bit signed integers, done one lane at a time.
//
// Lanes are visited in ascending order; a lane's value is written to
// mem[index[lane]] only when its mask entry is set. If two selected lanes
// carry the same index, the later lane's store is the one that remains.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::int64_t* mem,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) const {
  constexpr std::size_t lane_count = 4;
  std::size_t i = 0;
  while (i != lane_count) {
    if (m_mask[i]) {
      mem[index[i]] = m_value[i];
    }
    ++i;
  }
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1242,6 +1266,14 @@ class where_expression<simd_mask<std::int64_t, simd_abi::avx2_fixed_size<4>>,
m_value = value_type(_mm256_maskload_epi64(
reinterpret_cast<long long const*>(mem), static_cast<__m256i>(m_mask)));
}
// Masked gather of 64-bit signed integers: every lane selected by m_mask is
// loaded from mem[index[lane]]; unselected lanes keep their current m_value
// contents because m_value is passed as the intrinsic's source operand.
//
// Two conversions are spelled out explicitly:
//  * m_value -> __m256i via static_cast, as done for every other intrinsic
//    call in this file;
//  * mem -> long long const*, because the intrinsic's base pointer is
//    declared with that element type and std::int64_t may be a distinct
//    type (e.g. long). The neighbouring copy_from uses the same cast.
//
// The trailing 8 is the byte scale applied to each index
// (sizeof(std::int64_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::int64_t const* mem,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
  m_value = value_type(_mm256_mask_i32gather_epi64(
      static_cast<__m256i>(m_value), reinterpret_cast<long long const*>(mem),
      static_cast<__m128i>(index), static_cast<__m256i>(m_mask), 8));
}
template <
class u,
std::enable_if_t<std::is_convertible_v<
Expand Down Expand Up @@ -1282,6 +1314,14 @@ class const_where_expression<
static_cast<__m256i>(m_mask),
static_cast<__m256i>(m_value));
}
// Masked scatter of 64-bit unsigned integers, done one lane at a time.
//
// Each of the four lanes is examined in ascending order and stored to
// mem[index[lane]] when its mask entry is set; unselected lanes perform no
// store. With duplicate indices among selected lanes the highest such lane
// determines the final memory contents.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::uint64_t* mem,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) const {
  constexpr std::size_t lane_count = 4;
  for (std::size_t i = 0; i != lane_count; ++i) {
    if (!m_mask[i]) continue;  // lane not selected: leave memory alone
    mem[index[i]] = m_value[i];
  }
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1310,6 +1350,14 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx2_fixed_size<4>>,
m_value = value_type(_mm256_maskload_epi64(
reinterpret_cast<long long const*>(mem), static_cast<__m256i>(m_mask)));
}
// Masked gather of 64-bit unsigned integers: every lane selected by m_mask
// is loaded from mem[index[lane]]; unselected lanes keep their current
// m_value contents because m_value is passed as the intrinsic's source
// operand.
//
// Two conversions are spelled out explicitly:
//  * m_value -> __m256i via static_cast, as done for every other intrinsic
//    call in this file;
//  * mem -> long long const*, because the intrinsic's base pointer is
//    declared with a signed 64-bit element type and an unsigned pointer does
//    not convert to it implicitly. The neighbouring copy_from uses the same
//    cast. Only the bit pattern is moved, so signedness is irrelevant here.
//
// The trailing 8 is the byte scale applied to each index
// (sizeof(std::uint64_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::uint64_t const* mem,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
  m_value = value_type(_mm256_mask_i32gather_epi64(
      static_cast<__m256i>(m_value), reinterpret_cast<long long const*>(mem),
      static_cast<__m128i>(index), static_cast<__m256i>(m_mask), 8));
}
template <class u,
std::enable_if_t<
std::is_convertible_v<
Expand Down
66 changes: 65 additions & 1 deletion simd/src/Kokkos_SIMD_AVX512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ class where_expression<simd_mask<double, simd_abi::avx512_fixed_size<8>>,
double const* mem,
simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& index) {
m_value = value_type(_mm512_mask_i32gather_pd(
_mm512_set1_pd(0.0), static_cast<__mmask8>(m_mask),
static_cast<__m512d>(m_value), static_cast<__mmask8>(m_mask),
static_cast<__m256i>(index), mem, 8));
}
template <class U, std::enable_if_t<
Expand Down Expand Up @@ -1081,6 +1081,14 @@ class const_where_expression<
_mm256_mask_storeu_epi32(mem, static_cast<__mmask8>(m_mask),
static_cast<__m256i>(m_value));
}
// Masked scatter of eight 32-bit signed integers using the hardware scatter
// intrinsic. Lanes whose bit is set in the mask are stored to
// mem[index[lane]]; the literal 4 is the byte scale applied to each index
// (sizeof(std::int32_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::int32_t* mem,
    simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& index) const {
  __mmask8 const selected = static_cast<__mmask8>(m_mask);
  __m256i const offsets   = static_cast<__m256i>(index);
  __m256i const payload   = static_cast<__m256i>(m_value);
  _mm256_mask_i32scatter_epi32(mem, selected, offsets, payload, 4);
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1109,6 +1117,14 @@ class where_expression<simd_mask<std::int32_t, simd_abi::avx512_fixed_size<8>>,
m_value = value_type(_mm256_mask_loadu_epi32(
_mm256_set1_epi32(0), static_cast<__mmask8>(m_mask), mem));
}
// Masked gather of eight 32-bit signed integers. Lanes whose bit is set in
// the mask are loaded from mem[index[lane]]; the remaining lanes are taken
// from the current m_value, which is passed as the intrinsic's source
// operand. The literal 4 is the byte scale applied to each index
// (sizeof(std::int32_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::int32_t const* mem,
    simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& index) {
  __m256i const fallback  = static_cast<__m256i>(m_value);
  __mmask8 const selected = static_cast<__mmask8>(m_mask);
  __m256i const offsets   = static_cast<__m256i>(index);
  m_value = value_type(
      _mm256_mmask_i32gather_epi32(fallback, selected, offsets, mem, 4));
}
template <class U,
std::enable_if_t<
std::is_convertible_v<
Expand Down Expand Up @@ -1147,6 +1163,14 @@ class const_where_expression<
_mm256_mask_storeu_epi32(mem, static_cast<__mmask8>(m_mask),
static_cast<__m256i>(m_value));
}
// Masked scatter of eight 32-bit unsigned integers using the hardware
// scatter intrinsic. Lanes whose bit is set in the mask are stored to
// mem[index[lane]]; the literal 4 is the byte scale applied to each index
// (sizeof(std::uint32_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::uint32_t* mem,
    simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& index) const {
  __mmask8 const selected = static_cast<__mmask8>(m_mask);
  __m256i const offsets   = static_cast<__m256i>(index);
  __m256i const payload   = static_cast<__m256i>(m_value);
  _mm256_mask_i32scatter_epi32(mem, selected, offsets, payload, 4);
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1175,6 +1199,14 @@ class where_expression<simd_mask<std::uint32_t, simd_abi::avx512_fixed_size<8>>,
m_value = value_type(_mm256_mask_loadu_epi32(
_mm256_set1_epi32(0), static_cast<__mmask8>(m_mask), mem));
}
// Masked gather of eight 32-bit unsigned integers. Lanes whose bit is set in
// the mask are loaded from mem[index[lane]]; the remaining lanes are taken
// from the current m_value, which is passed as the intrinsic's source
// operand. The literal 4 is the byte scale applied to each index
// (sizeof(std::uint32_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::uint32_t const* mem,
    simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& index) {
  __m256i const fallback  = static_cast<__m256i>(m_value);
  __mmask8 const selected = static_cast<__mmask8>(m_mask);
  __m256i const offsets   = static_cast<__m256i>(index);
  m_value = value_type(
      _mm256_mmask_i32gather_epi32(fallback, selected, offsets, mem, 4));
}
template <class U,
std::enable_if_t<
std::is_convertible_v<
Expand Down Expand Up @@ -1213,6 +1245,14 @@ class const_where_expression<
_mm512_mask_storeu_epi64(mem, static_cast<__mmask8>(m_mask),
static_cast<__m512i>(m_value));
}
// Masked scatter of eight 64-bit signed integers using the hardware scatter
// intrinsic with 32-bit indices. Lanes whose bit is set in the mask are
// stored to mem[index[lane]]; the literal 8 is the byte scale applied to
// each index (sizeof(std::int64_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::int64_t* mem,
    simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& index) const {
  __mmask8 const selected = static_cast<__mmask8>(m_mask);
  __m256i const offsets   = static_cast<__m256i>(index);
  __m512i const payload   = static_cast<__m512i>(m_value);
  _mm512_mask_i32scatter_epi64(mem, selected, offsets, payload, 8);
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1241,6 +1281,14 @@ class where_expression<simd_mask<std::int64_t, simd_abi::avx512_fixed_size<8>>,
m_value = value_type(_mm512_mask_loadu_epi64(
_mm512_set1_epi64(0.0), static_cast<__mmask8>(m_mask), mem));
}
// Masked gather of eight 64-bit signed integers addressed by 32-bit indices.
// Lanes whose bit is set in the mask are loaded from mem[index[lane]]; the
// remaining lanes are taken from the current m_value, which is passed as the
// intrinsic's source operand. The literal 8 is the byte scale applied to
// each index (sizeof(std::int64_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::int64_t const* mem,
    simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& index) {
  __m512i const fallback  = static_cast<__m512i>(m_value);
  __mmask8 const selected = static_cast<__mmask8>(m_mask);
  __m256i const offsets   = static_cast<__m256i>(index);
  m_value = value_type(
      _mm512_mask_i32gather_epi64(fallback, selected, offsets, mem, 8));
}
template <class U,
std::enable_if_t<
std::is_convertible_v<
Expand Down Expand Up @@ -1279,6 +1327,14 @@ class const_where_expression<
_mm512_mask_storeu_epi64(mem, static_cast<__mmask8>(m_mask),
static_cast<__m512i>(m_value));
}
// Masked scatter of eight 64-bit unsigned integers using the hardware
// scatter intrinsic with 32-bit indices. Lanes whose bit is set in the mask
// are stored to mem[index[lane]]; the literal 8 is the byte scale applied to
// each index (sizeof(std::uint64_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::uint64_t* mem,
    simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& index) const {
  __mmask8 const selected = static_cast<__mmask8>(m_mask);
  __m256i const offsets   = static_cast<__m256i>(index);
  __m512i const payload   = static_cast<__m512i>(m_value);
  _mm512_mask_i32scatter_epi64(mem, selected, offsets, payload, 8);
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1307,6 +1363,14 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx512_fixed_size<8>>,
m_value = value_type(_mm512_mask_loadu_epi64(
_mm512_set1_epi64(0.0), static_cast<__mmask8>(m_mask), mem));
}
// Masked gather of eight 64-bit unsigned integers addressed by 32-bit
// indices. Lanes whose bit is set in the mask are loaded from
// mem[index[lane]]; the remaining lanes are taken from the current m_value,
// which is passed as the intrinsic's source operand. The literal 8 is the
// byte scale applied to each index (sizeof(std::uint64_t)).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::uint64_t const* mem,
    simd<std::int32_t, simd_abi::avx512_fixed_size<8>> const& index) {
  __m512i const fallback  = static_cast<__m512i>(m_value);
  __mmask8 const selected = static_cast<__mmask8>(m_mask);
  __m256i const offsets   = static_cast<__m256i>(index);
  m_value = value_type(
      _mm512_mask_i32gather_epi64(fallback, selected, offsets, mem, 8));
}
template <class U,
std::enable_if_t<
std::is_convertible_v<
Expand Down
42 changes: 42 additions & 0 deletions simd/src/Kokkos_SIMD_NEON.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,13 @@ class const_where_expression<
if (m_mask[0]) mem[0] = m_value[0];
if (m_mask[1]) mem[1] = m_value[1];
}
// Masked scatter of two 32-bit signed integers, done one lane at a time.
// Lane 0 is handled before lane 1; a lane's value is written to
// mem[index[lane]] only when its mask entry is set.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::int32_t* mem,
    simd<std::int32_t, simd_abi::neon_fixed_size<2>> const& index) const {
  for (std::size_t lane = 0; lane < 2; ++lane) {
    if (m_mask[lane]) {
      mem[index[lane]] = m_value[lane];
    }
  }
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1105,6 +1112,13 @@ class where_expression<simd_mask<std::int32_t, simd_abi::neon_fixed_size<2>>,
if (m_mask[0]) m_value[0] = mem[0];
if (m_mask[1]) m_value[1] = mem[1];
}
// Masked gather of two 32-bit signed integers, done one lane at a time.
// A lane is overwritten with mem[index[lane]] only when its mask entry is
// set; otherwise the lane keeps its current value.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::int32_t const* mem,
    simd<std::int32_t, simd_abi::neon_fixed_size<2>> const& index) {
  for (std::size_t lane = 0; lane < 2; ++lane) {
    if (m_mask[lane]) {
      m_value[lane] = mem[index[lane]];
    }
  }
}
template <
class U,
std::enable_if_t<
Expand Down Expand Up @@ -1143,6 +1157,13 @@ class const_where_expression<
if (m_mask[0]) mem[0] = m_value[0];
if (m_mask[1]) mem[1] = m_value[1];
}
// Masked scatter of two 64-bit signed integers, done one lane at a time.
// Lane 0 is handled before lane 1; a lane's value is written to
// mem[index[lane]] only when its mask entry is set.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::int64_t* mem,
    simd<std::int32_t, simd_abi::neon_fixed_size<2>> const& index) const {
  for (std::size_t lane = 0; lane < 2; ++lane) {
    if (m_mask[lane]) {
      mem[index[lane]] = m_value[lane];
    }
  }
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1171,6 +1192,13 @@ class where_expression<simd_mask<std::int64_t, simd_abi::neon_fixed_size<2>>,
if (m_mask[0]) m_value[0] = mem[0];
if (m_mask[1]) m_value[1] = mem[1];
}
// Masked gather of two 64-bit signed integers, done one lane at a time.
// A lane is overwritten with mem[index[lane]] only when its mask entry is
// set; otherwise the lane keeps its current value.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::int64_t const* mem,
    simd<std::int32_t, simd_abi::neon_fixed_size<2>> const& index) {
  for (std::size_t lane = 0; lane < 2; ++lane) {
    if (m_mask[lane]) {
      m_value[lane] = mem[index[lane]];
    }
  }
}
template <
class U,
std::enable_if_t<std::is_convertible_v<
Expand Down Expand Up @@ -1209,6 +1237,13 @@ class const_where_expression<
if (m_mask[0]) mem[0] = m_value[0];
if (m_mask[1]) mem[1] = m_value[1];
}
// Masked scatter of two 64-bit unsigned integers, done one lane at a time.
// Lane 0 is handled before lane 1; a lane's value is written to
// mem[index[lane]] only when its mask entry is set.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void scatter_to(
    std::uint64_t* mem,
    simd<std::int32_t, simd_abi::neon_fixed_size<2>> const& index) const {
  for (std::size_t lane = 0; lane < 2; ++lane) {
    if (m_mask[lane]) {
      mem[index[lane]] = m_value[lane];
    }
  }
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
Expand Down Expand Up @@ -1237,6 +1272,13 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::neon_fixed_size<2>>,
if (m_mask[0]) m_value[0] = mem[0];
if (m_mask[1]) m_value[1] = mem[1];
}
// Masked gather of two 64-bit unsigned integers, done one lane at a time.
// A lane is overwritten with mem[index[lane]] only when its mask entry is
// set; otherwise the lane keeps its current value.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
    std::uint64_t const* mem,
    simd<std::int32_t, simd_abi::neon_fixed_size<2>> const& index) {
  for (std::size_t lane = 0; lane < 2; ++lane) {
    if (m_mask[lane]) {
      m_value[lane] = mem[index[lane]];
    }
  }
}
template <class U,
std::enable_if_t<
std::is_convertible_v<
Expand Down
1 change: 1 addition & 0 deletions simd/unit_tests/TestSIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
#include <TestSIMD_ShiftOps.hpp>
#include <TestSIMD_Condition.hpp>
#include <TestSIMD_GeneratorCtors.hpp>
#include <TestSIMD_WhereExpressions.hpp>

0 comments on commit dab159d

Please sign in to comment.