Skip to content

Commit

Permalink
Fix infinity, quiet_NaN, signaling_NaN, isfinite, isnan, isinf for ha…
Browse files Browse the repository at this point in the history
…lf_t and bhalf_t (kokkos#6543)

* Fix nvcc warning for non-trivial types in bit_cast

* Introduce BitComparisonWrapper

* Implement isnan, isfinite, isinf for half_t, bhalf_t with bit comparison

* Fix infinity, quiet_NaN, signaling_NaN for half_t, bhalf_t

* Improve tests

* Disable TestCuda_WithoutInitializing for NVHPC

* Define exponent/fraction_mask in FloatingPointWrapper.hpp

* Minimize changes to TestMathematicalFunctions.hpp

* Enable tests for inf, quiet_nan, signaling_nan for half_t and bhalf_t

* Don't repeat storage class specifier in template specialization

* Try inline constexpr and move definitions for the same type together

* Disable numeric traits unit tests for NVHPC

* Define comparison operators for BitComparisonWrapper

* Fix TestNumericTraits, no constexpr constructor for [b]half_t
  • Loading branch information
masterleinad committed Nov 10, 2023
1 parent 97a90d5 commit d5a4802
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 36 deletions.
5 changes: 5 additions & 0 deletions containers/unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ foreach(Tag Threads;Serial;OpenMP;HPX;Cuda;HIP;SYCL)
LIST(REMOVE_ITEM UnitTestSources ${dir}/TestCuda_DynViewAPI_generic.cpp)
endif()

# FIXME_NVHPC: NVC++-S-0000-Internal compiler error. extractor: bad opc 0
if(KOKKOS_ENABLE_CUDA AND KOKKOS_CXX_COMPILER_ID STREQUAL NVHPC)
LIST(REMOVE_ITEM UnitTestSources ${dir}/TestCuda_WithoutInitializing.cpp)
endif()

KOKKOS_ADD_EXECUTABLE_AND_TEST(ContainersUnitTest_${Tag} SOURCES ${UnitTestSources})
endif()
endforeach()
Expand Down
2 changes: 1 addition & 1 deletion core/src/Kokkos_BitManipulation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ bit_cast(From const& from) noexcept {
return sycl::bit_cast<To>(from);
#else
To to;
memcpy(&to, &from, sizeof(To));
memcpy(static_cast<void*>(&to), static_cast<const void*>(&from), sizeof(To));
return to;
#endif
}
Expand Down
68 changes: 67 additions & 1 deletion core/src/impl/Kokkos_Half_FloatingPointWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define KOKKOS_HALF_FLOATING_POINT_WRAPPER_HPP_

#include <Kokkos_Macros.hpp>
#include <Kokkos_BitManipulation.hpp> // bit_cast

#include <type_traits>
#include <iosfwd> // istream & ostream for extraction and insertion ops
Expand Down Expand Up @@ -215,10 +216,70 @@ cast_from_wrapper(const Kokkos::Experimental::bhalf_t& x);
/************************** END forward declarations **************************/

namespace Impl {

/// Aggregate carrying a raw 16-bit pattern that is associated with FloatType.
///
/// All six relational operators follow the same recipe: the wrapper on the
/// left-hand side is converted to FloatType via static_cast, and the result
/// is compared against the right-hand operand with the corresponding
/// operator of FloatType. Only the "wrapper on the left" form is provided.
template <typename FloatType>
struct BitComparisonWrapper {
  /// The stored bit pattern.
  std::uint16_t value;

  /// Equality: converts lhs to FloatType, then compares with rhs.
  template <typename Rhs>
  KOKKOS_FUNCTION friend bool operator==(BitComparisonWrapper lhs, Rhs rhs) {
    return static_cast<FloatType>(lhs) == rhs;
  }

  /// Inequality: converts lhs to FloatType, then compares with rhs.
  template <typename Rhs>
  KOKKOS_FUNCTION friend bool operator!=(BitComparisonWrapper lhs, Rhs rhs) {
    return static_cast<FloatType>(lhs) != rhs;
  }

  /// Less-than: converts lhs to FloatType, then compares with rhs.
  template <typename Rhs>
  KOKKOS_FUNCTION friend bool operator<(BitComparisonWrapper lhs, Rhs rhs) {
    return static_cast<FloatType>(lhs) < rhs;
  }

  /// Less-or-equal: converts lhs to FloatType, then compares with rhs.
  template <typename Rhs>
  KOKKOS_FUNCTION friend bool operator<=(BitComparisonWrapper lhs, Rhs rhs) {
    return static_cast<FloatType>(lhs) <= rhs;
  }

  /// Greater-than: converts lhs to FloatType, then compares with rhs.
  template <typename Rhs>
  KOKKOS_FUNCTION friend bool operator>(BitComparisonWrapper lhs, Rhs rhs) {
    return static_cast<FloatType>(lhs) > rhs;
  }

  /// Greater-or-equal: converts lhs to FloatType, then compares with rhs.
  template <typename Rhs>
  KOKKOS_FUNCTION friend bool operator>=(BitComparisonWrapper lhs, Rhs rhs) {
    return static_cast<FloatType>(lhs) >= rhs;
  }
};

// Per-type bit masks selecting the exponent and the fraction field of a
// 16-bit floating-point pattern. The primary templates carry no initializer;
// usable values are supplied only by the explicit specializations below.
template <typename FloatType>
inline constexpr BitComparisonWrapper<FloatType> exponent_mask;
template <typename FloatType>
inline constexpr BitComparisonWrapper<FloatType> fraction_mask;

#ifdef KOKKOS_IMPL_HALF_TYPE_DEFINED
// half_t layout as written in the literals: 1 sign bit, 5 exponent bits,
// 10 fraction bits.
template <>
inline constexpr BitComparisonWrapper<Kokkos::Experimental::half_t>
exponent_mask<Kokkos::Experimental::half_t>{0b0'11111'0000000000};
template <>
inline constexpr BitComparisonWrapper<Kokkos::Experimental::half_t>
fraction_mask<Kokkos::Experimental::half_t>{0b0'00000'1111111111};
#endif

#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
// bhalf_t layout as written in the literals: 1 sign bit, 8 exponent bits,
// 7 fraction bits.
template <>
inline constexpr BitComparisonWrapper<Kokkos::Experimental::bhalf_t>
exponent_mask<Kokkos::Experimental::bhalf_t>{0b0'11111111'0000000};
template <>
inline constexpr BitComparisonWrapper<Kokkos::Experimental::bhalf_t>
fraction_mask<Kokkos::Experimental::bhalf_t>{0b0'00000000'1111111};
#endif

template <class FloatType>
class alignas(FloatType) floating_point_wrapper {
public:
using impl_type = FloatType;
using impl_type = FloatType;
using bit_comparison_type = BitComparisonWrapper<floating_point_wrapper>;

private:
impl_type val;
Expand Down Expand Up @@ -269,6 +330,11 @@ class alignas(FloatType) floating_point_wrapper {
#endif // KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
}

// Builds the wrapper from a raw bit pattern: the bits held by rhs are
// reinterpreted as impl_type via Kokkos::bit_cast.
KOKKOS_FUNCTION
floating_point_wrapper(bit_comparison_type rhs)
    : val(Kokkos::bit_cast<impl_type>(rhs)) {}

// Don't support implicit conversion back to impl_type.
// impl_type is a storage only type on host.
KOKKOS_FUNCTION
Expand Down
80 changes: 74 additions & 6 deletions core/src/impl/Kokkos_Half_MathematicalFunctions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define KOKKOS_HALF_MATHEMATICAL_FUNCTIONS_HPP_

#include <Kokkos_MathematicalFunctions.hpp> // For the float overloads
#include <Kokkos_BitManipulation.hpp> // bit_cast

// clang-format off
namespace Kokkos {
Expand Down Expand Up @@ -74,7 +75,7 @@ namespace Kokkos {
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, unsigned long) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, long long) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, unsigned long long)


#define KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF(FUNC, HALF_TYPE) \
KOKKOS_INLINE_FUNCTION bool FUNC(HALF_TYPE x) { \
Expand Down Expand Up @@ -155,10 +156,77 @@ KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, nextaf
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, copysign)
// Classification and comparison functions
// fpclassify
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF, isfinite)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF, isinf)
#if !defined(KOKKOS_ENABLE_SYCL) && !defined(KOKKOS_ENABLE_HIP) // FIXME_SYCL, FIXME_HIP
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF, isnan)

