diff --git a/flang/lib/Evaluate/fold-character.cpp b/flang/lib/Evaluate/fold-character.cpp index 2a55334866aa8..a599815fa7aee 100644 --- a/flang/lib/Evaluate/fold-character.cpp +++ b/flang/lib/Evaluate/fold-character.cpp @@ -80,8 +80,6 @@ Expr> FoldIntrinsicFunction( return FoldMaxvalMinval( context, std::move(funcRef), RelationalOperator::GT, *identity); } - } else if (name == "merge") { - return FoldMerge(context, std::move(funcRef)); } else if (name == "min") { return FoldMINorMAX(context, std::move(funcRef), Ordering::Less); } else if (name == "minval") { diff --git a/flang/lib/Evaluate/fold-complex.cpp b/flang/lib/Evaluate/fold-complex.cpp index efdb18f889132..520121ad254de 100644 --- a/flang/lib/Evaluate/fold-complex.cpp +++ b/flang/lib/Evaluate/fold-complex.cpp @@ -64,8 +64,6 @@ Expr> FoldIntrinsicFunction( } } else if (name == "dot_product") { return FoldDotProduct(context, std::move(funcRef)); - } else if (name == "merge") { - return FoldMerge(context, std::move(funcRef)); } else if (name == "product") { auto one{Scalar::FromInteger(value::Integer<8>{1}).value}; return FoldProduct(context, std::move(funcRef), Scalar{one}); diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h index aaa13ec371753..c47a22c99a457 100644 --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -64,6 +64,7 @@ template class Folder { Expr CSHIFT(FunctionRef &&); Expr EOSHIFT(FunctionRef &&); + Expr MERGE(FunctionRef &&); Expr PACK(FunctionRef &&); Expr RESHAPE(FunctionRef &&); Expr SPREAD(FunctionRef &&); @@ -397,9 +398,11 @@ template Expr Folder::Folding(Designator &&designator) { template Constant *Folder::Folding(std::optional &arg) { if (auto *expr{UnwrapExpr>(arg)}) { - if (!UnwrapExpr>(*expr)) { - if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) { - *expr = Fold(context_, std::move(*converted)); + if constexpr (T::category != TypeCategory::Derived) { + if (!UnwrapExpr>(*expr)) { + if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) { + *expr = Fold(context_, std::move(*converted)); + } } } return UnwrapConstantValue(*expr); @@ -411,8 +414,6 @@ template std::optional *...>> GetConstantArgumentsHelper( FoldingContext &context, ActualArguments &arguments, std::index_sequence) { - static_assert( - (... && IsSpecificIntrinsicType)); // TODO derived types for MERGE? static_assert(sizeof...(A) > 0); std::tuple *...> args{ Folder{context}.Folding(arguments.at(I))...}; @@ -489,7 +490,6 @@ Expr FoldElementalIntrinsicHelper(FoldingContext &context, } } CHECK(rank == GetRank(shape)); - // Compute all the scalar values of the results std::vector> results; if (TotalElementCount(shape) > 0) { @@ -513,6 +513,13 @@ Expr FoldElementalIntrinsicHelper(FoldingContext &context, auto len{static_cast( results.empty() ? 0 : results[0].length())}; return Expr{Constant{len, std::move(results), std::move(shape)}}; + } else if constexpr (TR::category == TypeCategory::Derived) { + if (!results.empty()) { + return Expr{rank == 0 + ? Constant{results.front()} + : Constant{results.front().derivedTypeSpec(), + std::move(results), std::move(shape)}}; + } } else { return Expr{Constant{std::move(results), std::move(shape)}}; } @@ -780,6 +787,16 @@ template Expr Folder::EOSHIFT(FunctionRef &&funcRef) { return MakeInvalidIntrinsic(std::move(funcRef)); } +template Expr Folder::MERGE(FunctionRef &&funcRef) { + return FoldElementalIntrinsic(context_, + std::move(funcRef), + ScalarFunc( + [](const Scalar &ifTrue, const Scalar &ifFalse, + const Scalar &predicate) -> Scalar { + return predicate.IsTrue() ? ifTrue : ifFalse; + })); +} + template Expr Folder::PACK(FunctionRef &&funcRef) { auto args{funcRef.arguments()}; CHECK(args.size() == 3); @@ -1126,6 +1143,8 @@ Expr FoldOperation(FoldingContext &context, FunctionRef &&funcRef) { return Folder{context}.CSHIFT(std::move(funcRef)); } else if (name == "eoshift") { return Folder{context}.EOSHIFT(std::move(funcRef)); + } else if (name == "merge") { + return Folder{context}.MERGE(std::move(funcRef)); } else if (name == "pack") { return Folder{context}.PACK(std::move(funcRef)); } else if (name == "reshape") { @@ -1147,17 +1166,6 @@ Expr FoldOperation(FoldingContext &context, FunctionRef &&funcRef) { return Expr{std::move(funcRef)}; } -template -Expr FoldMerge(FoldingContext &context, FunctionRef &&funcRef) { - return FoldElementalIntrinsic(context, - std::move(funcRef), - ScalarFunc( - [](const Scalar &ifTrue, const Scalar &ifFalse, - const Scalar &predicate) -> Scalar { - return predicate.IsTrue() ? ifTrue : ifFalse; - })); -} - Expr FoldOperation(FoldingContext &, ImpliedDoIndex &&); // Array constructor folding diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp index 02d5ea5a133ad..53659d2c36d7c 100644 --- a/flang/lib/Evaluate/fold-integer.cpp +++ b/flang/lib/Evaluate/fold-integer.cpp @@ -1038,8 +1038,6 @@ Expr> FoldIntrinsicFunction( } else if (name == "maxval") { return FoldMaxvalMinval(context, std::move(funcRef), RelationalOperator::GT, T::Scalar::Least()); - } else if (name == "merge") { - return FoldMerge(context, std::move(funcRef)); } else if (name == "merge_bits") { return FoldElementalIntrinsic( context, std::move(funcRef), &Scalar::MERGE_BITS); diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp index 129a8fc40577d..0803c86836811 100644 --- a/flang/lib/Evaluate/fold-logical.cpp +++ b/flang/lib/Evaluate/fold-logical.cpp @@ -215,8 +215,6 @@ Expr> FoldIntrinsicFunction( if (auto *expr{UnwrapExpr>(args[0])}) { return Fold(context, ConvertToType(std::move(*expr))); } - } else if (name == "merge") { - return FoldMerge(context, std::move(funcRef)); } else if (name == "parity") { return FoldAllAnyParity( context, std::move(funcRef), &Scalar::NEQV, Scalar{false}); diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp index 01a97951b0412..671d897ef7b2f 100644 --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -184,8 +184,6 @@ Expr> FoldIntrinsicFunction( } else if (name == "maxval") { return FoldMaxvalMinval(context, std::move(funcRef), RelationalOperator::GT, T::Scalar::HUGE().Negate()); - } else if (name == "merge") { - return FoldMerge(context, std::move(funcRef)); } else if (name == "min") { return FoldMINorMAX(context, std::move(funcRef), Ordering::Less); } else if (name == "minval") { diff --git a/flang/test/Evaluate/fold-merge.f90 b/flang/test/Evaluate/fold-merge.f90 new file mode 100644 index 0000000000000..9cbd0ca7f2a99 --- /dev/null +++ b/flang/test/Evaluate/fold-merge.f90 @@ -0,0 +1,22 @@ +! RUN: %python %S/test_folding.py %s %flang_fc1 +! Tests folding of MERGE +module m + type t + integer n + end type + logical, parameter :: test_01 = all(merge([1,2,3],4,[.true.,.false.,.true.]) == [1,4,3]) + logical, parameter :: test_02 = all(merge([1,2,3],4,.true.) == [1,2,3]) + logical, parameter :: test_03 = all(merge([1,2,3],4,.false.) == [4,4,4]) + logical, parameter :: test_04 = all(merge(1,4,[.true.,.false.,.true.,.false.]) == [1,4,1,4]) + type(t), parameter :: dt00a = merge(t(1),t(2),.true.) + logical, parameter :: test_05 = dt00a%n == 1 + type(t), parameter :: dt00b = merge(t(1),t(2),.false.) + logical, parameter :: test_06 = dt00b%n == 2 + type(t), parameter :: dt01(*) = merge([t(1),t(2)],[t(3),t(4)],[.false.,.true.]) + logical, parameter :: test_07 = all(dt01%n == [3,2]) + type(t), parameter :: dt02(*) = merge(t(1),[t(3),t(4)],.true.) + logical, parameter :: test_08 = all(dt02%n == [1,1]) + type(t), parameter :: dt03(*) = merge([t(1),t(2)],t(3),[.true.,.false.]) + logical, parameter :: test_09 = all(dt03%n == [1,3]) + logical, parameter :: test_10 = merge('ab','cd',.true.) == 'ab' +end