Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update common mathematical functions #4043

Merged
merged 9 commits into from
Jun 9, 2021

Conversation

dalg24
Copy link
Member

@dalg24 dalg24 commented May 22, 2021

  • Refactor definition of common math functions (attempt to reduce complexity and improve readability)
  • Add missing *f and *l "overloads" for float and long double
  • Implement function that return quiet NaN for SYCL

@dalg24 dalg24 force-pushed the fixup_common_math_functions branch 2 times, most recently from ea15a29 to 5435c6a Compare May 24, 2021 13:36
@dalg24 dalg24 marked this pull request as ready for review May 24, 2021 13:39
@dalg24 dalg24 marked this pull request as draft May 24, 2021 15:53
@crtrott
Copy link
Member

crtrott commented Jun 4, 2021

Whats the status of this?

@dalg24 dalg24 force-pushed the fixup_common_math_functions branch from 5435c6a to dc2933d Compare June 4, 2021 16:19
@dalg24 dalg24 changed the title Fix defect in common mathematical functions Update common mathematical functions Jun 4, 2021
@dalg24 dalg24 marked this pull request as ready for review June 4, 2021 16:20
@dalg24
Copy link
Member Author

dalg24 commented Jun 4, 2021

Whats the status of this?

Ready for review

@dalg24
Copy link
Member Author

dalg24 commented Jun 4, 2021

@lucbv please review

@masterleinad or @nliber please review the adding a definition of NaN for SYCL

@dalg24 dalg24 force-pushed the fixup_common_math_functions branch 3 times, most recently from 50d39d0 to 56d69f0 Compare June 4, 2021 17:45
@dalg24 dalg24 force-pushed the fixup_common_math_functions branch from 56d69f0 to c184bb4 Compare June 4, 2021 18:28
KOKKOS_INLINE_FUNCTION RETURNTYPE FUNC(ARGTYPE x) { \
using NAMESPACE_MATH_FUNCTIONS::FUNC; \
return FUNC(x); \
#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are there overloads other than for long double guarded here? I would prefer to avoid the duplicated code if possible.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What alternative do you suggest?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like

diff --git a/core/src/Kokkos_MathematicalFunctions.hpp b/core/src/Kokkos_MathematicalFunctions.hpp
index 50fde82d7..8c378559b 100644
--- a/core/src/Kokkos_MathematicalFunctions.hpp
+++ b/core/src/Kokkos_MathematicalFunctions.hpp
@@ -100,9 +100,7 @@ namespace Experimental {
 #define KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS
 #endif
 
-#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS)
-
-#define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC)                                 \
+#define KOKKOS_IMPL_MATH_UNARY_FUNCTION_OTHER(FUNC)                           \
   KOKKOS_INLINE_FUNCTION float FUNC(float x) {                                \
     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                         \
     return FUNC(x);                                                           \
@@ -111,18 +109,10 @@ namespace Experimental {
     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                         \
     return FUNC(x);                                                           \
   }                                                                           \
-  KOKKOS_INLINE_FUNCTION long double FUNC(long double x) {                    \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                         \
-    return FUNC(x);                                                           \
-  }                                                                           \
   KOKKOS_INLINE_FUNCTION float FUNC##f(float x) {                             \
     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                         \
     return FUNC(x);                                                           \
   }                                                                           \
-  KOKKOS_INLINE_FUNCTION long double FUNC##l(long double x) {                 \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                         \
-    return FUNC(x);                                                           \
-  }                                                                           \
   template <class T>                                                          \
   KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral<T>::value, double> \
   FUNC(T x) {                                                                 \
@@ -130,7 +120,23 @@ namespace Experimental {
     return FUNC(static_cast<double>(x));                                      \
   }
 
-#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC)                              \
+#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS)
+#define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC)                 \
+  KOKKOS_IMPL_MATH_UNARY_FUNCTION_OTHER(FUNC)                 \
+  KOKKOS_INLINE_FUNCTION long double FUNC(long double x) {    \
+    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;         \
+    return FUNC(x);                                           \
+  }                                                           \
+  KOKKOS_INLINE_FUNCTION long double FUNC##l(long double x) { \
+    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;         \
+    return FUNC(x);                                           \
+  }                                                           \
+#else
+#define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC) \
+  KOKKOS_IMPL_MATH_UNARY_FUNCTION_OTHER(FUNC)
+#endif
+
+#define KOKKOS_IMPL_MATH_UNARY_PREDICATE_OTHER(FUNC)                        \
   KOKKOS_INLINE_FUNCTION bool FUNC(float x) {                               \
     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \
     return FUNC(x);                                                         \
@@ -139,10 +145,6 @@ namespace Experimental {
     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \
     return FUNC(x);                                                         \
   }                                                                         \
-  KOKKOS_INLINE_FUNCTION bool FUNC(long double x) {                         \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \
-    return FUNC(x);                                                         \
-  }                                                                         \
   template <class T>                                                        \
   KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral<T>::value, bool> \
   FUNC(T x) {                                                               \
@@ -150,7 +152,19 @@ namespace Experimental {
     return FUNC(static_cast<double>(x));                                    \
   }
 
-#define KOKKOS_IMPL_MATH_BINARY_FUNCTION(FUNC)                               \
+#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS)
+#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC)        \
+  KOKKOS_IMPL_MATH_UNARY_PREDICATE_OTHER(FUNC)        \
+  KOKKOS_INLINE_FUNCTION bool FUNC(long double x) {   \
+    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
+    return FUNC(x);                                   \
+  }
+#else
+#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC) \
+  KOKKOS_IMPL_MATH_UNARY_PREDICATE_OTHER(FUNC)
+#endif
+
+#define KOKKOS_IMPL_MATH_BINARY_FUNCTION_OTHER(FUNC)                         \
   KOKKOS_INLINE_FUNCTION float FUNC(float x, float y) {                      \
     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                        \
     return FUNC(x, y);                                                       \
@@ -181,57 +195,29 @@ namespace Experimental {
     return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y));         \
   }
 
-#else  // long double overloads are not available
-
-#define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC)                                 \
-  KOKKOS_INLINE_FUNCTION float FUNC(float x) {                                \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                         \
-    return FUNC(x);                                                           \
-  }                                                                           \
-  KOKKOS_INLINE_FUNCTION double FUNC(double x) {                              \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                         \
-    return FUNC(x);                                                           \
-  }                                                                           \
-  KOKKOS_INLINE_FUNCTION float FUNC##f(float x) {                             \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                         \
-    return FUNC(x);                                                           \
-  }                                                                           \
-  template <class T>                                                          \
-  KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral<T>::value, double> \
-  FUNC(T x) {                                                                 \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                         \
-    return FUNC(static_cast<double>(x));                                      \
-  }
-
-#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC)                              \
-  KOKKOS_INLINE_FUNCTION bool FUNC(float x) {                               \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \
-    return FUNC(x);                                                         \
-  }                                                                         \
-  KOKKOS_INLINE_FUNCTION bool FUNC(double x) {                              \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \
-    return FUNC(x);                                                         \
-  }                                                                         \
-  template <class T>                                                        \
-  KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral<T>::value, bool> \
-  FUNC(T x) {                                                               \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \
-    return FUNC(static_cast<double>(x));                                    \
+#if defined(KOKKOS_IMPL_MATH_FUNCTIONS_HAVE_LONG_DOUBLE_OVERLOADS)
+#define KOKKOS_IMPL_MATH_BINARY_FUNCTION(FUNC)                               \
+  KOKKOS_IMPL_MATH_BINARY_FUNCTION_OTHER(FUNC)                               \
+  KOKKOS_INLINE_FUNCTION long double FUNC(long double x, long double y) {    \
+    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                        \
+    return FUNC(x, y);                                                       \
+  }                                                                          \
+  KOKKOS_INLINE_FUNCTION long double FUNC##l(long double x, long double y) { \
+    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                        \
+    return FUNC(x, y);                                                       \
+  }                                                                          \
+  template <class T1, class T2>                                              \
+  KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_arithmetic<T1>::value &&   \
+                                              std::is_arithmetic<T2>::value, \
+                                          Kokkos::Impl::promote_2_t<T1, T2>> \
+  FUNC(T1 x, T2 y) {                                                         \
+    using Promoted = Kokkos::Impl::promote_2_t<T1, T2>;                      \
+    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                        \
+    return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y));         \
   }
