Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clang/lib/DPCT/APINamesMath.inc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ ENTRY_REWRITE("__hfma")
ENTRY_REWRITE("__hfma2")

// Half Comparison Functions
ENTRY_RENAMED("__hisinf", MapNames::getClNamespace(false, true) + "isinf")
ENTRY_RENAMED("__hisnan", MapNames::getClNamespace(false, true) + "isnan")
ENTRY_REWRITE("__hisinf")
ENTRY_REWRITE("__hisnan")

// Half Math Functions
ENTRY_REWRITE("hceil")
Expand Down
766 changes: 496 additions & 270 deletions clang/lib/DPCT/APINamesMathRewrite.inc

Large diffs are not rendered by default.

46 changes: 46 additions & 0 deletions clang/runtime/dpct-rt/include/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ inline unsigned compare_mask(const sycl::vec<T, 2> a, const sycl::vec<T, 2> b,
-compare(a[1], b[1], binary_op))
.as<sycl::vec<unsigned, 1>>();
}
template <typename T, class BinaryOperation>
inline unsigned compare_mask(const sycl::marray<T, 2> a,
const sycl::marray<T, 2> b,
const BinaryOperation binary_op) {
return sycl::vec<short, 2>(-compare(a[0], b[0], binary_op),
-compare(a[1], b[1], binary_op))
.as<sycl::vec<unsigned, 1>>();
}

/// Performs 2 element unordered comparison.
/// \param [in] a The first value
Expand Down Expand Up @@ -263,6 +271,14 @@ inline unsigned unordered_compare_mask(const sycl::vec<T, 2> a,
-unordered_compare(a[1], b[1], binary_op))
.as<sycl::vec<unsigned, 1>>();
}
template <typename T, class BinaryOperation>
inline unsigned unordered_compare_mask(const sycl::marray<T, 2> a,
const sycl::marray<T, 2> b,
const BinaryOperation binary_op) {
return sycl::vec<short, 2>(-unordered_compare(a[0], b[0], binary_op),
-unordered_compare(a[1], b[1], binary_op))
.as<sycl::vec<unsigned, 1>>();
}

/// Determine whether 2 element value is NaN.
/// \param [in] a The input value
Expand Down Expand Up @@ -433,11 +449,26 @@ template <typename T> inline T fmax_nan(const T a, const T b) {
return NAN;
return sycl::fmax(a, b);
}
#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
template <>
inline sycl::ext::oneapi::bfloat16
fmax_nan(const sycl::ext::oneapi::bfloat16 a,
const sycl::ext::oneapi::bfloat16 b) {
if (detail::isnan(a) || detail::isnan(b))
return NAN;
return sycl::fmax(float(a), float(b));
}
#endif
template <typename T>
inline sycl::vec<T, 2> fmax_nan(const sycl::vec<T, 2> a,
const sycl::vec<T, 2> b) {
return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])};
}
template <typename T>
inline sycl::marray<T, 2> fmax_nan(const sycl::marray<T, 2> a,
const sycl::marray<T, 2> b) {
return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])};
}

/// Performs 2 elements comparison and returns the smaller one. If either of
/// inputs is NaN, then return NaN.
Expand All @@ -449,11 +480,26 @@ template <typename T> inline T fmin_nan(const T a, const T b) {
return NAN;
return sycl::fmin(a, b);
}
#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
template <>
inline sycl::ext::oneapi::bfloat16
fmin_nan(const sycl::ext::oneapi::bfloat16 a,
const sycl::ext::oneapi::bfloat16 b) {
if (detail::isnan(a) || detail::isnan(b))
return NAN;
return sycl::fmin(float(a), float(b));
}
#endif
template <typename T>
inline sycl::vec<T, 2> fmin_nan(const sycl::vec<T, 2> a,
const sycl::vec<T, 2> b) {
return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])};
}
template <typename T>
inline sycl::marray<T, 2> fmin_nan(const sycl::marray<T, 2> a,
const sycl::marray<T, 2> b) {
return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])};
}

