121 changes: 114 additions & 7 deletions libc/utils/MPFRWrapper/MPFRUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,48 +35,69 @@ namespace __llvm_libc {
namespace testing {
namespace mpfr {

template <typename T> struct Precision;

template <> struct Precision<float> {
static constexpr unsigned int value = 24;
};

template <> struct Precision<double> {
static constexpr unsigned int value = 53;
};

#if !(defined(__x86_64__) || defined(__i386__))
template <> struct Precision<long double> {
static constexpr unsigned int value = 64;
};
#else
template <> struct Precision<long double> {
static constexpr unsigned int value = 113;
};
#endif

class MPFRNumber {
// A precision value which allows sufficiently large additional
// precision even compared to quad-precision floating point values.
static constexpr unsigned int mpfrPrecision = 128;
unsigned int mpfrPrecision;

mpfr_t value;

public:
MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
MPFRNumber() : mpfrPrecision(128) { mpfr_init2(value, mpfrPrecision); }

// We use explicit EnableIf specializations to disallow implicit
// conversions. Implicit conversions can potentially lead to loss of
// precision.
template <typename XType,
cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
explicit MPFRNumber(XType x) {
explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_flt(value, x, MPFR_RNDN);
}

template <typename XType,
cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
explicit MPFRNumber(XType x) {
explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_d(value, x, MPFR_RNDN);
}

template <typename XType,
cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
explicit MPFRNumber(XType x) {
explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_ld(value, x, MPFR_RNDN);
}

template <typename XType,
cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
explicit MPFRNumber(XType x) {
explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set_sj(value, x, MPFR_RNDN);
}

MPFRNumber(const MPFRNumber &other) {
MPFRNumber(const MPFRNumber &other) : mpfrPrecision(other.mpfrPrecision) {
mpfr_init2(value, mpfrPrecision);
mpfr_set(value, other.value, MPFR_RNDN);
}

Expand All @@ -85,6 +106,7 @@ class MPFRNumber {
}

MPFRNumber &operator=(const MPFRNumber &rhs) {
mpfrPrecision = rhs.mpfrPrecision;
mpfr_set(value, rhs.value, MPFR_RNDN);
return *this;
}
Expand Down Expand Up @@ -193,6 +215,12 @@ class MPFRNumber {
return result;
}

MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) {
MPFRNumber result(*this);
mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN);
return result;
}

std::string str() const {
// 200 bytes should be more than sufficient to hold a 100-digit number
// plus additional bytes for the decimal point, '-' sign etc.
Expand Down Expand Up @@ -328,6 +356,22 @@ binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
}
}

template <typename InputType>
cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
ternaryOperationOneOutput(Operation op, InputType x, InputType y, InputType z) {
// For FMA function, we just need to compare with the mpfr_fma with the same
// precision as InputType. Using higher precision as the intermediate results
// to compare might incorrectly fail due to double-rounding errors.
constexpr unsigned int prec = Precision<InputType>::value;
MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec);
switch (op) {
case Operation::Fma:
return inputX.fma(inputY, inputZ);
default:
__builtin_unreachable();
}
}

template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS) {
Expand Down Expand Up @@ -476,6 +520,48 @@ template void explainBinaryOperationOneOutputError<long double>(
Operation, const BinaryInput<long double> &, long double,
testutils::StreamWrapper &);

template <typename T>
void explainTernaryOperationOneOutputError(Operation op,
const TernaryInput<T> &input,
T libcResult,
testutils::StreamWrapper &OS) {
MPFRNumber mpfrX(input.x, Precision<T>::value);
MPFRNumber mpfrY(input.y, Precision<T>::value);
MPFRNumber mpfrZ(input.z, Precision<T>::value);
FPBits<T> xbits(input.x);
FPBits<T> ybits(input.y);
FPBits<T> zbits(input.z);
MPFRNumber mpfrResult =
ternaryOperationOneOutput(op, input.x, input.y, input.z);
MPFRNumber mpfrMatchValue(libcResult);

OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str()
<< " z: " << mpfrZ.str() << '\n';
__llvm_libc::fputil::testing::describeValue("First input bits: ", input.x,
OS);
__llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y,
OS);
__llvm_libc::fputil::testing::describeValue("Third input bits: ", input.z,
OS);

OS << "Libc result: " << mpfrMatchValue.str() << '\n'
<< "MPFR result: " << mpfrResult.str() << '\n';
__llvm_libc::fputil::testing::describeValue(
"Libc floating point result bits: ", libcResult, OS);
__llvm_libc::fputil::testing::describeValue(
" MPFR rounded bits: ", mpfrResult.as<T>(), OS);
OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n';
}

