Skip to content

Commit

Permalink
core/src: Add half math functions to private header (kokkos#6124)
Browse files Browse the repository at this point in the history
* core/src: Add half math functions to private header

* Update formatting

* core/src: Formatting half math fns

* core/src: proposed math fns testing

* core/src/impl: Organize half math fns

* core/unit_test: Add half math trig fn tests

* core/unit_test: Cleanup with KE alias

* core/unit_test: Add half math power fn tests

* core/unit_test: Add half math exponential fn tests

* core/unit_test: Add half math hyperbolic fn tests

* core/unit_test: Add half math error and gamma fn tests

* core/unit_test: Add half math nearest integer floating point fn tests

* core/unit_test: Add half math floating point manipulation fn tests

* core/unit_test: Add math basic operation fn tests

* core/unit_test: Add half math isnan test and placeholders

* core/unit_test: Fix clang+serial build

* core/unit_test: Attempt to ignore -Wno-gnu-zero-variadic-macro-arguments

* Add include guards around diagnostic pragma

* Fix warning flag

* Disable remainder coverage

* Include pragma for other LLVM-based compilers

* core/unit_test: Workaround SYCL and MSVC build errors

* core/unit_test: nearbyint not available in sycl

* core: Disable isnan for half_t on SYCL

* core: Disable isnan for half_t on HIP

* Update core/src/impl/Kokkos_Half_MathematicalFunctions.hpp

Simplify KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER

Co-authored-by: Daniel Arndt <arndtd@ornl.gov>

* core/src: Fix half math func wrapper

* [ci skip] Fix typo

* Address Damiens feedback

* Add kokkos_type_is_half_t and kokkos_type_is_bhalf_t

* Update core/src/impl/Kokkos_Half_FloatingPointWrapper.hpp

Co-authored-by: Daniel Arndt <arndtd@ornl.gov>

* Only cast if necessary

* Fix formating

* Rename and move half traits

* Add patch from Damien

* Add runtime copy from Damien

* Update TODOs

* Update core/unit_test/TestMathematicalFunctions.hpp

Co-authored-by: Damien L-G <dalg24+github@gmail.com>

* Update core/unit_test/TestMathematicalFunctions.hpp

Co-authored-by: Damien L-G <dalg24+github@gmail.com>

* Update int,half binary fns to use float

* Add integral overloads

* Add float overloads

---------

Co-authored-by: Daniel Arndt <arndtd@ornl.gov>
Co-authored-by: Damien L-G <dalg24+github@gmail.com>
  • Loading branch information
3 people committed Sep 7, 2023
1 parent fc213ea commit 1affb05
Show file tree
Hide file tree
Showing 4 changed files with 705 additions and 52 deletions.
1 change: 1 addition & 0 deletions core/src/Kokkos_Half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <impl/Kokkos_Half_FloatingPointWrapper.hpp>
#include <impl/Kokkos_Half_NumericTraits.hpp>
#include <impl/Kokkos_Half_MathematicalFunctions.hpp>

#ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_HALF
#undef KOKKOS_IMPL_PUBLIC_INCLUDE
Expand Down
21 changes: 20 additions & 1 deletion core/src/impl/Kokkos_Half_FloatingPointWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@
#include <iosfwd> // istream & ostream for extraction and insertion ops
#include <string>

namespace Kokkos::Experimental::Impl {
/// @brief templated struct for determining if half_t is an alias to float.
/// @tparam T The type to specialize on.
template <class T>
struct is_float16 : std::false_type {};

/// @brief templated struct for determining if bhalf_t is an alias to float.
/// @tparam T The type to specialize on.
template <class T>
struct is_bfloat16 : std::false_type {};
} // namespace Kokkos::Experimental::Impl

#ifdef KOKKOS_IMPL_HALF_TYPE_DEFINED

// KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH: A macro to select which
Expand All @@ -44,6 +56,10 @@ class floating_point_wrapper;
// Declare half_t (binary16)
using half_t = Kokkos::Experimental::Impl::floating_point_wrapper<
Kokkos::Impl::half_impl_t ::type>;
namespace Impl {
template <>
struct is_float16<half_t> : std::true_type {};
} // namespace Impl
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(float val);
KOKKOS_INLINE_FUNCTION
Expand Down Expand Up @@ -110,7 +126,10 @@ KOKKOS_INLINE_FUNCTION
#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
using bhalf_t = Kokkos::Experimental::Impl::floating_point_wrapper<
Kokkos::Impl ::bhalf_impl_t ::type>;

namespace Impl {
template <>
struct is_bfloat16<bhalf_t> : std::true_type {};
} // namespace Impl
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(float val);
KOKKOS_INLINE_FUNCTION
Expand Down
191 changes: 191 additions & 0 deletions core/src/impl/Kokkos_Half_MathematicalFunctions.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOS_HALF_MATHEMATICAL_FUNCTIONS_HPP_
#define KOKKOS_HALF_MATHEMATICAL_FUNCTIONS_HPP_

#include <Kokkos_MathematicalFunctions.hpp> // For the float overloads

// clang-format off
namespace Kokkos {
// BEGIN macro definitions
#if defined(KOKKOS_HALF_T_IS_FLOAT) && !KOKKOS_HALF_T_IS_FLOAT
#define KOKKOS_IMPL_MATH_H_FUNC_WRAPPER(MACRO, FUNC) \
MACRO(FUNC, Kokkos::Experimental::half_t)
#else
#define KOKKOS_IMPL_MATH_H_FUNC_WRAPPER(MACRO, FUNC)
#endif

#if defined(KOKKOS_BHALF_T_IS_FLOAT) && !KOKKOS_BHALF_T_IS_FLOAT
#define KOKKOS_IMPL_MATH_B_FUNC_WRAPPER(MACRO, FUNC) \
MACRO(FUNC, Kokkos::Experimental::bhalf_t)
#else
#define KOKKOS_IMPL_MATH_B_FUNC_WRAPPER(MACRO, FUNC)
#endif

#define KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(MACRO, FUNC) \
KOKKOS_IMPL_MATH_H_FUNC_WRAPPER(MACRO, FUNC) \
KOKKOS_IMPL_MATH_B_FUNC_WRAPPER(MACRO, FUNC)


#define KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE(FUNC, HALF_TYPE) \
KOKKOS_INLINE_FUNCTION HALF_TYPE FUNC(HALF_TYPE x) { \
return static_cast<HALF_TYPE>(Kokkos::FUNC(static_cast<float>(x))); \
}

#define KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, MIXED_TYPE) \
KOKKOS_INLINE_FUNCTION double FUNC(HALF_TYPE x, MIXED_TYPE y) { \
return Kokkos::FUNC(static_cast<double>(x), static_cast<double>(y)); \
} \
KOKKOS_INLINE_FUNCTION double FUNC(MIXED_TYPE x, HALF_TYPE y) { \
return Kokkos::FUNC(static_cast<double>(x), static_cast<double>(y)); \
}

#define KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF(FUNC, HALF_TYPE) \
KOKKOS_INLINE_FUNCTION HALF_TYPE FUNC(HALF_TYPE x, HALF_TYPE y) { \
return static_cast<HALF_TYPE>( \
Kokkos::FUNC(static_cast<float>(x), static_cast<float>(y))); \
} \
KOKKOS_INLINE_FUNCTION float FUNC(float x, HALF_TYPE y) { \
return Kokkos::FUNC(static_cast<float>(x), static_cast<float>(y)); \
} \
KOKKOS_INLINE_FUNCTION float FUNC(HALF_TYPE x, float y) { \
return Kokkos::FUNC(static_cast<float>(x), static_cast<float>(y)); \
} \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, double) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, short) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, unsigned short) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, int) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, unsigned int) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, long) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, unsigned long) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, long long) \
KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF_MIXED(FUNC, HALF_TYPE, unsigned long long)


