Skip to content

Commit

Permalink
[flang] evaluate: Fold SQRT, HYPOT, & CABS
Browse files Browse the repository at this point in the history
Implement IEEE Real::SQRT() operation, then use it to
also implement Real::HYPOT(), which can then be used directly
to implement Complex::ABS().

Differential Revision: https://reviews.llvm.org/D109250
  • Loading branch information
klausler committed Sep 7, 2021
1 parent ea04bf3 commit c9e9635
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 16 deletions.
6 changes: 5 additions & 1 deletion flang/include/flang/Evaluate/complex.h
Expand Up @@ -77,6 +77,11 @@ template <typename REAL_TYPE> class Complex {
ValueWithRealFlags<Complex> Divide(
const Complex &, Rounding rounding = defaultRounding) const;

// ABS/CABS = HYPOT(re_, imag_) = SQRT(re_**2 + im_**2)
ValueWithRealFlags<Part> ABS(Rounding rounding = defaultRounding) const {
return re_.HYPOT(im_, rounding);
}

constexpr Complex FlushSubnormalToZero() const {
return {re_.FlushSubnormalToZero(), im_.FlushSubnormalToZero()};
}
Expand All @@ -88,7 +93,6 @@ template <typename REAL_TYPE> class Complex {
std::string DumpHexadecimal() const;
llvm::raw_ostream &AsFortran(llvm::raw_ostream &, int kind) const;

// TODO: (C)ABS once Real::HYPOT is done
// TODO: unit testing

private:
Expand Down
6 changes: 4 additions & 2 deletions flang/include/flang/Evaluate/real.h
Expand Up @@ -115,8 +115,10 @@ class Real : public common::RealDetails<PREC> {
ValueWithRealFlags<Real> Divide(
const Real &, Rounding rounding = defaultRounding) const;

// SQRT(x**2 + y**2) but computed so as to avoid spurious overflow
// TODO: not yet implemented; needed for CABS
ValueWithRealFlags<Real> SQRT(Rounding rounding = defaultRounding) const;

// HYPOT(x,y)=SQRT(x**2 + y**2) computed so as to avoid spurious
// intermediate overflows.
ValueWithRealFlags<Real> HYPOT(
const Real &, Rounding rounding = defaultRounding) const;

Expand Down
29 changes: 18 additions & 11 deletions flang/lib/Evaluate/fold-real.cpp
Expand Up @@ -27,8 +27,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
name == "bessel_y1" || name == "cos" || name == "cosh" || name == "erf" ||
name == "erfc" || name == "erfc_scaled" || name == "exp" ||
name == "gamma" || name == "log" || name == "log10" ||
name == "log_gamma" || name == "sin" || name == "sinh" ||
name == "sqrt" || name == "tan" || name == "tanh") {
name == "log_gamma" || name == "sin" || name == "sinh" || name == "tan" ||
name == "tanh") {
CHECK(args.size() == 1);
if (auto callable{GetHostRuntimeWrapper<T, T>(name)}) {
return FoldElementalIntrinsic<T, T>(
Expand All @@ -40,8 +40,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
} else if (name == "amax0" || name == "amin0" || name == "amin1" ||
name == "amax1" || name == "dmin1" || name == "dmax1") {
return RewriteSpecificMINorMAX(context, std::move(funcRef));
} else if (name == "atan" || name == "atan2" || name == "hypot" ||
name == "mod") {
} else if (name == "atan" || name == "atan2" || name == "mod") {
std::string localName{name == "atan" ? "atan2" : name};
CHECK(args.size() == 2);
if (auto callable{GetHostRuntimeWrapper<T, T, T>(localName)}) {
Expand Down Expand Up @@ -71,13 +70,10 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return FoldElementalIntrinsic<T, T>(
context, std::move(funcRef), &Scalar<T>::ABS);
} else if (auto *z{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
if (auto callable{GetHostRuntimeWrapper<T, ComplexT>("abs")}) {
return FoldElementalIntrinsic<T, ComplexT>(
context, std::move(funcRef), *callable);
} else {
context.messages().Say(
"abs(complex(kind=%d)) cannot be folded on host"_en_US, KIND);
}
return FoldElementalIntrinsic<T, ComplexT>(context, std::move(funcRef),
ScalarFunc<T, ComplexT>([](const Scalar<ComplexT> &z) -> Scalar<T> {
return z.ABS().value;
}));
} else {
common::die(" unexpected argument type inside abs");
}
Expand Down Expand Up @@ -108,6 +104,13 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return Expr<T>{Scalar<T>::EPSILON()};
} else if (name == "huge") {
return Expr<T>{Scalar<T>::HUGE()};
} else if (name == "hypot") {
CHECK(args.size() == 2);
return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
ScalarFunc<T, T, T>(
[](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> {
return x.HYPOT(y).value;
}));
} else if (name == "max") {
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
} else if (name == "maxval") {
Expand All @@ -130,6 +133,10 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
} else if (name == "sign") {
return FoldElementalIntrinsic<T, T, T>(
context, std::move(funcRef), &Scalar<T>::SIGN);
} else if (name == "sqrt") {
return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
ScalarFunc<T, T>(
[](const Scalar<T> &x) -> Scalar<T> { return x.SQRT().value; }));
} else if (name == "sum") {
return FoldSum<T>(context, std::move(funcRef));
} else if (name == "tiny") {
Expand Down
2 changes: 0 additions & 2 deletions flang/lib/Evaluate/intrinsics-library.cpp
Expand Up @@ -222,15 +222,13 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
FolderFactory<F, F{std::erfc}>::Create("erfc"),
FolderFactory<F, F{std::exp}>::Create("exp"),
FolderFactory<F, F{std::tgamma}>::Create("gamma"),
FolderFactory<F2, F2{std::hypot}>::Create("hypot"),
FolderFactory<F, F{std::log}>::Create("log"),
FolderFactory<F, F{std::log10}>::Create("log10"),
FolderFactory<F, F{std::lgamma}>::Create("log_gamma"),
FolderFactory<F2, F2{std::fmod}>::Create("mod"),
FolderFactory<F2, F2{std::pow}>::Create("pow"),
FolderFactory<F, F{std::sin}>::Create("sin"),
FolderFactory<F, F{std::sinh}>::Create("sinh"),
FolderFactory<F, F{std::sqrt}>::Create("sqrt"),
FolderFactory<F, F{std::tan}>::Create("tan"),
FolderFactory<F, F{std::tanh}>::Create("tanh"),
};
Expand Down
101 changes: 101 additions & 0 deletions flang/lib/Evaluate/real.cpp
Expand Up @@ -261,6 +261,107 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::Divide(
return result;
}

template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
ValueWithRealFlags<Real> result;
if (IsNotANumber()) {
result.value = NotANumber();
if (IsSignalingNaN()) {
result.flags.set(RealFlag::InvalidArgument);
}
} else if (IsNegative()) {
if (IsZero()) {
// SQRT(-0) == -0 in IEEE-754.
result.value.word_ = result.value.word_.IBSET(bits - 1);
} else {
result.value = NotANumber();
}
} else if (IsInfinite()) {
// SQRT(+Inf) == +Inf
result.value = Infinity(false);
} else {
// Slow but reliable bit-at-a-time method. Start with a clear significand
// and half the unbiased exponent, and then try to set significand bits
// in descending order of magnitude without exceeding the exact result.
int expo{UnbiasedExponent()};
if (IsSubnormal()) {
expo -= GetFraction().LEADZ();
}
expo = expo / 2 + exponentBias;
result.value.Normalize(false, expo, Fraction::MASKL(1));
for (int bit{significandBits - 1}; bit >= 0; --bit) {
Word word{result.value.word_};
result.value.word_ = word.IBSET(bit);
auto squared{result.value.Multiply(result.value, rounding)};
if (squared.flags.test(RealFlag::Overflow) ||
squared.flags.test(RealFlag::Underflow) ||
Compare(squared.value) == Relation::Less) {
result.value.word_ = word;
}
}
// The computed square root, when squared, has a square that's not greater
// than the original argument. Check this square against the square of the
// next Real value, and return that one if its square is closer in magnitude
// to the original argument.
Real resultSq{result.value.Multiply(result.value).value};
Real diff{Subtract(resultSq).value.ABS()};
if (diff.IsZero()) {
return result; // exact
}
Real ulp;
ulp.Normalize(false, expo, Fraction::MASKR(1));
Real nextAfter{result.value.Add(ulp).value};
auto nextAfterSq{nextAfter.Multiply(nextAfter)};
if (!nextAfterSq.flags.test(RealFlag::Overflow) &&
!nextAfterSq.flags.test(RealFlag::Underflow)) {
Real nextAfterDiff{Subtract(nextAfterSq.value).value.ABS()};
if (nextAfterDiff.Compare(diff) == Relation::Less) {
result.value = nextAfter;
if (nextAfterDiff.IsZero()) {
return result; // exact
}
}
}
result.flags.set(RealFlag::Inexact);
}
return result;
}

// HYPOT(x,y) = SQRT(x**2 + y**2) by definition, but those squared intermediate
// values are susceptible to over/underflow when computed naively.
// Assuming that x>=y, calculate instead:
// HYPOT(x,y) = SQRT(x**2 * (1+(y/x)**2))
// = ABS(x) * SQRT(1+(y/x)**2)
template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::HYPOT(
const Real &y, Rounding rounding) const {
ValueWithRealFlags<Real> result;
if (IsNotANumber() || y.IsNotANumber()) {
result.flags.set(RealFlag::InvalidArgument);
result.value = NotANumber();
} else if (ABS().Compare(y.ABS()) == Relation::Less) {
return y.HYPOT(*this);
} else if (IsZero()) {
return result; // x==y==0
} else {
auto yOverX{y.Divide(*this, rounding)}; // y/x
bool inexact{yOverX.flags.test(RealFlag::Inexact)};
auto squared{yOverX.value.Multiply(yOverX.value, rounding)}; // (y/x)**2
inexact |= squared.flags.test(RealFlag::Inexact);
Real one;
one.Normalize(false, exponentBias, Fraction::MASKL(1)); // 1.0
auto sum{squared.value.Add(one, rounding)}; // 1.0 + (y/x)**2
inexact |= sum.flags.test(RealFlag::Inexact);
auto sqrt{sum.value.SQRT()};
inexact |= sqrt.flags.test(RealFlag::Inexact);
result = sqrt.value.Multiply(ABS(), rounding);
if (inexact) {
result.flags.set(RealFlag::Inexact);
}
}
return result;
}

template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::ToWholeNumber(
common::RoundingMode mode) const {
Expand Down
40 changes: 40 additions & 0 deletions flang/test/Evaluate/folding28.f90
@@ -0,0 +1,40 @@
! RUN: %S/test_folding.sh %s %t %flang_fc1
! REQUIRES: shell
! Tests folding of SQRT()
module m
implicit none
! +Inf
real(8), parameter :: inf8 = z'7ff0000000000000'
logical, parameter :: test_inf8 = sqrt(inf8) == inf8
! max finite
real(8), parameter :: h8 = huge(1.0_8), h8z = z'7fefffffffffffff'
logical, parameter :: test_h8 = h8 == h8z
real(8), parameter :: sqrt_h8 = sqrt(h8), sqrt_h8z = z'5fefffffffffffff'
logical, parameter :: test_sqrt_h8 = sqrt_h8 == sqrt_h8z
real(8), parameter :: sqr_sqrt_h8 = sqrt_h8 * sqrt_h8, sqr_sqrt_h8z = z'7feffffffffffffe'
logical, parameter :: test_sqr_sqrt_h8 = sqr_sqrt_h8 == sqr_sqrt_h8z
! -0 (sqrt is -0)
real(8), parameter :: n08 = z'8000000000000000'
real(8), parameter :: sqrt_n08 = sqrt(n08)
!WARN: division by zero
real(8), parameter :: inf_n08 = 1.0_8 / sqrt_n08, inf_n08z = z'fff0000000000000'
logical, parameter :: test_n08 = inf_n08 == inf_n08z
! min normal
real(8), parameter :: t8 = tiny(1.0_8), t8z = z'0010000000000000'
logical, parameter :: test_t8 = t8 == t8z
real(8), parameter :: sqrt_t8 = sqrt(t8), sqrt_t8z = z'2000000000000000'
logical, parameter :: test_sqrt_t8 = sqrt_t8 == sqrt_t8z
real(8), parameter :: sqr_sqrt_t8 = sqrt_t8 * sqrt_t8
logical, parameter :: test_sqr_sqrt_t8 = sqr_sqrt_t8 == t8
! max subnormal
real(8), parameter :: maxs8 = z'000fffffffffffff'
real(8), parameter :: sqrt_maxs8 = sqrt(maxs8), sqrt_maxs8z = z'2000000000000000'
logical, parameter :: test_sqrt_maxs8 = sqrt_maxs8 == sqrt_maxs8z
! min subnormal
real(8), parameter :: mins8 = z'1'
real(8), parameter :: sqrt_mins8 = sqrt(mins8), sqrt_mins8z = z'1e60000000000000'
logical, parameter :: test_sqrt_mins8 = sqrt_mins8 == sqrt_mins8z
real(8), parameter :: sqr_sqrt_mins8 = sqrt_mins8 * sqrt_mins8
logical, parameter :: test_sqr_sqrt_mins8 = sqr_sqrt_mins8 == mins8
end module

0 comments on commit c9e9635

Please sign in to comment.