diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h index 093f26bea1a44..2c0e0883207e1 100644 --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -45,6 +45,12 @@ namespace Fortran::evaluate { +// Don't use Kahan extended precision summation any more when folding +// transformational intrinsic functions other than SUM, since it is +// not used in the runtime implementations of those functions and we +// want results to match. +static constexpr bool useKahanSummation{false}; + // Utilities template class Folder { public: diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h index 27b6db1fd8bf0..bd61969a822c3 100644 --- a/flang/lib/Evaluate/fold-matmul.h +++ b/flang/lib/Evaluate/fold-matmul.h @@ -58,18 +58,25 @@ static Expr FoldMatmul(FoldingContext &context, FunctionRef &&funcRef) { Element bElt{mb->At(bAt)}; if constexpr (T::category == TypeCategory::Real || T::category == TypeCategory::Complex) { - // Kahan summation - auto product{aElt.Multiply(bElt, rounding)}; + auto product{aElt.Multiply(bElt)}; overflow |= product.flags.test(RealFlag::Overflow); - auto next{correction.Add(product.value, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; - overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + if constexpr (useKahanSummation) { + auto next{correction.Add(product.value, rounding)}; + overflow |= next.flags.test(RealFlag::Overflow); + auto added{sum.Add(next.value, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + correction = added.value.Subtract(sum, rounding) + .value.Subtract(next.value, rounding) + .value; + sum = std::move(added.value); + } else { + auto added{sum.Add(product.value)}; + overflow |= added.flags.test(RealFlag::Overflow); + sum = std::move(added.value); + } } else if constexpr (T::category == TypeCategory::Integer) { + // Don't use Kahan summation in numeric MATMUL folding; + // the runtime doesn't use it, and results should match. auto product{aElt.MultiplySigned(bElt)}; overflow |= product.SignedMultiplicationOverflowed(); auto added{sum.AddSigned(product.lower)}; diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp index fd37437c643aa..4df709d3d2c21 100644 --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -54,7 +54,7 @@ template class Norm2Accumulator { : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {}; void operator()( Scalar &element, const ConstantSubscripts &at, bool /*first*/) { - // Kahan summation of scaled elements: + // Summation of scaled elements: // Naively, // NORM2(A(:)) = SQRT(SUM(A(:)**2)) // For any T > 0, we have mathematically @@ -76,24 +76,27 @@ template class Norm2Accumulator { auto item{array_.At(at)}; auto scaled{item.Divide(scale).value}; auto square{scaled.Multiply(scaled).value}; - auto next{square.Add(correction_, rounding_)}; - overflow_ |= next.flags.test(RealFlag::Overflow); - auto sum{element.Add(next.value, rounding_)}; - overflow_ |= sum.flags.test(RealFlag::Overflow); - correction_ = sum.value.Subtract(element, rounding_) - .value.Subtract(next.value, rounding_) - .value; - element = sum.value; + if constexpr (useKahanSummation) { + auto next{square.Add(correction_, rounding_)}; + overflow_ |= next.flags.test(RealFlag::Overflow); + auto sum{element.Add(next.value, rounding_)}; + overflow_ |= sum.flags.test(RealFlag::Overflow); + correction_ = sum.value.Subtract(element, rounding_) + .value.Subtract(next.value, rounding_) + .value; + element = sum.value; + } else { + auto sum{element.Add(square, rounding_)}; + overflow_ |= sum.flags.test(RealFlag::Overflow); + element = sum.value; + } } } bool overflow() const { return overflow_; } void Done(Scalar &result) { - // result+correction == SUM((data(:)/maxAbs)**2) - // result = maxAbs * SQRT(result+correction) - auto corrected{result.Add(correction_, rounding_)}; - overflow_ |= corrected.flags.test(RealFlag::Overflow); - correction_ = Scalar{}; - auto root{corrected.value.SQRT().value}; + // incoming result = SUM((data(:)/maxAbs)**2) + // outgoing result = maxAbs * SQRT(result) + auto root{result.SQRT().value}; auto product{root.Multiply(maxAbs_.At(maxAbsAt_))}; maxAbs_.IncrementSubscripts(maxAbsAt_); overflow_ |= product.flags.test(RealFlag::Overflow); diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h index c84d35734ab5a..ae17770dc2961 100644 --- a/flang/lib/Evaluate/fold-reduction.h +++ b/flang/lib/Evaluate/fold-reduction.h @@ -43,17 +43,23 @@ static Expr FoldDotProduct( Expr products{Fold( context, Expr{std::move(conjgA)} * Expr{Constant{*vb}})}; Constant &cProducts{DEREF(UnwrapConstantValue(products))}; - Element correction{}; // Use Kahan summation for greater precision. + [[maybe_unused]] Element correction{}; const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { - auto next{correction.Add(x, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; - overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + if constexpr (useKahanSummation) { + auto next{correction.Add(x, rounding)}; + overflow |= next.flags.test(RealFlag::Overflow); + auto added{sum.Add(next.value, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + correction = added.value.Subtract(sum, rounding) + .value.Subtract(next.value, rounding) + .value; + sum = std::move(added.value); + } else { + auto added{sum.Add(x, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + sum = std::move(added.value); + } } } else if constexpr (T::category == TypeCategory::Logical) { Expr conjunctions{Fold(context, @@ -80,17 +86,23 @@ static Expr FoldDotProduct( Expr products{ Fold(context, Expr{Constant{*va}} * Expr{Constant{*vb}})}; Constant &cProducts{DEREF(UnwrapConstantValue(products))}; - Element correction{}; // Use Kahan summation for greater precision. + [[maybe_unused]] Element correction{}; const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { - auto next{correction.Add(x, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; - overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + if constexpr (useKahanSummation) { + auto next{correction.Add(x, rounding)}; + overflow |= next.flags.test(RealFlag::Overflow); + auto added{sum.Add(next.value, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + correction = added.value.Subtract(sum, rounding) + .value.Subtract(next.value, rounding) + .value; + sum = std::move(added.value); + } else { + auto added{sum.Add(x, rounding)}; + overflow |= added.flags.test(RealFlag::Overflow); + sum = std::move(added.value); + } } } if (overflow) {