Skip to content

Commit

Permalink
Fix host-device annotation for where_expression/const_where_expression
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Jun 15, 2023
1 parent fff1bc6 commit 4bfa10e
Showing 1 changed file with 53 additions and 21 deletions.
74 changes: 53 additions & 21 deletions simd/src/Kokkos_SIMD_Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ namespace Kokkos {

namespace Experimental {

namespace simd_abi {
class scalar;
}

template <class T, class Abi>
class simd;

Expand Down Expand Up @@ -92,20 +96,40 @@ class where_expression<bool, T> : public const_where_expression<bool, T> {
};

template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
where_expression<simd_mask<T, Abi>, simd<T, Abi>>
where(typename simd<T, Abi>::mask_type const& mask, simd<T, Abi>& value) {
return where_expression(mask, value);
}

template <class T, class Abi>
template <class T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION
where_expression<simd_mask<T, Kokkos::Experimental::simd_abi::scalar>,
simd<T, Kokkos::Experimental::simd_abi::scalar>>
where(typename simd<
T, Kokkos::Experimental::simd_abi::scalar>::mask_type const& mask,
simd<T, Kokkos::Experimental::simd_abi::scalar>& value) {
return where_expression(mask, value);
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
const_where_expression<simd_mask<T, Abi>, simd<T, Abi>>
where(typename simd<T, Abi>::mask_type const& mask,
simd<T, Abi> const& value) {
return const_where_expression(mask, value);
}

template <class T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION
const_where_expression<simd_mask<T, Kokkos::Experimental::simd_abi::scalar>,
simd<T, Kokkos::Experimental::simd_abi::scalar>>
where(typename simd<
T, Kokkos::Experimental::simd_abi::scalar>::mask_type const& mask,
simd<T, Kokkos::Experimental::simd_abi::scalar> const& value) {
return const_where_expression(mask, value);
}

template <class T>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION where_expression<bool, T> where(
bool mask, T& value) {
Expand Down Expand Up @@ -308,38 +332,46 @@ KOKKOS_FORCEINLINE_FUNCTION where_expression<M, T>& operator/=(
// fallback implementations of reductions across simd_mask:

template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION bool all_of(
simd_mask<T, Abi> const& a) {
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::enable_if_t<
!std::is_same_v<Abi, Kokkos::Experimental::simd_abi::scalar>, bool>
all_of(simd_mask<T, Abi> const& a) {
return a == simd_mask<T, Abi>(true);
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION bool any_of(
simd_mask<T, Abi> const& a) {
return a != simd_mask<T, Abi>(false);
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t<
std::is_same_v<Abi, Kokkos::Experimental::simd_abi::scalar>, bool>
all_of(simd_mask<T, Abi> const& a) {
return a == simd_mask<T, Abi>(true);
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION bool none_of(
simd_mask<T, Abi> const& a) {
return a == simd_mask<T, Abi>(false);
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::enable_if_t<
!std::is_same_v<Abi, Kokkos::Experimental::simd_abi::scalar>, bool>
any_of(simd_mask<T, Abi> const& a) {
return a != simd_mask<T, Abi>(false);
}

namespace Impl {

template <typename T, typename Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr auto const& mask(
const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
return x.m_mask;
template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t<
std::is_same_v<Abi, Kokkos::Experimental::simd_abi::scalar>, bool>
any_of(simd_mask<T, Abi> const& a) {
return a != simd_mask<T, Abi>(false);
}

template <typename T, typename Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr auto const& value(
const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
return x.m_value;
template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::enable_if_t<
!std::is_same_v<Abi, Kokkos::Experimental::simd_abi::scalar>, bool>
none_of(simd_mask<T, Abi> const& a) {
return a == simd_mask<T, Abi>(false);
}

} // namespace Impl
template <class T, class Abi>
[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t<
std::is_same_v<Abi, Kokkos::Experimental::simd_abi::scalar>, bool>
none_of(simd_mask<T, Abi> const& a) {
return a == simd_mask<T, Abi>(false);
}

template <typename T, typename Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
Expand Down

0 comments on commit 4bfa10e

Please sign in to comment.