template void explainTernaryOperationOneOutputError<float>(
Operation, const TernaryInput<float> &, float, testutils::StreamWrapper &);
template void explainTernaryOperationOneOutputError<double>(
Operation, const TernaryInput<double> &, double,
testutils::StreamWrapper &);
template void explainTernaryOperationOneOutputError<long double>(
Operation, const TernaryInput<long double> &, long double,
testutils::StreamWrapper &);

template <typename T>
bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
double ulpError) {
Expand Down Expand Up @@ -575,6 +661,27 @@ compareBinaryOperationOneOutput<double>(Operation, const BinaryInput<double> &,
template bool compareBinaryOperationOneOutput<long double>(
Operation, const BinaryInput<long double> &, long double, double);

template <typename T>
bool compareTernaryOperationOneOutput(Operation op,
const TernaryInput<T> &input,
T libcResult, double ulpError) {
MPFRNumber mpfrResult =
ternaryOperationOneOutput(op, input.x, input.y, input.z);
double ulp = mpfrResult.ulp(libcResult);

bool bitsAreEven = ((FPBits<T>(libcResult).bitsAsUInt() & 1) == 0);
return (ulp < ulpError) ||
((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
}

template bool
compareTernaryOperationOneOutput<float>(Operation, const TernaryInput<float> &,
float, double);
template bool compareTernaryOperationOneOutput<double>(
Operation, const TernaryInput<double> &, double, double);
template bool compareTernaryOperationOneOutput<long double>(
Operation, const TernaryInput<long double> &, long double, double);

static mpfr_rnd_t getMPFRRoundingMode(RoundingMode mode) {
switch (mode) {
case RoundingMode::Upward:
Expand Down
24 changes: 22 additions & 2 deletions libc/utils/MPFRWrapper/MPFRUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,11 @@ enum class Operation : int {
RemQuo, // The first output, the floating point output, is the remainder.
EndBinaryOperationsTwoOutputs,

// Operations which take three floating point nubmers of the same type as
// input and produce a single floating point number of the same type as
// output.
BeginTernaryOperationsSingleOuput,
// TODO: Add operations like fma.
Fma,
EndTernaryOperationsSingleOutput,
};

Expand Down Expand Up @@ -113,6 +116,11 @@ template <typename T>
bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
T libcOutput, double t);

template <typename T>
bool compareTernaryOperationOneOutput(Operation op,
const TernaryInput<T> &input,
T libcOutput, double t);

template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS);
Expand All @@ -132,6 +140,12 @@ void explainBinaryOperationOneOutputError(Operation op,
T matchValue,
testutils::StreamWrapper &OS);

template <typename T>
void explainTernaryOperationOneOutputError(Operation op,
const TernaryInput<T> &input,
T matchValue,
testutils::StreamWrapper &OS);

template <Operation op, typename InputType, typename OutputType>
class MPFRMatcher : public testing::Matcher<OutputType> {
InputType input;
Expand Down Expand Up @@ -174,7 +188,7 @@ class MPFRMatcher : public testing::Matcher<OutputType> {

template <typename T>
static bool match(const TernaryInput<T> &in, T out, double tolerance) {
// TODO: Implement the comparision function and error reporter.
return compareTernaryOperationOneOutput(op, in, out, tolerance);
}

template <typename T>
Expand All @@ -199,6 +213,12 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
testutils::StreamWrapper &OS) {
explainBinaryOperationOneOutputError(op, in, out, OS);
}

template <typename T>
static void explainError(const TernaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explainTernaryOperationOneOutputError(op, in, out, OS);
}
};

} // namespace internal
Expand Down