#define KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF(FUNC, HALF_TYPE) \
KOKKOS_INLINE_FUNCTION bool FUNC(HALF_TYPE x) { \
return Kokkos::FUNC(static_cast<float>(x)); \
}

// END macros definitions


// Basic operations
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, abs)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, fabs)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, fmod)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, remainder)
// remquo
// fma
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, fmax)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, fmin)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, fdim)
// nanq
// Exponential functions
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, exp)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, exp2)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, expm1)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, log)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, log10)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, log2)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, log1p)
// Power functions
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, pow)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, sqrt)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, cbrt)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, hypot)
// Trigonometric functions
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, sin)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, cos)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, tan)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, asin)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, acos)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, atan)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, atan2)
// Hyperbolic functions
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, sinh)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, cosh)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, tanh)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, asinh)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, acosh)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, atanh)
// Error and gamma functions
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, erf)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, erfc)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, tgamma)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, lgamma)
// Nearest integer floating point functions
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, ceil)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, floor)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, trunc)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, round)
// lround
// llround
// FIXME_SYCL not available as of current SYCL 2020 specification (revision 4)
#ifndef KOKKOS_ENABLE_SYCL // FIXME_SYCL
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, nearbyint)
#endif
// rint
// lrint
// llrint
// Floating point manipulation functions
// frexp
// ldexp
// modf
// scalbn
// scalbln
// ilog
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE, logb)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, nextafter)
// nexttoward
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF, copysign)
// Classification and comparison functions
// fpclassify
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF, isfinite)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF, isinf)
#if !defined(KOKKOS_ENABLE_SYCL) && !defined(KOKKOS_ENABLE_HIP) // FIXME_SYCL, FIXME_HIP
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF, isnan)
#endif
// isnormal
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF, signbit)
// isgreater
// isgreaterequal
// isless
// islessequal
// islessgreater
// isunordered
// Complex number functions
#define KOKKOS_IMPL_MATH_COMPLEX_REAL_HALF(FUNC, HALF_TYPE) \
KOKKOS_INLINE_FUNCTION HALF_TYPE FUNC(HALF_TYPE x) { return x; }

#define KOKKOS_IMPL_MATH_COMPLEX_IMAG_HALF(FUNC, HALF_TYPE) \
KOKKOS_INLINE_FUNCTION HALF_TYPE FUNC(HALF_TYPE) { return 0; }

KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_COMPLEX_REAL_HALF, real)
KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER(KOKKOS_IMPL_MATH_COMPLEX_IMAG_HALF, imag)

#undef KOKKOS_IMPL_MATH_COMPLEX_REAL_HALF
#undef KOKKOS_IMPL_MATH_COMPLEX_IMAG_HALF
#undef KOKKOS_IMPL_MATH_UNARY_PREDICATE_HALF
#undef KOKKOS_IMPL_MATH_BINARY_FUNCTION_HALF
#undef KOKKOS_IMPL_MATH_UNARY_FUNCTION_HALF_TYPE
#undef KOKKOS_IMPL_MATH_HALF_FUNC_WRAPPER
#undef KOKKOS_IMPL_MATH_B_FUNC_WRAPPER
#undef KOKKOS_IMPL_MATH_H_FUNC_WRAPPER
} // namespace Kokkos
// clang-format on
#endif // KOKKOS_HALF_MATHEMATICAL_FUNCTIONS_HPP_

0 comments on commit 1affb05

Please sign in to comment.