#if defined(KOKKOS_HALF_T_IS_FLOAT) && !KOKKOS_HALF_T_IS_FLOAT
// half_t overload: x is finite exactly when the exponent field of its bit
// pattern is not all ones.
KOKKOS_INLINE_FUNCTION bool isfinite(Kokkos::Experimental::half_t x) {
  using half_type = Kokkos::Experimental::half_t;
  using bits_type = half_type::bit_comparison_type;
  constexpr bits_type all_exponent_bits = Kokkos::Experimental::Impl::exponent_mask<half_type>;
  // Reinterpret the underlying storage of x as its raw 16-bit pattern.
  const bits_type raw = bit_cast<bits_type>(static_cast<half_type::impl_type>(x));
  const bool exponent_saturated = (raw.value & all_exponent_bits.value) == all_exponent_bits.value;
  return !exponent_saturated;
}
#endif

#if defined(KOKKOS_BHALF_T_IS_FLOAT) && !KOKKOS_BHALF_T_IS_FLOAT
// bhalf_t overload: x is finite exactly when the exponent field of its bit
// pattern is not all ones.
KOKKOS_INLINE_FUNCTION bool isfinite(Kokkos::Experimental::bhalf_t x) {
  using bhalf_type = Kokkos::Experimental::bhalf_t;
  using bits_type  = bhalf_type::bit_comparison_type;
  constexpr bits_type all_exponent_bits = Kokkos::Experimental::Impl::exponent_mask<bhalf_type>;
  // Reinterpret the underlying storage of x as its raw 16-bit pattern.
  const bits_type raw = bit_cast<bits_type>(static_cast<bhalf_type::impl_type>(x));
  const bool exponent_saturated = (raw.value & all_exponent_bits.value) == all_exponent_bits.value;
  return !exponent_saturated;
}
#endif

