Skip to content

Commit 47f18af

Browse files
committed
[flang] Fold MAXVAL & MINVAL
Implement constant folding for the reduction transformational intrinsic functions MAXVAL and MINVAL. In anticipation of more folding work to follow, with (I hope) some common infrastructure, these two have been implemented in a new header file. Differential Revision: https://reviews.llvm.org/D104337
1 parent 46446e3 commit 47f18af

File tree

13 files changed

+294
-15
lines changed

13 files changed

+294
-15
lines changed

flang/include/flang/Evaluate/call.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,22 @@ class ProcedureRef {
218218
int Rank() const;
219219
bool IsElemental() const { return proc_.IsElemental(); }
220220
bool hasAlternateReturns() const { return hasAlternateReturns_; }
221+
222+
Expr<SomeType> *UnwrapArgExpr(int n) {
223+
if (static_cast<std::size_t>(n) < arguments_.size() && arguments_[n]) {
224+
return arguments_[n]->UnwrapExpr();
225+
} else {
226+
return nullptr;
227+
}
228+
}
229+
const Expr<SomeType> *UnwrapArgExpr(int n) const {
230+
if (static_cast<std::size_t>(n) < arguments_.size() && arguments_[n]) {
231+
return arguments_[n]->UnwrapExpr();
232+
} else {
233+
return nullptr;
234+
}
235+
}
236+
221237
bool operator==(const ProcedureRef &) const;
222238
llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
223239

flang/include/flang/Evaluate/integer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ class Integer {
358358

359359
static constexpr int DIGITS{bits - 1}; // don't count the sign bit
360360
static constexpr Integer HUGE() { return MASKR(bits - 1); }
361+
static constexpr Integer Least() { return MASKL(1); }
361362
static constexpr int RANGE{// in the sense of SELECTED_INT_KIND
362363
// This magic value is LOG10(2.)*1E12.
363364
static_cast<int>(((bits - 1) * 301029995664) / 1000000000000)};

flang/include/flang/Evaluate/shape.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ Constant<ExtentType> AsConstantShape(const ConstantSubscripts &);
4848
ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &);
4949
std::optional<ConstantSubscripts> AsConstantExtents(
5050
FoldingContext &, const Shape &);
51+
Shape AsShape(const ConstantSubscripts &);
52+
std::optional<Shape> AsShape(const std::optional<ConstantSubscripts> &);
5153

5254
inline int GetRank(const Shape &s) { return static_cast<int>(s.size()); }
5355

@@ -89,6 +91,7 @@ MaybeExtentExpr CountTrips(
8991

9092
// Computes SIZE() == PRODUCT(shape)
9193
MaybeExtentExpr GetSize(Shape &&);
94+
ConstantSubscript GetSize(const ConstantSubscripts &);
9295

9396
// Utility predicate: does an expression reference any implied DO index?
9497
bool ContainsAnyImpliedDoIndex(const ExtentExpr &);

flang/include/flang/Evaluate/tools.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,16 @@ std::optional<Expr<SomeType>> Negation(
644644
std::optional<Expr<LogicalResult>> Relate(parser::ContextualMessages &,
645645
RelationalOperator, Expr<SomeType> &&, Expr<SomeType> &&);
646646

647+
// Create a relational operation between two identically-typed operands
648+
// and wrap it up in an Expr<LogicalResult>.
649+
template <typename T>
650+
Expr<LogicalResult> PackageRelation(
651+
RelationalOperator opr, Expr<T> &&x, Expr<T> &&y) {
652+
static_assert(IsSpecificIntrinsicType<T>);
653+
return Expr<LogicalResult>{
654+
Relational<SomeType>{Relational<T>{opr, std::move(x), std::move(y)}}};
655+
}
656+
647657
template <int K>
648658
Expr<Type<TypeCategory::Logical, K>> LogicalNegation(
649659
Expr<Type<TypeCategory::Logical, K>> &&x) {

flang/lib/Evaluate/fold-character.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,49 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "fold-implementation.h"
10+
#include "fold-reduction.h"
1011

1112
namespace Fortran::evaluate {
1213

14+
static std::optional<ConstantSubscript> GetConstantLength(
15+
FoldingContext &context, Expr<SomeType> &&expr) {
16+
expr = Fold(context, std::move(expr));
17+
if (auto *chExpr{UnwrapExpr<Expr<SomeCharacter>>(expr)}) {
18+
if (auto len{chExpr->LEN()}) {
19+
return ToInt64(*len);
20+
}
21+
}
22+
return std::nullopt;
23+
}
24+
25+
template <typename T>
26+
static std::optional<ConstantSubscript> GetConstantLength(
27+
FoldingContext &context, FunctionRef<T> &funcRef, int zeroBasedArg) {
28+
if (auto *expr{funcRef.UnwrapArgExpr(zeroBasedArg)}) {
29+
return GetConstantLength(context, std::move(*expr));
30+
} else {
31+
return std::nullopt;
32+
}
33+
}
34+
35+
template <typename T>
36+
static std::optional<Scalar<T>> Identity(
37+
Scalar<T> str, std::optional<ConstantSubscript> len) {
38+
if (len) {
39+
return CharacterUtils<T::kind>::REPEAT(
40+
str, std::max<ConstantSubscript>(*len, 0));
41+
} else {
42+
return std::nullopt;
43+
}
44+
}
45+
1346
template <int KIND>
1447
Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
1548
FoldingContext &context,
1649
FunctionRef<Type<TypeCategory::Character, KIND>> &&funcRef) {
1750
using T = Type<TypeCategory::Character, KIND>;
51+
using StringType = Scalar<T>; // std::string or larger
52+
using SingleCharType = typename StringType::value_type; // char &c.
1853
auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
1954
CHECK(intrinsic);
2055
std::string name{intrinsic->name};
@@ -32,10 +67,24 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
3267
context, std::move(funcRef), CharacterUtils<KIND>::ADJUSTR);
3368
} else if (name == "max") {
3469
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
70+
} else if (name == "maxval") {
71+
SingleCharType least{0};
72+
if (auto identity{Identity<T>(
73+
StringType{least}, GetConstantLength(context, funcRef, 0))}) {
74+
return FoldMaxvalMinval<T>(
75+
context, std::move(funcRef), RelationalOperator::GT, *identity);
76+
}
3577
} else if (name == "merge") {
3678
return FoldMerge<T>(context, std::move(funcRef));
3779
} else if (name == "min") {
3880
return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
81+
} else if (name == "minval") {
82+
auto most{std::numeric_limits<SingleCharType>::max()};
83+
if (auto identity{Identity<T>(
84+
StringType{most}, GetConstantLength(context, funcRef, 0))}) {
85+
return FoldMaxvalMinval<T>(
86+
context, std::move(funcRef), RelationalOperator::LT, *identity);
87+
}
3988
} else if (name == "new_line") {
4089
return Expr<T>{Constant<T>{CharacterUtils<KIND>::NEW_LINE()}};
4190
} else if (name == "repeat") { // not elemental
@@ -52,7 +101,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
52101
CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}};
53102
}
54103
}
55-
// TODO: cshift, eoshift, maxval, minval, pack, reduce,
104+
// TODO: cshift, eoshift, maxloc, minloc, pack, reduce,
56105
// spread, transfer, transpose, unpack
57106
return Expr<T>{std::move(funcRef)};
58107
}

