Skip to content

Commit a4ff76e

Browse files
[libc][math][c++23] Implement basic arithmetic operations for BFloat16 (#151228)
This PR implements addition, subtraction, multiplication and division operations for BFloat16. --------- Signed-off-by: krishna2803 <kpandey81930@gmail.com> Signed-off-by: Krishna Pandey <kpandey81930@gmail.com> Co-authored-by: OverMighty <its.overmighty@gmail.com>
1 parent 2bb23d4 commit a4ff76e

21 files changed

+554
-49
lines changed

libc/src/__support/FPUtil/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ add_header_library(
285285
libc.hdr.stdint_proxy
286286
libc.src.__support.CPP.bit
287287
libc.src.__support.CPP.type_traits
288+
libc.src.__support.FPUtil.generic.add_sub
289+
libc.src.__support.FPUtil.generic.div
290+
libc.src.__support.FPUtil.generic.mul
288291
libc.src.__support.macros.config
289292
libc.src.__support.macros.properties.types
290293
)

libc/src/__support/FPUtil/bfloat16.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#include "src/__support/FPUtil/cast.h"
1616
#include "src/__support/FPUtil/comparison_operations.h"
1717
#include "src/__support/FPUtil/dyadic_float.h"
18+
#include "src/__support/FPUtil/generic/add_sub.h"
19+
#include "src/__support/FPUtil/generic/div.h"
20+
#include "src/__support/FPUtil/generic/mul.h"
1821
#include "src/__support/macros/config.h"
1922
#include "src/__support/macros/properties/types.h"
2023

@@ -81,6 +84,28 @@ struct BFloat16 {
8184
LIBC_INLINE bool operator>=(BFloat16 other) const {
8285
return fputil::greater_than_or_equals(*this, other);
8386
}
87+
88+
// Unary negation: returns a copy of this value with the opposite sign.
// Operates on the FPBits view of the value; no arithmetic routine is called.
LIBC_INLINE constexpr BFloat16 operator-() const {
89+
fputil::FPBits<bfloat16> result(*this);
90+
// Swap the sign: a positive input becomes NEG, anything else becomes POS.
result.set_sign(result.is_pos() ? Sign::NEG : Sign::POS);
91+
return result.get_val();
92+
}
93+
94+
// Addition: forwards (*this, other) to fputil::generic::add, requesting a
// BFloat16 result.
LIBC_INLINE BFloat16 operator+(BFloat16 other) const {
95+
return fputil::generic::add<BFloat16>(*this, other);
96+
}
97+
98+
// Subtraction: forwards (*this, other) to fputil::generic::sub, requesting a
// BFloat16 result.
LIBC_INLINE BFloat16 operator-(BFloat16 other) const {
99+
return fputil::generic::sub<BFloat16>(*this, other);
100+
}
101+
102+
// Multiplication: forwards (*this, other) to fputil::generic::mul.
// NOTE(review): the result type is spelled `bfloat16` here while operator+
// and operator- spell it `BFloat16` -- presumably the same type via an alias;
// confirm and consider using one spelling throughout.
LIBC_INLINE BFloat16 operator*(BFloat16 other) const {
103+
return fputil::generic::mul<bfloat16>(*this, other);
104+
}
105+
106+
// Division: forwards (*this, other) to fputil::generic::div, with *this as
// the first argument and `other` as the second.
// NOTE(review): the result type is spelled `bfloat16` here while operator+
// and operator- spell it `BFloat16` -- presumably the same type via an alias;
// confirm and consider using one spelling throughout.
LIBC_INLINE BFloat16 operator/(BFloat16 other) const {
107+
return fputil::generic::div<bfloat16>(*this, other);
108+
}
84109
}; // struct BFloat16
85110

86111
} // namespace fputil