#if defined(KOKKOS_HALF_T_IS_FLOAT) && !KOKKOS_HALF_T_IS_FLOAT
// half_t overload: x is infinite when every exponent bit is set and every
// fraction bit is clear. The sign bit is not inspected.
KOKKOS_INLINE_FUNCTION bool isinf(Kokkos::Experimental::half_t x) {
  using half_type = Kokkos::Experimental::half_t;
  using bits_type = half_type::bit_comparison_type;
  constexpr bits_type all_exponent_bits = Kokkos::Experimental::Impl::exponent_mask<half_type>;
  constexpr bits_type all_fraction_bits = Kokkos::Experimental::Impl::fraction_mask<half_type>;
  // Reinterpret the underlying storage of x as its raw 16-bit pattern.
  const bits_type raw = bit_cast<bits_type>(static_cast<half_type::impl_type>(x));
  const bool exponent_saturated = (raw.value & all_exponent_bits.value) == all_exponent_bits.value;
  const bool fraction_empty     = (raw.value & all_fraction_bits.value) == 0;
  return exponent_saturated && fraction_empty;
}
#endif

#if defined(KOKKOS_BHALF_T_IS_FLOAT) && !KOKKOS_BHALF_T_IS_FLOAT
// bhalf_t overload: x is infinite when every exponent bit is set and every
// fraction bit is clear. The sign bit is not inspected.
KOKKOS_INLINE_FUNCTION bool isinf(Kokkos::Experimental::bhalf_t x) {
  using bhalf_type = Kokkos::Experimental::bhalf_t;
  using bits_type  = bhalf_type::bit_comparison_type;
  constexpr bits_type all_exponent_bits = Kokkos::Experimental::Impl::exponent_mask<bhalf_type>;
  constexpr bits_type all_fraction_bits = Kokkos::Experimental::Impl::fraction_mask<bhalf_type>;
  // Reinterpret the underlying storage of x as its raw 16-bit pattern.
  const bits_type raw = bit_cast<bits_type>(static_cast<bhalf_type::impl_type>(x));
  const bool exponent_saturated = (raw.value & all_exponent_bits.value) == all_exponent_bits.value;
  const bool fraction_empty     = (raw.value & all_fraction_bits.value) == 0;
  return exponent_saturated && fraction_empty;
}
#endif

