Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions flang/include/flang/Evaluate/match.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "flang/Common/Fortran-consts.h"
#include "flang/Common/visit.h"
#include "flang/Evaluate/expression.h"
#include "flang/Support/Fortran.h"
#include "llvm/ADT/STLExtras.h"

#include <tuple>
Expand Down Expand Up @@ -86,9 +87,12 @@ template <typename T> struct TypePattern {
mutable const MatchType *ref{nullptr};
};

/// Matches one of the patterns provided as template arguments. All of these
/// patterns should have the same number of operands, i.e. they all should
/// try to match input expression with the same number of children, i.e.
/// Matches one of the patterns provided as template arguments.
/// Upon creation of an AnyOfPattern object with some arguments, say args,
/// each of the pattern objects will be created using args as arguments to
/// the constructor. This means that each of the patterns should be
/// constructible from args, in particular all patterns should take the same
/// number of inputs. So, for example,
/// AnyOfPattern<SomeBinaryOp, OtherBinaryOp> is ok, whereas
/// AnyOfPattern<SomeBinaryOp, SomeTernaryOp> is not.
template <typename... Patterns> struct AnyOfPattern {
Expand Down Expand Up @@ -178,16 +182,67 @@ struct OperationPattern : public TypePattern<OpType> {
};

template <typename OpType, typename... Ops>
OperationPattern(const Ops &...ops, llvm::type_identity<OpType>)
OperationPattern(const Ops &..., llvm::type_identity<OpType>)
-> OperationPattern<OpType, Ops...>;

// Encode the actual operator in the type, so that the class is constructible
// only from operand patterns. This will make it usable in AnyOfPattern.
template <common::LogicalOperator Operator, typename ValType, typename... Ops>
struct LogicalOperationPattern
: public OperationPattern<LogicalOperation<ValType::kind>, Ops...> {
using Base = OperationPattern<LogicalOperation<ValType::kind>, Ops...>;
static constexpr common::LogicalOperator opCode{Operator};

private:
template <int K> bool matchOp(const LogicalOperation<K> &op) const {
if constexpr (ValType::kind == K) {
return op.logicalOperator == opCode;
}
return false;
}
template <typename U> bool matchOp(const U &) const { return false; }

public:
LogicalOperationPattern(const Ops &...ops, llvm::type_identity<ValType> = {})
: Base(ops...) {}

template <typename T> bool match(const evaluate::Expr<T> &input) const {
// All logical operations (for a given type T) have the same operation
// type (LogicalOperation<T::kind>), so the type-based matching will not
// be able to tell specific operations from one another.
// Check the operation code first, if that matches then use the the
// base class's match.
if (common::visit([&](auto &&s) { return matchOp(s); }, deparen(input).u)) {
return Base::match(input);
} else {
return false;
}
}

template <typename U> bool match(const U &input) const { //
return false;
}
};

// No deduction guide for LogicalOperationPattern, since the "Operator"
// parameter cannot be deduced from the constructor arguments.

// Namespace-level definitions

template <typename T> using Expr = ExprPattern<T>;

template <typename OpType, typename... Ops>
using Op = OperationPattern<OpType, Ops...>;

template <common::LogicalOperator Operator, typename ValType, typename... Ops>
using LogicalOp = LogicalOperationPattern<Operator, ValType, Ops...>;

template <common::LogicalOperator Operator, typename Type, typename Op0,
typename Op1>
LogicalOp<Operator, Type, Op0, Op1> logical(const Op0 &op0, const Op1 &op1) {
return LogicalOp<Operator, Type, Op0, Op1>(op0, op1);
}

template <typename Pattern, typename Input>
bool match(const Pattern &pattern, const Input &input) {
return pattern.match(input);
Expand Down
82 changes: 56 additions & 26 deletions flang/lib/Semantics/check-omp-atomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ template <common::TypeCategory C, int K>
struct IsIntegral<evaluate::Type<C, K>> {
static constexpr bool value{//
C == common::TypeCategory::Integer ||
C == common::TypeCategory::Unsigned ||
C == common::TypeCategory::Logical};
C == common::TypeCategory::Unsigned};
};

template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
Expand All @@ -83,10 +82,25 @@ constexpr bool is_floating_point_v{IsFloatingPoint<T>::value};
template <typename T>
constexpr bool is_numeric_v{is_integral_v<T> || is_floating_point_v<T>};

template <typename...> struct IsLogical {
static constexpr bool value{false};
};

template <common::TypeCategory C, int K>
struct IsLogical<evaluate::Type<C, K>> {
static constexpr bool value{C == common::TypeCategory::Logical};
};

template <typename T> constexpr bool is_logical_v{IsLogical<T>::value};

template <typename T, typename Op0, typename Op1>
using ReassocOpBase = evaluate::match::AnyOfPattern< //
evaluate::match::Add<T, Op0, Op1>, //
evaluate::match::Mul<T, Op0, Op1>>;
evaluate::match::Mul<T, Op0, Op1>, //
evaluate::match::LogicalOp<common::LogicalOperator::And, T, Op0, Op1>,
evaluate::match::LogicalOp<common::LogicalOperator::Or, T, Op0, Op1>,
evaluate::match::LogicalOp<common::LogicalOperator::Eqv, T, Op0, Op1>,
evaluate::match::LogicalOp<common::LogicalOperator::Neqv, T, Op0, Op1>>;

template <typename T, typename Op0, typename Op1>
struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
Expand All @@ -110,16 +124,16 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
// Try to find cases where the input expression is of the form
// (1) (a . b) . c, or
// (2) a . (b . c),
// where . denotes an associative operation (currently + or *), and a, b, c
// are some subexpresions.
// where . denotes an associative operation, and a, b, c are some
// subexpresions.
// If one of the operands in the nested operation is the atomic variable
// (with some possible type conversions applied to it), bring it to the
// top-level operation, and move the top-level operand into the nested
// operation.
// For example, assuming x is the atomic variable:
// (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b.
template <typename T, typename U,
typename = std::enable_if_t<is_numeric_v<T>>>
typename = std::enable_if_t<is_numeric_v<T> || is_logical_v<T>>>
evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
if constexpr (is_floating_point_v<T>) {
if (!context_.langOptions().AssociativeMath) {
Expand All @@ -133,8 +147,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
// some order) from the example above.
evaluate::match::Expr<T> sub[3];
auto inner{reassocOp<T>(sub[0], sub[1])};
auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
auto outer1{reassocOp<T>(inner, sub[2])}; // inner . something
auto outer2{reassocOp<T>(sub[2], inner)}; // something . inner
#if !defined(__clang__) && !defined(_MSC_VER) && \
(__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
// If GCC version < 8.5, use this definition. For the other definition
Expand Down Expand Up @@ -167,37 +181,53 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
}
return common::visit(
[&](auto &&s) {
using Expr = evaluate::Expr<T>;
using TypeS = llvm::remove_cvref_t<decltype(s)>;
// This visitor has to be semantically correct for all possible
// types of s even though at runtime s will only be one of the
// matched types.
// Limit the construction to the operation types that we tried
// to match (otherwise TypeS(op1, op2) would fail for non-binary
// operations).
if constexpr (common::HasMember<TypeS, MatchTypes>) {
Expr atom{*sub[atomIdx].ref};
Expr op1{*sub[(atomIdx + 1) % 3].ref};
Expr op2{*sub[(atomIdx + 2) % 3].ref};
return Expr(
TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
} else {
return Expr(TypeS(s));
}
// Build the new expression from the matched components.
return Reconstruct<T, MatchTypes>(s, *sub[atomIdx].ref,
*sub[(atomIdx + 1) % 3].ref, *sub[(atomIdx + 2) % 3].ref);
},
evaluate::match::deparen(x).u);
}
return Id::operator()(std::move(x), u);
}

template <typename T, typename U,
typename = std::enable_if_t<!is_numeric_v<T>>>
typename = std::enable_if_t<!is_numeric_v<T> && !is_logical_v<T>>>
evaluate::Expr<T> operator()(
evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
return Id::operator()(std::move(x), u);
}

private:
template <typename T, typename MatchTypes, typename S>
evaluate::Expr<T> Reconstruct(const S &op, evaluate::Expr<T> atom,
evaluate::Expr<T> op1, evaluate::Expr<T> op2) {
using TypeS = llvm::remove_cvref_t<decltype(op)>;
// This function has to be semantically correct for all possible types
// of S even though at runtime s will only be one of the matched types.
// Limit the construction to the operation types that we tried to match
// (otherwise TypeS(op1, op2) would fail for non-binary operations).
if constexpr (!common::HasMember<TypeS, MatchTypes>) {
return evaluate::Expr<T>(TypeS(op));
} else if constexpr (is_logical_v<T>) {
constexpr int K{T::kind};
if constexpr (std::is_same_v<TypeS, evaluate::LogicalOperation<K>>) {
// Logical operators take an extra argument in their constructor,
// so they need their own reconstruction code.
common::LogicalOperator opCode{op.logicalOperator};
return evaluate::Expr<T>(TypeS( //
opCode, std::move(atom),
evaluate::Expr<T>(TypeS( //
opCode, std::move(op1), std::move(op2)))));
}
} else {
// Generic reconstruction.
return evaluate::Expr<T>(TypeS( //
std::move(atom),
evaluate::Expr<T>(TypeS( //
std::move(op1), std::move(op2)))));
}
}

template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
}
Expand Down
137 changes: 137 additions & 0 deletions flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s

subroutine f00(x, y, z)
implicit none
logical :: x, y, z

!$omp atomic update
x = x .and. y .and. z
end

!CHECK-LABEL: func.func @_QPf00
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
!CHECK: %[[AND_YZ:[0-9]+]] = arith.andi %[[CVT_Y]], %[[CVT_Z]] : i1
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
!CHECK: %[[AND_XYZ:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_YZ]] : i1
!CHECK: %[[RET:[0-9]+]] = fir.convert %[[AND_XYZ]] : (i1) -> !fir.logical<4>
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
!CHECK: }


