From e00afe4b6e3fa33d5fba1b90974f82c352c4cd04 Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Thu, 16 Dec 2021 15:39:38 +0300 Subject: [PATCH 1/5] Add sycl::known_identity for sycl::vec --- sycl/include/CL/sycl/known_identity.hpp | 130 ++++++++++++++++--- sycl/test/basic_tests/known_identity.cpp | 152 ++++++++++++++++++++++- 2 files changed, 263 insertions(+), 19 deletions(-) diff --git a/sycl/include/CL/sycl/known_identity.hpp b/sycl/include/CL/sycl/known_identity.hpp index 955d7b8a9cfca..55511bf6e44f2 100644 --- a/sycl/include/CL/sycl/known_identity.hpp +++ b/sycl/include/CL/sycl/known_identity.hpp @@ -65,33 +65,34 @@ using IsLogicalOR = // Identity = 0 template -using IsZeroIdentityOp = bool_constant< - (is_sgeninteger::value && - (IsPlus::value || IsBitOR::value || - IsBitXOR::value)) || - (is_sgenfloat::value && IsPlus::value)>; +using IsZeroIdentityOp = + bool_constant<(is_geninteger::value && + (IsPlus::value || + IsBitOR::value || + IsBitXOR::value)) || + (is_genfloat::value && IsPlus::value)>; // Identity = 1 template using IsOneIdentityOp = - bool_constant<(is_sgeninteger::value || is_sgenfloat::value) && + bool_constant<(is_geninteger::value || is_genfloat::value) && IsMultiplies::value>; // Identity = ~0 template -using IsOnesIdentityOp = bool_constant::value && +using IsOnesIdentityOp = bool_constant::value && IsBitAND::value>; // Identity = template using IsMinimumIdentityOp = - bool_constant<(is_sgeninteger::value || is_sgenfloat::value) && + bool_constant<(is_geninteger::value || is_genfloat::value) && IsMinimum::value>; // Identity = template using IsMaximumIdentityOp = - bool_constant<(is_sgeninteger::value || is_sgenfloat::value) && + bool_constant<(is_geninteger::value || is_genfloat::value) && IsMaximum::value>; // Identity = false @@ -125,7 +126,27 @@ template struct known_identity_impl< BinaryOperation, AccumulatorT, std::enable_if_t::value>> { - static constexpr AccumulatorT value = 0; + static constexpr AccumulatorT value = static_cast(0); +}; + +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +template +struct known_identity_impl< + BinaryOperation, vec, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr vec value = + vec(std::byte(0)); +}; +#endif + +template +struct known_identity_impl< + BinaryOperation, vec, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr vec value = + vec(sycl::half()); }; template @@ -145,8 +166,19 @@ template struct known_identity_impl< BinaryOperation, AccumulatorT, std::enable_if_t::value>> { - static constexpr AccumulatorT value = 1; + static constexpr AccumulatorT value = static_cast(1); +}; + +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +template +struct known_identity_impl< + BinaryOperation, vec, + std::enable_if_t< + IsOneIdentityOp, BinaryOperation>::value>> { + static constexpr vec value = + vec(std::byte(1)); }; +#endif template struct known_identity_impl< @@ -165,48 +197,110 @@ template struct known_identity_impl< BinaryOperation, AccumulatorT, std::enable_if_t::value>> { - static constexpr AccumulatorT value = ~static_cast(0); + static constexpr AccumulatorT value = static_cast(-1LL); +}; + +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +template +struct known_identity_impl< + BinaryOperation, vec, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr vec value = + vec(std::byte(-1LL)); }; +#endif /// Returns maximal possible value as identity for MIN operations. template struct known_identity_impl::value>> { - static constexpr AccumulatorT value = + static constexpr AccumulatorT value = static_cast( std::numeric_limits::has_infinity ? std::numeric_limits::infinity() - : (std::numeric_limits::max)(); + : (std::numeric_limits::max)()); }; +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +template +struct known_identity_impl< + BinaryOperation, vec, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr vec value = + static_cast>( + std::numeric_limits>::has_infinity + ? std::numeric_limits>::infinity() + : (std::numeric_limits>::max)()); +}; +#endif + /// Returns minimal possible value as identity for MAX operations. template struct known_identity_impl::value>> { - static constexpr AccumulatorT value = + static constexpr AccumulatorT value = static_cast( std::numeric_limits::has_infinity ? static_cast( -std::numeric_limits::infinity()) - : std::numeric_limits::lowest(); + : std::numeric_limits::lowest()); +}; + +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +template +struct known_identity_impl< + BinaryOperation, vec, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr vec value = static_cast< + vec>( + std::numeric_limits>::has_infinity + ? static_cast>( + -std::numeric_limits>::infinity()) + : std::numeric_limits>::lowest()); }; +#endif /// Returns false as identity for LOGICAL OR operations. template struct known_identity_impl< BinaryOperation, AccumulatorT, std::enable_if_t::value>> { - static constexpr AccumulatorT value = false; + static constexpr AccumulatorT value = static_cast(false); }; +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +template +struct known_identity_impl< + BinaryOperation, vec, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr vec value = + vec(std::byte(false)); +}; +#endif + /// Returns true as identity for LOGICAL AND operations. template struct known_identity_impl< BinaryOperation, AccumulatorT, std::enable_if_t::value>> { - static constexpr AccumulatorT value = true; + static constexpr AccumulatorT value = static_cast(true); }; +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +template +struct known_identity_impl< + BinaryOperation, vec, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr vec value = + vec(std::byte(true)); +}; +#endif + } // namespace detail // ---- has_known_identity diff --git a/sycl/test/basic_tests/known_identity.cpp b/sycl/test/basic_tests/known_identity.cpp index 06908d1ca247d..a758e7dd7f583 100644 --- a/sycl/test/basic_tests/known_identity.cpp +++ b/sycl/test/basic_tests/known_identity.cpp @@ -1,4 +1,5 @@ -// RUN: %clangxx -fsycl -fsyntax-only -Xclang -verify %s -Xclang -verify-ignore-unexpected=note,warning +// RUN: %clangxx -fsycl -Xclang -verify %s -Xclang -verify-ignore-unexpected=note,warning -o %t.out -std=c++17 +// RUN: %RUN_ON_HOST %t.out // expected-no-diagnostics // This test performs basic checks of has_known_identity and known_identity @@ -6,6 +7,7 @@ #include #include +#include using namespace cl::sycl; @@ -93,6 +95,152 @@ template void checkBoolKnownIdentity() { static_assert(known_identity, T>::value == false); } +template +bool compareVectors(const vec a, const vec b) { + bool res = true; + for (int i = 0; i < Num; ++i) { + res &= (a[i] == b[i]); + } + if (!res) { + for (int i = 0; i < Num; ++i) { + std::cout << "(" << (int)a[i] << " == " << (int)b[i] << ")" << std::endl; + } + } + return res; +} + +template +typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value, + void>::type +checkVecKnownIdentity() { + constexpr vec zeros(T(0)); + constexpr vec ones(T(1)); + constexpr vec bit_ones(~T(0)); + + static_assert(has_known_identity, vec>::value); + static_assert(has_known_identity>, vec>::value); + assert(compareVectors(known_identity, vec>::value, zeros)); + + static_assert(has_known_identity, vec>::value); + static_assert(has_known_identity>, vec>::value); + assert(compareVectors(known_identity, vec>::value, zeros)); + + static_assert(has_known_identity, vec>::value); + static_assert(has_known_identity>, vec>::value); + assert(compareVectors(known_identity, vec>::value, zeros)); + + static_assert(has_known_identity, vec>::value); + static_assert(has_known_identity>, vec>::value); + assert( + compareVectors(known_identity, vec>::value, bit_ones)); + + static_assert(has_known_identity, vec>::value); + static_assert( + has_known_identity>, vec>::value); + assert( + compareVectors(known_identity, vec>::value, zeros)); + + static_assert(has_known_identity, vec>::value); + static_assert( + has_known_identity>, vec>::value); + assert( + compareVectors(known_identity, vec>::value, ones)); + + static_assert(has_known_identity, vec>::value); + static_assert( + has_known_identity>, vec>::value); + assert( + compareVectors(known_identity, vec>::value, ones)); + + static_assert(has_known_identity, vec>::value); + static_assert(has_known_identity>, vec>::value); + if constexpr (!std::is_same::value) { + constexpr vec maxs(-std::numeric_limits::infinity()); + assert(compareVectors(known_identity, vec>::value, maxs)); + } + + static_assert(has_known_identity, vec>::value); + static_assert(has_known_identity>, vec>::value); + if constexpr (!std::is_same::value) { + constexpr vec mins(std::numeric_limits::infinity()); + assert(compareVectors(known_identity, vec>::value, mins)); + } +} + +template +typename std::enable_if::value || + std::is_same::value || + std::is_same::value, + void>::type +checkVecKnownIdentity() { + constexpr vec zeros(T(0.0f)); + constexpr vec ones(T(1.0f)); + + static_assert(has_known_identity, vec>::value); + static_assert(has_known_identity>, vec>::value); + assert(compareVectors(known_identity, vec>::value, zeros)); + + static_assert(has_known_identity, vec>::value); + static_assert( + has_known_identity>, vec>::value); + assert( + compareVectors(known_identity, vec>::value, ones)); + + static_assert(has_known_identity, vec>::value); + static_assert(has_known_identity>, vec>::value); + + static_assert(has_known_identity, vec>::value); + static_assert(has_known_identity>, vec>::value); +} + +void checkVecTypesKnownIdentity() { + +#define CHECK_VEC(type) \ + do { \ + checkVecKnownIdentity(); \ + checkVecKnownIdentity(); \ + checkVecKnownIdentity(); \ + checkVecKnownIdentity(); \ + checkVecKnownIdentity(); \ + checkVecKnownIdentity(); \ + } while (0) + +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) + CHECK_VEC(std::byte); +#endif + CHECK_VEC(int8_t); + CHECK_VEC(int16_t); + CHECK_VEC(int32_t); + CHECK_VEC(int64_t); + CHECK_VEC(uint8_t); + CHECK_VEC(uint16_t); + CHECK_VEC(uint32_t); + CHECK_VEC(uint64_t); + + CHECK_VEC(char); + CHECK_VEC(short int); + CHECK_VEC(int); + CHECK_VEC(long); + CHECK_VEC(long long); + CHECK_VEC(unsigned char); + CHECK_VEC(unsigned short int); + CHECK_VEC(unsigned int); + CHECK_VEC(unsigned long); + CHECK_VEC(unsigned long long); + CHECK_VEC(float); + CHECK_VEC(double); + + checkVecKnownIdentity(); + checkVecKnownIdentity(); + checkVecKnownIdentity(); + checkVecKnownIdentity(); + checkVecKnownIdentity(); + +#undef CHECK_VEC +} + int main() { checkIntKnownIdentity(); checkIntKnownIdentity(); @@ -143,6 +291,8 @@ int main() { checkBoolKnownIdentity(); + checkVecTypesKnownIdentity(); + // Few negative tests just to check that it does not always return true. static_assert(!has_known_identity, int>::value); static_assert(!has_known_identity, float>::value); From 78e77a69d5b1392dc805531d506f2742e407e3db Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Thu, 16 Dec 2021 15:53:18 +0300 Subject: [PATCH 2/5] Add sycl::known_identity for sycl::marray --- .../CL/sycl/detail/generic_type_lists.hpp | 210 +++++++++++++++--- sycl/include/CL/sycl/known_identity.hpp | 71 ++++++ sycl/test/basic_tests/known_identity.cpp | 155 +++++++++++++ 3 files changed, 408 insertions(+), 28 deletions(-) diff --git a/sycl/include/CL/sycl/detail/generic_type_lists.hpp b/sycl/include/CL/sycl/detail/generic_type_lists.hpp index a5866d0a391e5..7182e9f79f22e 100644 --- a/sycl/include/CL/sycl/detail/generic_type_lists.hpp +++ b/sycl/include/CL/sycl/detail/generic_type_lists.hpp @@ -21,6 +21,7 @@ __SYCL_INLINE_NAMESPACE(cl) { namespace sycl { template class vec; +template class marray; namespace detail { namespace half_impl { class half; @@ -40,7 +41,12 @@ using scalar_half_list = type_list; using vector_half_list = type_list, vec, vec, vec, vec, vec>; -using half_list = type_list; +using marray_half_list = + type_list, marray, marray, + marray, marray, marray>; + +using half_list = + type_list; using scalar_float_list = type_list; @@ -48,7 +54,12 @@ using vector_float_list = type_list, vec, vec, vec, vec, vec>; -using float_list = type_list; +using marray_float_list = + type_list, marray, marray, + marray, marray, marray>; + +using float_list = + type_list; using scalar_double_list = type_list; @@ -56,7 +67,12 @@ using vector_double_list = type_list, vec, vec, vec, vec, vec>; -using double_list = type_list; +using marray_double_list = + type_list, marray, marray, + marray, marray, marray>; + +using double_list = + type_list; using scalar_floating_list = type_list; @@ -64,7 +80,11 @@ using scalar_floating_list = using vector_floating_list = type_list; -using floating_list = type_list; +using marray_floating_list = + type_list; + +using floating_list = + type_list; // geometric floating point types using scalar_geo_half_list = type_list; @@ -113,8 +133,13 @@ using vector_default_char_list = type_list, vec, vec, vec, vec, vec>; +using marray_default_char_list = + type_list, marray, marray, + marray, marray, marray>; + using default_char_list = - type_list; + type_list; using scalar_signed_char_list = type_list; @@ -122,8 +147,14 @@ using vector_signed_char_list = type_list, vec, vec, vec, vec, vec>; +using marray_signed_char_list = + type_list, marray, + marray, marray, + marray, marray>; + using signed_char_list = - type_list; + type_list; using scalar_unsigned_char_list = type_list; @@ -132,8 +163,14 @@ using vector_unsigned_char_list = vec, vec, vec, vec>; +using marray_unsigned_char_list = + type_list, marray, + marray, marray, + marray, marray>; + using unsigned_char_list = - type_list; + type_list; using scalar_char_list = type_list; +using marray_char_list = + type_list; + using char_list = type_list; // short int types @@ -153,8 +194,14 @@ using vector_signed_short_list = vec, vec, vec>; +using marray_signed_short_list = + type_list, marray, + marray, marray, + marray, marray>; + using signed_short_list = - type_list; + type_list; using scalar_unsigned_short_list = type_list; @@ -163,14 +210,21 @@ using vector_unsigned_short_list = vec, vec, vec, vec>; +using marray_unsigned_short_list = + type_list, marray, + marray, marray, + marray, marray>; + using unsigned_short_list = - type_list; + type_list; using scalar_short_list = type_list; using vector_short_list = - type_list; + type_list; using short_list = type_list; @@ -181,8 +235,14 @@ using vector_signed_int_list = type_list, vec, vec, vec, vec, vec>; +using marray_signed_int_list = + type_list, marray, + marray, marray, + marray, marray>; + using signed_int_list = - type_list; + type_list; using scalar_unsigned_int_list = type_list; @@ -191,8 +251,14 @@ using vector_unsigned_int_list = vec, vec, vec>; +using marray_unsigned_int_list = + type_list, marray, + marray, marray, + marray, marray>; + using unsigned_int_list = - type_list; + type_list; using scalar_int_list = type_list; @@ -200,7 +266,10 @@ using scalar_int_list = using vector_int_list = type_list; -using int_list = type_list; +using marray_int_list = + type_list; + +using int_list = type_list; // long types using scalar_signed_long_list = type_list; @@ -209,8 +278,14 @@ using vector_signed_long_list = type_list, vec, vec, vec, vec, vec>; +using marray_signed_long_list = + type_list, marray, + marray, marray, + marray, marray>; + using signed_long_list = - type_list; + type_list; using scalar_unsigned_long_list = type_list; @@ -219,8 +294,14 @@ using vector_unsigned_long_list = vec, vec, vec, vec>; +using marray_unsigned_long_list = + type_list, marray, + marray, marray, + marray, marray>; + using unsigned_long_list = - type_list; + type_list; using scalar_long_list = type_list; @@ -228,7 +309,11 @@ using scalar_long_list = using vector_long_list = type_list; -using long_list = type_list; +using marray_long_list = + type_list; + +using long_list = + type_list; // long long types using scalar_signed_longlong_list = type_list; @@ -238,8 +323,14 @@ using vector_signed_longlong_list = vec, vec, vec, vec>; +using marray_signed_longlong_list = + type_list, marray, + marray, marray, + marray, marray>; + using signed_longlong_list = - type_list; + type_list; using scalar_unsigned_longlong_list = type_list; @@ -248,8 +339,14 @@ using vector_unsigned_longlong_list = vec, vec, vec, vec>; +using marray_unsigned_longlong_list = + type_list, marray, + marray, marray, + marray, marray>; + using unsigned_longlong_list = - type_list; + type_list; using scalar_longlong_list = type_list; @@ -257,7 +354,11 @@ using scalar_longlong_list = using vector_longlong_list = type_list; -using longlong_list = type_list; +using marray_longlong_list = + type_list; + +using longlong_list = + type_list; // long integer types using scalar_signed_long_integer_list = @@ -266,8 +367,12 @@ using scalar_signed_long_integer_list = using vector_signed_long_integer_list = type_list; +using marray_signed_long_integer_list = + type_list; + using signed_long_integer_list = - type_list; + type_list; using scalar_unsigned_long_integer_list = type_list; @@ -275,8 +380,12 @@ using scalar_unsigned_long_integer_list = using vector_unsigned_long_integer_list = type_list; +using marray_unsigned_long_integer_list = + type_list; + using unsigned_long_integer_list = type_list; + vector_unsigned_long_integer_list, + marray_unsigned_long_integer_list>; using scalar_long_integer_list = type_list; @@ -284,8 +393,12 @@ using scalar_long_integer_list = type_list; +using marray_long_integer_list = type_list; + using long_integer_list = - type_list; + type_list; #if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) // std::byte @@ -294,6 +407,10 @@ using scalar_byte_list = type_list; using vector_byte_list = type_list, vec, vec, vec, vec, vec>; + +using marray_byte_list = type_list, marray, + marray, marray, + marray, marray>; #endif // integer types @@ -311,8 +428,16 @@ using vector_signed_integer_list = type_list< vector_signed_short_list, vector_signed_int_list, vector_signed_long_list, vector_signed_longlong_list>; +using marray_signed_integer_list = type_list< + conditional_t::value, + type_list, + marray_signed_char_list>, + marray_signed_short_list, marray_signed_int_list, marray_signed_long_list, + marray_signed_longlong_list>; + using signed_integer_list = - type_list; + type_list; using scalar_unsigned_integer_list = type_list::value, @@ -340,8 +465,22 @@ using vector_unsigned_integer_list = #endif >; +using marray_unsigned_integer_list = + type_list::value, + type_list, + marray_unsigned_char_list>, + marray_unsigned_short_list, marray_unsigned_int_list, + marray_unsigned_long_list, marray_unsigned_longlong_list +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) + , + marray_byte_list +#endif + >; + using unsigned_integer_list = - type_list; + type_list; using scalar_integer_list = type_list; @@ -349,7 +488,11 @@ using scalar_integer_list = using vector_integer_list = type_list; -using integer_list = type_list; +using marray_integer_list = + type_list; + +using integer_list = + type_list; // basic types using scalar_signed_basic_list = @@ -358,15 +501,22 @@ using scalar_signed_basic_list = using vector_signed_basic_list = type_list; +using marray_signed_basic_list = + type_list; + using signed_basic_list = - type_list; + type_list; using scalar_unsigned_basic_list = type_list; using vector_unsigned_basic_list = type_list; +using marray_unsigned_basic_list = type_list; + using unsigned_basic_list = - type_list; + type_list; using scalar_basic_list = type_list; @@ -374,7 +524,11 @@ using scalar_basic_list = using vector_basic_list = type_list; -using basic_list = type_list; +using marray_basic_list = + type_list; + +using basic_list = + type_list; // nan builtin types using nan_list = type_list value = vec(std::byte(0)); }; + +template +struct known_identity_impl< + BinaryOperation, marray, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr marray value = + marray(std::byte(0)); +}; #endif template @@ -178,6 +187,15 @@ struct known_identity_impl< static constexpr vec value = vec(std::byte(1)); }; + +template +struct known_identity_impl< + BinaryOperation, marray, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr marray value = + marray(std::byte(1)); +}; #endif template @@ -209,6 +227,15 @@ struct known_identity_impl< static constexpr vec value = vec(std::byte(-1LL)); }; + +template +struct known_identity_impl< + BinaryOperation, marray, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr marray value = + marray(std::byte(-1LL)); +}; #endif /// Returns maximal possible value as identity for MIN operations. @@ -234,6 +261,18 @@ struct known_identity_impl< ? std::numeric_limits>::infinity() : (std::numeric_limits>::max)()); }; + +template +struct known_identity_impl< + BinaryOperation, marray, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr marray value = + static_cast>( + std::numeric_limits>::has_infinity + ? std::numeric_limits>::infinity() + : (std::numeric_limits>::max)()); +}; #endif /// Returns minimal possible value as identity for MAX operations. @@ -261,6 +300,20 @@ struct known_identity_impl< -std::numeric_limits>::infinity()) : std::numeric_limits>::lowest()); }; + +template +struct known_identity_impl< + BinaryOperation, marray, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr marray value = + static_cast>( + std::numeric_limits>::has_infinity + ? static_cast>( + -std::numeric_limits< + marray>::infinity()) + : std::numeric_limits>::lowest()); +}; #endif /// Returns false as identity for LOGICAL OR operations. @@ -280,6 +333,15 @@ struct known_identity_impl< static constexpr vec value = vec(std::byte(false)); }; + +template +struct known_identity_impl< + BinaryOperation, marray, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr marray value = + marray(std::byte(false)); +}; #endif /// Returns true as identity for LOGICAL AND operations. @@ -299,6 +361,15 @@ struct known_identity_impl< static constexpr vec value = vec(std::byte(true)); }; + +template +struct known_identity_impl< + BinaryOperation, marray, + std::enable_if_t, + BinaryOperation>::value>> { + static constexpr marray value = + marray(std::byte(true)); +}; #endif } // namespace detail diff --git a/sycl/test/basic_tests/known_identity.cpp b/sycl/test/basic_tests/known_identity.cpp index a758e7dd7f583..c6e85b26f9407 100644 --- a/sycl/test/basic_tests/known_identity.cpp +++ b/sycl/test/basic_tests/known_identity.cpp @@ -241,6 +241,160 @@ void checkVecTypesKnownIdentity() { #undef CHECK_VEC } +template +bool compareMarrays(const marray a, const marray b) { + bool res = true; + for (int i = 0; i < Num; ++i) { + res &= (a[i] == b[i]); + } + if (!res) { + for (int i = 0; i < Num; ++i) { + std::cout << "(" << (int)a[i] << " == " << (int)b[i] << ")" << std::endl; + } + } + return res; +} + +template +typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value, + void>::type +checkMarrayKnownIdentity() { + constexpr marray zeros(T(0)); + constexpr marray ones(T(1)); + constexpr marray bit_ones(~T(0)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + assert(compareMarrays(known_identity, marray>::value, zeros)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + assert( + compareMarrays(known_identity, marray>::value, zeros)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + assert( + compareMarrays(known_identity, marray>::value, zeros)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + assert(compareMarrays(known_identity, marray>::value, + bit_ones)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + assert(compareMarrays(known_identity, marray>::value, + zeros)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + assert(compareMarrays(known_identity, marray>::value, + ones)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + assert(compareMarrays(known_identity, marray>::value, + ones)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + if constexpr (!std::is_same::value) { + constexpr marray maxs(-std::numeric_limits::infinity()); + assert( + compareMarrays(known_identity, marray>::value, maxs)); + } + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + if constexpr (!std::is_same::value) { + constexpr marray mins(std::numeric_limits::infinity()); + assert( + compareMarrays(known_identity, marray>::value, mins)); + } +} + +template +typename std::enable_if::value || + std::is_same::value || + std::is_same::value, + void>::type +checkMarrayKnownIdentity() { + constexpr marray zeros(T(0.0f)); + constexpr marray ones(T(1.0f)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + assert(compareMarrays(known_identity, marray>::value, zeros)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + assert(compareMarrays(known_identity, marray>::value, + ones)); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); + + static_assert(has_known_identity, marray>::value); + static_assert( + has_known_identity>, marray>::value); +} + +void checkMarrayTypesKnownIdentity() { + +#define CHECK_MARRAY(type) \ + do { \ + checkMarrayKnownIdentity(); \ + checkMarrayKnownIdentity(); \ + checkMarrayKnownIdentity(); \ + checkMarrayKnownIdentity(); \ + checkMarrayKnownIdentity(); \ + checkMarrayKnownIdentity(); \ + } while (0) + +#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) + CHECK_MARRAY(std::byte); +#endif + CHECK_MARRAY(int8_t); + CHECK_MARRAY(int16_t); + CHECK_MARRAY(int32_t); + CHECK_MARRAY(int64_t); + CHECK_MARRAY(uint8_t); + CHECK_MARRAY(uint16_t); + CHECK_MARRAY(uint32_t); + CHECK_MARRAY(uint64_t); + + CHECK_MARRAY(char); + CHECK_MARRAY(short int); + CHECK_MARRAY(int); + CHECK_MARRAY(long); + CHECK_MARRAY(long long); + CHECK_MARRAY(unsigned char); + CHECK_MARRAY(unsigned short int); + CHECK_MARRAY(unsigned int); + CHECK_MARRAY(unsigned long); + CHECK_MARRAY(unsigned long long); + CHECK_MARRAY(half); + CHECK_MARRAY(float); + CHECK_MARRAY(double); + +#undef CHECK_MARRAY +} + int main() { checkIntKnownIdentity(); checkIntKnownIdentity(); @@ -292,6 +446,7 @@ int main() { checkBoolKnownIdentity(); checkVecTypesKnownIdentity(); + checkMarrayTypesKnownIdentity(); // Few negative tests just to check that it does not always return true. static_assert(!has_known_identity, int>::value); From d7b908d2ca71e5b02de032b66bd4314e770a2601 Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Thu, 16 Dec 2021 16:05:27 +0300 Subject: [PATCH 3/5] Test cleanup --- sycl/test/basic_tests/known_identity.cpp | 144 ++++++++++------------- 1 file changed, 63 insertions(+), 81 deletions(-) diff --git a/sycl/test/basic_tests/known_identity.cpp b/sycl/test/basic_tests/known_identity.cpp index c6e85b26f9407..950d2873bdf74 100644 --- a/sycl/test/basic_tests/known_identity.cpp +++ b/sycl/test/basic_tests/known_identity.cpp @@ -101,11 +101,6 @@ bool compareVectors(const vec a, const vec b) { for (int i = 0; i < Num; ++i) { res &= (a[i] == b[i]); } - if (!res) { - for (int i = 0; i < Num; ++i) { - std::cout << "(" << (int)a[i] << " == " << (int)b[i] << ")" << std::endl; - } - } return res; } @@ -195,50 +190,46 @@ checkVecKnownIdentity() { static_assert(has_known_identity>, vec>::value); } -void checkVecTypesKnownIdentity() { - -#define CHECK_VEC(type) \ - do { \ - checkVecKnownIdentity(); \ - checkVecKnownIdentity(); \ - checkVecKnownIdentity(); \ - checkVecKnownIdentity(); \ - checkVecKnownIdentity(); \ - checkVecKnownIdentity(); \ - } while (0) +template void checkVecTypeKnownIdentity() { + checkVecKnownIdentity(); + checkVecKnownIdentity(); + checkVecKnownIdentity(); + checkVecKnownIdentity(); + checkVecKnownIdentity(); + checkVecKnownIdentity(); +} +void checkVecTypesKnownIdentity() { #if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) - CHECK_VEC(std::byte); + checkVecTypeKnownIdentity(); #endif - CHECK_VEC(int8_t); - CHECK_VEC(int16_t); - CHECK_VEC(int32_t); - CHECK_VEC(int64_t); - CHECK_VEC(uint8_t); - CHECK_VEC(uint16_t); - CHECK_VEC(uint32_t); - CHECK_VEC(uint64_t); - - CHECK_VEC(char); - CHECK_VEC(short int); - CHECK_VEC(int); - CHECK_VEC(long); - CHECK_VEC(long long); - CHECK_VEC(unsigned char); - CHECK_VEC(unsigned short int); - CHECK_VEC(unsigned int); - CHECK_VEC(unsigned long); - CHECK_VEC(unsigned long long); - CHECK_VEC(float); - CHECK_VEC(double); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); + checkVecTypeKnownIdentity(); checkVecKnownIdentity(); checkVecKnownIdentity(); checkVecKnownIdentity(); checkVecKnownIdentity(); checkVecKnownIdentity(); - -#undef CHECK_VEC } template @@ -247,11 +238,6 @@ bool compareMarrays(const marray a, const marray b) { for (int i = 0; i < Num; ++i) { res &= (a[i] == b[i]); } - if (!res) { - for (int i = 0; i < Num; ++i) { - std::cout << "(" << (int)a[i] << " == " << (int)b[i] << ")" << std::endl; - } - } return res; } @@ -354,45 +340,41 @@ checkMarrayKnownIdentity() { has_known_identity>, marray>::value); } -void checkMarrayTypesKnownIdentity() { - -#define CHECK_MARRAY(type) \ - do { \ - checkMarrayKnownIdentity(); \ - checkMarrayKnownIdentity(); \ - checkMarrayKnownIdentity(); \ - checkMarrayKnownIdentity(); \ - checkMarrayKnownIdentity(); \ - checkMarrayKnownIdentity(); \ - } while (0) +template void checkMarrayTypeKnownIdentity() { + checkMarrayKnownIdentity(); + checkMarrayKnownIdentity(); + checkMarrayKnownIdentity(); + checkMarrayKnownIdentity(); + checkMarrayKnownIdentity(); + checkMarrayKnownIdentity(); +} +void checkMarrayTypesKnownIdentity() { #if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) - CHECK_MARRAY(std::byte); + checkMarrayTypeKnownIdentity(); #endif - CHECK_MARRAY(int8_t); - CHECK_MARRAY(int16_t); - CHECK_MARRAY(int32_t); - CHECK_MARRAY(int64_t); - CHECK_MARRAY(uint8_t); - CHECK_MARRAY(uint16_t); - CHECK_MARRAY(uint32_t); - CHECK_MARRAY(uint64_t); - - CHECK_MARRAY(char); - CHECK_MARRAY(short int); - CHECK_MARRAY(int); - CHECK_MARRAY(long); - CHECK_MARRAY(long long); - CHECK_MARRAY(unsigned char); - CHECK_MARRAY(unsigned short int); - CHECK_MARRAY(unsigned int); - CHECK_MARRAY(unsigned long); - CHECK_MARRAY(unsigned long long); - CHECK_MARRAY(half); - CHECK_MARRAY(float); - CHECK_MARRAY(double); - -#undef CHECK_MARRAY + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); + checkMarrayTypeKnownIdentity(); } int main() { From 124f9476454f0042b23bd11ad18494762edd13ce Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Fri, 17 Dec 2021 18:29:14 +0300 Subject: [PATCH 4/5] Update sycl/test/basic_tests/known_identity.cpp Co-authored-by: Steffen Larsen --- sycl/test/basic_tests/known_identity.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sycl/test/basic_tests/known_identity.cpp b/sycl/test/basic_tests/known_identity.cpp index 950d2873bdf74..75a3d35464877 100644 --- a/sycl/test/basic_tests/known_identity.cpp +++ b/sycl/test/basic_tests/known_identity.cpp @@ -105,10 +105,8 @@ bool compareVectors(const vec a, const vec b) { } template -typename std::enable_if::value && - !std::is_same::value && - !std::is_same::value, - void>::type +std::enable_if_t && !std::is_same_v && + !std::is_same_v> checkVecKnownIdentity() { constexpr vec zeros(T(0)); constexpr vec ones(T(1)); From 8f73ea6f63b0c38f5e1e37cb30eb13e55b049b36 Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Fri, 17 Dec 2021 19:15:45 +0300 Subject: [PATCH 5/5] Replace std::is_same with std::is_same_v --- sycl/test/basic_tests/known_identity.cpp | 25 ++++++++++++------------ 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/sycl/test/basic_tests/known_identity.cpp b/sycl/test/basic_tests/known_identity.cpp index 75a3d35464877..2b4f9b92c9158 100644 --- a/sycl/test/basic_tests/known_identity.cpp +++ b/sycl/test/basic_tests/known_identity.cpp @@ -149,23 +149,23 @@ checkVecKnownIdentity() { static_assert(has_known_identity, vec>::value); static_assert(has_known_identity>, vec>::value); - if constexpr (!std::is_same::value) { + if constexpr (!std::is_same_v) { constexpr vec maxs(-std::numeric_limits::infinity()); assert(compareVectors(known_identity, vec>::value, maxs)); } static_assert(has_known_identity, vec>::value); static_assert(has_known_identity>, vec>::value); - if constexpr (!std::is_same::value) { + if constexpr (!std::is_same_v) { constexpr vec mins(std::numeric_limits::infinity()); assert(compareVectors(known_identity, vec>::value, mins)); } } template -typename std::enable_if::value || - std::is_same::value || - std::is_same::value, +typename std::enable_if || + std::is_same_v || + std::is_same_v, void>::type checkVecKnownIdentity() { constexpr vec zeros(T(0.0f)); @@ -240,9 +240,8 @@ bool compareMarrays(const marray a, const marray b) { } template -typename std::enable_if::value && - !std::is_same::value && - !std::is_same::value, +typename std::enable_if && !std::is_same_v && + !std::is_same_v, void>::type checkMarrayKnownIdentity() { constexpr marray zeros(T(0)); @@ -293,7 +292,7 @@ checkMarrayKnownIdentity() { static_assert(has_known_identity, marray>::value); static_assert( has_known_identity>, marray>::value); - if constexpr (!std::is_same::value) { + if constexpr (!std::is_same_v) { constexpr marray maxs(-std::numeric_limits::infinity()); assert( compareMarrays(known_identity, marray>::value, maxs)); @@ -302,7 +301,7 @@ checkMarrayKnownIdentity() { static_assert(has_known_identity, marray>::value); static_assert( has_known_identity>, marray>::value); - if constexpr (!std::is_same::value) { + if constexpr (!std::is_same_v) { constexpr marray mins(std::numeric_limits::infinity()); assert( compareMarrays(known_identity, marray>::value, mins)); @@ -310,9 +309,9 @@ checkMarrayKnownIdentity() { } template -typename std::enable_if::value || - std::is_same::value || - std::is_same::value, +typename std::enable_if || + std::is_same_v || + std::is_same_v, void>::type checkMarrayKnownIdentity() { constexpr marray zeros(T(0.0f));