Skip to content

Commit

Permalink
[libc] Implement correct rounding with all rounding modes for hypot f…
Browse files Browse the repository at this point in the history
…unctions.

Update the rounding logic for generic hypot function so that it will round correctly with all rounding modes.

Reviewed By: sivachandra, zimmermann6

Differential Revision: https://reviews.llvm.org/D117590
  • Loading branch information
lntue committed Jan 20, 2022
1 parent af56004 commit aad0453
Show file tree
Hide file tree
Showing 7 changed files with 1,289 additions and 10 deletions.
26 changes: 23 additions & 3 deletions libc/src/__support/FPUtil/Hypot.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_HYPOT_H

#include "BasicOperations.h"
#include "FEnvImpl.h"
#include "FPBits.h"
#include "src/__support/CPP/TypeTraits.h"

Expand Down Expand Up @@ -143,11 +144,22 @@ static inline T hypot(T x, T y) {
if ((x_bits.get_unbiased_exponent() >=
y_bits.get_unbiased_exponent() + MantissaWidth<T>::VALUE + 2) ||
(y == 0)) {
// Check if the rounding mode is FE_UPWARD, will need -frounding-math so
// that the compiler does not optimize it away.
if ((y != 0) && (0x1p0f + 0x1p-24f != 0x1p0f)) {
UIntType out_bits = FPBits_t(abs(x)).uintval();
return T(FPBits_t(++out_bits));
}
return abs(x);
} else if ((y_bits.get_unbiased_exponent() >=
x_bits.get_unbiased_exponent() + MantissaWidth<T>::VALUE + 2) ||
(x == 0)) {
y_bits.set_sign(0);
// Check if the rounding mode is FE_UPWARD, will need -frounding-math so
// that the compiler does not optimize it away.
if ((x != 0) && (0x1p0f + 0x1p-24f != 0x1p0f)) {
UIntType out_bits = FPBits_t(abs(y)).uintval();
return T(FPBits_t(++out_bits));
}
return abs(y);
}

Expand Down Expand Up @@ -250,8 +262,16 @@ static inline T hypot(T x, T y) {
y_new >>= 1;

// Round to the nearest, tie to even.
if (round_bit && (lsb || sticky_bits || (r != 0))) {
++y_new;
switch (get_round()) {
case FE_TONEAREST:
// Round to nearest, ties to even
if (round_bit && (lsb || sticky_bits || (r != 0)))
++y_new;
break;
case FE_UPWARD:
if (round_bit || sticky_bits || (r != 0))
++y_new;
break;
}

if (y_new >= (ONE >> 1)) {
Expand Down
8 changes: 6 additions & 2 deletions libc/src/math/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,9 @@ add_entrypoint_object(
DEPENDS
libc.src.__support.FPUtil.fputil
COMPILE_OPTIONS
-O2
-O3
-frounding-math
-Wno-c++17-extensions
)

add_entrypoint_object(
Expand Down Expand Up @@ -1002,7 +1004,9 @@ add_entrypoint_object(
DEPENDS
libc.src.__support.FPUtil.fputil
COMPILE_OPTIONS
-O2
-O3
-frounding-math
-Wno-c++17-extensions
)

add_entrypoint_object(
Expand Down
4 changes: 4 additions & 0 deletions libc/test/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,8 @@ add_fp_unittest(
libc.include.math
libc.src.math.hypotf
libc.src.__support.FPUtil.fputil
COMPILE_OPTIONS
-Wno-c++17-extensions
)

add_fp_unittest(
Expand All @@ -1073,6 +1075,8 @@ add_fp_unittest(
libc.include.math
libc.src.math.hypot
libc.src.__support.FPUtil.fputil
COMPILE_OPTIONS
-Wno-c++17-extensions
)

add_fp_unittest(
Expand Down
16 changes: 11 additions & 5 deletions libc/test/src/math/HypotTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#define LLVM_LIBC_TEST_SRC_MATH_HYPOTTEST_H

#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/Hypot.h"
#include "utils/MPFRWrapper/MPFRUtils.h"
#include "utils/UnitTest/FPMatcher.h"
#include "utils/UnitTest/Test.h"
Expand Down Expand Up @@ -62,9 +61,9 @@ class HypotTestTemplate : public __llvm_libc::testing::Test {
y = -y;
}

T result = func(x, y);
mpfr::BinaryInput<T> input{x, y};
ASSERT_MPFR_MATCH(mpfr::Operation::Hypot, input, result, 0.5);
ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Hypot, input,
func(x, y), 0.5);
}
}
}
Expand All @@ -85,12 +84,19 @@ class HypotTestTemplate : public __llvm_libc::testing::Test {
y = -y;
}

T result = func(x, y);
mpfr::BinaryInput<T> input{x, y};
ASSERT_MPFR_MATCH(mpfr::Operation::Hypot, input, result, 0.5);
ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Hypot, input,
func(x, y), 0.5);
}
}
}

void test_input_list(Func func, int n, const mpfr::BinaryInput<T> *inputs) {
for (int i = 0; i < n; ++i) {
ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Hypot, inputs[i],
func(inputs[i].x, inputs[i].y), 0.5);
}
}
};

#endif // LLVM_LIBC_TEST_SRC_MATH_HYPOTTEST_H
2 changes: 2 additions & 0 deletions libc/test/src/math/differential_testing/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ add_diff_binary(
libc.src.math.hypotf
COMPILE_OPTIONS
-fno-builtin
-Wno-c++17-extensions
)

add_diff_binary(
Expand All @@ -417,4 +418,5 @@ add_diff_binary(
libc.src.math.hypot
COMPILE_OPTIONS
-fno-builtin
-Wno-c++17-extensions
)
Loading

0 comments on commit aad0453

Please sign in to comment.