subroutine f01(x, y, z)
implicit none
logical :: x, y, z

!$omp atomic update
x = x .or. y .or. z
end

!CHECK-LABEL: func.func @_QPf01
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
!CHECK: %[[OR_YZ:[0-9]+]] = arith.ori %[[CVT_Y]], %[[CVT_Z]] : i1
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
!CHECK: %[[OR_XYZ:[0-9]+]] = arith.ori %[[CVT_X]], %[[OR_YZ]] : i1
!CHECK: %[[RET:[0-9]+]] = fir.convert %[[OR_XYZ]] : (i1) -> !fir.logical<4>
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
!CHECK: }


subroutine f02(x, y, z)
implicit none
logical :: x, y, z

!$omp atomic update
x = x .eqv. y .eqv. z
end

!CHECK-LABEL: func.func @_QPf02
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
!CHECK: %[[EQV_YZ:[0-9]+]] = arith.cmpi eq, %[[CVT_Y]], %[[CVT_Z]] : i1
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
!CHECK: %[[EQV_XYZ:[0-9]+]] = arith.cmpi eq, %[[CVT_X]], %[[EQV_YZ]] : i1
!CHECK: %[[RET:[0-9]+]] = fir.convert %[[EQV_XYZ]] : (i1) -> !fir.logical<4>
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
!CHECK: }


