Skip to content

Commit

Permalink
SIMD: add shift ops for all int types (kokkos#6109)
Browse files Browse the repository at this point in the history
* Added unit tests and shift operators for avx512

* Added shifts for avx2; impl for int64_t sra missing

* Added shifts for neon

* clang-formatted

* Changed int usages to std::int32_t

* Added unit tests for shift ops with simd type as rhs

* Added device side unit test

* Revert commented out lines

* Changed argument type of rhs for simd shifts to take the same type
as lhs.
Added a few workarounds to avoid spurious uninitialized-variable warnings

* removed an unnecessary commented out line

* Quick fixes to neon simd

* Added a corner case check

* clang-formatted

* Fixed test cases for shift op

* clang-formatted

* Added missing nodiscard attributes

* Converted shift ops to hidden friends

* Replaced static_asserts on data types with enable_ifs
  • Loading branch information
ldh4 committed Aug 1, 2023
1 parent d89140d commit 78e8a32
Show file tree
Hide file tree
Showing 6 changed files with 685 additions and 165 deletions.
160 changes: 123 additions & 37 deletions simd/src/Kokkos_SIMD_AVX2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,40 @@ class simd<std::int32_t, simd_abi::avx2_fixed_size<4>> {
operator!=(simd const& other) const {
return !((*this) == other);
}

// Arithmetic (sign-propagating) right shift of every 32-bit lane of `lhs` by
// the same scalar count `rhs`, implemented with _mm_srai_epi32.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::int32_t, simd_abi::avx2_fixed_size<4>>
operator>>(simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& lhs,
           int rhs) noexcept {
  return simd<std::int32_t, simd_abi::avx2_fixed_size<4>>(
      _mm_srai_epi32(static_cast<__m128i>(lhs), rhs));
}

// Arithmetic (sign-propagating) right shift where each 32-bit lane of `lhs`
// is shifted by the count held in the matching lane of `rhs`, implemented
// with _mm_srav_epi32.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::int32_t, simd_abi::avx2_fixed_size<4>>
operator>>(
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& lhs,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& rhs) noexcept {
  return simd<std::int32_t, simd_abi::avx2_fixed_size<4>>(
      _mm_srav_epi32(static_cast<__m128i>(lhs), static_cast<__m128i>(rhs)));
}

// Left shift of every 32-bit lane of `lhs` by the same scalar count `rhs`,
// implemented with _mm_slli_epi32.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::int32_t, simd_abi::avx2_fixed_size<4>>
operator<<(simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& lhs,
           int rhs) noexcept {
  return simd<std::int32_t, simd_abi::avx2_fixed_size<4>>(
      _mm_slli_epi32(static_cast<__m128i>(lhs), rhs));
}

// Left shift where each 32-bit lane of `lhs` is shifted by the count held in
// the matching lane of `rhs`, implemented with _mm_sllv_epi32.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::int32_t, simd_abi::avx2_fixed_size<4>>
operator<<(
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& lhs,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& rhs) noexcept {
  return simd<std::int32_t, simd_abi::avx2_fixed_size<4>>(
      _mm_sllv_epi32(static_cast<__m128i>(lhs), static_cast<__m128i>(rhs)));
}
};

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
Expand All @@ -661,9 +695,9 @@ class simd<std::int32_t, simd_abi::avx2_fixed_size<4>> {
_mm_add_epi32(static_cast<__m128i>(lhs), static_cast<__m128i>(rhs)));
}

KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::int32_t, simd_abi::avx2_fixed_size<4>> abs(
simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& a) {
// Lane-wise absolute value of a 4-wide 32-bit signed simd, delegated to the
// _mm_abs_epi32 intrinsic on the underlying __m128i register.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::int32_t, simd_abi::avx2_fixed_size<4>>
abs(simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& a) {
  using int32x4 = simd<std::int32_t, simd_abi::avx2_fixed_size<4>>;
  return int32x4(_mm_abs_epi32(static_cast<__m128i>(a)));
}
Expand Down Expand Up @@ -767,6 +801,42 @@ class simd<std::int64_t, simd_abi::avx2_fixed_size<4>> {
operator!=(simd const& other) const {
return !((*this) == other);
}

// Arithmetic right shift for packed 64-bit ints is not available in AVX2
// KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend
// simd<std::int64_t, simd_abi::avx2_fixed_size<4>> operator>>(
// simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& lhs,
// int rhs) noexcept {
// return simd<std::int64_t, simd_abi::avx2_fixed_size<4>>(
// _mm256_srai_epi64(static_cast<__m256i>(lhs), rhs));
// }

// KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend
// simd<std::int64_t, simd_abi::avx2_fixed_size<4>> operator>>(
// simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& lhs,
// simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& rhs) noexcept {
// return simd<std::int64_t, simd_abi::avx2_fixed_size<4>>(
// _mm256_srav_epi64(static_cast<__m256i>(lhs),
// _mm256_cvtepi32_epi64(static_cast<__m128i>(static_cast<__m128i>(rhs))));
// }

// Left shift of every 64-bit lane of `lhs` by the same scalar count `rhs`,
// implemented with _mm256_slli_epi64.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::int64_t, simd_abi::avx2_fixed_size<4>>
operator<<(simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& lhs,
           int rhs) noexcept {
  return simd<std::int64_t, simd_abi::avx2_fixed_size<4>>(
      _mm256_slli_epi64(static_cast<__m256i>(lhs), rhs));
}

// Left shift where each 64-bit lane of `lhs` is shifted by the count held in
// the matching lane of `rhs`. The four 32-bit counts are first sign-extended
// to 64 bits (_mm256_cvtepi32_epi64) so they line up with the 64-bit lanes
// expected by _mm256_sllv_epi64.
// NOTE(review): rhs is a 32-bit simd here, unlike the same-type rhs taken by
// the std::uint64_t shift operators in this file — confirm this is intended.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::int64_t, simd_abi::avx2_fixed_size<4>>
operator<<(
    simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& lhs,
    simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& rhs) noexcept {
  return simd<std::int64_t, simd_abi::avx2_fixed_size<4>>(
      _mm256_sllv_epi64(static_cast<__m256i>(lhs),
                        _mm256_cvtepi32_epi64(static_cast<__m128i>(rhs))));
}
};

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
Expand All @@ -793,18 +863,18 @@ class simd<std::int64_t, simd_abi::avx2_fixed_size<4>> {

// Manually computing absolute values, because _mm256_abs_epi64
// is not in AVX2; it's available in AVX512.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::int64_t, simd_abi::avx2_fixed_size<4>> abs(
simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& a) {
// Lane-wise absolute value of a 4-wide 64-bit signed simd.
// The result is built from a callable taking a lane index: for index i it
// yields -a[i] when a[i] is negative and a[i] otherwise.
// NOTE(review): assuming a[i] yields std::int64_t, negating the most negative
// value would overflow — confirm whether callers can pass that input.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::int64_t, simd_abi::avx2_fixed_size<4>>
abs(simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& a) {
return simd<std::int64_t, simd_abi::avx2_fixed_size<4>>(
// The lambda captures `a` by reference and is only used inside this call.
[&](std::size_t i) { return (a[i] < 0) ? -a[i] : a[i]; });
}

KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::int64_t, simd_abi::avx2_fixed_size<4>> condition(
simd_mask<std::int64_t, simd_abi::avx2_fixed_size<4>> const& a,
simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& b,
simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& c) {
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::int64_t, simd_abi::avx2_fixed_size<4>>
condition(simd_mask<std::int64_t, simd_abi::avx2_fixed_size<4>> const& a,
simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& b,
simd<std::int64_t, simd_abi::avx2_fixed_size<4>> const& c) {
return simd<std::int64_t, simd_abi::avx2_fixed_size<4>>(_mm256_castpd_si256(
_mm256_blendv_pd(_mm256_castsi256_pd(static_cast<__m256i>(c)),
_mm256_castsi256_pd(static_cast<__m256i>(b)),
Expand Down Expand Up @@ -865,24 +935,6 @@ class simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> {
static_cast<__m256i>(mask_type(true)));
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd
operator>>(unsigned int rhs) const {
return _mm256_srli_epi64(m_value, rhs);
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd operator>>(
simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& rhs) const {
return _mm256_srlv_epi64(m_value,
_mm256_cvtepi32_epi64(static_cast<__m128i>(rhs)));
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd
operator<<(unsigned int rhs) const {
return _mm256_slli_epi64(m_value, rhs);
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd operator<<(
simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& rhs) const {
return _mm256_sllv_epi64(m_value,
_mm256_cvtepi32_epi64(static_cast<__m128i>(rhs)));
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd
operator&(simd const& other) const {
return _mm256_and_si256(m_value, other.m_value);
}
Expand All @@ -902,6 +954,40 @@ class simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> {
operator!=(simd const& other) const {
return !((*this) == other);
}

// Logical (zero-filling) right shift of every 64-bit lane of `lhs` by the
// same scalar count `rhs`, implemented with _mm256_srli_epi64.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::uint64_t, simd_abi::avx2_fixed_size<4>>
operator>>(simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& lhs,
           int rhs) noexcept {
  return simd<std::uint64_t, simd_abi::avx2_fixed_size<4>>(
      _mm256_srli_epi64(static_cast<__m256i>(lhs), rhs));
}

// Logical (zero-filling) right shift where each 64-bit lane of `lhs` is
// shifted by the count held in the matching lane of `rhs`, implemented with
// _mm256_srlv_epi64.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::uint64_t, simd_abi::avx2_fixed_size<4>>
operator>>(
    simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& lhs,
    simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& rhs) noexcept {
  return simd<std::uint64_t, simd_abi::avx2_fixed_size<4>>(_mm256_srlv_epi64(
      static_cast<__m256i>(lhs), static_cast<__m256i>(rhs)));
}

// Left shift of every 64-bit lane of `lhs` by the same scalar count `rhs`,
// implemented with _mm256_slli_epi64.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::uint64_t, simd_abi::avx2_fixed_size<4>>
operator<<(simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& lhs,
           int rhs) noexcept {
  return simd<std::uint64_t, simd_abi::avx2_fixed_size<4>>(
      _mm256_slli_epi64(static_cast<__m256i>(lhs), rhs));
}

// Left shift where each 64-bit lane of `lhs` is shifted by the count held in
// the matching lane of `rhs`, implemented with _mm256_sllv_epi64.
// [[nodiscard]]: the operator only computes and returns a new simd, so
// silently dropping the result is always a mistake.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd<
    std::uint64_t, simd_abi::avx2_fixed_size<4>>
operator<<(
    simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& lhs,
    simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& rhs) noexcept {
  return simd<std::uint64_t, simd_abi::avx2_fixed_size<4>>(_mm256_sllv_epi64(
      static_cast<__m256i>(lhs), static_cast<__m256i>(rhs)));
}
};

KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
Expand All @@ -925,17 +1011,17 @@ simd<std::int64_t, simd_abi::avx2_fixed_size<4>>::simd(
_mm256_sub_epi64(static_cast<__m256i>(lhs), static_cast<__m256i>(rhs)));
}

KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> abs(
simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& a) {
// Absolute value of a 4-wide 64-bit unsigned simd. Unsigned lanes are never
// negative, so the input is returned unchanged (as a copy).
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::uint64_t, simd_abi::avx2_fixed_size<4>>
abs(simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& a) {
return a;
}

KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> condition(
simd_mask<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& a,
simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& b,
simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& c) {
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<std::uint64_t, simd_abi::avx2_fixed_size<4>>
condition(simd_mask<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& a,
simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& b,
simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> const& c) {
return simd<std::uint64_t, simd_abi::avx2_fixed_size<4>>(_mm256_castpd_si256(
_mm256_blendv_pd(_mm256_castsi256_pd(static_cast<__m256i>(c)),
_mm256_castsi256_pd(static_cast<__m256i>(b)),
Expand Down

0 comments on commit 78e8a32

Please sign in to comment.