#if defined(KOKKOS_HALF_T_IS_FLOAT) && !KOKKOS_HALF_T_IS_FLOAT
// half_t overload: x is NaN when every exponent bit is set and at least one
// fraction bit is set. The sign bit is not inspected.
KOKKOS_INLINE_FUNCTION bool isnan(Kokkos::Experimental::half_t x) {
  using half_type = Kokkos::Experimental::half_t;
  using bits_type = half_type::bit_comparison_type;
  constexpr bits_type all_exponent_bits = Kokkos::Experimental::Impl::exponent_mask<half_type>;
  constexpr bits_type all_fraction_bits = Kokkos::Experimental::Impl::fraction_mask<half_type>;
  // Reinterpret the underlying storage of x as its raw 16-bit pattern.
  const bits_type raw = bit_cast<bits_type>(static_cast<half_type::impl_type>(x));
  const bool exponent_saturated = (raw.value & all_exponent_bits.value) == all_exponent_bits.value;
  const bool fraction_nonzero   = (raw.value & all_fraction_bits.value) != 0;
  return exponent_saturated && fraction_nonzero;
}
#endif

#if defined(KOKKOS_BHALF_T_IS_FLOAT) && !KOKKOS_BHALF_T_IS_FLOAT
// bhalf_t overload: x is NaN when every exponent bit is set and at least one
// fraction bit is set. The sign bit is not inspected.
KOKKOS_INLINE_FUNCTION bool isnan(Kokkos::Experimental::bhalf_t x) {
  using bhalf_type = Kokkos::Experimental::bhalf_t;
  using bits_type  = bhalf_type::bit_comparison_type;
  constexpr bits_type all_exponent_bits = Kokkos::Experimental::Impl::exponent_mask<bhalf_type>;
  constexpr bits_type all_fraction_bits = Kokkos::Experimental::Impl::fraction_mask<bhalf_type>;
  // Reinterpret the underlying storage of x as its raw 16-bit pattern.
  const bits_type raw = bit_cast<bits_type>(static_cast<bhalf_type::impl_type>(x));
  const bool exponent_saturated = (raw.value & all_exponent_bits.value) == all_exponent_bits.value;
  const bool fraction_nonzero   = (raw.value & all_fraction_bits.value) != 0;
  return exponent_saturated && fraction_nonzero;
}
#endif
// isnormal
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF, signbit)
Expand Down Expand Up @@ -188,4 +256,4 @@ KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_COMPLEX_IMAG_HALF, imag)
#undef KOKKOS_IMPL_MATH_H_FUNC_WRAPPER
} // namespace Kokkos
// clang-format on
#endif // KOKKOS_HALF_MATHEMATICAL_FUNCTIONS_HPP_
#endif // KOKKOS_HALF_MATHEMATICAL_FUNCTIONS_HPP_
20 changes: 10 additions & 10 deletions core/src/impl/Kokkos_Half_NumericTraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
///
template <>
struct Kokkos::Experimental::Impl::infinity_helper<Kokkos::Experimental::half_t> {
// NOTE(review): two declarations of `value` are visible here (an int and a
// bit_comparison_type); presumably the removed and the added line of a diff
// shown together - confirm against the actual header.
// Both literals spell the same 16-bit pattern: sign 0, all five exponent
// bits set, fraction zero.
static constexpr int value = 0x7C00;
static constexpr Kokkos::Experimental::half_t::bit_comparison_type value{0b0'11111'0000000000};
};

/// \brief: Minimum normalized number
Expand Down Expand Up @@ -157,30 +157,30 @@ struct Kokkos::Experimental::Impl::norm_min_helper<

