474 changes: 284 additions & 190 deletions libc/utils/MPFRWrapper/MPFRUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "utils/UnitTest/FPMatcher.h"

#include <cmath>
#include <fenv.h>
#include <memory>
#include <stdint.h>
#include <string>
Expand Down Expand Up @@ -55,141 +56,227 @@ template <> struct Precision<long double> {
};
#endif

// A precision value which allows sufficiently large additional
// precision compared to the floating point precision.
template <typename T> struct ExtraPrecision;

template <> struct ExtraPrecision<float> {
static constexpr unsigned int VALUE = 128;
};

template <> struct ExtraPrecision<double> {
static constexpr unsigned int VALUE = 256;
};

template <> struct ExtraPrecision<long double> {
static constexpr unsigned int VALUE = 256;
};

// If the ulp tolerance is less than or equal to 0.5, we would check that the
// result is rounded correctly with respect to the rounding mode by using the
// same precision as the inputs.
template <typename T>
static inline unsigned int get_precision(double ulp_tolerance) {
if (ulp_tolerance <= 0.5) {
return Precision<T>::VALUE;
} else {
return ExtraPrecision<T>::VALUE;
}
}

static inline mpfr_rnd_t get_mpfr_rounding_mode(RoundingMode mode) {
switch (mode) {
case RoundingMode::Upward:
return MPFR_RNDU;
break;
case RoundingMode::Downward:
return MPFR_RNDD;
break;
case RoundingMode::TowardZero:
return MPFR_RNDZ;
break;
case RoundingMode::Nearest:
return MPFR_RNDN;
break;
}
}

int get_fe_rounding(RoundingMode mode) {
switch (mode) {
case RoundingMode::Upward:
return FE_UPWARD;
break;
case RoundingMode::Downward:
return FE_DOWNWARD;
break;
case RoundingMode::TowardZero:
return FE_TOWARDZERO;
break;
case RoundingMode::Nearest:
return FE_TONEAREST;
break;
}
}

ForceRoundingMode::ForceRoundingMode(RoundingMode mode) {
old_rounding_mode = fegetround();
rounding_mode = get_fe_rounding(mode);
if (old_rounding_mode != rounding_mode)
fesetround(rounding_mode);
}

ForceRoundingMode::~ForceRoundingMode() {
if (old_rounding_mode != rounding_mode)
fesetround(old_rounding_mode);
}

