Skip to content

Commit

Permalink
Rebased and applied feedbacks
Browse files: browse the repository at this point in the history
  • Loading branch information
ldh4 committed Aug 24, 2023
1 parent dab159d commit 09fd0c2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
24 changes: 12 additions & 12 deletions simd/src/Kokkos_SIMD_AVX2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,9 +1108,9 @@ class where_expression<simd_mask<double, simd_abi::avx2_fixed_size<4>>,
void gather_from(
double const* mem,
simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
m_value = value_type(
_mm256_mask_i32gather_pd(m_value, mem, static_cast<__m128i>(index),
static_cast<__m256d>(m_mask), 8));
m_value = value_type(_mm256_mask_i32gather_pd(
static_cast<__m256d>(m_value), mem, static_cast<__m128i>(index),
static_cast<__m256d>(m_mask), 8));
}
template <class U,
std::enable_if_t<std::is_convertible_v<
Expand Down Expand Up @@ -1187,9 +1187,9 @@ class where_expression<simd_mask<std::int32_t, simd_abi::avx2_fixed_size<4>>,
void gather_from(
std::int32_t const* mem,
simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
m_value = value_type(
_mm_mask_i32gather_epi32(m_value, mem, static_cast<__m128i>(index),
static_cast<__m128i>(m_mask), 4));
m_value = value_type(_mm_mask_i32gather_epi32(
static_cast<__m128i>(m_value), mem, static_cast<__m128i>(index),
static_cast<__m128i>(m_mask), 4));
}
template <
class U,
Expand Down Expand Up @@ -1270,9 +1270,9 @@ class where_expression<simd_mask<std::int64_t, simd_abi::avx2_fixed_size<4>>,
void gather_from(
std::int64_t const* mem,
simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
m_value = value_type(
_mm256_mask_i32gather_epi64(m_value, mem, static_cast<__m128i>(index),
static_cast<__m256i>(m_mask), 8));
m_value = value_type(_mm256_mask_i32gather_epi64(
static_cast<__m256i>(m_value), reinterpret_cast<long long const*>(mem),
static_cast<__m128i>(index), static_cast<__m256i>(m_mask), 8));
}
template <
class u,
Expand Down Expand Up @@ -1354,9 +1354,9 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx2_fixed_size<4>>,
void gather_from(
std::uint64_t const* mem,
simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
m_value = value_type(
_mm256_mask_i32gather_epi64(m_value, mem, static_cast<__m128i>(index),
static_cast<__m256i>(m_mask), 8));
m_value = value_type(_mm256_mask_i32gather_epi64(
static_cast<__m256i>(m_value), reinterpret_cast<long long const*>(mem),
static_cast<__m128i>(index), static_cast<__m256i>(m_mask), 8));
}
template <class u,
std::enable_if_t<
Expand Down
8 changes: 4 additions & 4 deletions simd/unit_tests/include/TestSIMD_WhereExpressions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ inline void host_check_where_expr_scatter_to() {
for (std::size_t i = 0; i < nlanes; ++i) {
dst[i] = (2 + (i * 2));
index[i] = i;
expected_result[i] = (mask[i]) ? src[index[i]] : (2 + (i * 2));
expected_result[i] = (mask[i]) ? src[index[i]] : dst[i];
}
where(mask, src).scatter_to(dst, index);

Expand Down Expand Up @@ -71,7 +71,7 @@ inline void host_check_where_expr_gather_from() {
for (std::size_t i = 0; i < nlanes; ++i) {
dst[i] = (2 + (i * 2));
index[i] = i;
expected_result[i] = (mask[i]) ? src[index[i]] : (2 + (i * 2));
expected_result[i] = (mask[i]) ? src[index[i]] : dst[i];
}
where(mask, dst).gather_from(src, index);

Expand Down Expand Up @@ -119,7 +119,7 @@ KOKKOS_INLINE_FUNCTION void device_check_where_expr_scatter_to() {
for (std::size_t i = 0; i < nlanes; ++i) {
dst[i] = (2 + (i * 2));
index[i] = i;
expected_result[i] = (mask[i]) ? src[index[i]] : (2 + (i * 2));
expected_result[i] = (mask[i]) ? src[index[i]] : dst[i];
}
where(mask, src).scatter_to(dst, index);

Expand Down Expand Up @@ -149,7 +149,7 @@ KOKKOS_INLINE_FUNCTION void device_check_where_expr_gather_from() {
for (std::size_t i = 0; i < nlanes; ++i) {
dst[i] = (2 + (i * 2));
index[i] = i;
expected_result[i] = (mask[i]) ? src[index[i]] : (2 + (i * 2));
expected_result[i] = (mask[i]) ? src[index[i]] : dst[i];
}
where(mask, dst).gather_from(src, index);

Expand Down

0 comments on commit 09fd0c2

Please sign in to comment.