/// \brief: Quiet not a half precision number
///
/// IEEE 754 defines this as all exponent bits high.
/// IEEE 754 defines this as all exponent bits and the first fraction bit high.
///
/// Quiet NaN in binary16:
/// [s e e e e e f f f f f f f f f f]
/// [1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]
/// [0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0]
/// bit index: 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0
template <>
struct Kokkos::Experimental::Impl::quiet_NaN_helper<
Kokkos::Experimental::half_t> {
// NOTE(review): two declarations of `value` are visible here (a float and a
// bit_comparison_type); presumably the removed and the added line of a diff
// shown together - confirm against the actual header.
// The float is initialised from the integer literal 0xfc000, whereas the
// bit_comparison_type holds the 16-bit pattern with sign 0, all five
// exponent bits set and only the most significant fraction bit set.
static constexpr float value = 0xfc000;
static constexpr Kokkos::Experimental::half_t::bit_comparison_type value{0b0'11111'1000000000};
};

/// \brief: Signaling not a half precision number
///
/// IEEE 754 defines this as all exponent bits and the first fraction bit high.
/// IEEE 754 defines this as all exponent bits and the second fraction bit high.
///
/// Signaling NaN in binary16:
/// [s e e e e e f f f f f f f f f f]
/// [1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0]
/// [0 1 1 1 1 1 0 1 0 0 0 0 0 0 0 0]
/// bit index: 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0
template <>
struct Kokkos::Experimental::Impl::signaling_NaN_helper<
Kokkos::Experimental::half_t> {
// NOTE(review): two declarations of `value` are visible here (a float and a
// bit_comparison_type); presumably the removed and the added line of a diff
// shown together - confirm against the actual header.
// The float is initialised from the integer literal 0xfe000, whereas the
// bit_comparison_type holds the 16-bit pattern with sign 0, all five
// exponent bits set and only the second most significant fraction bit set.
static constexpr float value = 0xfe000;
static constexpr Kokkos::Experimental::half_t::bit_comparison_type value{0b0'11111'0100000000};
};

/// \brief: Number of digits in the mantissa that can be represented
Expand Down Expand Up @@ -267,7 +267,7 @@ struct Kokkos::Experimental::Impl::max_exponent_helper<
///
template <>
struct Kokkos::Experimental::Impl::infinity_helper<Kokkos::Experimental::bhalf_t> {
// NOTE(review): two declarations of `value` are visible here (an int and a
// bit_comparison_type); presumably the removed and the added line of a diff
// shown together - confirm against the actual header.
// Both literals spell the same 16-bit pattern: sign 0, all eight exponent
// bits set, fraction zero.
static constexpr int value = 0x7F80;
static constexpr Kokkos::Experimental::bhalf_t::bit_comparison_type value{0b0'11111111'0000000};
};

// Minimum normalized number
Expand Down Expand Up @@ -303,13 +303,13 @@ struct Kokkos::Experimental::Impl::norm_min_helper<
template <>
struct Kokkos::Experimental::Impl::quiet_NaN_helper<
Kokkos::Experimental::bhalf_t> {
// NOTE(review): two declarations of `value` are visible here (a float and a
// bit_comparison_type); presumably the removed and the added line of a diff
// shown together - confirm against the actual header.
// The float is initialised from the integer literal 0x7fc000, whereas the
// bit_comparison_type holds the 16-bit pattern with sign 0, all eight
// exponent bits set and only the most significant fraction bit set.
static constexpr float value = 0x7fc000;
static constexpr Kokkos::Experimental::bhalf_t::bit_comparison_type value{0b0'11111111'1000000};
};
// Signaling not a bhalf number
template <>
struct Kokkos::Experimental::Impl::signaling_NaN_helper<
Kokkos::Experimental::bhalf_t> {
// NOTE(review): two declarations of `value` are visible here (a float and a
// bit_comparison_type); presumably the removed and the added line of a diff
// shown together - confirm against the actual header.
// The float is initialised from the integer literal 0x7fe000, whereas the
// bit_comparison_type holds the 16-bit pattern with sign 0, all eight
// exponent bits set and only the second most significant fraction bit set.
static constexpr float value = 0x7fe000;
static constexpr Kokkos::Experimental::bhalf_t::bit_comparison_type value{0b0'11111111'0100000};
};
// Number of digits in the mantissa that can be represented
// without losing precision.
Expand Down

0 comments on commit d5a4802

Please sign in to comment.