class MPFRNumber {
// A precision value which allows sufficiently large additional
// precision even compared to quad-precision floating point values.
unsigned int mpfr_precision;
mpfr_rnd_t mpfr_rounding;

mpfr_t value;

public:
MPFRNumber() : mpfr_precision(256) { mpfr_init2(value, mpfr_precision); }
MPFRNumber() : mpfr_precision(256), mpfr_rounding(MPFR_RNDN) {
mpfr_init2(value, mpfr_precision);
}

// We use explicit EnableIf specializations to disallow implicit
// conversions. Implicit conversions can potentially lead to loss of
// precision.
template <typename XType,
cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
explicit MPFRNumber(XType x, int precision = 128)
: mpfr_precision(precision) {
explicit MPFRNumber(XType x, int precision = ExtraPrecision<XType>::VALUE,
RoundingMode rounding = RoundingMode::Nearest)
: mpfr_precision(precision),
mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
mpfr_init2(value, mpfr_precision);
mpfr_set_flt(value, x, MPFR_RNDN);
mpfr_set_flt(value, x, mpfr_rounding);
}

template <typename XType,
cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
explicit MPFRNumber(XType x, int precision = 128)
: mpfr_precision(precision) {
explicit MPFRNumber(XType x, int precision = ExtraPrecision<XType>::VALUE,
RoundingMode rounding = RoundingMode::Nearest)
: mpfr_precision(precision),
mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
mpfr_init2(value, mpfr_precision);
mpfr_set_d(value, x, MPFR_RNDN);
mpfr_set_d(value, x, mpfr_rounding);
}

template <typename XType,
cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
explicit MPFRNumber(XType x, int precision = 128)
: mpfr_precision(precision) {
explicit MPFRNumber(XType x, int precision = ExtraPrecision<XType>::VALUE,
RoundingMode rounding = RoundingMode::Nearest)
: mpfr_precision(precision),
mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
mpfr_init2(value, mpfr_precision);
mpfr_set_ld(value, x, MPFR_RNDN);
mpfr_set_ld(value, x, mpfr_rounding);
}

template <typename XType,
cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
explicit MPFRNumber(XType x, int precision = 128)
: mpfr_precision(precision) {
explicit MPFRNumber(XType x, int precision = ExtraPrecision<float>::VALUE,
RoundingMode rounding = RoundingMode::Nearest)
: mpfr_precision(precision),
mpfr_rounding(get_mpfr_rounding_mode(rounding)) {
mpfr_init2(value, mpfr_precision);
mpfr_set_sj(value, x, MPFR_RNDN);
mpfr_set_sj(value, x, mpfr_rounding);
}

MPFRNumber(const MPFRNumber &other) : mpfr_precision(other.mpfr_precision) {
MPFRNumber(const MPFRNumber &other)
: mpfr_precision(other.mpfr_precision),
mpfr_rounding(other.mpfr_rounding) {
mpfr_init2(value, mpfr_precision);
mpfr_set(value, other.value, MPFR_RNDN);
mpfr_set(value, other.value, mpfr_rounding);
}

~MPFRNumber() { mpfr_clear(value); }

MPFRNumber &operator=(const MPFRNumber &rhs) {
mpfr_precision = rhs.mpfr_precision;
mpfr_set(value, rhs.value, MPFR_RNDN);
mpfr_rounding = rhs.mpfr_rounding;
mpfr_set(value, rhs.value, mpfr_rounding);
return *this;
}

MPFRNumber abs() const {
MPFRNumber result;
mpfr_abs(result.value, value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_abs(result.value, value, mpfr_rounding);
return result;
}

MPFRNumber ceil() const {
MPFRNumber result;
MPFRNumber result(*this);
mpfr_ceil(result.value, value);
return result;
}

MPFRNumber cos() const {
MPFRNumber result;
mpfr_cos(result.value, value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_cos(result.value, value, mpfr_rounding);
return result;
}

MPFRNumber exp() const {
MPFRNumber result;
mpfr_exp(result.value, value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_exp(result.value, value, mpfr_rounding);
return result;
}

MPFRNumber exp2() const {
MPFRNumber result;
mpfr_exp2(result.value, value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_exp2(result.value, value, mpfr_rounding);
return result;
}

MPFRNumber expm1() const {
MPFRNumber result;
mpfr_expm1(result.value, value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_expm1(result.value, value, mpfr_rounding);
return result;
}

MPFRNumber floor() const {
MPFRNumber result;
MPFRNumber result(*this);
mpfr_floor(result.value, value);
return result;
}

MPFRNumber frexp(int &exp) {
MPFRNumber result;
MPFRNumber result(*this);
mpfr_exp_t resultExp;
mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN);
mpfr_frexp(&resultExp, result.value, value, mpfr_rounding);
exp = resultExp;
return result;
}

MPFRNumber hypot(const MPFRNumber &b) {
MPFRNumber result;
mpfr_hypot(result.value, value, b.value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_hypot(result.value, value, b.value, mpfr_rounding);
return result;
}

MPFRNumber log() const {
MPFRNumber result;
mpfr_log(result.value, value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_log(result.value, value, mpfr_rounding);
return result;
}

MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
MPFRNumber remainder;
MPFRNumber remainder(*this);
long q;
mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN);
mpfr_remquo(remainder.value, &q, value, divisor.value, mpfr_rounding);
quotient = q;
return remainder;
}

MPFRNumber round() const {
MPFRNumber result;
MPFRNumber result(*this);
mpfr_round(result.value, value);
return result;
}

bool roung_to_long(long &result) const {
bool round_to_long(long &result) const {
// We first calculate the rounded value. This way, when converting
// to long using mpfr_get_si, the rounding direction of MPFR_RNDN
// (or any other rounding mode), does not have an influence.
Expand All @@ -199,14 +286,14 @@ class MPFRNumber {
return mpfr_erangeflag_p();
}

bool roung_to_long(mpfr_rnd_t rnd, long &result) const {
MPFRNumber rint_result;
bool round_to_long(mpfr_rnd_t rnd, long &result) const {
MPFRNumber rint_result(*this);
mpfr_rint(rint_result.value, value, rnd);
return rint_result.roung_to_long(result);
return rint_result.round_to_long(result);
}

MPFRNumber rint(mpfr_rnd_t rnd) const {
MPFRNumber result;
MPFRNumber result(*this);
mpfr_rint(result.value, value, rnd);
return result;
}
Expand Down Expand Up @@ -239,32 +326,32 @@ class MPFRNumber {
}

MPFRNumber sin() const {
MPFRNumber result;
mpfr_sin(result.value, value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_sin(result.value, value, mpfr_rounding);
return result;
}

MPFRNumber sqrt() const {
MPFRNumber result;
mpfr_sqrt(result.value, value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_sqrt(result.value, value, mpfr_rounding);
return result;
}

MPFRNumber tan() const {
MPFRNumber result;
mpfr_tan(result.value, value, MPFR_RNDN);
MPFRNumber result(*this);
mpfr_tan(result.value, value, mpfr_rounding);
return result;
}

MPFRNumber trunc() const {
MPFRNumber result;
MPFRNumber result(*this);
mpfr_trunc(result.value, value);
return result;
}

MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) {
MPFRNumber result(*this);
mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN);
mpfr_fma(result.value, value, b.value, c.value, mpfr_rounding);
return result;
}

Expand All @@ -282,10 +369,14 @@ class MPFRNumber {
// These functions are useful for debugging.
template <typename T> T as() const;

template <> float as<float>() const { return mpfr_get_flt(value, MPFR_RNDN); }
template <> double as<double>() const { return mpfr_get_d(value, MPFR_RNDN); }
template <> float as<float>() const {
return mpfr_get_flt(value, mpfr_rounding);
}
template <> double as<double>() const {
return mpfr_get_d(value, mpfr_rounding);
}
template <> long double as<long double>() const {
return mpfr_get_ld(value, MPFR_RNDN);
return mpfr_get_ld(value, mpfr_rounding);
}

void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
Expand Down Expand Up @@ -378,8 +469,9 @@ namespace internal {

template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
unary_operation(Operation op, InputType input) {
MPFRNumber mpfrInput(input);
unary_operation(Operation op, InputType input, unsigned int precision,
RoundingMode rounding) {
MPFRNumber mpfrInput(input, precision, rounding);
switch (op) {
case Operation::Abs:
return mpfrInput.abs();
Expand Down Expand Up @@ -420,8 +512,9 @@ unary_operation(Operation op, InputType input) {

template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
unary_operation_two_outputs(Operation op, InputType input, int &output) {
MPFRNumber mpfrInput(input);
unary_operation_two_outputs(Operation op, InputType input, int &output,
unsigned int precision, RoundingMode rounding) {
MPFRNumber mpfrInput(input, precision, rounding);
switch (op) {
case Operation::Frexp:
return mpfrInput.frexp(output);
Expand All @@ -432,8 +525,10 @@ unary_operation_two_outputs(Operation op, InputType input, int &output) {

template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
binary_operation_one_output(Operation op, InputType x, InputType y) {
MPFRNumber inputX(x), inputY(y);
binary_operation_one_output(Operation op, InputType x, InputType y,
unsigned int precision, RoundingMode rounding) {
MPFRNumber inputX(x, precision, rounding);
MPFRNumber inputY(y, precision, rounding);
switch (op) {
case Operation::Hypot:
return inputX.hypot(inputY);
Expand All @@ -445,8 +540,10 @@ binary_operation_one_output(Operation op, InputType x, InputType y) {
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
binary_operation_two_outputs(Operation op, InputType x, InputType y,
int &output) {
MPFRNumber inputX(x), inputY(y);
int &output, unsigned int precision,
RoundingMode rounding) {
MPFRNumber inputX(x, precision, rounding);
MPFRNumber inputY(y, precision, rounding);
switch (op) {
case Operation::RemQuo:
return inputX.remquo(inputY, output);
Expand All @@ -458,12 +555,14 @@ binary_operation_two_outputs(Operation op, InputType x, InputType y,
template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
ternary_operation_one_output(Operation op, InputType x, InputType y,
InputType z) {
InputType z, unsigned int precision,
RoundingMode rounding) {
// For FMA function, we just need to compare with the mpfr_fma with the same
// precision as InputType. Using higher precision as the intermediate results
// to compare might incorrectly fail due to double-rounding errors.
constexpr unsigned int prec = Precision<InputType>::VALUE;
MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec);
MPFRNumber inputX(x, precision, rounding);
MPFRNumber inputY(y, precision, rounding);
MPFRNumber inputZ(z, precision, rounding);
switch (op) {
case Operation::Fma:
return inputX.fma(inputY, inputZ);
Expand All @@ -475,13 +574,14 @@ ternary_operation_one_output(Operation op, InputType x, InputType y,
template <typename T>
void explain_unary_operation_single_output_error(Operation op, T input,
T matchValue,
double ulp_tolerance,
RoundingMode rounding,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrInput(input);
MPFRNumber mpfr_result = unary_operation(op, input);
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfrInput(input, precision);
MPFRNumber mpfr_result;
mpfr_result = unary_operation(op, input, precision, rounding);
MPFRNumber mpfrMatchValue(matchValue);
FPBits<T> inputBits(input);
FPBits<T> matchBits(matchValue);
FPBits<T> mpfr_resultBits(mpfr_result.as<T>());
OS << "Match value not within tolerance value of MPFR result:\n"
<< " Input decimal: " << mpfrInput.str() << '\n';
__llvm_libc::fputil::testing::describeValue(" Input bits: ", input, OS);
Expand All @@ -498,21 +598,24 @@ void explain_unary_operation_single_output_error(Operation op, T input,

template void
explain_unary_operation_single_output_error<float>(Operation op, float, float,
double, RoundingMode,
testutils::StreamWrapper &);
template void explain_unary_operation_single_output_error<double>(
Operation op, double, double, testutils::StreamWrapper &);
Operation op, double, double, double, RoundingMode,
testutils::StreamWrapper &);
template void explain_unary_operation_single_output_error<long double>(
Operation op, long double, long double, testutils::StreamWrapper &);
Operation op, long double, long double, double, RoundingMode,
testutils::StreamWrapper &);

template <typename T>
void explain_unary_operation_two_outputs_error(
Operation op, T input, const BinaryOutput<T> &libc_result,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrInput(input);
FPBits<T> inputBits(input);
double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) {
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfrInput(input, precision);
int mpfrIntResult;
MPFRNumber mpfr_result =
unary_operation_two_outputs(op, input, mpfrIntResult);
MPFRNumber mpfr_result = unary_operation_two_outputs(op, input, mpfrIntResult,
precision, rounding);

if (mpfrIntResult != libc_result.i) {
OS << "MPFR integral result: " << mpfrIntResult << '\n'
Expand Down Expand Up @@ -541,26 +644,26 @@ void explain_unary_operation_two_outputs_error(
}

template void explain_unary_operation_two_outputs_error<float>(
Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &);
template void
explain_unary_operation_two_outputs_error<double>(Operation, double,
const BinaryOutput<double> &,
testutils::StreamWrapper &);
template void explain_unary_operation_two_outputs_error<long double>(
Operation, long double, const BinaryOutput<long double> &,
Operation, float, const BinaryOutput<float> &, double, RoundingMode,
testutils::StreamWrapper &);
template void explain_unary_operation_two_outputs_error<double>(
Operation, double, const BinaryOutput<double> &, double, RoundingMode,
testutils::StreamWrapper &);
template void explain_unary_operation_two_outputs_error<long double>(
Operation, long double, const BinaryOutput<long double> &, double,
RoundingMode, testutils::StreamWrapper &);

template <typename T>
void explain_binary_operation_two_outputs_error(
Operation op, const BinaryInput<T> &input,
const BinaryOutput<T> &libc_result, testutils::StreamWrapper &OS) {
MPFRNumber mpfrX(input.x);
MPFRNumber mpfrY(input.y);
FPBits<T> xbits(input.x);
FPBits<T> ybits(input.y);
const BinaryOutput<T> &libc_result, double ulp_tolerance,
RoundingMode rounding, testutils::StreamWrapper &OS) {
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfrX(input.x, precision);
MPFRNumber mpfrY(input.y, precision);
int mpfrIntResult;
MPFRNumber mpfr_result =
binary_operation_two_outputs(op, input.x, input.y, mpfrIntResult);
MPFRNumber mpfr_result = binary_operation_two_outputs(
op, input.x, input.y, mpfrIntResult, precision, rounding);
MPFRNumber mpfrMatchValue(libc_result.f);

OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'
Expand All @@ -576,25 +679,27 @@ void explain_binary_operation_two_outputs_error(
}

template void explain_binary_operation_two_outputs_error<float>(
Operation, const BinaryInput<float> &, const BinaryOutput<float> &,
testutils::StreamWrapper &);
Operation, const BinaryInput<float> &, const BinaryOutput<float> &, double,
RoundingMode, testutils::StreamWrapper &);
template void explain_binary_operation_two_outputs_error<double>(
Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
testutils::StreamWrapper &);
double, RoundingMode, testutils::StreamWrapper &);
template void explain_binary_operation_two_outputs_error<long double>(
Operation, const BinaryInput<long double> &,
const BinaryOutput<long double> &, testutils::StreamWrapper &);
const BinaryOutput<long double> &, double, RoundingMode,
testutils::StreamWrapper &);

template <typename T>
void explain_binary_operation_one_output_error(Operation op,
const BinaryInput<T> &input,
T libc_result,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrX(input.x);
MPFRNumber mpfrY(input.y);
void explain_binary_operation_one_output_error(
Operation op, const BinaryInput<T> &input, T libc_result,
double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) {
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfrX(input.x, precision);
MPFRNumber mpfrY(input.y, precision);
FPBits<T> xbits(input.x);
FPBits<T> ybits(input.y);
MPFRNumber mpfr_result = binary_operation_one_output(op, input.x, input.y);
MPFRNumber mpfr_result =
binary_operation_one_output(op, input.x, input.y, precision, rounding);
MPFRNumber mpfrMatchValue(libc_result);

OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n';
Expand All @@ -613,26 +718,28 @@ void explain_binary_operation_one_output_error(Operation op,
}

template void explain_binary_operation_one_output_error<float>(
Operation, const BinaryInput<float> &, float, testutils::StreamWrapper &);
Operation, const BinaryInput<float> &, float, double, RoundingMode,
testutils::StreamWrapper &);
template void explain_binary_operation_one_output_error<double>(
Operation, const BinaryInput<double> &, double, testutils::StreamWrapper &);
template void explain_binary_operation_one_output_error<long double>(
Operation, const BinaryInput<long double> &, long double,
Operation, const BinaryInput<double> &, double, double, RoundingMode,
testutils::StreamWrapper &);
template void explain_binary_operation_one_output_error<long double>(
Operation, const BinaryInput<long double> &, long double, double,
RoundingMode, testutils::StreamWrapper &);

template <typename T>
void explain_ternary_operation_one_output_error(Operation op,
const TernaryInput<T> &input,
T libc_result,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrX(input.x, Precision<T>::VALUE);
MPFRNumber mpfrY(input.y, Precision<T>::VALUE);
MPFRNumber mpfrZ(input.z, Precision<T>::VALUE);
void explain_ternary_operation_one_output_error(
Operation op, const TernaryInput<T> &input, T libc_result,
double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS) {
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfrX(input.x, precision);
MPFRNumber mpfrY(input.y, precision);
MPFRNumber mpfrZ(input.z, precision);
FPBits<T> xbits(input.x);
FPBits<T> ybits(input.y);
FPBits<T> zbits(input.z);
MPFRNumber mpfr_result =
ternary_operation_one_output(op, input.x, input.y, input.z);
MPFRNumber mpfr_result = ternary_operation_one_output(
op, input.x, input.y, input.z, precision, rounding);
MPFRNumber mpfrMatchValue(libc_result);

OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str()
Expand All @@ -654,68 +761,70 @@ void explain_ternary_operation_one_output_error(Operation op,
}

template void explain_ternary_operation_one_output_error<float>(
Operation, const TernaryInput<float> &, float, testutils::StreamWrapper &);
Operation, const TernaryInput<float> &, float, double, RoundingMode,
testutils::StreamWrapper &);
template void explain_ternary_operation_one_output_error<double>(
Operation, const TernaryInput<double> &, double,
Operation, const TernaryInput<double> &, double, double, RoundingMode,
testutils::StreamWrapper &);
template void explain_ternary_operation_one_output_error<long double>(
Operation, const TernaryInput<long double> &, long double,
testutils::StreamWrapper &);
Operation, const TernaryInput<long double> &, long double, double,
RoundingMode, testutils::StreamWrapper &);

template <typename T>
bool compare_unary_operation_single_output(Operation op, T input, T libc_result,
double ulp_error) {
// If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
// is rounded to the nearest even.
MPFRNumber mpfr_result = unary_operation(op, input);
double ulp_tolerance,
RoundingMode rounding) {
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfr_result;
mpfr_result = unary_operation(op, input, precision, rounding);
double ulp = mpfr_result.ulp(libc_result);
bool bits_are_even = ((FPBits<T>(libc_result).uintval() & 1) == 0);
return (ulp < ulp_error) ||
((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
return (ulp <= ulp_tolerance);
}

template bool compare_unary_operation_single_output<float>(Operation, float,
float, double);
float, double,
RoundingMode);
template bool compare_unary_operation_single_output<double>(Operation, double,
double, double);
template bool compare_unary_operation_single_output<long double>(Operation,
long double,
long double,
double);
double, double,
RoundingMode);
template bool compare_unary_operation_single_output<long double>(
Operation, long double, long double, double, RoundingMode);

template <typename T>
bool compare_unary_operation_two_outputs(Operation op, T input,
const BinaryOutput<T> &libc_result,
double ulp_error) {
double ulp_tolerance,
RoundingMode rounding) {
int mpfrIntResult;
MPFRNumber mpfr_result =
unary_operation_two_outputs(op, input, mpfrIntResult);
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfr_result = unary_operation_two_outputs(op, input, mpfrIntResult,
precision, rounding);
double ulp = mpfr_result.ulp(libc_result.f);

if (mpfrIntResult != libc_result.i)
return false;

bool bits_are_even = ((FPBits<T>(libc_result.f).uintval() & 1) == 0);
return (ulp < ulp_error) ||
((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
return (ulp <= ulp_tolerance);
}

template bool
compare_unary_operation_two_outputs<float>(Operation, float,
const BinaryOutput<float> &, double);
template bool compare_unary_operation_two_outputs<float>(
Operation, float, const BinaryOutput<float> &, double, RoundingMode);
template bool compare_unary_operation_two_outputs<double>(
Operation, double, const BinaryOutput<double> &, double);
Operation, double, const BinaryOutput<double> &, double, RoundingMode);
template bool compare_unary_operation_two_outputs<long double>(
Operation, long double, const BinaryOutput<long double> &, double);
Operation, long double, const BinaryOutput<long double> &, double,
RoundingMode);

template <typename T>
bool compare_binary_operation_two_outputs(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &libc_result,
double ulp_error) {
double ulp_tolerance,
RoundingMode rounding) {
int mpfrIntResult;
MPFRNumber mpfr_result =
binary_operation_two_outputs(op, input.x, input.y, mpfrIntResult);
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfr_result = binary_operation_two_outputs(
op, input.x, input.y, mpfrIntResult, precision, rounding);
double ulp = mpfr_result.ulp(libc_result.f);

if (mpfrIntResult != libc_result.i) {
Expand All @@ -727,81 +836,66 @@ bool compare_binary_operation_two_outputs(Operation op,
}
}

bool bits_are_even = ((FPBits<T>(libc_result.f).uintval() & 1) == 0);
return (ulp < ulp_error) ||
((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
return (ulp <= ulp_tolerance);
}

template bool compare_binary_operation_two_outputs<float>(
Operation, const BinaryInput<float> &, const BinaryOutput<float> &, double);
Operation, const BinaryInput<float> &, const BinaryOutput<float> &, double,
RoundingMode);
template bool compare_binary_operation_two_outputs<double>(
Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
double);
double, RoundingMode);
template bool compare_binary_operation_two_outputs<long double>(
Operation, const BinaryInput<long double> &,
const BinaryOutput<long double> &, double);
const BinaryOutput<long double> &, double, RoundingMode);

template <typename T>
bool compare_binary_operation_one_output(Operation op,
const BinaryInput<T> &input,
T libc_result, double ulp_error) {
MPFRNumber mpfr_result = binary_operation_one_output(op, input.x, input.y);
T libc_result, double ulp_tolerance,
RoundingMode rounding) {
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfr_result =
binary_operation_one_output(op, input.x, input.y, precision, rounding);
double ulp = mpfr_result.ulp(libc_result);

bool bits_are_even = ((FPBits<T>(libc_result).uintval() & 1) == 0);
return (ulp < ulp_error) ||
((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
return (ulp <= ulp_tolerance);
}

template bool compare_binary_operation_one_output<float>(
Operation, const BinaryInput<float> &, float, double);
Operation, const BinaryInput<float> &, float, double, RoundingMode);
template bool compare_binary_operation_one_output<double>(
Operation, const BinaryInput<double> &, double, double);
Operation, const BinaryInput<double> &, double, double, RoundingMode);
template bool compare_binary_operation_one_output<long double>(
Operation, const BinaryInput<long double> &, long double, double);
Operation, const BinaryInput<long double> &, long double, double,
RoundingMode);

template <typename T>
bool compare_ternary_operation_one_output(Operation op,
const TernaryInput<T> &input,
T libc_result, double ulp_error) {
MPFRNumber mpfr_result =
ternary_operation_one_output(op, input.x, input.y, input.z);
T libc_result, double ulp_tolerance,
RoundingMode rounding) {
unsigned int precision = get_precision<T>(ulp_tolerance);
MPFRNumber mpfr_result = ternary_operation_one_output(
op, input.x, input.y, input.z, precision, rounding);
double ulp = mpfr_result.ulp(libc_result);

bool bits_are_even = ((FPBits<T>(libc_result).uintval() & 1) == 0);
return (ulp < ulp_error) ||
((ulp == ulp_error) && ((ulp != 0.5) || bits_are_even));
return (ulp <= ulp_tolerance);
}

template bool compare_ternary_operation_one_output<float>(
Operation, const TernaryInput<float> &, float, double);
Operation, const TernaryInput<float> &, float, double, RoundingMode);
template bool compare_ternary_operation_one_output<double>(
Operation, const TernaryInput<double> &, double, double);
Operation, const TernaryInput<double> &, double, double, RoundingMode);
template bool compare_ternary_operation_one_output<long double>(
Operation, const TernaryInput<long double> &, long double, double);

static mpfr_rnd_t get_mpfr_rounding_mode(RoundingMode mode) {
switch (mode) {
case RoundingMode::Upward:
return MPFR_RNDU;
break;
case RoundingMode::Downward:
return MPFR_RNDD;
break;
case RoundingMode::TowardZero:
return MPFR_RNDZ;
break;
case RoundingMode::Nearest:
return MPFR_RNDN;
break;
}
}
Operation, const TernaryInput<long double> &, long double, double,
RoundingMode);

} // namespace internal

template <typename T> bool round_to_long(T x, long &result) {
MPFRNumber mpfr(x);
return mpfr.roung_to_long(result);
return mpfr.round_to_long(result);
}

template bool round_to_long<float>(float, long &);
Expand All @@ -810,7 +904,7 @@ template bool round_to_long<long double>(long double, long &);

template <typename T> bool round_to_long(T x, RoundingMode mode, long &result) {
MPFRNumber mpfr(x);
return mpfr.roung_to_long(internal::get_mpfr_rounding_mode(mode), result);
return mpfr.round_to_long(get_mpfr_rounding_mode(mode), result);
}

template bool round_to_long<float>(float, RoundingMode, long &);
Expand All @@ -819,7 +913,7 @@ template bool round_to_long<long double>(long double, RoundingMode, long &);

template <typename T> T round(T x, RoundingMode mode) {
MPFRNumber mpfr(x);
MPFRNumber result = mpfr.rint(internal::get_mpfr_rounding_mode(mode));
MPFRNumber result = mpfr.rint(get_mpfr_rounding_mode(mode));
return result.as<T>();
}

Expand Down
164 changes: 109 additions & 55 deletions libc/utils/MPFRWrapper/MPFRUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ enum class Operation : int {
EndTernaryOperationsSingleOutput,
};

enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };

int get_fe_rounding(RoundingMode mode);

struct ForceRoundingMode {
ForceRoundingMode(RoundingMode);
~ForceRoundingMode();

int old_rounding_mode;
int rounding_mode;
};

template <typename T> struct BinaryInput {
static_assert(
__llvm_libc::cpp::IsFloatingPointType<T>::Value,
Expand Down Expand Up @@ -108,65 +120,72 @@ struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {

template <typename T>
bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
double t);
double ulp_tolerance,
RoundingMode rounding);
template <typename T>
bool compare_unary_operation_two_outputs(Operation op, T input,
const BinaryOutput<T> &libc_output,
double t);
double ulp_tolerance,
RoundingMode rounding);
template <typename T>
bool compare_binary_operation_two_outputs(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &libc_output,
double t);
double ulp_tolerance,
RoundingMode rounding);

template <typename T>
bool compare_binary_operation_one_output(Operation op,
const BinaryInput<T> &input,
T libc_output, double t);
T libc_output, double ulp_tolerance,
RoundingMode rounding);

template <typename T>
bool compare_ternary_operation_one_output(Operation op,
const TernaryInput<T> &input,
T libc_output, double t);
T libc_output, double ulp_tolerance,
RoundingMode rounding);

template <typename T>
void explain_unary_operation_single_output_error(Operation op, T input,
T match_value,
double ulp_tolerance,
RoundingMode rounding,
testutils::StreamWrapper &OS);
template <typename T>
void explain_unary_operation_two_outputs_error(
Operation op, T input, const BinaryOutput<T> &match_value,
testutils::StreamWrapper &OS);
double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);
template <typename T>
void explain_binary_operation_two_outputs_error(
Operation op, const BinaryInput<T> &input,
const BinaryOutput<T> &match_value, testutils::StreamWrapper &OS);
const BinaryOutput<T> &match_value, double ulp_tolerance,
RoundingMode rounding, testutils::StreamWrapper &OS);

template <typename T>
void explain_binary_operation_one_output_error(Operation op,
const BinaryInput<T> &input,
T match_value,
testutils::StreamWrapper &OS);
void explain_binary_operation_one_output_error(
Operation op, const BinaryInput<T> &input, T match_value,
double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);

template <typename T>
void explain_ternary_operation_one_output_error(Operation op,
const TernaryInput<T> &input,
T match_value,
testutils::StreamWrapper &OS);
void explain_ternary_operation_one_output_error(
Operation op, const TernaryInput<T> &input, T match_value,
double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS);

template <Operation op, typename InputType, typename OutputType>
class MPFRMatcher : public testing::Matcher<OutputType> {
InputType input;
OutputType match_value;
double ulp_tolerance;
RoundingMode rounding;

public:
MPFRMatcher(InputType testInput, double ulp_tolerance)
: input(testInput), ulp_tolerance(ulp_tolerance) {}
MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
: input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}

bool match(OutputType libcResult) {
match_value = libcResult;
return match(input, match_value, ulp_tolerance);
return match(input, match_value);
}

// This method is marked with NOLINT because it the name `explainError`
Expand All @@ -176,59 +195,64 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
}

private:
template <typename T> static bool match(T in, T out, double tolerance) {
return compare_unary_operation_single_output(op, in, out, tolerance);
template <typename T> bool match(T in, T out) {
return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
rounding);
}

template <typename T>
static bool match(T in, const BinaryOutput<T> &out, double tolerance) {
return compare_unary_operation_two_outputs(op, in, out, tolerance);
template <typename T> bool match(T in, const BinaryOutput<T> &out) {
return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
rounding);
}

template <typename T>
static bool match(const BinaryInput<T> &in, T out, double tolerance) {
return compare_binary_operation_one_output(op, in, out, tolerance);
template <typename T> bool match(const BinaryInput<T> &in, T out) {
return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
rounding);
}

template <typename T>
static bool match(BinaryInput<T> in, const BinaryOutput<T> &out,
double tolerance) {
return compare_binary_operation_two_outputs(op, in, out, tolerance);
bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
rounding);
}

template <typename T>
static bool match(const TernaryInput<T> &in, T out, double tolerance) {
return compare_ternary_operation_one_output(op, in, out, tolerance);
template <typename T> bool match(const TernaryInput<T> &in, T out) {
return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
rounding);
}

template <typename T>
static void explain_error(T in, T out, testutils::StreamWrapper &OS) {
explain_unary_operation_single_output_error(op, in, out, OS);
void explain_error(T in, T out, testutils::StreamWrapper &OS) {
explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
rounding, OS);
}

template <typename T>
static void explain_error(T in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explain_unary_operation_two_outputs_error(op, in, out, OS);
void explain_error(T in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
rounding, OS);
}

template <typename T>
static void explain_error(const BinaryInput<T> &in,
const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explain_binary_operation_two_outputs_error(op, in, out, OS);
void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
rounding, OS);
}

template <typename T>
static void explain_error(const BinaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explain_binary_operation_one_output_error(op, in, out, OS);
void explain_error(const BinaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
rounding, OS);
}

template <typename T>
static void explain_error(const TernaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explain_ternary_operation_one_output_error(op, in, out, OS);
void explain_error(const TernaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
rounding, OS);
}
};

Expand Down Expand Up @@ -264,12 +288,12 @@ template <Operation op, typename InputType, typename OutputType>
__attribute__((no_sanitize("address")))
cpp::EnableIfType<is_valid_operation<op, InputType, OutputType>(),
internal::MPFRMatcher<op, InputType, OutputType>>
get_mpfr_matcher(InputType input, OutputType output_unused, double t) {
return internal::MPFRMatcher<op, InputType, OutputType>(input, t);
get_mpfr_matcher(InputType input, OutputType output_unused,
double ulp_tolerance, RoundingMode rounding) {
return internal::MPFRMatcher<op, InputType, OutputType>(input, ulp_tolerance,
rounding);
}

enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };

template <typename T> T round(T x, RoundingMode mode);

template <typename T> bool round_to_long(T x, long &result);
Expand All @@ -279,12 +303,42 @@ template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);
} // namespace testing
} // namespace __llvm_libc

#define EXPECT_MPFR_MATCH(op, input, match_value, tolerance) \
// GET_MPFR_DUMMY_ARG is going to be added to the end of GET_MPFR_MACRO as a
// simple way to avoid the compiler warning `gnu-zero-variadic-macro-arguments`.
#define GET_MPFR_DUMMY_ARG(...) 0

#define GET_MPFR_MACRO(__1, __2, __3, __4, __5, __NAME, ...) __NAME

#define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \
EXPECT_THAT(match_value, \
__llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
input, match_value, ulp_tolerance, \
__llvm_libc::testing::mpfr::RoundingMode::Nearest))

#define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \
rounding) \
EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
input, match_value, tolerance))
input, match_value, ulp_tolerance, rounding))

#define EXPECT_MPFR_MATCH(...) \
GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING, \
EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \
(__VA_ARGS__)

#define ASSERT_MPFR_MATCH(op, input, match_value, tolerance) \
#define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \
ASSERT_THAT(match_value, \
__llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
input, match_value, ulp_tolerance, \
__llvm_libc::testing::mpfr::RoundingMode::Nearest))

#define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \
rounding) \
ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
input, match_value, tolerance))
input, match_value, ulp_tolerance, rounding))

#define ASSERT_MPFR_MATCH(...) \
GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING, \
ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \
(__VA_ARGS__)

#endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H