Skip to content

Commit

Permalink
SIMD: add float simd support (kokkos#6177)
Browse files Browse the repository at this point in the history
* Added float simd for AVX2

* Added float simd for AVX512

* Added float simd in NEON

* WIP: adding in scatter_to/gather_from

need tests and impls in other backends as well first

* Added scatter_to/gather_from for float simd

* Added cbrt, exp and log for float simd

* Fixed gather_from in avx512

* Converted binary ops to hidden friends

* clang-formatted

* Fixes based on feedback. Reordering and removing unnecessary constructors.
  • Loading branch information
ldh4 committed Aug 30, 2023
1 parent 4617e4a commit e76708a
Show file tree
Hide file tree
Showing 4 changed files with 877 additions and 4 deletions.
8 changes: 4 additions & 4 deletions simd/src/Kokkos_SIMD.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,19 +165,19 @@ class data_types {};
// Per-architecture selection of host_abi_set (the SIMD ABIs paired with the
// scalar ABI) and data_type_set (the element types listed for that
// architecture). Exactly one branch is taken, chosen by the architecture
// macros below.
// NOTE(review): how these two aliases are consumed is not visible in this
// section -- confirm against their users before changing a list.
#if defined(KOKKOS_ARCH_AVX512XEON)
// AVX-512: scalar ABI plus the 8-lane avx512_fixed_size ABI.
using host_abi_set = abi_set<simd_abi::scalar, simd_abi::avx512_fixed_size<8>>;
using data_type_set = data_types<std::int32_t, std::uint32_t, std::int64_t,
std::uint64_t, double>;
std::uint64_t, double, float>;
#elif defined(KOKKOS_ARCH_AVX2)
// AVX2: scalar ABI plus the 4-lane avx2_fixed_size ABI.
using host_abi_set = abi_set<simd_abi::scalar, simd_abi::avx2_fixed_size<4>>;
using data_type_set =
data_types<std::int32_t, std::int64_t, std::uint64_t, double>;
data_types<std::int32_t, std::int64_t, std::uint64_t, double, float>;
#elif defined(__ARM_NEON)
// NEON: scalar ABI plus the 2-lane neon_fixed_size ABI.
using host_abi_set = abi_set<simd_abi::scalar, simd_abi::neon_fixed_size<2>>;
using data_type_set =
data_types<std::int32_t, std::int64_t, std::uint64_t, double>;
data_types<std::int32_t, std::int64_t, std::uint64_t, double, float>;
#else
// No vector extension detected: scalar ABI only.
using host_abi_set = abi_set<simd_abi::scalar>;
using data_type_set = data_types<std::int32_t, std::uint32_t, std::int64_t,
std::uint64_t, double>;
std::uint64_t, double, float>;
#endif

using device_abi_set = abi_set<simd_abi::scalar>;
Expand Down
338 changes: 338 additions & 0 deletions simd/src/Kokkos_SIMD_AVX2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,82 @@ class simd_mask<double, simd_abi::avx2_fixed_size<4>> {
}
};

// Four-lane mask that accompanies simd<float, avx2_fixed_size<4>>.
// The lanes live in an __m128; a lane is read as "true" when the
// corresponding bit reported by _mm_movemask_ps is set.
template <>
class simd_mask<float, simd_abi::avx2_fixed_size<4>> {
  __m128 m_value;

 public:
  // Proxy object that reads or writes a single lane of a mask register.
  class reference {
    __m128& m_mask;
    int m_lane;
    // Register with all 32 bits set in lane m_lane and zero everywhere else.
    KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION __m128 bit_mask() const {
      __m128i const lane_bits = _mm_setr_epi32(
          -std::int32_t(m_lane == 0), -std::int32_t(m_lane == 1),
          -std::int32_t(m_lane == 2), -std::int32_t(m_lane == 3));
      return _mm_castsi128_ps(lane_bits);
    }

