diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index b0b8a93d6697c..b7842a8fd99e4 100644 --- a/sycl/include/syclcompat/math.hpp +++ b/sycl/include/syclcompat/math.hpp @@ -118,6 +118,37 @@ class vectorized_binary { } }; +// Vectorized_binary for logical operations +template +class vectorized_binary< + VecT, BinaryOperation, + std::enable_if_t()( + std::declval(), + std::declval()))>>> { +public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) { + unsigned result = 0; + constexpr size_t elem_size = 8 * sizeof(typename VecT::element_type); + static_assert(elem_size < 32, + "Vector element size must be less than 4 bytes"); + constexpr unsigned bool_mask = (1U << elem_size) - 1; + + for (size_t i = 0; i < a.size(); ++i) { + bool comp_result = binary_op(a[i], b[i]); + result |= (comp_result ? bool_mask : 0U) << (i * elem_size); + } + + VecT v4; + for (size_t i = 0; i < v4.size(); ++i) { + v4[i] = static_cast( + (result >> (i * elem_size)) & bool_mask); + } + + return v4; + } +}; + /// Extend the 'val' to 'bit' size, zero extend for unsigned int and signed /// extend for signed int. Returns a signed integer type. template @@ -1040,7 +1071,7 @@ struct average { } // namespace detail -/// Compute vectorized binary operation value for two values, with each value +/// Compute vectorized binary operation value for two/four values, with each /// treated as a vector type \p VecT. /// \tparam [in] VecT The type of the vector /// \tparam [in] BinaryOperation The binary operation class @@ -1052,14 +1083,19 @@ struct average { template inline unsigned vectorized_binary(unsigned a, unsigned b, const BinaryOperation binary_op, - bool need_relu = false) { + [[maybe_unused]] bool need_relu = false) { sycl::vec v0{a}, v1{b}; auto v2 = v0.as(); auto v3 = v1.as(); auto v4 = detail::vectorized_binary()(v2, v3, binary_op); - if (need_relu) - v4 = relu(v4); + if constexpr (!std::is_same_v< + bool, decltype(std::declval()( + std::declval(), + std::declval()))>) { + if (need_relu) + v4 = relu(v4); + } v0 = v4.template as>(); return v0; } diff --git a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp index 9c57c88ce445b..630d4b9c9f154 100644 --- a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp +++ b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp @@ -48,6 +48,18 @@ void test_vectorized_binary(unsigned op1, unsigned op2, unsigned expected, op1, op2, expected, need_relu); } +template +void test_vectorized_binary_logical(unsigned op1, unsigned op2, + unsigned expected) { + std::cout << __PRETTY_FUNCTION__ << std::endl; + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op2, expected, false); +} + template void vectorized_unary_kernel(unsigned *a, unsigned *r) { *r = syclcompat::vectorized_unary(*a, UnaryOp()); @@ -203,5 +215,67 @@ int main() { test_vectorized_binary_with_pred( 0x80010002, 0x00040002, 0x00040002, false, true); + // Logical Binary Operators v2 + test_vectorized_binary_logical, sycl::short2>( + 0xFFF00002, 0xFFF00001, 0xFFFF0000); + test_vectorized_binary_logical, sycl::short2>( + 0x0001F00F, 0x0003F00F, 0x0000FFFF); + + test_vectorized_binary_logical, sycl::short2>( + 0xFFF00002, 0xFFF00001, 0x0000FFFF); + test_vectorized_binary_logical, sycl::short2>( + 0x0001F00F, 0x0003F00F, 0xFFFF0000); + + test_vectorized_binary_logical, sycl::short2>( + 0xFFF00002, 0xFFF00001, 0xFFFFFFFF); + test_vectorized_binary_logical, sycl::short2>( + 0x0001F00F, 0x0003F001, 0x0000FFFF); + + test_vectorized_binary_logical, sycl::short2>( + 0xFFF00002, 0xFFF00001, 0x0000FFFF); + test_vectorized_binary_logical, sycl::short2>( + 0x0003F00F, 0x0001F00F, 0xFFFF0000); + + test_vectorized_binary_logical, sycl::short2>( + 0xFFF00001, 0xF0F00002, 0x0000FFFF); + test_vectorized_binary_logical, sycl::short2>( + 0x0001FF0F, 0x0003F00F, 0xFFFF0000); + + test_vectorized_binary_logical, sycl::short2>( + 0xFFF00001, 0xFFF00002, 0x0000FFFF); + test_vectorized_binary_logical, sycl::short2>( + 0x0001F00F, 0x0003F00F, 0xFFFF0000); + + // Logical Binary Operators v4 + test_vectorized_binary_logical, sycl::uchar4>( + 0x0001F00F, 0x0003F00F, 0xFF00FFFF); + test_vectorized_binary_logical, sycl::uchar4>( + 0x0102F0F0, 0x0202F0FF, 0x00FFFF00); + + test_vectorized_binary_logical, sycl::uchar4>( + 0x0001F00F, 0xFF01F10F, 0xFF00FF00); + test_vectorized_binary_logical, sycl::uchar4>( + 0x0201F0F0, 0x0202F0FF, 0x00FF00FF); + + test_vectorized_binary_logical, sycl::uchar4>( + 0xFFF00002, 0xFFF10101, 0xFF0000FF); + test_vectorized_binary_logical, sycl::uchar4>( + 0x0001F1F0, 0x0103F001, 0x0000FFFF); + + test_vectorized_binary_logical, sycl::uchar4>( + 0xFFF00002, 0xF0F00001, 0xFF0000FF); + test_vectorized_binary_logical, sycl::uchar4>( + 0x0103F0F1, 0x0102F0F0, 0x00FF00FF); + + test_vectorized_binary_logical, sycl::uchar4>( + 0xFFF10001, 0xFFF00100, 0xFF00FF00); + test_vectorized_binary_logical, sycl::uchar4>( + 0x0101F1F0, 0x0003F0F1, 0x00FF00FF); + + test_vectorized_binary_logical, sycl::uchar4>( + 0xFFF10001, 0xFFF20100, 0x00FFFF00); + test_vectorized_binary_logical, sycl::uchar4>( + 0x0101F1F0, 0x0102F1F1, 0x00FF00FF); + return 0; }