392 changes: 198 additions & 194 deletions libc/utils/MPFRWrapper/MPFRUtils.cpp

Large diffs are not rendered by default.

149 changes: 77 additions & 72 deletions libc/utils/MPFRWrapper/MPFRUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,132 +98,137 @@ namespace internal {

template <typename T1, typename T2>
struct AreMatchingBinaryInputAndBinaryOutput {
static constexpr bool value = false;
static constexpr bool VALUE = false;
};

template <typename T>
struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
static constexpr bool value = cpp::IsFloatingPointType<T>::Value;
static constexpr bool VALUE = cpp::IsFloatingPointType<T>::Value;
};

template <typename T>
bool compareUnaryOperationSingleOutput(Operation op, T input, T libcOutput,
double t);
bool compare_unary_operation_single_output(Operation op, T input, T libc_output,
double t);
template <typename T>
bool compareUnaryOperationTwoOutputs(Operation op, T input,
const BinaryOutput<T> &libcOutput,
double t);
bool compare_unary_operation_two_outputs(Operation op, T input,
const BinaryOutput<T> &libc_output,
double t);
template <typename T>
bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
const BinaryOutput<T> &libcOutput,
double t);
bool compare_binary_operation_two_outputs(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &libc_output,
double t);

template <typename T>
bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
T libcOutput, double t);
bool compare_binary_operation_one_output(Operation op,
const BinaryInput<T> &input,
T libc_output, double t);

template <typename T>
bool compareTernaryOperationOneOutput(Operation op,
const TernaryInput<T> &input,
T libcOutput, double t);
bool compare_ternary_operation_one_output(Operation op,
const TernaryInput<T> &input,
T libc_output, double t);

template <typename T>
void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
testutils::StreamWrapper &OS);
void explain_unary_operation_single_output_error(Operation op, T input,
T match_value,
testutils::StreamWrapper &OS);
template <typename T>
void explainUnaryOperationTwoOutputsError(Operation op, T input,
const BinaryOutput<T> &matchValue,
testutils::StreamWrapper &OS);
void explain_unary_operation_two_outputs_error(
Operation op, T input, const BinaryOutput<T> &match_value,
testutils::StreamWrapper &OS);
template <typename T>
void explainBinaryOperationTwoOutputsError(Operation op,
const BinaryInput<T> &input,
const BinaryOutput<T> &matchValue,
testutils::StreamWrapper &OS);
void explain_binary_operation_two_outputs_error(
Operation op, const BinaryInput<T> &input,
const BinaryOutput<T> &match_value, testutils::StreamWrapper &OS);

template <typename T>
void explainBinaryOperationOneOutputError(Operation op,
const BinaryInput<T> &input,
T matchValue,
testutils::StreamWrapper &OS);
void explain_binary_operation_one_output_error(Operation op,
const BinaryInput<T> &input,
T match_value,
testutils::StreamWrapper &OS);

template <typename T>
void explainTernaryOperationOneOutputError(Operation op,
const TernaryInput<T> &input,
T matchValue,
testutils::StreamWrapper &OS);
void explain_ternary_operation_one_output_error(Operation op,
const TernaryInput<T> &input,
T match_value,
testutils::StreamWrapper &OS);

