Skip to content

Commit

Permalink
core/src: Add half single and double mixed compare (LT,GT,LE,GE) (kok…
Browse files Browse the repository at this point in the history
…kos#6407)

* core/src: Add half single and double mixed compare (LT,GT,LE,GE)

* Implement PR feedback:

  - Check whether T is convertible to float
  - Try upcasting floating_point_wrapper to float and relying on
  the toolchain's implicit upcasting to kick in
  - Try comparing impl_type with T if impl_type is a full type

* Add missing endif

* Add missing ifdefs

* Update core/src/impl/Kokkos_Half_FloatingPointWrapper.hpp

Co-authored-by: Daniel Arndt <arndtd@ornl.gov>

* Remove HALF_IS_FULL_TYPE branch

---------

Co-authored-by: Daniel Arndt <arndtd@ornl.gov>
  • Loading branch information
e10harvey and masterleinad committed Oct 4, 2023
1 parent e04f637 commit 2075ae7
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 27 deletions.
72 changes: 72 additions & 0 deletions core/src/impl/Kokkos_Half_FloatingPointWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,27 +839,99 @@ class alignas(FloatType) floating_point_wrapper {
return tmp_lhs < tmp_rhs;
}

// Less-than with a wrapper on the left and a built-in scalar on the right.
// Participates in overload resolution only when T is exactly float or double.
template <class T>
KOKKOS_FUNCTION friend std::enable_if_t<
    std::is_convertible_v<T, float> &&
        (std::is_same_v<T, float> || std::is_same_v<T, double>),
    bool>
operator<(floating_point_wrapper lhs, T rhs) {
  // Convert the wrapper to float first; the comparison against rhs is then
  // the built-in one (float is promoted to double when T is double).
  const float lhs_as_float = static_cast<float>(lhs);
  return lhs_as_float < rhs;
}

// Less-than with a built-in scalar on the left and a wrapper on the right.
// Participates in overload resolution only when T is exactly float or double.
template <class T>
KOKKOS_FUNCTION friend std::enable_if_t<
    std::is_convertible_v<T, float> &&
        (std::is_same_v<T, float> || std::is_same_v<T, double>),
    bool>
operator<(T lhs, floating_point_wrapper rhs) {
  // Convert the wrapper to float first; the comparison against lhs is then
  // the built-in one (float is promoted to double when T is double).
  const float rhs_as_float = static_cast<float>(rhs);
  return lhs < rhs_as_float;
}

// Greater-than for volatile-qualified wrappers.
// Both operands are copied into non-volatile locals (lhs first, then rhs) and
// the comparison is delegated to the operator> that accepts those copies.
KOKKOS_FUNCTION
friend bool operator>(const volatile floating_point_wrapper& lhs,
                      const volatile floating_point_wrapper& rhs) {
  floating_point_wrapper lhs_copy = lhs;
  floating_point_wrapper rhs_copy = rhs;
  return lhs_copy > rhs_copy;
}

// Greater-than with a wrapper on the left and a built-in scalar on the right.
// Participates in overload resolution only when T is exactly float or double.
template <class T>
KOKKOS_FUNCTION friend std::enable_if_t<
    std::is_convertible_v<T, float> &&
        (std::is_same_v<T, float> || std::is_same_v<T, double>),
    bool>
operator>(floating_point_wrapper lhs, T rhs) {
  // Convert the wrapper to float first; the comparison against rhs is then
  // the built-in one (float is promoted to double when T is double).
  const float lhs_as_float = static_cast<float>(lhs);
  return lhs_as_float > rhs;
}

// Greater-than with a built-in scalar on the left and a wrapper on the right.
// Participates in overload resolution only when T is exactly float or double.
template <class T>
KOKKOS_FUNCTION friend std::enable_if_t<
    std::is_convertible_v<T, float> &&
        (std::is_same_v<T, float> || std::is_same_v<T, double>),
    bool>
operator>(T lhs, floating_point_wrapper rhs) {
  // Convert the wrapper to float first; the comparison against lhs is then
  // the built-in one (float is promoted to double when T is double).
  const float rhs_as_float = static_cast<float>(rhs);
  return lhs > rhs_as_float;
}

// Less-than-or-equal for volatile-qualified wrappers.
// Both operands are copied into non-volatile locals (lhs first, then rhs) and
// the comparison is delegated to the operator<= that accepts those copies.
KOKKOS_FUNCTION
friend bool operator<=(const volatile floating_point_wrapper& lhs,
                       const volatile floating_point_wrapper& rhs) {
  floating_point_wrapper lhs_copy = lhs;
  floating_point_wrapper rhs_copy = rhs;
  return lhs_copy <= rhs_copy;
}

// Less-than-or-equal with a wrapper on the left and a built-in scalar on the
// right. Participates in overload resolution only when T is exactly float or
// double.
template <class T>
KOKKOS_FUNCTION friend std::enable_if_t<
    std::is_convertible_v<T, float> &&
        (std::is_same_v<T, float> || std::is_same_v<T, double>),
    bool>
operator<=(floating_point_wrapper lhs, T rhs) {
  // Convert the wrapper to float first; the comparison against rhs is then
  // the built-in one (float is promoted to double when T is double).
  const float lhs_as_float = static_cast<float>(lhs);
  return lhs_as_float <= rhs;
}

// Less-than-or-equal with a built-in scalar on the left and a wrapper on the
// right. Participates in overload resolution only when T is exactly float or
// double.
template <class T>
KOKKOS_FUNCTION friend std::enable_if_t<
    std::is_convertible_v<T, float> &&
        (std::is_same_v<T, float> || std::is_same_v<T, double>),
    bool>
operator<=(T lhs, floating_point_wrapper rhs) {
  // Convert the wrapper to float first; the comparison against lhs is then
  // the built-in one (float is promoted to double when T is double).
  const float rhs_as_float = static_cast<float>(rhs);
  return lhs <= rhs_as_float;
}

// Greater-than-or-equal for volatile-qualified wrappers.
// Both operands are copied into non-volatile locals (lhs first, then rhs) and
// the comparison is delegated to the operator>= that accepts those copies.
KOKKOS_FUNCTION
friend bool operator>=(const volatile floating_point_wrapper& lhs,
                       const volatile floating_point_wrapper& rhs) {
  floating_point_wrapper lhs_copy = lhs;
  floating_point_wrapper rhs_copy = rhs;
  return lhs_copy >= rhs_copy;
}

// Greater-than-or-equal with a wrapper on the left and a built-in scalar on
// the right. Participates in overload resolution only when T is exactly float
// or double.
template <class T>
KOKKOS_FUNCTION friend std::enable_if_t<
    std::is_convertible_v<T, float> &&
        (std::is_same_v<T, float> || std::is_same_v<T, double>),
    bool>
operator>=(floating_point_wrapper lhs, T rhs) {
  // Convert the wrapper to float first; the comparison against rhs is then
  // the built-in one (float is promoted to double when T is double).
  const float lhs_as_float = static_cast<float>(lhs);
  return lhs_as_float >= rhs;
}

// Greater-than-or-equal with a built-in scalar on the left and a wrapper on
// the right. Participates in overload resolution only when T is exactly float
// or double.
template <class T>
KOKKOS_FUNCTION friend std::enable_if_t<
    std::is_convertible_v<T, float> &&
        (std::is_same_v<T, float> || std::is_same_v<T, double>),
    bool>
operator>=(T lhs, floating_point_wrapper rhs) {
  // Convert the wrapper to float first; the comparison against lhs is then
  // the built-in one (float is promoted to double when T is double).
  const float rhs_as_float = static_cast<float>(rhs);
  return lhs >= rhs_as_float;
}

// Insertion and extraction operators
friend std::ostream& operator<<(std::ostream& os,
const floating_point_wrapper& x) {
Expand Down
103 changes: 76 additions & 27 deletions core/unit_test/TestHalfOperators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,27 @@ enum OP_TESTS {
OR,
EQ,
NEQ,
LT,
GT,
LE,
GE, // TODO: TW,
LT_H_H,
LT_H_S,
LT_S_H,
LT_H_D,
LT_D_H,
GT_H_H,
GT_H_S,
GT_S_H,
GT_H_D,
GT_D_H,
LE_H_H,
LE_H_S,
LE_S_H,
LE_H_D,
LE_D_H,
GE_H_H,
GE_H_S,
GE_S_H,
GE_H_D,
GE_D_H,
// TODO: TW,
PASS_BY_REF,
AO_IMPL_HALF,
AO_HALF_T,
Expand Down Expand Up @@ -292,20 +309,20 @@ struct Functor_TestHalfVolatileOperators {
actual_lhs(ASSIGN) = static_cast<double>(nv_tmp);
expected_lhs(ASSIGN) = d_lhs;

actual_lhs(LT) = h_lhs < h_rhs;
expected_lhs(LT) = d_lhs < d_rhs;
actual_lhs(LT_H_H) = h_lhs < h_rhs;
expected_lhs(LT_H_H) = d_lhs < d_rhs;

actual_lhs(LE) = h_lhs <= h_rhs;
expected_lhs(LE) = d_lhs <= d_rhs;
actual_lhs(LE_H_H) = h_lhs <= h_rhs;
expected_lhs(LE_H_H) = d_lhs <= d_rhs;

actual_lhs(NEQ) = h_lhs != h_rhs;
expected_lhs(NEQ) = d_lhs != d_rhs;

actual_lhs(GT) = h_lhs > h_rhs;
expected_lhs(GT) = d_lhs > d_rhs;
actual_lhs(GT_H_H) = h_lhs > h_rhs;
expected_lhs(GT_H_H) = d_lhs > d_rhs;

actual_lhs(GE) = h_lhs >= h_rhs;
expected_lhs(GE) = d_lhs >= d_rhs;
actual_lhs(GE_H_H) = h_lhs >= h_rhs;
expected_lhs(GE_H_H) = d_lhs >= d_rhs;

actual_lhs(EQ) = h_lhs == h_rhs;
expected_lhs(EQ) = d_lhs == d_rhs;
Expand Down Expand Up @@ -879,17 +896,49 @@ struct Functor_TestHalfOperators {
actual_lhs(NEQ) = h_lhs != h_rhs;
expected_lhs(NEQ) = d_lhs != d_rhs;

actual_lhs(LT) = h_lhs < h_rhs;
expected_lhs(LT) = d_lhs < d_rhs;

actual_lhs(GT) = h_lhs > h_rhs;
expected_lhs(GT) = d_lhs > d_rhs;

actual_lhs(LE) = h_lhs <= h_rhs;
expected_lhs(LE) = d_lhs <= d_rhs;

actual_lhs(GE) = h_lhs >= h_rhs;
expected_lhs(GE) = d_lhs >= d_rhs;
actual_lhs(LT_H_H) = h_lhs < h_rhs;
expected_lhs(LT_H_H) = d_lhs < d_rhs;
actual_lhs(LT_H_S) = h_lhs < static_cast<float>(h_rhs);
expected_lhs(LT_H_S) = d_lhs < d_rhs;
actual_lhs(LT_S_H) = static_cast<float>(h_lhs) < h_rhs;
expected_lhs(LT_S_H) = d_lhs < d_rhs;
actual_lhs(LT_H_D) = h_lhs < static_cast<double>(h_rhs);
expected_lhs(LT_H_D) = d_lhs < d_rhs;
actual_lhs(LT_D_H) = static_cast<double>(h_lhs) < h_rhs;
expected_lhs(LT_D_H) = d_lhs < d_rhs;

actual_lhs(GT_H_H) = h_lhs > h_rhs;
expected_lhs(GT_H_H) = d_lhs > d_rhs;
actual_lhs(GT_H_S) = h_lhs > static_cast<float>(h_rhs);
expected_lhs(GT_H_S) = d_lhs > d_rhs;
actual_lhs(GT_S_H) = static_cast<float>(h_lhs) > h_rhs;
expected_lhs(GT_S_H) = d_lhs > d_rhs;
actual_lhs(GT_H_D) = h_lhs > static_cast<double>(h_rhs);
expected_lhs(GT_H_D) = d_lhs > d_rhs;
actual_lhs(GT_D_H) = static_cast<double>(h_lhs) > h_rhs;
expected_lhs(GT_D_H) = d_lhs > d_rhs;

actual_lhs(LE_H_H) = h_lhs <= h_rhs;
expected_lhs(LE_H_H) = d_lhs <= d_rhs;
actual_lhs(LE_H_S) = h_lhs <= static_cast<float>(h_rhs);
expected_lhs(LE_H_S) = d_lhs <= d_rhs;
actual_lhs(LE_S_H) = static_cast<float>(h_lhs) <= h_rhs;
expected_lhs(LE_S_H) = d_lhs <= d_rhs;
actual_lhs(LE_H_D) = h_lhs <= static_cast<double>(h_rhs);
expected_lhs(LE_H_D) = d_lhs <= d_rhs;
actual_lhs(LE_D_H) = static_cast<double>(h_lhs) <= h_rhs;
expected_lhs(LE_D_H) = d_lhs <= d_rhs;

actual_lhs(GE_H_H) = h_lhs >= h_rhs;
expected_lhs(GE_H_H) = d_lhs >= d_rhs;
actual_lhs(GE_H_S) = h_lhs >= static_cast<float>(h_rhs);
expected_lhs(GE_H_S) = d_lhs >= d_rhs;
actual_lhs(GE_S_H) = static_cast<float>(h_lhs) >= h_rhs;
expected_lhs(GE_S_H) = d_lhs >= d_rhs;
actual_lhs(GE_H_D) = h_lhs >= static_cast<double>(h_rhs);
expected_lhs(GE_H_D) = d_lhs >= d_rhs;
actual_lhs(GE_D_H) = static_cast<double>(h_lhs) >= h_rhs;
expected_lhs(GE_D_H) = d_lhs >= d_rhs;

// actual_lhs(TW) = h_lhs <=> h_rhs; // Need C++20?
// expected_lhs(TW) = d_lhs <=> d_rhs; // Need C++20?
Expand Down Expand Up @@ -961,10 +1010,10 @@ void __test_half_operators(half_type h_lhs, half_type h_rhs) {
Kokkos::deep_copy(f_device_expected_lhs, f_device.expected_lhs);
for (int op_test = 0; op_test < N_OP_TESTS; op_test++) {
// printf("op_test = %d\n", op_test);
if (op_test == ASSIGN || op_test == LT || op_test == LE || op_test == NEQ ||
op_test == EQ || op_test == GT || op_test == GE ||
op_test == CADD_H_H || op_test == CSUB_H_H || op_test == CMUL_H_H ||
op_test == CDIV_H_H) {
if (op_test == ASSIGN || op_test == LT_H_H || op_test == LE_H_H ||
op_test == NEQ || op_test == EQ || op_test == GT_H_H ||
op_test == GE_H_H || op_test == CADD_H_H || op_test == CSUB_H_H ||
op_test == CMUL_H_H || op_test == CDIV_H_H) {
ASSERT_NEAR(f_device_actual_lhs(op_test), f_device_expected_lhs(op_test),
epsilon);
ASSERT_NEAR(f_host.actual_lhs(op_test), f_host.expected_lhs(op_test),
Expand Down

0 comments on commit 2075ae7

Please sign in to comment.