From fb65bab86373821a1a04b038f248de775701bd0c Mon Sep 17 00:00:00 2001 From: "Tang, Jiajun" Date: Thu, 7 Sep 2023 11:05:18 +0800 Subject: [PATCH] [SYCLomatic] Add migration of 59 bf16 APIs and refine 2 half APIs. Signed-off-by: Tang, Jiajun jiajun.tang@intel.com --- clang/lib/DPCT/APINamesMath.inc | 4 +- clang/lib/DPCT/APINamesMathRewrite.inc | 766 ++++++++++++------ clang/runtime/dpct-rt/include/math.hpp | 46 ++ clang/test/dpct/math/bfloat16/bfloat16.cu | 106 +++ .../math/bfloat16/bfloat16_cuda12_after.cu | 41 + .../math/bfloat16/bfloat16_experimental.cu | 106 +++ clang/test/dpct/math/bfloat16/bfloat16_ext.cu | 8 + clang/test/dpct/math/cuda-math-extension.cu | 4 + 8 files changed, 809 insertions(+), 272 deletions(-) create mode 100644 clang/test/dpct/math/bfloat16/bfloat16_cuda12_after.cu diff --git a/clang/lib/DPCT/APINamesMath.inc b/clang/lib/DPCT/APINamesMath.inc index 31167e2ffd12..85967db36ed2 100644 --- a/clang/lib/DPCT/APINamesMath.inc +++ b/clang/lib/DPCT/APINamesMath.inc @@ -12,8 +12,8 @@ ENTRY_REWRITE("__hfma") ENTRY_REWRITE("__hfma2") // Half Comparison Functions -ENTRY_RENAMED("__hisinf", MapNames::getClNamespace(false, true) + "isinf") -ENTRY_RENAMED("__hisnan", MapNames::getClNamespace(false, true) + "isnan") +ENTRY_REWRITE("__hisinf") +ENTRY_REWRITE("__hisnan") // Half Math Functions ENTRY_REWRITE("hceil") diff --git a/clang/lib/DPCT/APINamesMathRewrite.inc b/clang/lib/DPCT/APINamesMathRewrite.inc index f49648e726f4..82878b76f355 100644 --- a/clang/lib/DPCT/APINamesMathRewrite.inc +++ b/clang/lib/DPCT/APINamesMathRewrite.inc @@ -749,11 +749,13 @@ MATH_API_REWRITER_DEVICE( "__heq", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__heq"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__heq", CALL(MapNames::getClNamespace() + - "ext::intel::math::heq", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__heq", CALL(MapNames::getClNamespace() + + "ext::intel::math::heq", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__heq"), BINARY_OP_FACTORY_ENTRY("__heq", BinaryOperatorKind::BO_EQ, makeCallArgCreatorWithCall(0), @@ -763,11 +765,13 @@ MATH_API_REWRITER_DEVICE( "__hequ", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hequ"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hequ", CALL(MapNames::getClNamespace() + - "ext::intel::math::hequ", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hequ", CALL(MapNames::getClNamespace() + + "ext::intel::math::hequ", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hequ"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -780,11 +784,13 @@ MATH_API_REWRITER_DEVICE( "__hge", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hge"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hge", CALL(MapNames::getClNamespace() + - "ext::intel::math::hge", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hge", CALL(MapNames::getClNamespace() + + "ext::intel::math::hge", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hge"), BINARY_OP_FACTORY_ENTRY("__hge", BinaryOperatorKind::BO_GE, makeCallArgCreatorWithCall(0), @@ -794,11 +800,13 @@ MATH_API_REWRITER_DEVICE( "__hgeu", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hgeu"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hgeu", CALL(MapNames::getClNamespace() + - "ext::intel::math::hgeu", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hgeu", CALL(MapNames::getClNamespace() + + "ext::intel::math::hgeu", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hgeu"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -811,11 +819,13 @@ MATH_API_REWRITER_DEVICE( "__hgt", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hgt"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hgt", CALL(MapNames::getClNamespace() + - "ext::intel::math::hgt", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hgt", CALL(MapNames::getClNamespace() + + "ext::intel::math::hgt", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hgt"), BINARY_OP_FACTORY_ENTRY("__hgt", BinaryOperatorKind::BO_GT, makeCallArgCreatorWithCall(0), @@ -825,11 +835,13 @@ MATH_API_REWRITER_DEVICE( "__hgtu", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hgtu"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hgtu", CALL(MapNames::getClNamespace() + - "ext::intel::math::hgtu", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hgtu", CALL(MapNames::getClNamespace() + + "ext::intel::math::hgtu", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hgtu"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -838,15 +850,67 @@ MATH_API_REWRITER_DEVICE( ARG(0), ARG(1), LITERAL("std::greater<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hisinf", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hisinf"), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hisinf", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hisinf", + ARG(0))))), + EMPTY_FACTORY_ENTRY("__hisinf"), + CONDITIONAL_FACTORY_ENTRY( + CheckArgType(0, "__half"), + CALL_FACTORY_ENTRY( + "__hisinf", + CALL(MapNames::getClNamespace(false, true) + "isinf", ARG(0))), + CALL_FACTORY_ENTRY("__hisinf", + CALL(MapNames::getClNamespace(false, true) + + "isinf", + CALL("float", ARG(0))))))) + +MATH_API_REWRITER_DEVICE_OVERLOAD( + CheckArgType(0, "__half"), + MATH_API_REWRITER_DEVICE( + "__hisnan", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hisnan"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hisnan", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hisnan", + ARG(0)))), + EMPTY_FACTORY_ENTRY("__hisnan"), + CALL_FACTORY_ENTRY("__hisnan", + CALL(MapNames::getClNamespace(false, true) + + "isnan", + ARG(0))))), + MATH_API_REWRITER_EXPERIMENTAL_BFLOAT16( + "__hisnan", + CALL_FACTORY_ENTRY("__hisnan", + CALL(MapNames::getClNamespace(false, true) + + "ext::oneapi::experimental::isnan", + ARG(0))), + CALL_FACTORY_ENTRY("__hisnan", + CALL(MapNames::getClNamespace(false, true) + "isnan", + CALL("float", ARG(0)))))) + MATH_API_REWRITER_DEVICE( "__hle", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hle"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hle", CALL(MapNames::getClNamespace() + - "ext::intel::math::hle", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hle", CALL(MapNames::getClNamespace() + + "ext::intel::math::hle", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hle"), BINARY_OP_FACTORY_ENTRY("__hle", BinaryOperatorKind::BO_LE, makeCallArgCreatorWithCall(0), @@ -856,11 +920,13 @@ MATH_API_REWRITER_DEVICE( "__hleu", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hleu"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hleu", CALL(MapNames::getClNamespace() + - "ext::intel::math::hleu", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hleu", CALL(MapNames::getClNamespace() + + "ext::intel::math::hleu", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hleu"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -873,11 +939,13 @@ MATH_API_REWRITER_DEVICE( "__hlt", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hlt"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hlt", CALL(MapNames::getClNamespace() + - "ext::intel::math::hlt", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hlt", CALL(MapNames::getClNamespace() + + "ext::intel::math::hlt", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hlt"), BINARY_OP_FACTORY_ENTRY("__hlt", BinaryOperatorKind::BO_LT, makeCallArgCreatorWithCall(0), @@ -887,11 +955,13 @@ MATH_API_REWRITER_DEVICE( "__hltu", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hltu"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hltu", CALL(MapNames::getClNamespace() + - "ext::intel::math::hltu", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hltu", CALL(MapNames::getClNamespace() + + "ext::intel::math::hltu", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hltu"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -900,57 +970,85 @@ MATH_API_REWRITER_DEVICE( ARG(0), ARG(1), LITERAL("std::less<>()")))))) -MATH_API_REWRITER_DEVICE( - "__hmax", - MATH_API_DEVICE_NODES( - EMPTY_FACTORY_ENTRY("__hmax"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hmax", CALL(MapNames::getClNamespace() + - "ext::intel::math::hmax", - ARG(0), ARG(1)))), - EMPTY_FACTORY_ENTRY("__hmax"), - CALL_FACTORY_ENTRY("__hmax", CALL(MapNames::getClNamespace() + "fmax", - ARG(0), ARG(1))))) +MATH_API_REWRITER_DEVICE_OVERLOAD( + CheckArgType(0, "__half"), + MATH_API_REWRITER_DEVICE( + "__hmax", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmax"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmax", CALL(MapNames::getClNamespace() + + "ext::intel::math::hmax", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmax"), + CALL_FACTORY_ENTRY("__hmax", + CALL(MapNames::getClNamespace() + "fmax", ARG(0), + ARG(1))))), + MATH_API_REWRITER_EXPERIMENTAL_BFLOAT16( + "__hmax", + CALL_FACTORY_ENTRY("__hmax", + CALL(MapNames::getClNamespace(false, true) + + "ext::oneapi::experimental::fmax", + ARG(0), ARG(1))), + CALL_FACTORY_ENTRY("__hmax", + CALL(MapNames::getClNamespace(false, true) + "fmax", + CALL("float", ARG(0)), CALL("float", ARG(1)))))) MATH_API_REWRITER_DEVICE( "__hmax_nan", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hmax_nan"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hmax_nan", - CALL(MapNames::getClNamespace() + - "ext::intel::math::hmax_nan", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmax_nan", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmax_nan", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hmax_nan"), CALL_FACTORY_ENTRY("__hmax_nan", CALL(MapNames::getDpctNamespace() + "fmax_nan", ARG(0), ARG(1))))) -MATH_API_REWRITER_DEVICE( - "__hmin", - MATH_API_DEVICE_NODES( - EMPTY_FACTORY_ENTRY("__hmin"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hmin", CALL(MapNames::getClNamespace() + - "ext::intel::math::hmin", - ARG(0), ARG(1)))), - EMPTY_FACTORY_ENTRY("__hmin"), - CALL_FACTORY_ENTRY("__hmin", CALL(MapNames::getClNamespace() + "fmin", - ARG(0), ARG(1))))) +MATH_API_REWRITER_DEVICE_OVERLOAD( + CheckArgType(0, "__half"), + MATH_API_REWRITER_DEVICE( + "__hmin", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmin"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmin", CALL(MapNames::getClNamespace() + + "ext::intel::math::hmin", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmin"), + CALL_FACTORY_ENTRY("__hmin", + CALL(MapNames::getClNamespace() + "fmin", ARG(0), + ARG(1))))), + MATH_API_REWRITER_EXPERIMENTAL_BFLOAT16( + "__hmin", + CALL_FACTORY_ENTRY("__hmin", + CALL(MapNames::getClNamespace(false, true) + + "ext::oneapi::experimental::fmin", + ARG(0), ARG(1))), + CALL_FACTORY_ENTRY("__hmin", + CALL(MapNames::getClNamespace(false, true) + "fmin", + CALL("float", ARG(0)), CALL("float", ARG(1)))))) MATH_API_REWRITER_DEVICE( "__hmin_nan", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hmin_nan"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hmin_nan", - CALL(MapNames::getClNamespace() + - "ext::intel::math::hmin_nan", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmin_nan", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmin_nan", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hmin_nan"), CALL_FACTORY_ENTRY("__hmin_nan", CALL(MapNames::getDpctNamespace() + "fmin_nan", @@ -960,11 +1058,13 @@ MATH_API_REWRITER_DEVICE( "__hne", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hne"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hne", CALL(MapNames::getClNamespace() + - "ext::intel::math::hne", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hne", CALL(MapNames::getClNamespace() + + "ext::intel::math::hne", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hne"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -977,11 +1077,13 @@ MATH_API_REWRITER_DEVICE( "__hneu", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hneu"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hneu", CALL(MapNames::getClNamespace() + - "ext::intel::math::hneu", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hneu", CALL(MapNames::getClNamespace() + + "ext::intel::math::hneu", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hneu"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -995,11 +1097,14 @@ MATH_API_REWRITER_DEVICE( "__hbeq2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbeq2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbeq2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbeq2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbeq2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbeq2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbeq2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1012,11 +1117,14 @@ MATH_API_REWRITER_DEVICE( "__hbequ2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbequ2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbequ2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbequ2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbequ2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbequ2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbequ2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1029,11 +1137,14 @@ MATH_API_REWRITER_DEVICE( "__hbge2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbge2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbge2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbge2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbge2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbge2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbge2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1046,11 +1157,14 @@ MATH_API_REWRITER_DEVICE( "__hbgeu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbgeu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbgeu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbgeu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbgeu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbgeu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbgeu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1063,11 +1177,14 @@ MATH_API_REWRITER_DEVICE( "__hbgt2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbgt2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbgt2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbgt2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbgt2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbgt2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbgt2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1079,11 +1196,14 @@ MATH_API_REWRITER_DEVICE( "__hbgtu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbgtu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbgtu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbgtu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbgtu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbgtu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbgtu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1096,11 +1216,14 @@ MATH_API_REWRITER_DEVICE( "__hble2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hble2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hble2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hble2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hble2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hble2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hble2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1113,11 +1236,14 @@ MATH_API_REWRITER_DEVICE( "__hbleu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbleu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbleu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbleu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbleu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbleu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbleu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1130,11 +1256,14 @@ MATH_API_REWRITER_DEVICE( "__hblt2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hblt2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hblt2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hblt2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hblt2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hblt2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hblt2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1146,11 +1275,14 @@ MATH_API_REWRITER_DEVICE( "__hbltu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbltu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbltu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbltu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbltu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbltu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbltu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1163,11 +1295,14 @@ MATH_API_REWRITER_DEVICE( "__hbne2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbne2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbne2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbne2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbne2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbne2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbne2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1180,11 +1315,14 @@ MATH_API_REWRITER_DEVICE( "__hbneu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hbneu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hbneu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hbneu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hbneu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hbneu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hbneu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1197,11 +1335,13 @@ MATH_API_REWRITER_DEVICE( "__heq2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__heq2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__heq2", CALL(MapNames::getClNamespace() + - "ext::intel::math::heq2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__heq2", CALL(MapNames::getClNamespace() + + "ext::intel::math::heq2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__heq2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1222,11 +1362,14 @@ MATH_API_REWRITER_DEVICE( "__hequ2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hequ2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hequ2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hequ2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hequ2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hequ2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hequ2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1250,11 +1393,13 @@ MATH_API_REWRITER_DEVICE( "__hge2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hge2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hge2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hge2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hge2", CALL(MapNames::getClNamespace() + + "ext::intel::math::hge2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hge2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1276,11 +1421,14 @@ MATH_API_REWRITER_DEVICE( "__hgeu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hgeu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hgeu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hgeu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hgeu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hgeu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hgeu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1304,11 +1452,13 @@ MATH_API_REWRITER_DEVICE( "__hgt2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hgt2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hgt2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hgt2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hgt2", CALL(MapNames::getClNamespace() + + "ext::intel::math::hgt2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hgt2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1329,11 +1479,14 @@ MATH_API_REWRITER_DEVICE( "__hgtu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hgtu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hgtu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hgtu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hgtu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hgtu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hgtu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1353,32 +1506,50 @@ MATH_API_REWRITER_DEVICE( "unordered_compare_mask", ARG(0), ARG(1), LITERAL("std::greater<>()"))))) -MATH_API_REWRITER_DEVICE( - "__hisnan2", - MATH_API_DEVICE_NODES( - EMPTY_FACTORY_ENTRY("__hisnan2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hisnan2", - CALL(MapNames::getClNamespace() + - "ext::intel::math::hisnan2", - ARG(0)))), - EMPTY_FACTORY_ENTRY("__hisnan2"), - FEATURE_REQUEST_FACTORY( - HelperFeatureEnum::device_ext, - CALL_FACTORY_ENTRY("__hisnan2", - CALL(MapNames::getDpctNamespace() + "isnan", - ARG(0)))))) +MATH_API_REWRITER_DEVICE_OVERLOAD( + CheckArgType(0, "__half2"), + MATH_API_REWRITER_DEVICE( + "__hisnan2", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hisnan2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hisnan2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hisnan2", + ARG(0)))), + EMPTY_FACTORY_ENTRY("__hisnan2"), + FEATURE_REQUEST_FACTORY( + HelperFeatureEnum::device_ext, + CALL_FACTORY_ENTRY("__hisnan2", + CALL(MapNames::getDpctNamespace() + "isnan", + ARG(0)))))), + MATH_API_REWRITER_EXPERIMENTAL_BFLOAT16( + "__hisnan2", + CALL_FACTORY_ENTRY("__hisnan2", + CALL(MapNames::getClNamespace(false, true) + + "ext::oneapi::experimental::isnan", + ARG(0))), + CALL_FACTORY_ENTRY( + "__hisnan2", + CALL(MapNames::getClNamespace() + "marray<" + + MapNames::getClNamespace() + "ext::oneapi::bfloat16, 2>", + CALL(MapNames::getClNamespace() + "isnan", + CALL("float", ARRAY_SUBSCRIPT(ARG(0), LITERAL("0")))), + CALL(MapNames::getClNamespace() + "isnan", + CALL("float", ARRAY_SUBSCRIPT(ARG(0), LITERAL("1")))))))) MATH_API_REWRITER_DEVICE( "__hle2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hle2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hle2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hle2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hle2", CALL(MapNames::getClNamespace() + + "ext::intel::math::hle2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hle2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1400,11 +1571,14 @@ MATH_API_REWRITER_DEVICE( "__hleu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hleu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hleu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hleu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hleu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hleu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hleu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1428,11 +1602,13 @@ MATH_API_REWRITER_DEVICE( "__hlt2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hlt2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hlt2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hlt2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hlt2", CALL(MapNames::getClNamespace() + + "ext::intel::math::hlt2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hlt2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1453,11 +1629,14 @@ MATH_API_REWRITER_DEVICE( "__hltu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hltu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hltu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hltu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hltu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hltu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hltu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1477,69 +1656,111 @@ MATH_API_REWRITER_DEVICE( ARG(0), ARG(1), LITERAL("std::less<>()"))))) -MATH_API_REWRITER_DEVICE( - "__hmax2", - MATH_API_DEVICE_NODES( - EMPTY_FACTORY_ENTRY("__hmax2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hmax2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hmax2", - ARG(0), ARG(1)))), - EMPTY_FACTORY_ENTRY("__hmax2"), - CALL_FACTORY_ENTRY("__hmax2", - CALL(MapNames::getClNamespace() + "half2", +MATH_API_REWRITER_DEVICE_OVERLOAD( + CheckArgType(0, "__half2"), + MATH_API_REWRITER_DEVICE( + "__hmax2", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmax2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmax2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmax2", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmax2"), + CALL_FACTORY_ENTRY( + "__hmax2", CALL(MapNames::getClNamespace() + "half2", CALL(MapNames::getClNamespace() + "fmax", ARRAY_SUBSCRIPT(ARG(0), LITERAL("0")), ARRAY_SUBSCRIPT(ARG(1), LITERAL("0"))), CALL(MapNames::getClNamespace() + "fmax", ARRAY_SUBSCRIPT(ARG(0), LITERAL("1")), - ARRAY_SUBSCRIPT(ARG(1), LITERAL("1"))))))) + ARRAY_SUBSCRIPT(ARG(1), LITERAL("1"))))))), + MATH_API_REWRITER_EXPERIMENTAL_BFLOAT16( + "__hmax2", + CALL_FACTORY_ENTRY("__hmax2", + CALL(MapNames::getClNamespace(false, true) + + "ext::oneapi::experimental::fmax", + ARG(0), ARG(1))), + CALL_FACTORY_ENTRY( + "__hmax2", + CALL(MapNames::getClNamespace() + "marray<" + + MapNames::getClNamespace() + "ext::oneapi::bfloat16, 2>", + CALL(MapNames::getClNamespace() + "fmax", + CALL("float", ARRAY_SUBSCRIPT(ARG(0), LITERAL("0"))), + CALL("float", ARRAY_SUBSCRIPT(ARG(1), LITERAL("0")))), + CALL(MapNames::getClNamespace() + "fmax", + CALL("float", ARRAY_SUBSCRIPT(ARG(0), LITERAL("1"))), + CALL("float", ARRAY_SUBSCRIPT(ARG(1), LITERAL("1")))))))) MATH_API_REWRITER_DEVICE( "__hmax2_nan", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hmax2_nan"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hmax2_nan", - CALL(MapNames::getClNamespace() + - "ext::intel::math::hmax2_nan", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmax2_nan", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmax2_nan", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hmax2_nan"), CALL_FACTORY_ENTRY("__hmax2_nan", CALL(MapNames::getDpctNamespace() + "fmax_nan", ARG(0), ARG(1))))) -MATH_API_REWRITER_DEVICE( - "__hmin2", - MATH_API_DEVICE_NODES( - EMPTY_FACTORY_ENTRY("__hmin2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hmin2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hmin2", - ARG(0), ARG(1)))), - EMPTY_FACTORY_ENTRY("__hmin2"), - CALL_FACTORY_ENTRY("__hmin2", - CALL(MapNames::getClNamespace() + "half2", +MATH_API_REWRITER_DEVICE_OVERLOAD( + CheckArgType(0, "__half2"), + MATH_API_REWRITER_DEVICE( + "__hmin2", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmin2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmin2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmin2", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmin2"), + CALL_FACTORY_ENTRY( + "__hmin2", CALL(MapNames::getClNamespace() + "half2", CALL(MapNames::getClNamespace() + "fmin", ARRAY_SUBSCRIPT(ARG(0), LITERAL("0")), ARRAY_SUBSCRIPT(ARG(1), LITERAL("0"))), CALL(MapNames::getClNamespace() + "fmin", ARRAY_SUBSCRIPT(ARG(0), LITERAL("1")), - ARRAY_SUBSCRIPT(ARG(1), LITERAL("1"))))))) + ARRAY_SUBSCRIPT(ARG(1), LITERAL("1"))))))), + MATH_API_REWRITER_EXPERIMENTAL_BFLOAT16( + "__hmin2", + CALL_FACTORY_ENTRY("__hmin2", + CALL(MapNames::getClNamespace(false, true) + + "ext::oneapi::experimental::fmin", + ARG(0), ARG(1))), + CALL_FACTORY_ENTRY( + "__hmin2", + CALL(MapNames::getClNamespace() + "marray<" + + MapNames::getClNamespace() + "ext::oneapi::bfloat16, 2>", + CALL(MapNames::getClNamespace() + "fmin", + CALL("float", ARRAY_SUBSCRIPT(ARG(0), LITERAL("0"))), + CALL("float", ARRAY_SUBSCRIPT(ARG(1), LITERAL("0")))), + CALL(MapNames::getClNamespace() + "fmin", + CALL("float", ARRAY_SUBSCRIPT(ARG(0), LITERAL("1"))), + CALL("float", ARRAY_SUBSCRIPT(ARG(1), LITERAL("1")))))))) MATH_API_REWRITER_DEVICE( "__hmin2_nan", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hmin2_nan"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hmin2_nan", - CALL(MapNames::getClNamespace() + - "ext::intel::math::hmin2_nan", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmin2_nan", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmin2_nan", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hmin2_nan"), CALL_FACTORY_ENTRY("__hmin2_nan", CALL(MapNames::getDpctNamespace() + "fmin_nan", @@ -1549,11 +1770,13 @@ MATH_API_REWRITER_DEVICE( "__hne2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hne2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hne2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hne2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hne2", CALL(MapNames::getClNamespace() + + "ext::intel::math::hne2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hne2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, @@ -1575,11 +1798,14 @@ MATH_API_REWRITER_DEVICE( "__hneu2", MATH_API_DEVICE_NODES( EMPTY_FACTORY_ENTRY("__hneu2"), - HEADER_INSERT_FACTORY( - HeaderType::HT_SYCL_Math, - CALL_FACTORY_ENTRY("__hneu2", CALL(MapNames::getClNamespace() + - "ext::intel::math::hneu2", - ARG(0), ARG(1)))), + MATH_API_SPECIFIC_ELSE_EMU( + CheckArgType(0, "__half2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hneu2", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hneu2", + ARG(0), ARG(1))))), EMPTY_FACTORY_ENTRY("__hneu2"), FEATURE_REQUEST_FACTORY( HelperFeatureEnum::device_ext, diff --git a/clang/runtime/dpct-rt/include/math.hpp b/clang/runtime/dpct-rt/include/math.hpp index 0a6bf96e5134..e31a7ed6d4bb 100644 --- a/clang/runtime/dpct-rt/include/math.hpp +++ b/clang/runtime/dpct-rt/include/math.hpp @@ -235,6 +235,14 @@ inline unsigned compare_mask(const sycl::vec a, const sycl::vec b, -compare(a[1], b[1], binary_op)) .as>(); } +template +inline unsigned compare_mask(const sycl::marray a, + const sycl::marray b, + const BinaryOperation binary_op) { + return sycl::vec(-compare(a[0], b[0], binary_op), + -compare(a[1], b[1], binary_op)) + .as>(); +} /// Performs 2 element unordered comparison. /// \param [in] a The first value @@ -263,6 +271,14 @@ inline unsigned unordered_compare_mask(const sycl::vec a, -unordered_compare(a[1], b[1], binary_op)) .as>(); } +template +inline unsigned unordered_compare_mask(const sycl::marray a, + const sycl::marray b, + const BinaryOperation binary_op) { + return sycl::vec(-unordered_compare(a[0], b[0], binary_op), + -unordered_compare(a[1], b[1], binary_op)) + .as>(); +} /// Determine whether 2 element value is NaN. /// \param [in] a The input value @@ -433,11 +449,26 @@ template inline T fmax_nan(const T a, const T b) { return NAN; return sycl::fmax(a, b); } +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +template <> +inline sycl::ext::oneapi::bfloat16 +fmax_nan(const sycl::ext::oneapi::bfloat16 a, + const sycl::ext::oneapi::bfloat16 b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmax(float(a), float(b)); +} +#endif template inline sycl::vec fmax_nan(const sycl::vec a, const sycl::vec b) { return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; } +template +inline sycl::marray fmax_nan(const sycl::marray a, + const sycl::marray b) { + return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; +} /// Performs 2 elements comparison and returns the smaller one. If either of /// inputs is NaN, then return NaN. @@ -449,11 +480,26 @@ template inline T fmin_nan(const T a, const T b) { return NAN; return sycl::fmin(a, b); } +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +template <> +inline sycl::ext::oneapi::bfloat16 +fmin_nan(const sycl::ext::oneapi::bfloat16 a, + const sycl::ext::oneapi::bfloat16 b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmin(float(a), float(b)); +} +#endif template inline sycl::vec fmin_nan(const sycl::vec a, const sycl::vec b) { return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; } +template +inline sycl::marray fmin_nan(const sycl::marray a, + const sycl::marray b) { + return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; +} /// A sycl::abs wrapper functors. struct abs { diff --git a/clang/test/dpct/math/bfloat16/bfloat16.cu b/clang/test/dpct/math/bfloat16/bfloat16.cu index 579116114c94..a60635885ccd 100644 --- a/clang/test/dpct/math/bfloat16/bfloat16.cu +++ b/clang/test/dpct/math/bfloat16/bfloat16.cu @@ -104,6 +104,112 @@ __global__ void kernelFuncBfloat162Arithmetic() { bf162 = __hsub2_sat(bf162_1, bf162_2); } +__global__ void kernelFuncBfloat16Comparison() { + // CHECK: sycl::ext::oneapi::bfloat16 bf16_1, bf16_2; + __nv_bfloat16 bf16_1, bf16_2; + bool b; + // CHECK: b = bf16_1 == bf16_2; + b = __heq(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::equal_to<>()); + b = __hequ(bf16_1, bf16_2); + // CHECK: b = bf16_1 >= bf16_2; + b = __hge(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::greater_equal<>()); + b = __hgeu(bf16_1, bf16_2); + // CHECK: b = bf16_1 > bf16_2; + b = __hgt(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::greater<>()); + b = __hgtu(bf16_1, bf16_2); + // CHECK: b = sycl::isinf(float(bf16_1)); + b = __hisinf(bf16_1); + // CHECK: b = sycl::isnan(float(bf16_1)); + b = __hisnan(bf16_1); + // CHECK: b = bf16_1 <= bf16_2; + b = __hle(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::less_equal<>()); + b = __hleu(bf16_1, bf16_2); + // CHECK: b = bf16_1 < bf16_2; + b = __hlt(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::less<>()); + b = __hltu(bf16_1, bf16_2); + // CHECK: b = sycl::fmax(float(bf16_1), float(bf16_2)); + b = __hmax(bf16_1, bf16_2); + // CHECK: b = dpct::fmax_nan(bf16_1, bf16_2); + b = __hmax_nan(bf16_1, bf16_2); + // CHECK: b = sycl::fmin(float(bf16_1), float(bf16_2)); + b = __hmin(bf16_1, bf16_2); + // CHECK: b = dpct::fmin_nan(bf16_1, bf16_2); + b = __hmin_nan(bf16_1, bf16_2); + // CHECK: b = dpct::compare(bf16_1, bf16_2, std::not_equal_to<>()); + b = __hne(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::not_equal_to<>()); + b = __hneu(bf16_1, bf16_2); +} + +__global__ void kernelFuncBfloat162Comparison() { + // CHECK: sycl::marray bf162, bf162_1, bf162_2; + __nv_bfloat162 bf162, bf162_1, bf162_2; + bool b; + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::equal_to<>()); + b = __hbeq2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::equal_to<>()); + b = __hbequ2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::greater_equal<>()); + b = __hbge2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::greater_equal<>()); + b = __hbgeu2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::greater<>()); + b = __hbgt2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::greater<>()); + b = __hbgtu2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::less_equal<>()); + b = __hble2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::less_equal<>()); + b = __hbleu2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::less<>()); + b = __hblt2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::less<>()); + b = __hbltu2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::not_equal_to<>()); + b = __hbne2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::not_equal_to<>()); + b = __hbneu2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::equal_to<>()); + bf162 = __heq2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::equal_to<>()); + bf162 = __hequ2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::greater_equal<>()); + bf162 = __hge2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::greater_equal<>()); + bf162 = __hgeu2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::greater<>()); + bf162 = __hgt2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::greater<>()); + bf162 = __hgtu2(bf162_1, bf162_2); + // CHECK: bf162 = sycl::marray(sycl::isnan(float(bf162_1[0])), sycl::isnan(float(bf162_1[1]))); + bf162 = __hisnan2(bf162_1); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::less_equal<>()); + bf162 = __hle2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::less_equal<>()); + bf162 = __hleu2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::less<>()); + bf162 = __hlt2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::less<>()); + bf162 = __hltu2(bf162_1, bf162_2); + // CHECK: bf162 = sycl::marray(sycl::fmax(float(bf162_1[0]), float(bf162_2[0])), sycl::fmax(float(bf162_1[1]), float(bf162_2[1]))); + bf162 = __hmax2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::fmax_nan(bf162_1, bf162_2); + bf162 = __hmax2_nan(bf162_1, bf162_2); + // CHECK: bf162 = sycl::marray(sycl::fmin(float(bf162_1[0]), float(bf162_2[0])), sycl::fmin(float(bf162_1[1]), float(bf162_2[1]))); + bf162 = __hmin2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::fmin_nan(bf162_1, bf162_2); + bf162 = __hmin2_nan(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::not_equal_to<>()); + bf162 = __hne2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::not_equal_to<>()); + bf162 = __hneu2(bf162_1, bf162_2); +} + // CHECK: void test_conversions_device(sycl::ext::oneapi::bfloat16 *deviceArrayBFloat16) { // CHECK-NEXT: float f, f_1, f_2; // CHECK-NEXT: sycl::float2 f2, f2_1, f2_2; diff --git a/clang/test/dpct/math/bfloat16/bfloat16_cuda12_after.cu b/clang/test/dpct/math/bfloat16/bfloat16_cuda12_after.cu new file mode 100644 index 000000000000..f04b567c28c8 --- /dev/null +++ b/clang/test/dpct/math/bfloat16/bfloat16_cuda12_after.cu @@ -0,0 +1,41 @@ +// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2, cuda-11.0, cuda-11.1, cuda-11.2, cuda-11.3, cuda-11.4, cuda-11.5, cuda-11.6, cuda-11.7, cuda-11.8 +// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2, v11.0, v11.1, v11.2, v11.3, v11.4, v11.5, v11.6, v11.7, v11.8 +// RUN: dpct --format-range=none --use-experimental-features=bfloat16_math_functions -out-root %T/math/bfloat16/bfloat16_cuda12_after %s --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only +// RUN: FileCheck %s --match-full-lines --input-file %T/math/bfloat16/bfloat16_cuda12_after/bfloat16_cuda12_after.dp.cpp + +#include "cuda_bf16.h" + +__global__ void kernelFuncBfloat162Comparison() { + // CHECK: sycl::marray bf162, bf162_1, bf162_2; + __nv_bfloat162 bf162, bf162_1, bf162_2; + unsigned u; + + // Half2 Comparison Functions + + // CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::equal_to<>()); + u = __heq2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::equal_to<>()); + u = __hequ2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::greater_equal<>()); + u = __hge2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::greater_equal<>()); + u = __hgeu2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::greater<>()); + u = __hgt2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::greater<>()); + u = __hgtu2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::less_equal<>()); + u = __hle2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::less_equal<>()); + u = __hleu2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::less<>()); + u = __hlt2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::less<>()); + u = __hltu2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::not_equal_to<>()); + u = __hne2_mask(bf162_1, bf162_2); + // CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::not_equal_to<>()); + u = __hneu2_mask(bf162_1, bf162_2); +} + +int main() { return 0; } diff --git a/clang/test/dpct/math/bfloat16/bfloat16_experimental.cu b/clang/test/dpct/math/bfloat16/bfloat16_experimental.cu index ddaf2801260e..07550cfbc6d5 100644 --- a/clang/test/dpct/math/bfloat16/bfloat16_experimental.cu +++ b/clang/test/dpct/math/bfloat16/bfloat16_experimental.cu @@ -77,6 +77,112 @@ __global__ void kernelFuncBfloat162Arithmetic() { bf162 = __hsub2_sat(bf162_1, bf162_2); } +__global__ void kernelFuncBfloat16Comparison() { + // CHECK: sycl::ext::oneapi::bfloat16 bf16_1, bf16_2; + __nv_bfloat16 bf16_1, bf16_2; + bool b; + // CHECK: b = bf16_1 == bf16_2; + b = __heq(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::equal_to<>()); + b = __hequ(bf16_1, bf16_2); + // CHECK: b = bf16_1 >= bf16_2; + b = __hge(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::greater_equal<>()); + b = __hgeu(bf16_1, bf16_2); + // CHECK: b = bf16_1 > bf16_2; + b = __hgt(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::greater<>()); + b = __hgtu(bf16_1, bf16_2); + // CHECK: b = sycl::isinf(float(bf16_1)); + b = __hisinf(bf16_1); + // CHECK: b = sycl::ext::oneapi::experimental::isnan(bf16_1); + b = __hisnan(bf16_1); + // CHECK: b = bf16_1 <= bf16_2; + b = __hle(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::less_equal<>()); + b = __hleu(bf16_1, bf16_2); + // CHECK: b = bf16_1 < bf16_2; + b = __hlt(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::less<>()); + b = __hltu(bf16_1, bf16_2); + // CHECK: b = sycl::ext::oneapi::experimental::fmax(bf16_1, bf16_2); + b = __hmax(bf16_1, bf16_2); + // CHECK: b = dpct::fmax_nan(bf16_1, bf16_2); + b = __hmax_nan(bf16_1, bf16_2); + // CHECK: b = sycl::ext::oneapi::experimental::fmin(bf16_1, bf16_2); + b = __hmin(bf16_1, bf16_2); + // CHECK: b = dpct::fmin_nan(bf16_1, bf16_2); + b = __hmin_nan(bf16_1, bf16_2); + // CHECK: b = dpct::compare(bf16_1, bf16_2, std::not_equal_to<>()); + b = __hne(bf16_1, bf16_2); + // CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::not_equal_to<>()); + b = __hneu(bf16_1, bf16_2); +} + +__global__ void kernelFuncBfloat162Comparison() { + // CHECK: sycl::marray bf162, bf162_1, bf162_2; + __nv_bfloat162 bf162, bf162_1, bf162_2; + bool b; + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::equal_to<>()); + b = __hbeq2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::equal_to<>()); + b = __hbequ2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::greater_equal<>()); + b = __hbge2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::greater_equal<>()); + b = __hbgeu2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::greater<>()); + b = __hbgt2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::greater<>()); + b = __hbgtu2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::less_equal<>()); + b = __hble2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::less_equal<>()); + b = __hbleu2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::less<>()); + b = __hblt2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::less<>()); + b = __hbltu2(bf162_1, bf162_2); + // CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::not_equal_to<>()); + b = __hbne2(bf162_1, bf162_2); + // CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::not_equal_to<>()); + b = __hbneu2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::equal_to<>()); + bf162 = __heq2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::equal_to<>()); + bf162 = __hequ2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::greater_equal<>()); + bf162 = __hge2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::greater_equal<>()); + bf162 = __hgeu2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::greater<>()); + bf162 = __hgt2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::greater<>()); + bf162 = __hgtu2(bf162_1, bf162_2); + // CHECK: bf162 = sycl::ext::oneapi::experimental::isnan(bf162_1); + bf162 = __hisnan2(bf162_1); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::less_equal<>()); + bf162 = __hle2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::less_equal<>()); + bf162 = __hleu2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::less<>()); + bf162 = __hlt2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::less<>()); + bf162 = __hltu2(bf162_1, bf162_2); + // CHECK: bf162 = sycl::ext::oneapi::experimental::fmax(bf162_1, bf162_2); + bf162 = __hmax2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::fmax_nan(bf162_1, bf162_2); + bf162 = __hmax2_nan(bf162_1, bf162_2); + // CHECK: bf162 = sycl::ext::oneapi::experimental::fmin(bf162_1, bf162_2); + bf162 = __hmin2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::fmin_nan(bf162_1, bf162_2); + bf162 = __hmin2_nan(bf162_1, bf162_2); + // CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::not_equal_to<>()); + bf162 = __hne2(bf162_1, bf162_2); + // CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::not_equal_to<>()); + bf162 = __hneu2(bf162_1, bf162_2); +} + __global__ void kernelFuncBfloat16Math() { // CHECK: sycl::ext::oneapi::bfloat16 bf16, bf16_1; __nv_bfloat16 bf16, bf16_1; diff --git a/clang/test/dpct/math/bfloat16/bfloat16_ext.cu b/clang/test/dpct/math/bfloat16/bfloat16_ext.cu index 3edb1f95bc7c..09cb4bab1ddc 100644 --- a/clang/test/dpct/math/bfloat16/bfloat16_ext.cu +++ b/clang/test/dpct/math/bfloat16/bfloat16_ext.cu @@ -12,4 +12,12 @@ __global__ void kernelFuncBfloat162Arithmetic() { bf162 = __h2div(bf162_1, bf162_2); } +__global__ void kernelFuncBfloat16Comparison() { + // CHECK: sycl::ext::oneapi::bfloat16 bf16_1, bf16_2; + __nv_bfloat16 bf16_1, bf16_2; + bool b; + // CHECK: b = bf16_1 == bf16_2; + b = __heq(bf16_1, bf16_2); +} + int main() { return 0; } diff --git a/clang/test/dpct/math/cuda-math-extension.cu b/clang/test/dpct/math/cuda-math-extension.cu index 1f737d43fd48..bfb42868de0c 100644 --- a/clang/test/dpct/math/cuda-math-extension.cu +++ b/clang/test/dpct/math/cuda-math-extension.cu @@ -149,6 +149,10 @@ __global__ void kernelFuncHalf() { b = __hgt(h, h_1); // CHECK: b = sycl::ext::intel::math::hgtu(h, h_1); b = __hgtu(h, h_1); + // CHECK: b = sycl::ext::intel::math::hisinf(h); + b = __hisinf(h); + // CHECK: b = sycl::ext::intel::math::hisnan(h); + b = __hisnan(h); // CHECK: b = sycl::ext::intel::math::hle(h, h_1); b = __hle(h, h_1); // CHECK: b = sycl::ext::intel::math::hleu(h, h_1);