/// A sycl::abs wrapper functors.
struct abs {
Expand Down
106 changes: 106 additions & 0 deletions clang/test/dpct/math/bfloat16/bfloat16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,112 @@ __global__ void kernelFuncBfloat162Arithmetic() {
bf162 = __hsub2_sat(bf162_1, bf162_2);
}

__global__ void kernelFuncBfloat16Comparison() {
// CHECK: sycl::ext::oneapi::bfloat16 bf16_1, bf16_2;
__nv_bfloat16 bf16_1, bf16_2;
bool b;
// CHECK: b = bf16_1 == bf16_2;
b = __heq(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::equal_to<>());
b = __hequ(bf16_1, bf16_2);
// CHECK: b = bf16_1 >= bf16_2;
b = __hge(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::greater_equal<>());
b = __hgeu(bf16_1, bf16_2);
// CHECK: b = bf16_1 > bf16_2;
b = __hgt(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::greater<>());
b = __hgtu(bf16_1, bf16_2);
// CHECK: b = sycl::isinf(float(bf16_1));
b = __hisinf(bf16_1);
// CHECK: b = sycl::isnan(float(bf16_1));
b = __hisnan(bf16_1);
// CHECK: b = bf16_1 <= bf16_2;
b = __hle(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::less_equal<>());
b = __hleu(bf16_1, bf16_2);
// CHECK: b = bf16_1 < bf16_2;
b = __hlt(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::less<>());
b = __hltu(bf16_1, bf16_2);
// CHECK: b = sycl::fmax(float(bf16_1), float(bf16_2));
b = __hmax(bf16_1, bf16_2);
// CHECK: b = dpct::fmax_nan(bf16_1, bf16_2);
b = __hmax_nan(bf16_1, bf16_2);
// CHECK: b = sycl::fmin(float(bf16_1), float(bf16_2));
b = __hmin(bf16_1, bf16_2);
// CHECK: b = dpct::fmin_nan(bf16_1, bf16_2);
b = __hmin_nan(bf16_1, bf16_2);
// CHECK: b = dpct::compare(bf16_1, bf16_2, std::not_equal_to<>());
b = __hne(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::not_equal_to<>());
b = __hneu(bf16_1, bf16_2);
}

__global__ void kernelFuncBfloat162Comparison() {
// CHECK: sycl::marray<sycl::ext::oneapi::bfloat16, 2> bf162, bf162_1, bf162_2;
__nv_bfloat162 bf162, bf162_1, bf162_2;
bool b;
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::equal_to<>());
b = __hbeq2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::equal_to<>());
b = __hbequ2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::greater_equal<>());
b = __hbge2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::greater_equal<>());
b = __hbgeu2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::greater<>());
b = __hbgt2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::greater<>());
b = __hbgtu2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::less_equal<>());
b = __hble2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::less_equal<>());
b = __hbleu2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::less<>());
b = __hblt2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::less<>());
b = __hbltu2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::not_equal_to<>());
b = __hbne2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::not_equal_to<>());
b = __hbneu2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::equal_to<>());
bf162 = __heq2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::equal_to<>());
bf162 = __hequ2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::greater_equal<>());
bf162 = __hge2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::greater_equal<>());
bf162 = __hgeu2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::greater<>());
bf162 = __hgt2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::greater<>());
bf162 = __hgtu2(bf162_1, bf162_2);
// CHECK: bf162 = sycl::marray<sycl::ext::oneapi::bfloat16, 2>(sycl::isnan(float(bf162_1[0])), sycl::isnan(float(bf162_1[1])));
bf162 = __hisnan2(bf162_1);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::less_equal<>());
bf162 = __hle2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::less_equal<>());
bf162 = __hleu2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::less<>());
bf162 = __hlt2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::less<>());
bf162 = __hltu2(bf162_1, bf162_2);
// CHECK: bf162 = sycl::marray<sycl::ext::oneapi::bfloat16, 2>(sycl::fmax(float(bf162_1[0]), float(bf162_2[0])), sycl::fmax(float(bf162_1[1]), float(bf162_2[1])));
bf162 = __hmax2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::fmax_nan(bf162_1, bf162_2);
bf162 = __hmax2_nan(bf162_1, bf162_2);
// CHECK: bf162 = sycl::marray<sycl::ext::oneapi::bfloat16, 2>(sycl::fmin(float(bf162_1[0]), float(bf162_2[0])), sycl::fmin(float(bf162_1[1]), float(bf162_2[1])));
bf162 = __hmin2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::fmin_nan(bf162_1, bf162_2);
bf162 = __hmin2_nan(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::not_equal_to<>());
bf162 = __hne2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::not_equal_to<>());
bf162 = __hneu2(bf162_1, bf162_2);
}