template <Operation op, typename InputType, typename OutputType>
class MPFRMatcher : public testing::Matcher<OutputType> {
InputType input;
OutputType matchValue;
double ulpTolerance;
OutputType match_value;
double ulp_tolerance;

public:
MPFRMatcher(InputType testInput, double ulpTolerance)
: input(testInput), ulpTolerance(ulpTolerance) {}
MPFRMatcher(InputType testInput, double ulp_tolerance)
: input(testInput), ulp_tolerance(ulp_tolerance) {}

bool match(OutputType libcResult) {
matchValue = libcResult;
return match(input, matchValue, ulpTolerance);
match_value = libcResult;
return match(input, match_value, ulp_tolerance);
}

void explainError(testutils::StreamWrapper &OS) override {
explainError(input, matchValue, OS);
// This method is marked with NOLINT because it the name `explainError`
// does not confirm to the coding style.
void explainError(testutils::StreamWrapper &OS) override { // NOLINT
explain_error(input, match_value, OS);
}

private:
template <typename T> static bool match(T in, T out, double tolerance) {
return compareUnaryOperationSingleOutput(op, in, out, tolerance);
return compare_unary_operation_single_output(op, in, out, tolerance);
}

template <typename T>
static bool match(T in, const BinaryOutput<T> &out, double tolerance) {
return compareUnaryOperationTwoOutputs(op, in, out, tolerance);
return compare_unary_operation_two_outputs(op, in, out, tolerance);
}

template <typename T>
static bool match(const BinaryInput<T> &in, T out, double tolerance) {
return compareBinaryOperationOneOutput(op, in, out, tolerance);
return compare_binary_operation_one_output(op, in, out, tolerance);
}

template <typename T>
static bool match(BinaryInput<T> in, const BinaryOutput<T> &out,
double tolerance) {
return compareBinaryOperationTwoOutputs(op, in, out, tolerance);
return compare_binary_operation_two_outputs(op, in, out, tolerance);
}

template <typename T>
static bool match(const TernaryInput<T> &in, T out, double tolerance) {
return compareTernaryOperationOneOutput(op, in, out, tolerance);
return compare_ternary_operation_one_output(op, in, out, tolerance);
}

template <typename T>
static void explainError(T in, T out, testutils::StreamWrapper &OS) {
explainUnaryOperationSingleOutputError(op, in, out, OS);
static void explain_error(T in, T out, testutils::StreamWrapper &OS) {
explain_unary_operation_single_output_error(op, in, out, OS);
}

template <typename T>
static void explainError(T in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explainUnaryOperationTwoOutputsError(op, in, out, OS);
static void explain_error(T in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explain_unary_operation_two_outputs_error(op, in, out, OS);
}

template <typename T>
static void explainError(const BinaryInput<T> &in, const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explainBinaryOperationTwoOutputsError(op, in, out, OS);
static void explain_error(const BinaryInput<T> &in,
const BinaryOutput<T> &out,
testutils::StreamWrapper &OS) {
explain_binary_operation_two_outputs_error(op, in, out, OS);
}

template <typename T>
static void explainError(const BinaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explainBinaryOperationOneOutputError(op, in, out, OS);
static void explain_error(const BinaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explain_binary_operation_one_output_error(op, in, out, OS);
}

template <typename T>
static void explainError(const TernaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explainTernaryOperationOneOutputError(op, in, out, OS);
static void explain_error(const TernaryInput<T> &in, T out,
testutils::StreamWrapper &OS) {
explain_ternary_operation_one_output_error(op, in, out, OS);
}
};

Expand All @@ -232,7 +237,7 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
// Return true if the input and ouput types for the operation op are valid
// types.
template <Operation op, typename InputType, typename OutputType>
constexpr bool isValidOperation() {
constexpr bool is_valid_operation() {
return (Operation::BeginUnaryOperationsSingleOutput < op &&
op < Operation::EndUnaryOperationsSingleOutput &&
cpp::IsSame<InputType, OutputType>::Value &&
Expand All @@ -248,7 +253,7 @@ constexpr bool isValidOperation() {
(Operation::BeginBinaryOperationsTwoOutputs < op &&
op < Operation::EndBinaryOperationsTwoOutputs &&
internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
OutputType>::value) ||
OutputType>::VALUE) ||
(Operation::BeginTernaryOperationsSingleOuput < op &&
op < Operation::EndTernaryOperationsSingleOutput &&
cpp::IsFloatingPointType<OutputType>::Value &&
Expand All @@ -257,29 +262,29 @@ constexpr bool isValidOperation() {

template <Operation op, typename InputType, typename OutputType>
__attribute__((no_sanitize("address")))
cpp::EnableIfType<isValidOperation<op, InputType, OutputType>(),
cpp::EnableIfType<is_valid_operation<op, InputType, OutputType>(),
internal::MPFRMatcher<op, InputType, OutputType>>
getMPFRMatcher(InputType input, OutputType outputUnused, double t) {
get_mpfr_matcher(InputType input, OutputType output_unused, double t) {
return internal::MPFRMatcher<op, InputType, OutputType>(input, t);
}

enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest };

template <typename T> T Round(T x, RoundingMode mode);
template <typename T> T round(T x, RoundingMode mode);

template <typename T> bool RoundToLong(T x, long &result);
template <typename T> bool RoundToLong(T x, RoundingMode mode, long &result);
template <typename T> bool round_to_long(T x, long &result);
template <typename T> bool round_to_long(T x, RoundingMode mode, long &result);

} // namespace mpfr
} // namespace testing
} // namespace __llvm_libc

#define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance) \
EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \
input, matchValue, tolerance))
#define EXPECT_MPFR_MATCH(op, input, match_value, tolerance) \
EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
input, match_value, tolerance))

#define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance) \
ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \
input, matchValue, tolerance))
#define ASSERT_MPFR_MATCH(op, input, match_value, tolerance) \
ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \
input, match_value, tolerance))

#endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H