Skip to content

Commit

Permalink
Introduce impl_get_value/impl_get_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Jun 15, 2023
1 parent 0de2de2 commit 4a57a2e
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 76 deletions.
48 changes: 32 additions & 16 deletions simd/src/Kokkos_SIMD_AVX2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,11 +958,15 @@ class const_where_expression<simd_mask<double, simd_abi::avx2_fixed_size<4>>,
}
}

friend constexpr auto const& Impl::mask<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1026,11 +1030,15 @@ class const_where_expression<
static_cast<__m128i>(m_value));
}

friend constexpr auto const& Impl::mask<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1088,11 +1096,15 @@ class const_where_expression<
static_cast<__m256i>(m_value));
}

friend constexpr auto const& Impl::mask<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1152,11 +1164,15 @@ class const_where_expression<
static_cast<__m256i>(m_value));
}

friend constexpr auto const& Impl::mask<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down
76 changes: 48 additions & 28 deletions simd/src/Kokkos_SIMD_AVX512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -854,11 +854,15 @@ class const_where_expression<simd_mask<double, simd_abi::avx512_fixed_size<8>>,
static_cast<__m512d>(m_value), 8);
}

friend constexpr auto const& Impl::mask<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -922,11 +926,15 @@ class const_where_expression<
static_cast<__m256i>(m_value));
}

friend constexpr auto const& Impl::mask<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -984,11 +992,15 @@ class const_where_expression<
static_cast<__m256i>(m_value));
}

friend constexpr auto const& Impl::mask<std::uint32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::uint32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1046,11 +1058,15 @@ class const_where_expression<
static_cast<__m512i>(m_value));
}

friend constexpr auto const& Impl::mask<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1108,11 +1124,15 @@ class const_where_expression<
static_cast<__m512i>(m_value));
}