// CHECK: void test_conversions_device(sycl::ext::oneapi::bfloat16 *deviceArrayBFloat16) {
// CHECK-NEXT: float f, f_1, f_2;
// CHECK-NEXT: sycl::float2 f2, f2_1, f2_2;
Expand Down
41 changes: 41 additions & 0 deletions clang/test/dpct/math/bfloat16/bfloat16_cuda12_after.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2, cuda-11.0, cuda-11.1, cuda-11.2, cuda-11.3, cuda-11.4, cuda-11.5, cuda-11.6, cuda-11.7, cuda-11.8
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2, v11.0, v11.1, v11.2, v11.3, v11.4, v11.5, v11.6, v11.7, v11.8
// RUN: dpct --format-range=none --use-experimental-features=bfloat16_math_functions -out-root %T/math/bfloat16/bfloat16_cuda12_after %s --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only
// RUN: FileCheck %s --match-full-lines --input-file %T/math/bfloat16/bfloat16_cuda12_after/bfloat16_cuda12_after.dp.cpp

#include "cuda_bf16.h"

__global__ void kernelFuncBfloat162Comparison() {
// CHECK: sycl::marray<sycl::ext::oneapi::bfloat16, 2> bf162, bf162_1, bf162_2;
__nv_bfloat162 bf162, bf162_1, bf162_2;
unsigned u;

// Half2 Comparison Functions

// CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::equal_to<>());
u = __heq2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::equal_to<>());
u = __hequ2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::greater_equal<>());
u = __hge2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::greater_equal<>());
u = __hgeu2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::greater<>());
u = __hgt2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::greater<>());
u = __hgtu2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::less_equal<>());
u = __hle2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::less_equal<>());
u = __hleu2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::less<>());
u = __hlt2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::less<>());
u = __hltu2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::compare_mask(bf162_1, bf162_2, std::not_equal_to<>());
u = __hne2_mask(bf162_1, bf162_2);
// CHECK: u = dpct::unordered_compare_mask(bf162_1, bf162_2, std::not_equal_to<>());
u = __hneu2_mask(bf162_1, bf162_2);
}

int main() { return 0; }
106 changes: 106 additions & 0 deletions clang/test/dpct/math/bfloat16/bfloat16_experimental.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,112 @@ __global__ void kernelFuncBfloat162Arithmetic() {
bf162 = __hsub2_sat(bf162_1, bf162_2);
}

__global__ void kernelFuncBfloat16Comparison() {
// CHECK: sycl::ext::oneapi::bfloat16 bf16_1, bf16_2;
__nv_bfloat16 bf16_1, bf16_2;
bool b;
// CHECK: b = bf16_1 == bf16_2;
b = __heq(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::equal_to<>());
b = __hequ(bf16_1, bf16_2);
// CHECK: b = bf16_1 >= bf16_2;
b = __hge(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::greater_equal<>());
b = __hgeu(bf16_1, bf16_2);
// CHECK: b = bf16_1 > bf16_2;
b = __hgt(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::greater<>());
b = __hgtu(bf16_1, bf16_2);
// CHECK: b = sycl::isinf(float(bf16_1));
b = __hisinf(bf16_1);
// CHECK: b = sycl::ext::oneapi::experimental::isnan(bf16_1);
b = __hisnan(bf16_1);
// CHECK: b = bf16_1 <= bf16_2;
b = __hle(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::less_equal<>());
b = __hleu(bf16_1, bf16_2);
// CHECK: b = bf16_1 < bf16_2;
b = __hlt(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::less<>());
b = __hltu(bf16_1, bf16_2);
// CHECK: b = sycl::ext::oneapi::experimental::fmax(bf16_1, bf16_2);
b = __hmax(bf16_1, bf16_2);
// CHECK: b = dpct::fmax_nan(bf16_1, bf16_2);
b = __hmax_nan(bf16_1, bf16_2);
// CHECK: b = sycl::ext::oneapi::experimental::fmin(bf16_1, bf16_2);
b = __hmin(bf16_1, bf16_2);
// CHECK: b = dpct::fmin_nan(bf16_1, bf16_2);
b = __hmin_nan(bf16_1, bf16_2);
// CHECK: b = dpct::compare(bf16_1, bf16_2, std::not_equal_to<>());
b = __hne(bf16_1, bf16_2);
// CHECK: b = dpct::unordered_compare(bf16_1, bf16_2, std::not_equal_to<>());
b = __hneu(bf16_1, bf16_2);
}

__global__ void kernelFuncBfloat162Comparison() {
// CHECK: sycl::marray<sycl::ext::oneapi::bfloat16, 2> bf162, bf162_1, bf162_2;
__nv_bfloat162 bf162, bf162_1, bf162_2;
bool b;
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::equal_to<>());
b = __hbeq2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::equal_to<>());
b = __hbequ2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::greater_equal<>());
b = __hbge2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::greater_equal<>());
b = __hbgeu2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::greater<>());
b = __hbgt2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::greater<>());
b = __hbgtu2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::less_equal<>());
b = __hble2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::less_equal<>());
b = __hbleu2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::less<>());
b = __hblt2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::less<>());
b = __hbltu2(bf162_1, bf162_2);
// CHECK: b = dpct::compare_both(bf162_1, bf162_2, std::not_equal_to<>());
b = __hbne2(bf162_1, bf162_2);
// CHECK: b = dpct::unordered_compare_both(bf162_1, bf162_2, std::not_equal_to<>());
b = __hbneu2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::equal_to<>());
bf162 = __heq2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::equal_to<>());
bf162 = __hequ2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::greater_equal<>());
bf162 = __hge2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::greater_equal<>());
bf162 = __hgeu2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::greater<>());
bf162 = __hgt2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::greater<>());
bf162 = __hgtu2(bf162_1, bf162_2);
// CHECK: bf162 = sycl::ext::oneapi::experimental::isnan(bf162_1);
bf162 = __hisnan2(bf162_1);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::less_equal<>());
bf162 = __hle2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::less_equal<>());
bf162 = __hleu2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::less<>());
bf162 = __hlt2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::less<>());
bf162 = __hltu2(bf162_1, bf162_2);
// CHECK: bf162 = sycl::ext::oneapi::experimental::fmax(bf162_1, bf162_2);
bf162 = __hmax2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::fmax_nan(bf162_1, bf162_2);
bf162 = __hmax2_nan(bf162_1, bf162_2);
// CHECK: bf162 = sycl::ext::oneapi::experimental::fmin(bf162_1, bf162_2);
bf162 = __hmin2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::fmin_nan(bf162_1, bf162_2);
bf162 = __hmin2_nan(bf162_1, bf162_2);
// CHECK: bf162 = dpct::compare(bf162_1, bf162_2, std::not_equal_to<>());
bf162 = __hne2(bf162_1, bf162_2);
// CHECK: bf162 = dpct::unordered_compare(bf162_1, bf162_2, std::not_equal_to<>());
bf162 = __hneu2(bf162_1, bf162_2);
}

__global__ void kernelFuncBfloat16Math() {
// CHECK: sycl::ext::oneapi::bfloat16 bf16, bf16_1;
__nv_bfloat16 bf16, bf16_1;
Expand Down
8 changes: 8 additions & 0 deletions clang/test/dpct/math/bfloat16/bfloat16_ext.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,12 @@ __global__ void kernelFuncBfloat162Arithmetic() {
bf162 = __h2div(bf162_1, bf162_2);
}

__global__ void kernelFuncBfloat16Comparison() {
// CHECK: sycl::ext::oneapi::bfloat16 bf16_1, bf16_2;
__nv_bfloat16 bf16_1, bf16_2;
bool b;
// CHECK: b = bf16_1 == bf16_2;
b = __heq(bf16_1, bf16_2);
}

int main() { return 0; }
4 changes: 4 additions & 0 deletions clang/test/dpct/math/cuda-math-extension.cu
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ __global__ void kernelFuncHalf() {
b = __hgt(h, h_1);
// CHECK: b = sycl::ext::intel::math::hgtu(h, h_1);
b = __hgtu(h, h_1);
// CHECK: b = sycl::ext::intel::math::hisinf(h);
b = __hisinf(h);
// CHECK: b = sycl::ext::intel::math::hisnan(h);
b = __hisnan(h);
// CHECK: b = sycl::ext::intel::math::hle(h, h_1);
b = __hle(h, h_1);
// CHECK: b = sycl::ext::intel::math::hleu(h, h_1);
Expand Down