   public:
    KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference(__m128& mask_arg,
                                                    int lane_arg)
        : m_mask(mask_arg), m_lane(lane_arg) {}
    // Sets (value == true) or clears (value == false) the referenced lane,
    // leaving the other lanes untouched.
    KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference
    operator=(bool value) const {
      __m128 const selected_lane = bit_mask();
      m_mask = value ? _mm_or_ps(selected_lane, m_mask)
                     : _mm_andnot_ps(selected_lane, m_mask);
      return *this;
    }
    // Reads the referenced lane as a bool.
    KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION operator bool() const {
      int const lane_bit = 1 << m_lane;
      return (_mm_movemask_ps(m_mask) & lane_bit) != 0;
    }
  };

  using value_type = bool;
  using abi_type   = simd_abi::avx2_fixed_size<4>;

  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask() = default;
  // Broadcast: every lane becomes all-ones (true) or all-zeros (false).
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd_mask(value_type value)
      : m_value(_mm_castsi128_ps(_mm_set1_epi32(-std::int32_t(value)))) {}
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() {
    return 4;
  }
  // Wraps an existing mask register without modification.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask(
      __m128 const& value_in)
      : m_value(value_in) {}
  // Exposes the underlying register.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator __m128()
      const {
    return m_value;
  }
  // Writable access to lane i through a proxy.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) {
    return reference(m_value, int(i));
  }
  // Read-only access to lane i. The const_cast only lets the proxy be
  // constructed; the proxy is read from, never assigned through.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type
  operator[](std::size_t i) const {
    reference const lane_proxy(const_cast<__m128&>(m_value), int(i));
    return static_cast<value_type>(lane_proxy);
  }
  // Lane-wise logical OR.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask
  operator||(simd_mask const& other) const {
    __m128 const either = _mm_or_ps(m_value, other.m_value);
    return simd_mask(either);
  }
  // Lane-wise logical AND.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask
  operator&&(simd_mask const& other) const {
    __m128 const both = _mm_and_ps(m_value, other.m_value);
    return simd_mask(both);
  }
  // Lane-wise logical NOT, computed as (~m_value) & all-true.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask operator!() const {
    __m128 const all_true = static_cast<__m128>(simd_mask(true));
    return simd_mask(_mm_andnot_ps(m_value, all_true));
  }
  // Two masks compare equal when their movemask bit patterns agree.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool operator==(
      simd_mask const& other) const {
    int const my_bits    = _mm_movemask_ps(m_value);
    int const their_bits = _mm_movemask_ps(other.m_value);
    return my_bits == their_bits;
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool operator!=(
      simd_mask const& other) const {
    return !(*this == other);
  }
};

template <>
class simd_mask<std::int32_t, simd_abi::avx2_fixed_size<4>> {
__m128i m_value;
Expand Down Expand Up @@ -620,6 +696,189 @@ simd<double, simd_abi::avx2_fixed_size<4>> condition(
static_cast<__m256d>(a)));
}

// Four-lane single-precision SIMD vector for the AVX2 ABI, stored in an
// __m128. Arithmetic and comparison operators act lane-wise; comparisons
// return the matching simd_mask type.
template <>
class simd<float, simd_abi::avx2_fixed_size<4>> {
  __m128 m_value;

