-
Notifications
You must be signed in to change notification settings - Fork 407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update common mathematical functions #4043
Conversation
ea15a29
to
5435c6a
Compare
Whats the status of this? |
5435c6a
to
dc2933d
Compare
Ready for review |
@lucbv please review @masterleinad or @nliber please review the adding a definition of NaN for SYCL |
50d39d0
to
56d69f0
Compare
56d69f0
to
c184bb4
Compare
KOKKOS_INLINE_FUNCTION RETURNTYPE FUNC(ARGTYPE x) { \ | ||
using NAMESPACE_MATH_FUNCTIONS::FUNC; \ | ||
return FUNC(x); \ | ||
#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are there overloads other than for long double
guarded here? I would prefer to avoid the duplicated code if possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What alternative do you suggest?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something like
diff --git a/core/src/Kokkos_MathematicalFunctions.hpp b/core/src/Kokkos_MathematicalFunctions.hpp
index 50fde82d7..8c378559b 100644
--- a/core/src/Kokkos_MathematicalFunctions.hpp
+++ b/core/src/Kokkos_MathematicalFunctions.hpp
@@ -100,9 +100,7 @@ namespace Experimental {
#define KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS
#endif
-#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS)
-
-#define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC) \
+#define KOKKOS_IMPL_MATH_UNARY_FUNCTION_OTHER(FUNC) \
KOKKOS_INLINE_FUNCTION float FUNC(float x) { \
using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
return FUNC(x); \
@@ -111,18 +109,10 @@ namespace Experimental {
using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
return FUNC(x); \
} \
- KOKKOS_INLINE_FUNCTION long double FUNC(long double x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x); \
- } \
KOKKOS_INLINE_FUNCTION float FUNC##f(float x) { \
using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
return FUNC(x); \
} \
- KOKKOS_INLINE_FUNCTION long double FUNC##l(long double x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x); \
- } \
template <class T> \
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral<T>::value, double> \
FUNC(T x) { \
@@ -130,7 +120,23 @@ namespace Experimental {
return FUNC(static_cast<double>(x)); \
}
-#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC) \
+#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS)
+#define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC) \
+ KOKKOS_IMPL_MATH_UNARY_FUNCTION_OTHER(FUNC) \
+ KOKKOS_INLINE_FUNCTION long double FUNC(long double x) { \
+ using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
+ return FUNC(x); \
+ } \
+ KOKKOS_INLINE_FUNCTION long double FUNC##l(long double x) { \
+ using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
+ return FUNC(x); \
+ } \
+#else
+#define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC) \
+ KOKKOS_IMPL_MATH_UNARY_FUNCTION_OTHER(FUNC)
+#endif
+
+#define KOKKOS_IMPL_MATH_UNARY_PREDICATE_OTHER(FUNC) \
KOKKOS_INLINE_FUNCTION bool FUNC(float x) { \
using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
return FUNC(x); \
@@ -139,10 +145,6 @@ namespace Experimental {
using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
return FUNC(x); \
} \
- KOKKOS_INLINE_FUNCTION bool FUNC(long double x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x); \
- } \
template <class T> \
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral<T>::value, bool> \
FUNC(T x) { \
@@ -150,7 +152,19 @@ namespace Experimental {
return FUNC(static_cast<double>(x)); \
}
-#define KOKKOS_IMPL_MATH_BINARY_FUNCTION(FUNC) \
+#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS)
+#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC) \
+ KOKKOS_IMPL_MATH_UNARY_PREDICATE_OTHER(FUNC) \
+ KOKKOS_INLINE_FUNCTION bool FUNC(long double x) { \
+ using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
+ return FUNC(x); \
+ }
+#else
+#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC) \
+ KOKKOS_IMPL_MATH_UNARY_PREDICATE_OTHER(FUNC)
+#endif
+
+#define KOKKOS_IMPL_MATH_BINARY_FUNCTION_OTHER(FUNC) \
KOKKOS_INLINE_FUNCTION float FUNC(float x, float y) { \
using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
return FUNC(x, y); \
@@ -181,57 +195,29 @@ namespace Experimental {
return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y)); \
}
-#else // long double overloads are not available
-
-#define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC) \
- KOKKOS_INLINE_FUNCTION float FUNC(float x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x); \
- } \
- KOKKOS_INLINE_FUNCTION double FUNC(double x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x); \
- } \
- KOKKOS_INLINE_FUNCTION float FUNC##f(float x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x); \
- } \
- template <class T> \
- KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral<T>::value, double> \
- FUNC(T x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(static_cast<double>(x)); \
- }
-
-#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC) \
- KOKKOS_INLINE_FUNCTION bool FUNC(float x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x); \
- } \
- KOKKOS_INLINE_FUNCTION bool FUNC(double x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x); \
- } \
- template <class T> \
- KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral<T>::value, bool> \
- FUNC(T x) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(static_cast<double>(x)); \
+#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS)
+#define KOKKOS_IMPL_MATH_BINARY_FUNCTION(FUNC) \
+ KOKKOS_IMPL_MATH_BINARY_FUNCTION_OTHER(FUNC) \
+ KOKKOS_INLINE_FUNCTION long double FUNC(long double x, long double y) { \
+ using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
+ return FUNC(x, y); \
+ } \
+ KOKKOS_INLINE_FUNCTION long double FUNC##l(long double x, long double y) { \
+ using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
+ return FUNC(x, y); \
+ } \
+ template <class T1, class T2> \
+ KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_arithmetic<T1>::value && \
+ std::is_arithmetic<T2>::value, \
+ Kokkos::Impl::promote_2_t<T1, T2>> \
+ FUNC(T1 x, T2 y) { \
+ using Promoted = Kokkos::Impl::promote_2_t<T1, T2>; \
+ using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
+ return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y)); \
}
-
+#else
#define KOKKOS_IMPL_MATH_BINARY_FUNCTION(FUNC) \
- KOKKOS_INLINE_FUNCTION float FUNC(float x, float y) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x, y); \
- } \
- KOKKOS_INLINE_FUNCTION double FUNC(double x, double y) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x, y); \
- } \
- KOKKOS_INLINE_FUNCTION float FUNC##f(float x, float y) { \
- using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
- return FUNC(x, y); \
- } \
+ KOKKOS_IMPL_MATH_BINARY_FUNCTION_OTHER(FUNC) \
template <class T1, class T2> \
KOKKOS_INLINE_FUNCTION std::enable_if_t< \
std::is_arithmetic<T1>::value && std::is_arithmetic<T2>::value && \
@@ -243,7 +229,6 @@ namespace Experimental {
using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y)); \
}
-
#endif
// Basic operations
optionally combing all the branches.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nah I don't like it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's your argument for having the duplications then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's what I was trying to move away from with the refactor. The goal was less macro to reduce complexity and increase readability, even if it means duplicating code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we disagree on this point here then. 🙂
Retest this please |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes @dalg24 this looks good
*f
and*l
"overloads" forfloat
andlong double