|
| 1 | +//===-- lib/Evaluate/fold-reduction.h -------------------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +// TODO: ALL, ANY, COUNT, DOT_PRODUCT, FINDLOC, IALL, IANY, IPARITY, |
| 10 | +// NORM2, MAXLOC, MINLOC, PARITY, PRODUCT, SUM |
| 11 | + |
| 12 | +#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_ |
| 13 | +#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_ |
| 14 | + |
| 15 | +#include "fold-implementation.h" |
| 16 | + |
| 17 | +namespace Fortran::evaluate { |
| 18 | + |
| 19 | +// MAXVAL & MINVAL |
| 20 | +template <typename T> |
| 21 | +Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref, |
| 22 | + RelationalOperator opr, Scalar<T> identity) { |
| 23 | + static_assert(T::category == TypeCategory::Integer || |
| 24 | + T::category == TypeCategory::Real || |
| 25 | + T::category == TypeCategory::Character); |
| 26 | + using Element = typename Constant<T>::Element; |
| 27 | + auto &arg{ref.arguments()}; |
| 28 | + CHECK(arg.size() <= 3); |
| 29 | + if (arg.empty()) { |
| 30 | + return Expr<T>{std::move(ref)}; |
| 31 | + } |
| 32 | + Constant<T> *array{Folder<T>{context}.Folding(arg[0])}; |
| 33 | + if (!array || array->Rank() < 1) { |
| 34 | + return Expr<T>{std::move(ref)}; |
| 35 | + } |
| 36 | + std::optional<ConstantSubscript> dim; |
| 37 | + if (arg.size() >= 2 && arg[1]) { |
| 38 | + if (auto *dimConst{Folder<SubscriptInteger>{context}.Folding(arg[1])}) { |
| 39 | + if (auto dimScalar{dimConst->GetScalarValue()}) { |
| 40 | + dim.emplace(dimScalar->ToInt64()); |
| 41 | + if (*dim < 1 || *dim > array->Rank()) { |
| 42 | + context.messages().Say( |
| 43 | + "DIM=%jd is not valid for an array of rank %d"_err_en_US, |
| 44 | + static_cast<std::intmax_t>(*dim), array->Rank()); |
| 45 | + dim.reset(); |
| 46 | + } |
| 47 | + } |
| 48 | + } |
| 49 | + if (!dim) { |
| 50 | + return Expr<T>{std::move(ref)}; |
| 51 | + } |
| 52 | + } |
| 53 | + Constant<LogicalResult> *mask{}; |
| 54 | + if (arg.size() >= 3 && arg[2]) { |
| 55 | + mask = Folder<LogicalResult>{context}.Folding(arg[2]); |
| 56 | + if (!mask) { |
| 57 | + return Expr<T>{std::move(ref)}; |
| 58 | + } |
| 59 | + if (!CheckConformance(context.messages(), AsShape(array->shape()), |
| 60 | + AsShape(mask->shape()), |
| 61 | + CheckConformanceFlags::RightScalarExpandable, "ARRAY=", "MASK=") |
| 62 | + .value_or(false)) { |
| 63 | + return Expr<T>{std::move(ref)}; |
| 64 | + } |
| 65 | + } |
| 66 | + // Do it |
| 67 | + ConstantSubscripts at{array->lbounds()}, maskAt; |
| 68 | + bool maskAllFalse{false}; |
| 69 | + if (mask) { |
| 70 | + if (auto scalar{mask->GetScalarValue()}) { |
| 71 | + if (scalar->IsTrue()) { |
| 72 | + mask = nullptr; // all .TRUE. |
| 73 | + } else { |
| 74 | + maskAllFalse = true; |
| 75 | + } |
| 76 | + } else { |
| 77 | + maskAt = mask->lbounds(); |
| 78 | + } |
| 79 | + } |
| 80 | + std::vector<Element> result; |
| 81 | + ConstantSubscripts resultShape; // empty -> scalar |
| 82 | + // Internal function to accumulate into result.back(). |
| 83 | + auto Accumulate{[&]() { |
| 84 | + if (!maskAllFalse && (maskAt.empty() || mask->At(maskAt).IsTrue())) { |
| 85 | + Expr<LogicalResult> test{ |
| 86 | + PackageRelation(opr, Expr<T>{Constant<T>{array->At(at)}}, |
| 87 | + Expr<T>{Constant<T>{result.back()}})}; |
| 88 | + auto folded{GetScalarConstantValue<LogicalResult>( |
| 89 | + test.Rewrite(context, std::move(test)))}; |
| 90 | + CHECK(folded.has_value()); |
| 91 | + if (folded->IsTrue()) { |
| 92 | + result.back() = array->At(at); |
| 93 | + } |
| 94 | + } |
| 95 | + }}; |
| 96 | + if (dim) { // DIM= is present, so result is an array |
| 97 | + resultShape = array->shape(); |
| 98 | + resultShape.erase(resultShape.begin() + (*dim - 1)); |
| 99 | + ConstantSubscript dimExtent{array->shape().at(*dim - 1)}; |
| 100 | + ConstantSubscript &dimAt{at[*dim - 1]}; |
| 101 | + ConstantSubscript dimLbound{dimAt}; |
| 102 | + ConstantSubscript *maskDimAt{maskAt.empty() ? nullptr : &maskAt[*dim - 1]}; |
| 103 | + ConstantSubscript maskLbound{maskDimAt ? *maskDimAt : 0}; |
| 104 | + for (auto n{GetSize(resultShape)}; n-- > 0; |
| 105 | + IncrementSubscripts(at, array->shape())) { |
| 106 | + dimAt = dimLbound; |
| 107 | + if (maskDimAt) { |
| 108 | + *maskDimAt = maskLbound; |
| 109 | + } |
| 110 | + result.push_back(identity); |
| 111 | + for (ConstantSubscript j{0}; j < dimExtent; |
| 112 | + ++j, ++dimAt, maskDimAt && ++*maskDimAt) { |
| 113 | + Accumulate(); |
| 114 | + } |
| 115 | + if (maskDimAt) { |
| 116 | + IncrementSubscripts(maskAt, mask->shape()); |
| 117 | + } |
| 118 | + } |
| 119 | + } else { // no DIM=, result is scalar |
| 120 | + result.push_back(identity); |
| 121 | + for (auto n{array->size()}; n-- > 0; |
| 122 | + IncrementSubscripts(at, array->shape())) { |
| 123 | + Accumulate(); |
| 124 | + if (!maskAt.empty()) { |
| 125 | + IncrementSubscripts(maskAt, mask->shape()); |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | + if constexpr (T::category == TypeCategory::Character) { |
| 130 | + return Expr<T>{Constant<T>{static_cast<ConstantSubscript>(identity.size()), |
| 131 | + std::move(result), std::move(resultShape)}}; |
| 132 | + } else { |
| 133 | + return Expr<T>{Constant<T>{std::move(result), std::move(resultShape)}}; |
| 134 | + } |
| 135 | +} |
| 136 | + |
| 137 | +} // namespace Fortran::evaluate |
| 138 | +#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_ |
0 commit comments