 public:
  using value_type = float;
  using abi_type   = simd_abi::avx2_fixed_size<4>;
  using mask_type  = simd_mask<value_type, abi_type>;
  using reference  = value_type&;
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd()            = default;
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd const&) = default;
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd&&)      = default;
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd const&) = default;
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd&&) = default;
  // Number of lanes.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() {
    return 4;
  }
  // Broadcast constructor: converts `value` to float and copies it into
  // every lane.
  template <typename U, std::enable_if_t<std::is_convertible_v<U, value_type>,
                                         bool> = false>
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(U&& value)
      : m_value(_mm_set1_ps(value_type(value))) {}
  // Generator constructor: lane i is initialized with
  // gen(std::integral_constant<std::size_t, i>()).
  // Marked host-only like every other member, because the body uses SSE
  // intrinsics that exist only on the host.
  template <typename G,
            std::enable_if_t<
                std::is_invocable_r_v<value_type, G,
                                      std::integral_constant<std::size_t, 0>>,
                bool> = false>
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(G&& gen)
      : m_value(_mm_setr_ps(gen(std::integral_constant<std::size_t, 0>()),
                            gen(std::integral_constant<std::size_t, 1>()),
                            gen(std::integral_constant<std::size_t, 2>()),
                            gen(std::integral_constant<std::size_t, 3>()))) {}
  // Wraps an existing register without modification.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd(
      __m128 const& value_in)
      : m_value(value_in) {}
  // Lane access by reinterpreting the register storage as a float array.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) {
    return reinterpret_cast<value_type*>(&m_value)[i];
  }
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type
  operator[](std::size_t i) const {
    return reinterpret_cast<value_type const*>(&m_value)[i];
  }
  // Loads four consecutive floats from ptr (unaligned load).
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr,
                                                       element_aligned_tag) {
    m_value = _mm_loadu_ps(ptr);
  }
  // Stores the four lanes to ptr (unaligned store).
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(
      value_type* ptr, element_aligned_tag) const {
    _mm_storeu_ps(ptr, m_value);
  }
  // Exposes the underlying register.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator __m128()
      const {
    return m_value;
  }
  // Lane-wise negation.
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd operator-() const
      noexcept {
    // Flip the sign bit rather than computing 0 - x: the subtraction turns
    // +0.0f into +0.0f, whereas scalar negation yields -0.0f.
    return simd(_mm_xor_ps(_mm_set1_ps(-0.0f), m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd operator*(
      simd const& lhs, simd const& rhs) noexcept {
    return simd(_mm_mul_ps(lhs.m_value, rhs.m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd operator/(
      simd const& lhs, simd const& rhs) noexcept {
    return simd(_mm_div_ps(lhs.m_value, rhs.m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd operator+(
      simd const& lhs, simd const& rhs) noexcept {
    return simd(_mm_add_ps(lhs.m_value, rhs.m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend simd operator-(
      simd const& lhs, simd const& rhs) noexcept {
    return simd(_mm_sub_ps(lhs.m_value, rhs.m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend mask_type
  operator<(simd const& lhs, simd const& rhs) noexcept {
    return mask_type(_mm_cmplt_ps(lhs.m_value, rhs.m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend mask_type
  operator>(simd const& lhs, simd const& rhs) noexcept {
    return mask_type(_mm_cmpgt_ps(lhs.m_value, rhs.m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend mask_type
  operator<=(simd const& lhs, simd const& rhs) noexcept {
    return mask_type(_mm_cmple_ps(lhs.m_value, rhs.m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend mask_type
  operator>=(simd const& lhs, simd const& rhs) noexcept {
    return mask_type(_mm_cmpge_ps(lhs.m_value, rhs.m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend mask_type
  operator==(simd const& lhs, simd const& rhs) noexcept {
    return mask_type(_mm_cmpeq_ps(lhs.m_value, rhs.m_value));
  }
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION friend mask_type
  operator!=(simd const& lhs, simd const& rhs) noexcept {
    return mask_type(_mm_cmpneq_ps(lhs.m_value, rhs.m_value));
  }
};

// Lane-wise copysign: every bit of `a` except the sign bit, combined with the
// sign bit of `b`.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> copysign(
    simd<float, simd_abi::avx2_fixed_size<4>> const& a,
    simd<float, simd_abi::avx2_fixed_size<4>> const& b) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  // -0.0 has only the sign bit set, so it serves as the sign-bit selector.
  __m128 const sign_mask = _mm_set1_ps(-0.0);
  __m128 const magnitude = _mm_andnot_ps(sign_mask, static_cast<__m128>(a));
  __m128 const sign_only = _mm_and_ps(sign_mask, static_cast<__m128>(b));
  return simd_type(_mm_xor_ps(magnitude, sign_only));
}

// Lane-wise absolute value, obtained by clearing the sign bit of each lane.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> abs(
    simd<float, simd_abi::avx2_fixed_size<4>> const& a) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  // -0.0 has only the sign bit set; and-not with it removes that bit.
  __m128 const sign_mask = _mm_set1_ps(-0.0);
  __m128 const cleared   = _mm_andnot_ps(sign_mask, static_cast<__m128>(a));
  return simd_type(cleared);
}

// Lane-wise square root.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> sqrt(
    simd<float, simd_abi::avx2_fixed_size<4>> const& a) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  __m128 const lanes = static_cast<__m128>(a);
  return simd_type(_mm_sqrt_ps(lanes));
}

#ifdef __INTEL_COMPILER

// Lane-wise cube root via the _mm_cbrt_ps intrinsic (this definition is only
// compiled when __INTEL_COMPILER is defined).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> cbrt(
    simd<float, simd_abi::avx2_fixed_size<4>> const& a) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  __m128 const lanes = static_cast<__m128>(a);
  return simd_type(_mm_cbrt_ps(lanes));
}

// Lane-wise exponential via the _mm_exp_ps intrinsic (this definition is only
// compiled when __INTEL_COMPILER is defined).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> exp(
    simd<float, simd_abi::avx2_fixed_size<4>> const& a) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  __m128 const lanes = static_cast<__m128>(a);
  return simd_type(_mm_exp_ps(lanes));
}

// Lane-wise logarithm via the _mm_log_ps intrinsic (this definition is only
// compiled when __INTEL_COMPILER is defined).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> log(
    simd<float, simd_abi::avx2_fixed_size<4>> const& a) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  __m128 const lanes = static_cast<__m128>(a);
  return simd_type(_mm_log_ps(lanes));
}

#endif

// Lane-wise fused multiply-add: a * b + c, computed with _mm_fmadd_ps.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> fma(
    simd<float, simd_abi::avx2_fixed_size<4>> const& a,
    simd<float, simd_abi::avx2_fixed_size<4>> const& b,
    simd<float, simd_abi::avx2_fixed_size<4>> const& c) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  __m128 const multiplicand = static_cast<__m128>(a);
  __m128 const multiplier   = static_cast<__m128>(b);
  __m128 const addend       = static_cast<__m128>(c);
  return simd_type(_mm_fmadd_ps(multiplicand, multiplier, addend));
}

// Lane-wise maximum of a and b, computed with _mm_max_ps(a, b).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> max(
    simd<float, simd_abi::avx2_fixed_size<4>> const& a,
    simd<float, simd_abi::avx2_fixed_size<4>> const& b) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  __m128 const first  = static_cast<__m128>(a);
  __m128 const second = static_cast<__m128>(b);
  return simd_type(_mm_max_ps(first, second));
}

// Lane-wise minimum of a and b, computed with _mm_min_ps(a, b).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> min(
    simd<float, simd_abi::avx2_fixed_size<4>> const& a,
    simd<float, simd_abi::avx2_fixed_size<4>> const& b) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  __m128 const first  = static_cast<__m128>(a);
  __m128 const second = static_cast<__m128>(b);
  return simd_type(_mm_min_ps(first, second));
}

// Lane-wise select: yields b in lanes where mask `a` is set and c elsewhere.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
simd<float, simd_abi::avx2_fixed_size<4>> condition(
    simd_mask<float, simd_abi::avx2_fixed_size<4>> const& a,
    simd<float, simd_abi::avx2_fixed_size<4>> const& b,
    simd<float, simd_abi::avx2_fixed_size<4>> const& c) {
  using simd_type = simd<float, simd_abi::avx2_fixed_size<4>>;
  __m128 const selector = static_cast<__m128>(a);
  __m128 const if_true  = static_cast<__m128>(b);
  __m128 const if_false = static_cast<__m128>(c);
  // _mm_blendv_ps takes the "false" operand first and the "true" one second.
  return simd_type(_mm_blendv_ps(if_false, if_true, selector));
}

template <>
class simd<std::int32_t, simd_abi::avx2_fixed_size<4>> {
__m128i m_value;
Expand Down Expand Up @@ -1126,6 +1385,85 @@ class where_expression<simd_mask<double, simd_abi::avx2_fixed_size<4>>,
}
};

// Read-only masked view over a simd<float, avx2_fixed_size<4>>: pairs a value
// with a mask and offers masked stores and scatters of the selected lanes.
template <>
class const_where_expression<simd_mask<float, simd_abi::avx2_fixed_size<4>>,
                             simd<float, simd_abi::avx2_fixed_size<4>>> {
 public:
  using abi_type   = simd_abi::avx2_fixed_size<4>;
  using value_type = simd<float, abi_type>;
  using mask_type  = simd_mask<float, abi_type>;

 protected:
  // Held as a non-const reference so the derived where_expression can write
  // through it; this class itself only reads it.
  value_type& m_value;
  mask_type const& m_mask;

 public:
  const_where_expression(mask_type const& mask_arg, value_type const& value_arg)
      : m_value(const_cast<value_type&>(value_arg)), m_mask(mask_arg) {}

  // Masked store of the value to mem[0..3] using _mm_maskstore_ps.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void copy_to(float* mem, element_aligned_tag) const {
    __m128i const store_mask = _mm_castps_si128(static_cast<__m128>(m_mask));
    __m128 const lanes       = static_cast<__m128>(m_value);
    _mm_maskstore_ps(mem, store_mask, lanes);
  }
  // For each lane selected by the mask, writes that lane to mem[index[lane]].
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void scatter_to(
      float* mem,
      simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) const {
    std::size_t lane = 0;
    while (lane < 4) {
      if (m_mask[lane]) {
        mem[index[lane]] = m_value[lane];
      }
      ++lane;
    }
  }

  // Accessor for the wrapped value.
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type const&
  impl_get_value() const {
    return m_value;
  }

  // Accessor for the wrapped mask.
  [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type const&
  impl_get_mask() const {
    return m_mask;
  }
};

// Writable masked view over a simd<float, avx2_fixed_size<4>>: adds masked
// load, gather and assignment on top of const_where_expression.
template <>
class where_expression<simd_mask<float, simd_abi::avx2_fixed_size<4>>,
                       simd<float, simd_abi::avx2_fixed_size<4>>>
    : public const_where_expression<
          simd_mask<float, simd_abi::avx2_fixed_size<4>>,
          simd<float, simd_abi::avx2_fixed_size<4>>> {
 public:
  where_expression(
      simd_mask<float, simd_abi::avx2_fixed_size<4>> const& mask_arg,
      simd<float, simd_abi::avx2_fixed_size<4>>& value_arg)
      : const_where_expression(mask_arg, value_arg) {}
  // Masked load from mem[0..3]; the whole value is replaced by the result of
  // _mm_maskload_ps.
  // NOTE(review): _mm_maskload_ps is documented to zero lanes whose mask bit
  // is clear, so unselected lanes are presumably overwritten with 0 rather
  // than preserved -- confirm this is intended.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void copy_from(float const* mem, element_aligned_tag) {
    __m128i const load_mask = _mm_castps_si128(static_cast<__m128>(m_mask));
    m_value                 = value_type(_mm_maskload_ps(mem, load_mask));
  }
  // Masked gather: the current value is passed as the source operand of
  // _mm_mask_i32gather_ps, with mem[index[lane]] gathered under the mask.
  // The scale of 4 is the byte stride applied to each index.
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
  void gather_from(
      float const* mem,
      simd<std::int32_t, simd_abi::avx2_fixed_size<4>> const& index) {
    m_value = value_type(_mm_mask_i32gather_ps(
        static_cast<__m128>(m_value), mem, static_cast<__m128i>(index),
        static_cast<__m128>(m_mask), 4));
  }
  // Masked assignment: selected lanes take the corresponding lane of x
  // (converted to the simd type); the remaining lanes keep their value.
  template <class U,
            std::enable_if_t<std::is_convertible_v<
                                 U, simd<float, simd_abi::avx2_fixed_size<4>>>,
                             bool> = false>
  KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void operator=(U&& x) {
    value_type const incoming = static_cast<value_type>(std::forward<U>(x));
    __m128 const blended =
        _mm_blendv_ps(static_cast<__m128>(m_value),
                      static_cast<__m128>(incoming),
                      static_cast<__m128>(m_mask));
    m_value = value_type(blended);
  }
};

template <>
class const_where_expression<
simd_mask<std::int32_t, simd_abi::avx2_fixed_size<4>>,
Expand Down

0 comments on commit e76708a

Please sign in to comment.