flang/lib/Evaluate/fold-implementation.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,9 @@ template <typename T> Expr<T> Folder<T>::Reshape(FunctionRef<T> &&funcRef) {
600600
template <typename T>
601601
Expr<T> FoldMINorMAX(
602602
FoldingContext &context, FunctionRef<T> &&funcRef, Ordering order) {
603+
static_assert(T::category == TypeCategory::Integer ||
604+
T::category == TypeCategory::Real ||
605+
T::category == TypeCategory::Character);
603606
std::vector<Constant<T> *> constantArgs;
604607
// Call Folding on all arguments, even if some are not constant,
605608
// to make operand promotion explicit.
@@ -608,8 +611,9 @@ Expr<T> FoldMINorMAX(
608611
constantArgs.push_back(cst);
609612
}
610613
}
611-
if (constantArgs.size() != funcRef.arguments().size())
614+
if (constantArgs.size() != funcRef.arguments().size()) {
612615
return Expr<T>(std::move(funcRef));
616+
}
613617
CHECK(constantArgs.size() > 0);
614618
Expr<T> result{std::move(*constantArgs[0])};
615619
for (std::size_t i{1}; i < constantArgs.size(); ++i) {

flang/lib/Evaluate/fold-integer.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "fold-implementation.h"
10+
#include "fold-reduction.h"
1011
#include "flang/Evaluate/check-expression.h"
1112

1213
namespace Fortran::evaluate {
@@ -474,6 +475,9 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
474475
},
475476
sx->u);
476477
}
478+
} else if (name == "maxval") {
479+
return FoldMaxvalMinval<T>(context, std::move(funcRef),
480+
RelationalOperator::GT, T::Scalar::Least());
477481
} else if (name == "merge") {
478482
return FoldMerge<T>(context, std::move(funcRef));
479483
} else if (name == "merge_bits") {
@@ -492,6 +496,9 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
492496
return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
493497
} else if (name == "min0" || name == "min1") {
494498
return RewriteSpecificMINorMAX(context, std::move(funcRef));
499+
} else if (name == "minval") {
500+
return FoldMaxvalMinval<T>(
501+
context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
495502
} else if (name == "mod") {
496503
return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
497504
ScalarFuncWithContext<T, T, T>(
@@ -650,8 +657,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
650657
// TODO:
651658
// cshift, dot_product, eoshift,
652659
// findloc, iall, iany, iparity, ibits, image_status, ishftc,
653-
// matmul, maxloc, maxval,
654-
// minloc, minval, not, pack, product, reduce,
660+
// matmul, maxloc, minloc, not, pack, product, reduce,
655661
// sign, spread, sum, transfer, transpose, unpack
656662
return Expr<T>{std::move(funcRef)};
657663
}

flang/lib/Evaluate/fold-real.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "fold-implementation.h"
10+
#include "fold-reduction.h"
1011

1112
namespace Fortran::evaluate {
1213

@@ -109,10 +110,16 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
109110
return Expr<T>{Scalar<T>::HUGE()};
110111
} else if (name == "max") {
111112
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
113+
} else if (name == "maxval") {
114+
return FoldMaxvalMinval<T>(context, std::move(funcRef),
115+
RelationalOperator::GT, T::Scalar::HUGE().Negate());
112116
} else if (name == "merge") {
113117
return FoldMerge<T>(context, std::move(funcRef));
114118
} else if (name == "min") {
115119
return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
120+
} else if (name == "minval") {
121+
return FoldMaxvalMinval<T>(
122+
context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
116123
} else if (name == "real") {
117124
if (auto *expr{args[0].value().UnwrapExpr()}) {
118125
return ToReal<KIND>(context, std::move(*expr));
@@ -124,7 +131,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
124131
return Expr<T>{Scalar<T>::TINY()};
125132
}
126133
// TODO: cshift, dim, dot_product, eoshift, fraction, matmul,
127-
// maxval, minval, modulo, nearest, norm2, pack, product,
134+
// maxloc, minloc, modulo, nearest, norm2, pack, product,
128135
// reduce, rrspacing, scale, set_exponent, spacing, spread,
129136
// sum, transfer, transpose, unpack, bessel_jn (transformational) and
130137
// bessel_yn (transformational)
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
//===-- lib/Evaluate/fold-reduction.h -------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// TODO: ALL, ANY, COUNT, DOT_PRODUCT, FINDLOC, IALL, IANY, IPARITY,
10+
// NORM2, MAXLOC, MINLOC, PARITY, PRODUCT, SUM
11+
12+
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
13+
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
14+
15+
#include "fold-implementation.h"
16+
17+
namespace Fortran::evaluate {
18+
19+
// MAXVAL & MINVAL
20+
template <typename T>
21+
Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
22+
RelationalOperator opr, Scalar<T> identity) {
23+
static_assert(T::category == TypeCategory::Integer ||
24+
T::category == TypeCategory::Real ||
25+
T::category == TypeCategory::Character);
26+
using Element = typename Constant<T>::Element;
27+
auto &arg{ref.arguments()};
28+
CHECK(arg.size() <= 3);
29+
if (arg.empty()) {
30+
return Expr<T>{std::move(ref)};
31+
}
32+
Constant<T> *array{Folder<T>{context}.Folding(arg[0])};
33+
if (!array || array->Rank() < 1) {
34+
return Expr<T>{std::move(ref)};
35+
}
36+
std::optional<ConstantSubscript> dim;
37+
if (arg.size() >= 2 && arg[1]) {
38+
if (auto *dimConst{Folder<SubscriptInteger>{context}.Folding(arg[1])}) {
39+
if (auto dimScalar{dimConst->GetScalarValue()}) {
40+
dim.emplace(dimScalar->ToInt64());
41+
if (*dim < 1 || *dim > array->Rank()) {
42+
context.messages().Say(
43+
"DIM=%jd is not valid for an array of rank %d"_err_en_US,
44+
static_cast<std::intmax_t>(*dim), array->Rank());
45+
dim.reset();
46+
}
47+
}
48+
}
49+
if (!dim) {
50+
return Expr<T>{std::move(ref)};
51+
}
52+
}
53+
Constant<LogicalResult> *mask{};
54+
if (arg.size() >= 3 && arg[2]) {
55+
mask = Folder<LogicalResult>{context}.Folding(arg[2]);
56+
if (!mask) {
57+
return Expr<T>{std::move(ref)};
58+
}
59+
if (!CheckConformance(context.messages(), AsShape(array->shape()),
60+
AsShape(mask->shape()),
61+
CheckConformanceFlags::RightScalarExpandable, "ARRAY=", "MASK=")
62+
.value_or(false)) {
63+
return Expr<T>{std::move(ref)};
64+
}
65+
}
66+
// Do it
67+
ConstantSubscripts at{array->lbounds()}, maskAt;
68+
bool maskAllFalse{false};
69+
if (mask) {
70+
if (auto scalar{mask->GetScalarValue()}) {
71+
if (scalar->IsTrue()) {
72+
mask = nullptr; // all .TRUE.
73+
} else {
74+
maskAllFalse = true;
75+
}
76+
} else {
77+
maskAt = mask->lbounds();
78+
}
79+
}
80+
std::vector<Element> result;
81+
ConstantSubscripts resultShape; // empty -> scalar
82+
// Internal function to accumulate into result.back().
83+
auto Accumulate{[&]() {
84+
if (!maskAllFalse && (maskAt.empty() || mask->At(maskAt).IsTrue())) {
85+
Expr<LogicalResult> test{
86+
PackageRelation(opr, Expr<T>{Constant<T>{array->At(at)}},
87+
Expr<T>{Constant<T>{result.back()}})};
88+
auto folded{GetScalarConstantValue<LogicalResult>(
89+
test.Rewrite(context, std::move(test)))};
90+
CHECK(folded.has_value());
91+
if (folded->IsTrue()) {
92+
result.back() = array->At(at);
93+
}
94+
}
95+
}};
96+
if (dim) { // DIM= is present, so result is an array
97+
resultShape = array->shape();
98+
resultShape.erase(resultShape.begin() + (*dim - 1));
99+
ConstantSubscript dimExtent{array->shape().at(*dim - 1)};
100+
ConstantSubscript &dimAt{at[*dim - 1]};
101+
ConstantSubscript dimLbound{dimAt};
102+
ConstantSubscript *maskDimAt{maskAt.empty() ? nullptr : &maskAt[*dim - 1]};
103+
ConstantSubscript maskLbound{maskDimAt ? *maskDimAt : 0};
104+
for (auto n{GetSize(resultShape)}; n-- > 0;
105+
IncrementSubscripts(at, array->shape())) {
106+
dimAt = dimLbound;
107+
if (maskDimAt) {
108+
*maskDimAt = maskLbound;
109+
}
110+
result.push_back(identity);
111+
for (ConstantSubscript j{0}; j < dimExtent;
112+
++j, ++dimAt, maskDimAt && ++*maskDimAt) {
113+
Accumulate();
114+
}
115+
if (maskDimAt) {
116+
IncrementSubscripts(maskAt, mask->shape());
117+
}
118+
}
119+
} else { // no DIM=, result is scalar
120+
result.push_back(identity);
121+
for (auto n{array->size()}; n-- > 0;
122+
IncrementSubscripts(at, array->shape())) {
123+
Accumulate();
124+
if (!maskAt.empty()) {
125+
IncrementSubscripts(maskAt, mask->shape());
126+
}
127+
}
128+
}
129+
if constexpr (T::category == TypeCategory::Character) {
130+
return Expr<T>{Constant<T>{static_cast<ConstantSubscript>(identity.size()),
131+
std::move(result), std::move(resultShape)}};
132+
} else {
133+
return Expr<T>{Constant<T>{std::move(result), std::move(resultShape)}};
134+
}
135+
}
136+
137+
} // namespace Fortran::evaluate
138+
#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_

0 commit comments

Comments
 (0)