libc/src/__support/FPUtil/cast.h

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,47 +27,47 @@ LIBC_INLINE constexpr cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
2727
OutType>
2828
cast(InType x) {
2929
// Casting to the same type is a no-op.
30-
if constexpr (cpp::is_same_v<InType, OutType>)
30+
if constexpr (cpp::is_same_v<InType, OutType>) {
3131
return x;
32-
33-
// bfloat16 is always defined (for now)
34-
if constexpr (cpp::is_same_v<OutType, bfloat16> ||
35-
cpp::is_same_v<InType, bfloat16>
32+
} else {
33+
if constexpr (cpp::is_same_v<OutType, bfloat16> ||
34+
cpp::is_same_v<InType, bfloat16>
3635
#if defined(LIBC_TYPES_HAS_FLOAT16) && !defined(__LIBC_USE_FLOAT16_CONVERSION)
37-
|| cpp::is_same_v<OutType, float16> ||
38-
cpp::is_same_v<InType, float16>
36+
|| cpp::is_same_v<OutType, float16> ||
37+
cpp::is_same_v<InType, float16>
3938
#endif
40-
) {
41-
using InFPBits = FPBits<InType>;
42-
using InStorageType = typename InFPBits::StorageType;
43-
using OutFPBits = FPBits<OutType>;
44-
using OutStorageType = typename OutFPBits::StorageType;
39+
) {
40+
using InFPBits = FPBits<InType>;
41+
using InStorageType = typename InFPBits::StorageType;
42+
using OutFPBits = FPBits<OutType>;
43+
using OutStorageType = typename OutFPBits::StorageType;
4544

46-
InFPBits x_bits(x);
45+
InFPBits x_bits(x);
4746

48-
if (x_bits.is_nan()) {
49-
if (x_bits.is_signaling_nan()) {
50-
raise_except_if_required(FE_INVALID);
51-
return OutFPBits::quiet_nan().get_val();
52-
}
47+
if (x_bits.is_nan()) {
48+
if (x_bits.is_signaling_nan()) {
49+
raise_except_if_required(FE_INVALID);
50+
return OutFPBits::quiet_nan().get_val();
51+
}
5352

54-
InStorageType x_mant = x_bits.get_mantissa();
55-
if (InFPBits::FRACTION_LEN > OutFPBits::FRACTION_LEN)
56-
x_mant >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
57-
return OutFPBits::quiet_nan(x_bits.sign(),
58-
static_cast<OutStorageType>(x_mant))
59-
.get_val();
60-
}
53+
InStorageType x_mant = x_bits.get_mantissa();
54+
if (InFPBits::FRACTION_LEN > OutFPBits::FRACTION_LEN)
55+
x_mant >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
56+
return OutFPBits::quiet_nan(x_bits.sign(),
57+
static_cast<OutStorageType>(x_mant))
58+
.get_val();
59+
}
6160

62-
if (x_bits.is_inf())
63-
return OutFPBits::inf(x_bits.sign()).get_val();
61+
if (x_bits.is_inf())
62+
return OutFPBits::inf(x_bits.sign()).get_val();
6463

65-
constexpr size_t MAX_FRACTION_LEN =
66-
cpp::max(OutFPBits::FRACTION_LEN, InFPBits::FRACTION_LEN);
67-
DyadicFloat<cpp::bit_ceil(MAX_FRACTION_LEN)> xd(x);
68-
return xd.template as<OutType, /*ShouldSignalExceptions=*/true>();
69-
} else {
70-
return static_cast<OutType>(x);
64+
constexpr size_t MAX_FRACTION_LEN =
65+
cpp::max(OutFPBits::FRACTION_LEN, InFPBits::FRACTION_LEN);
66+
DyadicFloat<cpp::bit_ceil(MAX_FRACTION_LEN)> xd(x);
67+
return xd.template as<OutType, /*ShouldSignalExceptions=*/true>();
68+
} else {
69+
return static_cast<OutType>(x);
70+
}
7171
}
7272
}
7373

libc/src/__support/FPUtil/dyadic_float.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ LIBC_INLINE constexpr DyadicFloat<Bits> quick_mul(const DyadicFloat<Bits> &a,
576576
// Check the leading bit directly, should be faster than using clz in
577577
// normalize().
578578
if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] >>
579-
63 ==
579+
(DyadicFloat<Bits>::MantissaType::WORD_SIZE - 1) ==
580580
0)
581581
result.shift_left(1);
582582
} else {

libc/src/__support/FPUtil/generic/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ add_header_library(
6868
libc.src.__support.FPUtil.rounding_mode
6969
libc.src.__support.macros.attributes
7070
libc.src.__support.macros.optimization
71+
libc.src.__support.macros.properties.types
7172
)
7273

7374
add_header_library(
@@ -77,6 +78,7 @@ add_header_library(
7778
DEPENDS
7879
libc.hdr.errno_macros
7980
libc.hdr.fenv_macros
81+
libc.src.__support.CPP.algorithm
8082
libc.src.__support.CPP.bit
8183
libc.src.__support.CPP.type_traits
8284
libc.src.__support.FPUtil.basic_operations

libc/src/__support/FPUtil/generic/add_sub.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,22 @@ add_or_sub(InType x, InType y) {
104104
}
105105
}
106106

107-
// volatile prevents Clang from converting tmp to OutType and then
108-
// immediately back to InType before negating it, resulting in double
109-
// rounding.
110-
volatile InType tmp = y;
111-
if constexpr (IsSub)
112-
tmp = -tmp;
113-
return cast<OutType>(tmp);
107+
if constexpr (cpp::is_same_v<InType, bfloat16> &&
108+
cpp::is_same_v<OutType, bfloat16>) {
109+
OutFPBits y_bits(y);
110+
if constexpr (IsSub)
111+
y_bits.set_sign(y_bits.sign().negate());
112+
return y_bits.get_val();
113+
} else {
114+
115+
// volatile prevents Clang from converting tmp to OutType and then
116+
// immediately back to InType before negating it, resulting in double
117+
// rounding.
118+
volatile InType tmp = y;
119+
if constexpr (IsSub)
120+
tmp = -tmp;
121+
return cast<OutType>(tmp);
122+
}
114123
}
115124

116125
if (y_bits.is_zero())

libc/src/__support/FPUtil/generic/div.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "hdr/errno_macros.h"
1313
#include "hdr/fenv_macros.h"
14+
#include "src/__support/CPP/algorithm.h"
1415
#include "src/__support/CPP/bit.h"
1516
#include "src/__support/CPP/type_traits.h"
1617
#include "src/__support/FPUtil/BasicOperations.h"
@@ -34,8 +35,9 @@ div(InType x, InType y) {
3435
using OutStorageType = typename OutFPBits::StorageType;
3536
using InFPBits = FPBits<InType>;
3637
using InStorageType = typename InFPBits::StorageType;
37-
using DyadicFloat =
38-
DyadicFloat<cpp::bit_ceil(static_cast<size_t>(InFPBits::SIG_LEN + 1))>;
38+
using DyadicFloat = DyadicFloat<cpp::max(
39+
static_cast<size_t>(16),
40+
cpp::bit_ceil(static_cast<size_t>(InFPBits::SIG_LEN + 1)))>;
3941

4042
InFPBits x_bits(x);
4143
InFPBits y_bits(y);

libc/test/src/math/exhaustive/CMakeLists.txt

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,3 +567,75 @@ add_fp_unittest(
567567
LINK_LIBRARIES
568568
-lpthread
569569
)
570+
571+
add_fp_unittest(
572+
bfloat16_add_test
573+
NO_RUN_POSTBUILD
574+
NEED_MPFR
575+
SUITE
576+
libc_math_exhaustive_tests
577+
SRCS
578+
bfloat16_add_test.cpp
579+
COMPILE_OPTIONS
580+
${libc_opt_high_flag}
581+
DEPENDS
582+
.exhaustive_test
583+
libc.src.__support.FPUtil.bfloat16
584+
libc.src.__support.FPUtil.fp_bits
585+
LINK_LIBRARIES
586+
-lpthread
587+
)
588+
589+
add_fp_unittest(
590+
bfloat16_div_test
591+
NO_RUN_POSTBUILD
592+
NEED_MPFR
593+
SUITE
594+
libc_math_exhaustive_tests
595+
SRCS
596+
bfloat16_div_test.cpp
597+
COMPILE_OPTIONS
598+
${libc_opt_high_flag}
599+
DEPENDS
600+
.exhaustive_test
601+
libc.src.__support.FPUtil.bfloat16
602+
libc.src.__support.FPUtil.fp_bits
603+
LINK_LIBRARIES
604+
-lpthread
605+
)
606+
607+
add_fp_unittest(
608+
bfloat16_mul_test
609+
NO_RUN_POSTBUILD
610+
NEED_MPFR
611+
SUITE
612+
libc_math_exhaustive_tests
613+
SRCS
614+
bfloat16_mul_test.cpp
615+
COMPILE_OPTIONS
616+
${libc_opt_high_flag}
617+
DEPENDS
618+
.exhaustive_test
619+
libc.src.__support.FPUtil.bfloat16
620+
libc.src.__support.FPUtil.fp_bits
621+
LINK_LIBRARIES
622+
-lpthread
623+
)
624+
625+
add_fp_unittest(
626+
bfloat16_sub_test
627+
NO_RUN_POSTBUILD
628+
NEED_MPFR
629+
SUITE
630+
libc_math_exhaustive_tests
631+
SRCS
632+
bfloat16_sub_test.cpp
633+
COMPILE_OPTIONS
634+
${libc_opt_high_flag}
635+
DEPENDS
636+
.exhaustive_test
637+
libc.src.__support.FPUtil.bfloat16
638+
libc.src.__support.FPUtil.fp_bits
639+
LINK_LIBRARIES
640+
-lpthread
641+
)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===-- Exhaustive tests for bfloat16 addition ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "exhaustive_test.h"
10+
#include "src/__support/FPUtil/FPBits.h"
11+
#include "src/__support/FPUtil/bfloat16.h"
12+
#include "test/UnitTest/FPMatcher.h"
13+
#include "utils/MPFRWrapper/MPCommon.h"
14+
#include "utils/MPFRWrapper/MPFRUtils.h"
15+
16+
namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
17+
using LIBC_NAMESPACE::fputil::BFloat16;
18+
19+
// Operation under test: BFloat16::operator+ wrapped as a plain function.
static BFloat16 add_func(BFloat16 x, BFloat16 y) { return x + y; }
20+
21+
// Checker used with LlvmLibcExhaustiveMathTest: compares add_func(x, y)
// against MPFR's Add operation for pairs of bfloat16 bit patterns.
struct Bfloat16AddChecker : public virtual LIBC_NAMESPACE::testing::Test {
22+
using FloatType = BFloat16;
23+
using FPBits = LIBC_NAMESPACE::fputil::FPBits<bfloat16>;
24+
using StorageType = typename FPBits::StorageType;
25+
26+
// Checks every (x, y) pair of bit patterns with x in [x_start, x_stop] and
// y in [x, y_stop] under the given rounding mode.
// Returns the number of pairs whose result did not match.
// NOTE(review): y_start is never read; the inner loop starts at the current
// x pattern instead. Presumably this relies on addition being commutative and
// on callers passing the same range for x and y -- confirm.
uint64_t check(uint16_t x_start, uint16_t x_stop, uint16_t y_start,
27+
uint16_t y_stop, mpfr::RoundingMode rounding) {
28+
mpfr::ForceRoundingMode r(rounding);
29+
// `true` converts to 1 in the uint64_t return type, so this path returns 1.
// NOTE(review): presumably r.success is false when the rounding mode could
// not be applied -- confirm how the caller interprets this value.
if (!r.success)
30+
return true;
31+
uint16_t xbits = x_start;
32+
uint64_t failed = 0;
33+
// Both loops compare before the post-increment, so the stop bounds are
// inclusive: the last iteration runs with the counter equal to its stop value.
do {
34+
BFloat16 x = FPBits(xbits).get_val();
35+
uint16_t ybits = xbits;
36+
do {
37+
BFloat16 y = FPBits(ybits).get_val();
38+
mpfr::BinaryInput<BFloat16> input{x, y};
39+
// 0.5 is the tolerance argument handed to the matcher (presumably in ULPs
// -- confirm against the MPFR wrapper).
bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY(
40+
mpfr::Operation::Add, input, add_func(x, y), 0.5, rounding);
41+
failed += (!correct);
42+
} while (ybits++ < y_stop);
43+
} while (xbits++ < x_stop);
44+
return failed;
45+
}
46+
};
47+
48+
using LlvmLibcBfloat16ExhaustiveAddTest =
49+
LlvmLibcExhaustiveMathTest<Bfloat16AddChecker, 1 << 2>;
50+
51+
// range: [0, inf]
52+
static constexpr uint16_t POS_START = 0x0000U;
53+
static constexpr uint16_t POS_STOP = 0x7f80U;
54+
55+
// range: [-0, -inf]
56+
static constexpr uint16_t NEG_START = 0x8000U;
57+
static constexpr uint16_t NEG_STOP = 0xff80U;
58+
59+
// Runs the checker with both x and y bounded by [POS_START, POS_STOP].
TEST_F(LlvmLibcBfloat16ExhaustiveAddTest, PositiveRange) {
60+
test_full_range_all_roundings(POS_START, POS_STOP, POS_START, POS_STOP);
61+
}
62+
63+
// Runs the checker with both x and y bounded by [NEG_START, NEG_STOP].
TEST_F(LlvmLibcBfloat16ExhaustiveAddTest, NegativeRange) {
64+
test_full_range_all_roundings(NEG_START, NEG_STOP, NEG_START, NEG_STOP);
65+
}

0 commit comments

Comments
 (0)