345 changes: 295 additions & 50 deletions libc/utils/MPFRWrapper/MPFRUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"

#include <memory>
#include <mpfr.h>
#include <stdint.h>
#include <string>
Expand Down Expand Up @@ -65,50 +66,90 @@ class MPFRNumber {
mpfr_set_sj(value, x, MPFR_RNDN);
}

template <typename XType,
cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
MPFRNumber(Operation op, XType rawValue) {
mpfr_init2(value, mpfrPrecision);
MPFRNumber mpfrInput(rawValue);
switch (op) {
case Operation::Abs:
mpfr_abs(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Ceil:
mpfr_ceil(value, mpfrInput.value);
break;
case Operation::Cos:
mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Exp:
mpfr_exp(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Exp2:
mpfr_exp2(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Floor:
mpfr_floor(value, mpfrInput.value);
break;
case Operation::Round:
mpfr_round(value, mpfrInput.value);
break;
case Operation::Sin:
mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Sqrt:
mpfr_sqrt(value, mpfrInput.value, MPFR_RNDN);
break;
case Operation::Trunc:
mpfr_trunc(value, mpfrInput.value);
break;
}
}

MPFRNumber(const MPFRNumber &other) {
mpfr_set(value, other.value, MPFR_RNDN);
}

~MPFRNumber() { mpfr_clear(value); }
MPFRNumber &operator=(const MPFRNumber &rhs) {
mpfr_set(value, rhs.value, MPFR_RNDN);
return *this;
}

MPFRNumber abs() const {
MPFRNumber result;
mpfr_abs(result.value, value, MPFR_RNDN);
return result;
}

MPFRNumber ceil() const {
MPFRNumber result;
mpfr_ceil(result.value, value);
return result;
}

MPFRNumber cos() const {
MPFRNumber result;
mpfr_cos(result.value, value, MPFR_RNDN);
return result;
}

MPFRNumber exp() const {
MPFRNumber result;
mpfr_exp(result.value, value, MPFR_RNDN);
return result;
}

MPFRNumber exp2() const {
MPFRNumber result;
mpfr_exp2(result.value, value, MPFR_RNDN);
return result;
}

MPFRNumber floor() const {
MPFRNumber result;
mpfr_floor(result.value, value);
return result;
}

MPFRNumber frexp(int &exp) {
MPFRNumber result;
mpfr_exp_t resultExp;
mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN);
exp = resultExp;
return result;
}

MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
MPFRNumber remainder;
long q;
mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN);
quotient = q;
return remainder;
}

MPFRNumber round() const {
MPFRNumber result;
mpfr_round(result.value, value);
return result;
}

MPFRNumber sin() const {
MPFRNumber result;
mpfr_sin(result.value, value, MPFR_RNDN);
return result;
}

MPFRNumber sqrt() const {
MPFRNumber result;
mpfr_sqrt(result.value, value, MPFR_RNDN);
return result;
}

MPFRNumber trunc() const {
MPFRNumber result;
mpfr_trunc(result.value, value);
return result;
}