-
+#else
 #define KOKKOS_IMPL_MATH_BINARY_FUNCTION(FUNC)                          \
-  KOKKOS_INLINE_FUNCTION float FUNC(float x, float y) {                 \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                   \
-    return FUNC(x, y);                                                  \
-  }                                                                     \
-  KOKKOS_INLINE_FUNCTION double FUNC(double x, double y) {              \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                   \
-    return FUNC(x, y);                                                  \
-  }                                                                     \
-  KOKKOS_INLINE_FUNCTION float FUNC##f(float x, float y) {              \
-    using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                   \
-    return FUNC(x, y);                                                  \
-  }                                                                     \
+  KOKKOS_IMPL_MATH_BINARY_FUNCTION_OTHER(FUNC)                          \
   template <class T1, class T2>                                         \
   KOKKOS_INLINE_FUNCTION std::enable_if_t<                              \
       std::is_arithmetic<T1>::value && std::is_arithmetic<T2>::value && \
@@ -243,7 +229,6 @@ namespace Experimental {
     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                   \
     return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y));    \
   }
-
 #endif
 
 // Basic operations

optionally combing all the branches.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nah I don't like it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's your argument for having the duplications then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's what I was trying to move away from with the refactor. The goal was less macro to reduce complexity and increase readability, even if it means duplicating code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we disagree on this point here then. 🙂

core/src/Kokkos_MathematicalFunctions.hpp Outdated Show resolved Hide resolved
core/src/Kokkos_MathematicalFunctions.hpp Outdated Show resolved Hide resolved
core/src/Kokkos_MathematicalFunctions.hpp Outdated Show resolved Hide resolved
@dalg24
Copy link
Member Author

dalg24 commented Jun 4, 2021

Retest this please

Copy link
Contributor

@lucbv lucbv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes @dalg24 this looks good

@crtrott crtrott merged commit 7695b81 into kokkos:develop Jun 9, 2021
@dalg24 dalg24 deleted the fixup_common_math_functions branch June 9, 2021 15:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants