Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5a000ac
No constexpr construction in math tests
joeatodd Sep 25, 2024
c32b263
Add syclcompat::is_floating_point_v
joeatodd Sep 26, 2024
9f28f62
Add bfloat16 to tests and generalize to container types
joeatodd Sep 27, 2024
8319654
Add some guards & TODOs to math.hpp
joeatodd Sep 27, 2024
ec00197
Generalize should_skip to variadic types
joeatodd Sep 27, 2024
f8f62aa
Specialize std::common_type for bfloat16
joeatodd Sep 30, 2024
dad728c
Add bfloat16 support (& more extensive tests) to fmin/max_nan
joeatodd Sep 30, 2024
9681cc7
Define constexpr bool for bfloat16 support
joeatodd Sep 30, 2024
a5e8e10
Add marray fmin/fmax_nan to docs
joeatodd Sep 30, 2024
4497d82
Extend isnan support to bfloat16
joeatodd Sep 30, 2024
1fe5980
Formatting
joeatodd Sep 30, 2024
847d45d
Fix bug in cmul_add tests
joeatodd Sep 30, 2024
55e95ad
Add cmul_add<bfloat> & draft test
joeatodd Oct 1, 2024
5b0e0b5
Enable relu for bfloat
joeatodd Oct 1, 2024
51860b3
Add bfloat16-container support to clamp
joeatodd Oct 1, 2024
d16dd56
Add max & min bfloat16 support & tests
joeatodd Oct 1, 2024
05b0748
Generalize compare_mask & unordered_compare_mask
joeatodd Oct 2, 2024
fd9ebf4
Tidy a comment
joeatodd Oct 2, 2024
e5f9231
Add bfloat16 support to cbrt
joeatodd Oct 2, 2024
f95dc85
Revert "Add cmul_add<bfloat> & draft test"
joeatodd Oct 2, 2024
d7b8676
Merge branch 'sycl' into jtodd/bfloat16_support
joeatodd Oct 14, 2024
4da9833
Move std::common_type specialization to traits.hpp
joeatodd Oct 15, 2024
ac2b4e9
Document std::common_type_t<bfloat16,...>
joeatodd Oct 16, 2024
b0303bb
Revert unneeded cbrt<bfloat16> support
joeatodd Oct 16, 2024
fc11764
Review fix includes
joeatodd Oct 16, 2024
f21f62e
Assert bfloat16 math support in `isnan`, `max`, and `min`
joeatodd Oct 16, 2024
48e786a
Merge branch 'sycl' into jtodd/bfloat16_support
joeatodd Oct 16, 2024
20012b3
Fix incorrect local memory usage in tests
joeatodd Oct 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 60 additions & 10 deletions sycl/doc/syclcompat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1725,7 +1725,51 @@ second operand, respectively. These three APIs return a single 32-bit value with
the accumulated result, which is unsigned if both operands are `uint32_t` and
signed otherwise.

Various maths functions are defined to operate on any floating point type.
`syclcompat::is_floating_point_v` extends the standard library's
`std::is_floating_point_v` to include `sycl::half` and, where available,
`sycl::ext::oneapi::bfloat16`. The current version of SYCLcompat also provides
a specialization of `std::common_type_t` for `sycl::ext::oneapi::bfloat16`,
though this will be moved to the `sycl_ext_oneapi_bfloat16` extension in
future.

```cpp
namespace std {
template <> struct common_type<sycl::ext::oneapi::bfloat16> {
using type = sycl::ext::oneapi::bfloat16;
};

template <>
struct common_type<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16> {
using type = sycl::ext::oneapi::bfloat16;
};

template <typename T> struct common_type<sycl::ext::oneapi::bfloat16, T> {
using type = sycl::ext::oneapi::bfloat16;
};

template <typename T> struct common_type<T, sycl::ext::oneapi::bfloat16> {
using type = sycl::ext::oneapi::bfloat16;
};
} // namespace std
```

```cpp
namespace syclcompat{

// Trait for extended floating point definition
template <typename T>
struct is_floating_point : std::is_floating_point<T>{};

template <> struct is_floating_point<sycl::half> : std::true_type {};

#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
template <> struct is_floating_point<sycl::ext::oneapi::bfloat16> : std::true_type {};
#endif

template <typename T>
inline constexpr bool is_floating_point_v = is_floating_point<T>::value;

inline unsigned int funnelshift_l(unsigned int low, unsigned int high,
unsigned int shift);

Expand All @@ -1752,11 +1796,9 @@ inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a);
// cbrt function wrapper.
template <typename ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
std::is_same_v<ValueT, sycl::half>,
ValueT>
cbrt(ValueT val) {
return sycl::cbrt(static_cast<ValueT>(val));
}
cbrt(ValueT val);

// For floating-point types, `float` or `double` arguments are acceptable.
// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
Expand Down Expand Up @@ -1794,6 +1836,10 @@ template <typename ValueT, typename ValueU>
inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
fmax_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);

template <typename ValueT, typename ValueU>
inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
fmax_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);

// Compares two elements and returns the smaller one. If either input is NaN,
// returns NaN.
template <typename ValueT, typename ValueU>
Expand All @@ -1803,6 +1849,10 @@ template <typename ValueT, typename ValueU>
inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
fmin_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);

template <typename ValueT, typename ValueU>
inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
fmin_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);

inline float pow(const float a, const int b) { return sycl::pown(a, b); }
inline double pow(const double a, const int b) { return sycl::pown(a, b); }

Expand Down Expand Up @@ -1863,14 +1913,13 @@ unordered_compare_both(const ValueT a, const ValueT b,
const BinaryOperation binary_op);

template <typename ValueT, class BinaryOperation>
inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
const sycl::vec<ValueT, 2> b,
const BinaryOperation binary_op);
inline std::enable_if_t<ValueT::size() == 2, unsigned>
compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op);

template <typename ValueT, class BinaryOperation>
inline unsigned unordered_compare_mask(const sycl::vec<ValueT, 2> a,
const sycl::vec<ValueT, 2> b,
const BinaryOperation binary_op);
inline std::enable_if_t<ValueT::size() == 2, unsigned>
unordered_compare_mask(const ValueT a, const ValueT b,
const BinaryOperation binary_op);

template <typename S, typename T> inline T vectorized_max(T a, T b);

Expand Down Expand Up @@ -1924,6 +1973,7 @@ inline dot_product_acc_t<T1, T2> dp2a_hi(T1 a, T2 b,
template <typename T1, typename T2>
inline dot_product_acc_t<T1, T2> dp4a(T1 a, T2 b,
dot_product_acc_t<T1, T2> c);
} // namespace syclcompat
```

`vectorized_binary` computes the `BinaryOperation` for two operands,
Expand Down
Loading
Loading