std::string str() const {
// 200 bytes should be more than sufficient to hold a 100-digit number
Expand Down Expand Up @@ -179,10 +220,65 @@ class MPFRNumber {

namespace internal {

template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
unaryOperation(Operation op, InputType input) {
MPFRNumber mpfrInput(input);
switch (op) {
case Operation::Abs:
return mpfrInput.abs();
case Operation::Ceil:
return mpfrInput.ceil();
case Operation::Cos:
return mpfrInput.cos();
case Operation::Exp:
return mpfrInput.exp();
case Operation::Exp2:
return mpfrInput.exp2();
case Operation::Floor:
return mpfrInput.floor();
case Operation::Round:
return mpfrInput.round();
case Operation::Sin:
return mpfrInput.sin();
case Operation::Sqrt:
return mpfrInput.sqrt();
case Operation::Trunc:
return mpfrInput.trunc();
default:
__builtin_unreachable();
}
}

template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
unaryOperationTwoOutputs(Operation op, InputType input, int &output) {
MPFRNumber mpfrInput(input);
switch (op) {
case Operation::Frexp:
return mpfrInput.frexp(output);
default:
__builtin_unreachable();
}
}

template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
MPFRNumber inputX(x), inputY(y);
switch (op) {
case Operation::RemQuo:
return inputX.remquo(inputY, output);
default:
__builtin_unreachable();
}
}

template <typename T>
void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
MPFRNumber mpfrResult(operation, input);
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrInput(input);
MPFRNumber mpfrResult = unaryOperation(op, input);
MPFRNumber mpfrMatchValue(matchValue);
FPBits<T> inputBits(input);
FPBits<T> matchBits(matchValue);
Expand All @@ -201,25 +297,174 @@ void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
<< '\n';
}

template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
template void
MPFRMatcher<long double>::explainError(testutils::StreamWrapper &);
explainUnaryOperationSingleOutputError<float>(Operation op, float, float,
testutils::StreamWrapper &);
template void
explainUnaryOperationSingleOutputError<double>(Operation op, double, double,
testutils::StreamWrapper &);
template void explainUnaryOperationSingleOutputError<long double>(
Operation op, long double, long double, testutils::StreamWrapper &);

template <typename T>
void explainUnaryOperationTwoOutputsError(Operation op, T input,
const BinaryOutput<T> &libcResult,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrInput(input);
FPBits<T> inputBits(input);
int mpfrIntResult;
MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);

if (mpfrIntResult != libcResult.i) {
OS << "MPFR integral result: " << mpfrIntResult << '\n'
<< "Libc integral result: " << libcResult.i << '\n';
} else {
OS << "Integral result from libc matches integral result from MPFR.\n";
}

MPFRNumber mpfrMatchValue(libcResult.f);
OS << "Libc floating point result is not within tolerance value of the MPFR "
<< "result.\n\n";

OS << " Input decimal: " << mpfrInput.str() << "\n\n";

OS << "Libc floating point value: " << mpfrMatchValue.str() << '\n';
__llvm_libc::fputil::testing::describeValue(
" Libc floating point bits: ", libcResult.f, OS);
OS << "\n\n";

OS << " MPFR result: " << mpfrResult.str() << '\n';
__llvm_libc::fputil::testing::describeValue(
" MPFR rounded: ", mpfrResult.as<T>(), OS);
OS << '\n'
<< " ULP error: "
<< std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
}

template void explainUnaryOperationTwoOutputsError<float>(
Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &);
template void
explainUnaryOperationTwoOutputsError<double>(Operation, double,
const BinaryOutput<double> &,
testutils::StreamWrapper &);
template void explainUnaryOperationTwoOutputsError<long double>(
Operation, long double, const BinaryOutput<long double> &,
testutils::StreamWrapper &);

