From 79912d7eae1a6fe01946eadcf498d652f94e12e2 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Thu, 4 Sep 2025 14:29:55 -0500 Subject: [PATCH] [flang][OpenMP] Reassociate logical ATOMIC update expressions This is a follow-up to PR153488 and PR155840, this time for expressions of logical type. The handling of logical operations in Expr differs slightly from regular arithmetic operations. The difference is that the specific operation (e.g. and, or, etc.) is not a part of the type, but stored as a data member. Both the matching code and the reconstruction code needed to be extended to correctly handle the data member. This fixes https://github.com/llvm/llvm-project/issues/144944 --- flang/include/flang/Evaluate/match.h | 63 +++++++- flang/lib/Semantics/check-omp-atomic.cpp | 82 +++++++---- .../OpenMP/atomic-update-reassoc-logical.f90 | 137 ++++++++++++++++++ 3 files changed, 252 insertions(+), 30 deletions(-) create mode 100644 flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 diff --git a/flang/include/flang/Evaluate/match.h b/flang/include/flang/Evaluate/match.h index 01932226fa500..32a4a7409fba7 100644 --- a/flang/include/flang/Evaluate/match.h +++ b/flang/include/flang/Evaluate/match.h @@ -11,6 +11,7 @@ #include "flang/Common/Fortran-consts.h" #include "flang/Common/visit.h" #include "flang/Evaluate/expression.h" +#include "flang/Support/Fortran.h" #include "llvm/ADT/STLExtras.h" #include @@ -86,9 +87,12 @@ template struct TypePattern { mutable const MatchType *ref{nullptr}; }; -/// Matches one of the patterns provided as template arguments. All of these -/// patterns should have the same number of operands, i.e. they all should -/// try to match input expression with the same number of children, i.e. +/// Matches one of the patterns provided as template arguments. +/// Upon creation of an AnyOfPattern object with some arguments, say args, +/// each of the pattern objects will be created using args as arguments to +/// the constructor. 
This means that each of the patterns should be +/// constructible from args, in particular all patterns should take the same +/// number of inputs. So, for example, /// AnyOfPattern is ok, whereas /// AnyOfPattern is not. template struct AnyOfPattern { @@ -178,9 +182,51 @@ struct OperationPattern : public TypePattern { }; template -OperationPattern(const Ops &...ops, llvm::type_identity) +OperationPattern(const Ops &..., llvm::type_identity) -> OperationPattern; +// Encode the actual operator in the type, so that the class is constructible +// only from operand patterns. This will make it usable in AnyOfPattern. +template +struct LogicalOperationPattern + : public OperationPattern, Ops...> { + using Base = OperationPattern, Ops...>; + static constexpr common::LogicalOperator opCode{Operator}; + +private: + template bool matchOp(const LogicalOperation &op) const { + if constexpr (ValType::kind == K) { + return op.logicalOperator == opCode; + } + return false; + } + template bool matchOp(const U &) const { return false; } + +public: + LogicalOperationPattern(const Ops &...ops, llvm::type_identity = {}) + : Base(ops...) {} + + template bool match(const evaluate::Expr &input) const { + // All logical operations (for a given type T) have the same operation + // type (LogicalOperation), so the type-based matching will not + // be able to tell specific operations from one another. + // Check the operation code first, if that matches then use the + // base class's match. + if (common::visit([&](auto &&s) { return matchOp(s); }, deparen(input).u)) { + return Base::match(input); + } else { + return false; + } + } + + template bool match(const U &input) const { // + return false; + } +}; + +// No deduction guide for LogicalOperationPattern, since the "Operator" +// parameter cannot be deduced from the constructor arguments. 
+ // Namespace-level definitions template using Expr = ExprPattern; @@ -188,6 +234,15 @@ template using Expr = ExprPattern; template using Op = OperationPattern; +template +using LogicalOp = LogicalOperationPattern; + +template +LogicalOp logical(const Op0 &op0, const Op1 &op1) { + return LogicalOp(op0, op1); +} + template bool match(const Pattern &pattern, const Input &input) { return pattern.match(input); diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index f25497ece61c4..ab8aa5f342e48 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -61,8 +61,7 @@ template struct IsIntegral> { static constexpr bool value{// C == common::TypeCategory::Integer || - C == common::TypeCategory::Unsigned || - C == common::TypeCategory::Logical}; + C == common::TypeCategory::Unsigned}; }; template constexpr bool is_integral_v{IsIntegral::value}; @@ -83,10 +82,25 @@ constexpr bool is_floating_point_v{IsFloatingPoint::value}; template constexpr bool is_numeric_v{is_integral_v || is_floating_point_v}; +template struct IsLogical { + static constexpr bool value{false}; +}; + +template +struct IsLogical> { + static constexpr bool value{C == common::TypeCategory::Logical}; +}; + +template constexpr bool is_logical_v{IsLogical::value}; + template using ReassocOpBase = evaluate::match::AnyOfPattern< // evaluate::match::Add, // - evaluate::match::Mul>; + evaluate::match::Mul, // + evaluate::match::LogicalOp, + evaluate::match::LogicalOp, + evaluate::match::LogicalOp, + evaluate::match::LogicalOp>; template struct ReassocOp : public ReassocOpBase { @@ -110,8 +124,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { // Try to find cases where the input expression is of the form // (1) (a . b) . c, or // (2) a . (b . c), - // where . denotes an associative operation (currently + or *), and a, b, c - // are some subexpresions. + // where . 
denotes an associative operation, and a, b, c are some + // subexpressions. // If one of the operands in the nested operation is the atomic variable // (with some possible type conversions applied to it), bring it to the // top-level operation, and move the top-level operand into the nested @@ -119,7 +133,7 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { // For example, assuming x is the atomic variable: // (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b. template >> + typename = std::enable_if_t || is_logical_v>> evaluate::Expr operator()(evaluate::Expr &&x, const U &u) { if constexpr (is_floating_point_v) { if (!context_.langOptions().AssociativeMath) { @@ -133,8 +147,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { // some order) from the example above. evaluate::match::Expr sub[3]; auto inner{reassocOp(sub[0], sub[1])}; - auto outer1{reassocOp(inner, sub[2])}; // inner + something - auto outer2{reassocOp(sub[2], inner)}; // something + inner + auto outer1{reassocOp(inner, sub[2])}; // inner . something + auto outer2{reassocOp(sub[2], inner)}; // something . inner #if !defined(__clang__) && !defined(_MSC_VER) && \ (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5)) // If GCC version < 8.5, use this definition. For the other definition @@ -167,23 +181,9 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { } return common::visit( [&](auto &&s) { - using Expr = evaluate::Expr; - using TypeS = llvm::remove_cvref_t; - // This visitor has to be semantically correct for all possible - // types of s even though at runtime s will only be one of the - // matched types. - // Limit the construction to the operation types that we tried - // to match (otherwise TypeS(op1, op2) would fail for non-binary - // operations). 
- if constexpr (common::HasMember) { - Expr atom{*sub[atomIdx].ref}; - Expr op1{*sub[(atomIdx + 1) % 3].ref}; - Expr op2{*sub[(atomIdx + 2) % 3].ref}; - return Expr( - TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2))))); - } else { - return Expr(TypeS(s)); - } + // Build the new expression from the matched components. + return Reconstruct(s, *sub[atomIdx].ref, + *sub[(atomIdx + 1) % 3].ref, *sub[(atomIdx + 2) % 3].ref); }, evaluate::match::deparen(x).u); } @@ -191,13 +191,43 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { } template >> + typename = std::enable_if_t && !is_logical_v>> evaluate::Expr operator()( evaluate::Expr &&x, const U &u, NonIntegralTag = {}) { return Id::operator()(std::move(x), u); } private: + template + evaluate::Expr Reconstruct(const S &op, evaluate::Expr atom, + evaluate::Expr op1, evaluate::Expr op2) { + using TypeS = llvm::remove_cvref_t; + // This function has to be semantically correct for all possible types + // of S even though at runtime s will only be one of the matched types. + // Limit the construction to the operation types that we tried to match + // (otherwise TypeS(op1, op2) would fail for non-binary operations). + if constexpr (!common::HasMember) { + return evaluate::Expr(TypeS(op)); + } else if constexpr (is_logical_v) { + constexpr int K{T::kind}; + if constexpr (std::is_same_v>) { + // Logical operators take an extra argument in their constructor, + // so they need their own reconstruction code. + common::LogicalOperator opCode{op.logicalOperator}; + return evaluate::Expr(TypeS( // + opCode, std::move(atom), + evaluate::Expr(TypeS( // + opCode, std::move(op1), std::move(op2))))); + } + } else { + // Generic reconstruction. 
+ return evaluate::Expr(TypeS( // + std::move(atom), + evaluate::Expr(TypeS( // + std::move(op1), std::move(op2))))); + } + } + template bool IsAtom(const evaluate::Expr &x) const { return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_); } diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 new file mode 100644 index 0000000000000..ccde4fed12f2f --- /dev/null +++ b/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 @@ -0,0 +1,137 @@ +!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +subroutine f00(x, y, z) + implicit none + logical :: x, y, z + + !$omp atomic update + x = x .and. y .and. z +end + +!CHECK-LABEL: func.func @_QPf00 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref> +!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref> +!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1 +!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1 +!CHECK: %[[AND_YZ:[0-9]+]] = arith.andi %[[CVT_Y]], %[[CVT_Z]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1 +!CHECK: %[[AND_XYZ:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_YZ]] : i1 +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[AND_XYZ]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: } + + +subroutine f01(x, y, z) + implicit none + logical :: x, y, z + + !$omp atomic update + x = x .or. y .or. 
z +end + +!CHECK-LABEL: func.func @_QPf01 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref> +!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref> +!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1 +!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1 +!CHECK: %[[OR_YZ:[0-9]+]] = arith.ori %[[CVT_Y]], %[[CVT_Z]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1 +!CHECK: %[[OR_XYZ:[0-9]+]] = arith.ori %[[CVT_X]], %[[OR_YZ]] : i1 +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[OR_XYZ]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: } + + +subroutine f02(x, y, z) + implicit none + logical :: x, y, z + + !$omp atomic update + x = x .eqv. y .eqv. 
z +end + +!CHECK-LABEL: func.func @_QPf02 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref> +!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref> +!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1 +!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1 +!CHECK: %[[EQV_YZ:[0-9]+]] = arith.cmpi eq, %[[CVT_Y]], %[[CVT_Z]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1 +!CHECK: %[[EQV_XYZ:[0-9]+]] = arith.cmpi eq, %[[CVT_X]], %[[EQV_YZ]] : i1 +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[EQV_XYZ]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: } + + +subroutine f03(x, y, z) + implicit none + logical :: x, y, z + + !$omp atomic update + x = x .neqv. y .neqv. 
z +end + +!CHECK-LABEL: func.func @_QPf03 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref> +!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref> +!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1 +!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1 +!CHECK: %[[NEQV_YZ:[0-9]+]] = arith.cmpi ne, %[[CVT_Y]], %[[CVT_Z]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1 +!CHECK: %[[NEQV_XYZ:[0-9]+]] = arith.cmpi ne, %[[CVT_X]], %[[NEQV_YZ]] : i1 +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[NEQV_XYZ]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: } + + +subroutine f04(x, a, b, c) + implicit none + logical(kind=4) :: x + logical(kind=8) :: a, b, c + + !$omp atomic update + x = ((b .and. a) .and. x) .and. 
c +end + +!CHECK-LABEL: func.func @_QPf04 +!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref> +!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref> +!CHECK: %[[CVT_B:[0-9]+]] = fir.convert %[[LOAD_B]] : (!fir.logical<8>) -> i1 +!CHECK: %[[CVT_A:[0-9]+]] = fir.convert %[[LOAD_A]] : (!fir.logical<8>) -> i1 +!CHECK: %[[AND_BA:[0-9]+]] = arith.andi %[[CVT_B]], %[[CVT_A]] : i1 +!CHECK: %[[LOAD_C:[0-9]+]] = fir.load %[[C]]#0 : !fir.ref> +!CHECK: %[[CVT_C:[0-9]+]] = fir.convert %[[LOAD_C]] : (!fir.logical<8>) -> i1 +!CHECK: %[[AND_BAC:[0-9]+]] = arith.andi %[[AND_BA]], %[[CVT_C]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT8_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> !fir.logical<8> +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[CVT8_X]] : (!fir.logical<8>) -> i1 +!CHECK: %[[AND_XBAC:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_BAC]] : i1 + +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[AND_XBAC]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: }