friend constexpr auto const& Impl::mask<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1152,34 +1172,34 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx512_fixed_size<8>>,
simd_mask<std::int32_t, simd_abi::avx512_fixed_size<8>>,
simd<std::int32_t, simd_abi::avx512_fixed_size<8>>> const& x) {
return _mm512_mask_reduce_max_epi32(
static_cast<__mmask8>(Impl::mask(x)),
_mm512_castsi256_si512(static_cast<__m256i>(Impl::value(x))));
static_cast<__mmask8>(x.impl_get_mask()),
_mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value())));
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double hmin(
const_where_expression<simd_mask<double, simd_abi::avx512_fixed_size<8>>,
simd<double, simd_abi::avx512_fixed_size<8>>> const&
x) {
return _mm512_mask_reduce_min_pd(static_cast<__mmask8>(Impl::mask(x)),
static_cast<__m512d>(Impl::value(x)));
return _mm512_mask_reduce_min_pd(static_cast<__mmask8>(x.impl_get_mask()),
static_cast<__m512d>(x.impl_get_value()));
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t reduce(
const_where_expression<
simd_mask<std::int64_t, simd_abi::avx512_fixed_size<8>>,
simd<std::int64_t, simd_abi::avx512_fixed_size<8>>> const& x,
std::int64_t, std::plus<>) {
return _mm512_mask_reduce_add_epi64(static_cast<__mmask8>(Impl::mask(x)),
static_cast<__m512i>(Impl::value(x)));
return _mm512_mask_reduce_add_epi64(static_cast<__mmask8>(x.impl_get_mask()),
static_cast<__m512i>(x.impl_get_value()));
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double reduce(
const_where_expression<simd_mask<double, simd_abi::avx512_fixed_size<8>>,
simd<double, simd_abi::avx512_fixed_size<8>>> const&
x,
double, std::plus<>) {
return _mm512_mask_reduce_add_pd(static_cast<__mmask8>(Impl::mask(x)),
static_cast<__m512d>(Impl::value(x)));
return _mm512_mask_reduce_add_pd(static_cast<__mmask8>(x.impl_get_mask()),
static_cast<__m512d>(x.impl_get_value()));
}

} // namespace Experimental
Expand Down
12 changes: 6 additions & 6 deletions simd/src/Kokkos_SIMD_Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ template <typename T, typename Abi>
template <typename T, typename Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
hmin(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
auto const& v = Impl::value(x);
auto const& m = Impl::mask(x);
auto const& v = x.impl_get_value();
auto const& m = x.impl_get_mask();
auto result = Kokkos::reduction_identity<T>::min();
for (std::size_t i = 0; i < v.size(); ++i) {
if (m[i]) result = Kokkos::min(result, v[i]);
Expand All @@ -356,8 +356,8 @@ hmin(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
hmax(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
auto const& v = Impl::value(x);
auto const& m = Impl::mask(x);
auto const& v = x.impl_get_value();
auto const& m = x.impl_get_mask();
auto result = Kokkos::reduction_identity<T>::max();
for (std::size_t i = 0; i < v.size(); ++i) {
if (m[i]) result = Kokkos::max(result, v[i]);
Expand All @@ -369,8 +369,8 @@ template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
reduce(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x, T,
std::plus<>) {
auto const& v = Impl::value(x);
auto const& m = Impl::mask(x);
auto const& v = x.impl_get_value();
auto const& m = x.impl_get_mask();
auto result = Kokkos::reduction_identity<T>::sum();
for (std::size_t i = 0; i < v.size(); ++i) {
if (m[i]) result += v[i];
Expand Down
48 changes: 32 additions & 16 deletions simd/src/Kokkos_SIMD_NEON.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -898,11 +898,15 @@ class const_where_expression<simd_mask<double, simd_abi::neon_fixed_size<2>>,
if (m_mask[1]) mem[index[1]] = m_value[1];
}

friend constexpr auto const& Impl::mask<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<double, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -966,11 +970,15 @@ class const_where_expression<
if (m_mask[1]) mem[1] = m_value[1];
}

friend constexpr auto const& Impl::mask<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int32_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1028,11 +1036,15 @@ class const_where_expression<
if (m_mask[1]) mem[1] = m_value[1];
}

friend constexpr auto const& Impl::mask<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::int64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down Expand Up @@ -1090,11 +1102,15 @@ class const_where_expression<
if (m_mask[1]) mem[1] = m_value[1];
}

friend constexpr auto const& Impl::mask<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend constexpr auto const& Impl::value<std::uint64_t, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <>
Expand Down
25 changes: 15 additions & 10 deletions simd/src/Kokkos_SIMD_Scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,15 @@ class const_where_expression<simd_mask<T, simd_abi::scalar>,
mem[static_cast<Integral>(index)] = static_cast<T>(m_value);
}

friend KOKKOS_FUNCTION constexpr auto const& Impl::mask<T, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
impl_get_value() const {
return m_value;
}

friend KOKKOS_FUNCTION constexpr auto const& Impl::value<T, abi_type>(
const_where_expression<mask_type, value_type> const& x);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
impl_get_mask() const {
return m_mask;
}
};

template <class T>
Expand Down Expand Up @@ -302,25 +306,26 @@ template <class T>
reduce(const_where_expression<simd_mask<T, simd_abi::scalar>,
simd<T, simd_abi::scalar>> const& x,
T identity_element, std::plus<>) {
return static_cast<bool>(Impl::mask(x)) ? static_cast<T>(Impl::value(x))
: identity_element;
return static_cast<bool>(x.impl_get_mask())
? static_cast<T>(x.impl_get_value())
: identity_element;
}

template <class T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T
hmax(const_where_expression<simd_mask<T, simd_abi::scalar>,
simd<T, simd_abi::scalar>> const& x) {
return static_cast<bool>(Impl::mask(x))
? static_cast<T>(Impl::value(x))
return static_cast<bool>(x.impl_get_mask())
? static_cast<T>(x.impl_get_value())
: Kokkos::reduction_identity<T>::max();
}

template <class T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T
hmin(const_where_expression<simd_mask<T, simd_abi::scalar>,
simd<T, simd_abi::scalar>> const& x) {
return static_cast<bool>(Impl::mask(x))
? static_cast<T>(Impl::value(x))
return static_cast<bool>(x.impl_get_mask())
? static_cast<T>(x.impl_get_value())
: Kokkos::reduction_identity<T>::min();
}

Expand Down

0 comments on commit 4a57a2e

Please sign in to comment.