subroutine f03(x, y, z)
implicit none
logical :: x, y, z

!$omp atomic update
x = x .neqv. y .neqv. z
end

!CHECK-LABEL: func.func @_QPf03
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
!CHECK: %[[NEQV_YZ:[0-9]+]] = arith.cmpi ne, %[[CVT_Y]], %[[CVT_Z]] : i1
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
!CHECK: %[[NEQV_XYZ:[0-9]+]] = arith.cmpi ne, %[[CVT_X]], %[[NEQV_YZ]] : i1
!CHECK: %[[RET:[0-9]+]] = fir.convert %[[NEQV_XYZ]] : (i1) -> !fir.logical<4>
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
!CHECK: }


subroutine f04(x, a, b, c)
implicit none
logical(kind=4) :: x
logical(kind=8) :: a, b, c

!$omp atomic update
x = ((b .and. a) .and. x) .and. c
end

!CHECK-LABEL: func.func @_QPf04
!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<!fir.logical<8>>
!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<!fir.logical<8>>
!CHECK: %[[CVT_B:[0-9]+]] = fir.convert %[[LOAD_B]] : (!fir.logical<8>) -> i1
!CHECK: %[[CVT_A:[0-9]+]] = fir.convert %[[LOAD_A]] : (!fir.logical<8>) -> i1
!CHECK: %[[AND_BA:[0-9]+]] = arith.andi %[[CVT_B]], %[[CVT_A]] : i1
!CHECK: %[[LOAD_C:[0-9]+]] = fir.load %[[C]]#0 : !fir.ref<!fir.logical<8>>
!CHECK: %[[CVT_C:[0-9]+]] = fir.convert %[[LOAD_C]] : (!fir.logical<8>) -> i1
!CHECK: %[[AND_BAC:[0-9]+]] = arith.andi %[[AND_BA]], %[[CVT_C]] : i1
!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
!CHECK: %[[CVT8_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> !fir.logical<8>
!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[CVT8_X]] : (!fir.logical<8>) -> i1
!CHECK: %[[AND_XBAC:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_BAC]] : i1

!CHECK: %[[RET:[0-9]+]] = fir.convert %[[AND_XBAC]] : (i1) -> !fir.logical<4>
!CHECK: omp.yield(%[[RET]] : !fir.logical<4>)
!CHECK: }