template <typename T>
bool compare(Operation op, T input, T libcResult, double ulpError) {
void explainBinaryOperationTwoOutputsError(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &libcResult,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrX(input.x);
MPFRNumber mpfrY(input.y);
FPBits<T> xbits(input.x);
FPBits<T> ybits(input.y);
int mpfrIntResult;
MPFRNumber mpfrResult =
binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
MPFRNumber mpfrMatchValue(libcResult.f);

OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'
<< "MPFR integral result: " << mpfrIntResult << '\n'
<< "Libc integral result: " << libcResult.i << '\n'
<< "Libc floating point result: " << mpfrMatchValue.str() << '\n'
<< " MPFR result: " << mpfrResult.str() << '\n';
__llvm_libc::fputil::testing::describeValue(
"Libc floating point result bits: ", libcResult.f, OS);
__llvm_libc::fputil::testing::describeValue(
" MPFR rounded bits: ", mpfrResult.as<T>(), OS);
OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
}

template void explainBinaryOperationTwoOutputsError<float>(
Operation, const BinaryInput<float> &, const BinaryOutput<float> &,
testutils::StreamWrapper &);
template void explainBinaryOperationTwoOutputsError<double>(
Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
testutils::StreamWrapper &);
template void explainBinaryOperationTwoOutputsError<long double>(
Operation, const BinaryInput<long double> &,
const BinaryOutput<long double> &, testutils::StreamWrapper &);

template <typename T>
bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
double ulpError) {
// If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
// is rounded to the nearest even.
MPFRNumber mpfrResult(op, input);
MPFRNumber mpfrResult = unaryOperation(op, input);
double ulp = mpfrResult.ulp(libcResult);
bool bitsAreEven = ((FPBits<T>(libcResult).bitsAsUInt() & 1) == 0);
return (ulp < ulpError) ||
((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
}

template bool compare<float>(Operation, float, float, double);
template bool compare<double>(Operation, double, double, double);
template bool compare<long double>(Operation, long double, long double, double);
template bool compareUnaryOperationSingleOutput<float>(Operation, float, float,
double);
template bool compareUnaryOperationSingleOutput<double>(Operation, double,
double, double);
template bool compareUnaryOperationSingleOutput<long double>(Operation,
long double,
long double,
double);

template <typename T>
bool compareUnaryOperationTwoOutputs(Operation op, T input,
const BinaryOutput<T> &libcResult,
double ulpError) {
int mpfrIntResult;
MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
double ulp = mpfrResult.ulp(libcResult.f);

if (mpfrIntResult != libcResult.i)
return false;

bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0);
return (ulp < ulpError) ||
((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
}

template bool
compareUnaryOperationTwoOutputs<float>(Operation, float,
const BinaryOutput<float> &, double);
template bool
compareUnaryOperationTwoOutputs<double>(Operation, double,
const BinaryOutput<double> &, double);
template bool compareUnaryOperationTwoOutputs<long double>(
Operation, long double, const BinaryOutput<long double> &, double);

template <typename T>
bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
const BinaryOutput<T> &libcResult,
double ulpError) {
int mpfrIntResult;
MPFRNumber mpfrResult =
binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
double ulp = mpfrResult.ulp(libcResult.f);

if (mpfrIntResult != libcResult.i) {
if (op == Operation::RemQuo) {
if ((0x7 & mpfrIntResult) != libcResult.i)
return false;
} else {
return false;
}
}

bool bitsAreEven = ((FPBits<T>(libcResult.f).bitsAsUInt() & 1) == 0);
return (ulp < ulpError) ||
((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
}

template bool
compareBinaryOperationTwoOutputs<float>(Operation, const BinaryInput<float> &,
const BinaryOutput<float> &, double);
template bool
compareBinaryOperationTwoOutputs<double>(Operation, const BinaryInput<double> &,
const BinaryOutput<double> &, double);
template bool compareBinaryOperationTwoOutputs<long double>(
Operation, const BinaryInput<long double> &,
const BinaryOutput<long double> &, double);

} // namespace internal

Expand Down
200 changes: 176 additions & 24 deletions libc/utils/MPFRWrapper/MPFRUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ namespace testing {
namespace mpfr {

enum class Operation : int {
// Operations with take a single floating point number as input
// and produce a single floating point number as output. The input
// and output floating point numbers are of the same kind.
BeginUnaryOperationsSingleOutput,
Abs,
Ceil,
Cos,
Expand All @@ -28,57 +32,205 @@ enum class Operation : int {
Round,
Sin,
Sqrt,
Trunc
Trunc,
EndUnaryOperationsSingleOutput,

// Operations which take a single floating point nubmer as input
// but produce two outputs. The first ouput is a floating point
// number of the same type as the input. The second output is of type
// 'int'.
BeginUnaryOperationsTwoOutputs,
Frexp, // Floating point output, the first output, is the fractional part.
EndUnaryOperationsTwoOutputs,

// Operations wich take two floating point nubmers of the same type as
// input and produce a single floating point number of the same type as
// output.
BeginBinaryOperationsSingleOutput,
// TODO: Add operations like hypot.
EndBinaryOperationsSingleOutput,

// Operations which take two floating point numbers of the same type as
// input and produce two outputs. The first output is a floating nubmer of
// the same type as the inputs. The second output is af type 'int'.
BeginBinaryOperationsTwoOutputs,
RemQuo, // The first output, the floating point output, is the remainder.
EndBinaryOperationsTwoOutputs,

BeginTernaryOperationsSingleOuput,
// TODO: Add operations like fma.
EndTernaryOperationsSingleOutput,
};

template <typename T> struct BinaryInput {
static_assert(
__llvm_libc::cpp::IsFloatingPointType<T>::Value,
"Template parameter of BinaryInput must be a floating point type.");

using Type = T;
T x, y;
};

template <typename T> struct TernaryInput {
static_assert(
__llvm_libc::cpp::IsFloatingPointType<T>::Value,
"Template parameter of TernaryInput must be a floating point type.");

using Type = T;
T x, y, z;
};

template <typename T> struct BinaryOutput {
T f;
int i;
};

namespace internal {

template <typename T1, typename T2>
struct AreMatchingBinaryInputAndBinaryOutput {
static constexpr bool value = false;
};

template <typename T>
bool compare(Operation op, T input, T libcOutput, double t);
struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
static constexpr bool value = cpp::IsFloatingPointType<T>::Value;
};

template <typename T> class MPFRMatcher : public testing::Matcher<T> {
static_assert(__llvm_libc::cpp::IsFloatingPointType<T>::Value,
"MPFRMatcher can only be used with floating point values.");
template <typename T>
bool compareUnaryOperationSingleOutput(Operation op, T input, T libcOutput,
double t);
template <typename T>
bool compareUnaryOperationTwoOutputs(Operation op, T input,
const BinaryOutput<T> &libcOutput,
double t);
template <typename T>
bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
const BinaryOutput<T> &libcOutput,
double t);

Operation operation;
T input;
T matchValue;
template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS);
template <typename T>
void explainUnaryOperationTwoOutputsError(Operation op, T input,
const BinaryOutput<T> &matchValue,
testutils::StreamWrapper &OS);
template <typename T>
void explainBinaryOperationTwoOutputsError(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &matchValue,
testutils::StreamWrapper &OS);

template <Operation op, typename InputType, typename OutputType>
class MPFRMatcher : public testing::Matcher<OutputType> {
InputType input;
OutputType matchValue;
double ulpTolerance;

public:
MPFRMatcher(Operation op, T testInput, double ulpTolerance)
: operation(op), input(testInput), ulpTolerance(ulpTolerance) {}
MPFRMatcher(InputType testInput, double ulpTolerance)
: input(testInput), ulpTolerance(ulpTolerance) {}

bool match(T libcResult) {
bool match(OutputType libcResult) {
matchValue = libcResult;
return internal::compare(operation, input, libcResult, ulpTolerance);
return match(input, matchValue, ulpTolerance);
}

void explainError(testutils::StreamWrapper &OS) override;
void explainError(testutils::StreamWrapper &OS) override {
explainError(input, matchValue, OS);
}

private:
template <typename T> static bool match(T in, T out, double tolerance) {
return compareUnaryOperationSingleOutput(op, in, out, tolerance);
}

template <typename T>
static bool match(T in, const BinaryOutput<T> &out, double tolerance) {
return compareUnaryOperationTwoOutputs(op, in, out, tolerance);
}

template <typename T>
static bool match(const BinaryInput<T> &in, T out, double tolerance) {
// TODO: Implement the comparision function and error reporter.
}

template <typename T>
static bool match(BinaryInput<T> in, const BinaryOutput<T> &out,
double tolerance) {
return compareBinaryOperationTwoOutputs(op, in, out, tolerance);
}

template <typename T>
static bool match(const TernaryInput<T> &in, T out, double tolerance) {
// TODO: Implement the comparision function and error reporter.
}

template <typename T>
static void explainError(T in, T out, testutils::StreamWrapper &OS) {
explainUnaryOperationSingleOutputError(op, in, out, OS);
}

template <typename T>
static void explainError(T in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explainUnaryOperationTwoOutputsError(op, in, out, OS);
}

template <typename T>
static void explainError(const BinaryInput<T> &in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explainBinaryOperationTwoOutputsError(op, in, out, OS);
}
};

} // namespace internal

template <typename T, typename U>
// Return true if the input and ouput types for the operation op are valid
// types.
template <Operation op, typename InputType, typename OutputType>
constexpr bool isValidOperation() {
return (Operation::BeginUnaryOperationsSingleOutput < op &&
op < Operation::EndUnaryOperationsSingleOutput &&
cpp::IsSame<InputType, OutputType>::Value &&
cpp::IsFloatingPointType<InputType>::Value) ||
(Operation::BeginUnaryOperationsTwoOutputs < op &&
op < Operation::EndUnaryOperationsTwoOutputs &&
cpp::IsFloatingPointType<InputType>::Value &&
cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) ||
(Operation::BeginBinaryOperationsSingleOutput < op &&
op < Operation::EndBinaryOperationsSingleOutput &&
cpp::IsFloatingPointType<OutputType>::Value &&
cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) ||
(Operation::BeginBinaryOperationsTwoOutputs < op &&
op < Operation::EndBinaryOperationsTwoOutputs &&
internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
OutputType>::value) ||
(Operation::BeginTernaryOperationsSingleOuput < op &&
op < Operation::EndTernaryOperationsSingleOutput &&
cpp::IsFloatingPointType<OutputType>::Value &&
cpp::IsSame<InputType, TernaryInput<OutputType>>::Value);
}

template <Operation op, typename InputType, typename OutputType>
__attribute__((no_sanitize("address")))
typename cpp::EnableIfType<cpp::IsSameV<U, double>, internal::MPFRMatcher<T>>
getMPFRMatcher(Operation op, T input, U t) {
static_assert(
__llvm_libc::cpp::IsFloatingPointType<T>::Value,
"getMPFRMatcher can only be used to match floating point results.");
return internal::MPFRMatcher<T>(op, input, t);
cpp::EnableIfType<isValidOperation<op, InputType, OutputType>(),
internal::MPFRMatcher<op, InputType, OutputType>>
getMPFRMatcher(InputType input, OutputType outputUnused, double t) {
return internal::MPFRMatcher<op, InputType, OutputType>(input, t);
}

} // namespace mpfr
} // namespace testing
} // namespace __llvm_libc

#define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance) \
EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \
op, input, tolerance))
EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \
input, matchValue, tolerance))

#define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance) \
ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher( \
op, input, tolerance))
ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \
input, matchValue, tolerance))

#endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H