Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Allow lowering of sub-expressions to be overridden #69944

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions flang/include/flang/Lower/AbstractConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ using SomeExpr = Fortran::evaluate::Expr<Fortran::evaluate::SomeType>;
using SymbolRef = Fortran::common::Reference<const Fortran::semantics::Symbol>;
class StatementContext;

using ExprToValueMap = llvm::DenseMap<const SomeExpr *, mlir::Value>;

//===----------------------------------------------------------------------===//
// AbstractConverter interface
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -90,6 +92,14 @@ class AbstractConverter {
/// added or replaced at the inner-most level of the local symbol map.
virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;

/// Override lowering of expression with pre-lowered values.
/// Associate mlir::Value to evaluate::Expr. All subsequent call to
/// genExprXXX() will replace any occurrence of an overridden
/// expression in the expression tree by the pre-lowered values.
virtual void overrideExprValues(const ExprToValueMap *) = 0;
void resetExprOverrides() { overrideExprValues(nullptr); }
virtual const ExprToValueMap *getExprOverrides() = 0;

/// Get the label set associated with a symbol.
virtual bool lookupLabelSet(SymbolRef sym, pft::LabelSet &labelSet) = 0;

Expand Down
11 changes: 11 additions & 0 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
addSymbol(sym, exval, /*forced=*/true);
}

void
overrideExprValues(const Fortran::lower::ExprToValueMap *map) override final {
exprValueOverrides = map;
}

const Fortran::lower::ExprToValueMap *getExprOverrides() override final {
return exprValueOverrides;
}

bool lookupLabelSet(Fortran::lower::SymbolRef sym,
Fortran::lower::pft::LabelSet &labelSet) override final {
Fortran::lower::pft::FunctionLikeUnit &owningProc =
Expand Down Expand Up @@ -4890,6 +4899,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Whether an OpenMP target region or declare target function/subroutine
/// intended for device offloading has been detected
bool ompDeviceCodeFound = false;

const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
};

} // namespace
Expand Down
17 changes: 17 additions & 0 deletions flang/lib/Lower/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2963,8 +2963,21 @@ class ScalarExprLowering {
return asArray(x);
}

template <typename A>
mlir::Value getIfOverridenExpr(const Fortran::evaluate::Expr<A> &x) {
if (const Fortran::lower::ExprToValueMap *map =
converter.getExprOverrides()) {
Fortran::lower::SomeExpr someExpr = toEvExpr(x);
if (auto match = map->find(&someExpr); match != map->end())
return match->second;
}
return mlir::Value{};
}

template <typename A>
ExtValue gen(const Fortran::evaluate::Expr<A> &x) {
if (mlir::Value val = getIfOverridenExpr(x))
return val;
// Whole array symbols or components, and results of transformational
// functions already have a storage and the scalar expression lowering path
// is used to not create a new temporary storage.
Expand All @@ -2978,6 +2991,8 @@ class ScalarExprLowering {
}
template <typename A>
ExtValue genval(const Fortran::evaluate::Expr<A> &x) {
if (mlir::Value val = getIfOverridenExpr(x))
return val;
if (isScalar(x) || Fortran::evaluate::UnwrapWholeSymbolDataRef(x) ||
inInitializer)
return std::visit([&](const auto &e) { return genval(e); }, x.u);
Expand All @@ -2987,6 +3002,8 @@ class ScalarExprLowering {
template <int KIND>
ExtValue genval(const Fortran::evaluate::Expr<Fortran::evaluate::Type<
Fortran::common::TypeCategory::Logical, KIND>> &exp) {
if (mlir::Value val = getIfOverridenExpr(exp))
return val;
return std::visit([&](const auto &e) { return genval(e); }, exp.u);
}

Expand Down
11 changes: 11 additions & 0 deletions flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,17 @@ class HlfirBuilder {

template <typename T>
hlfir::EntityWithAttributes gen(const Fortran::evaluate::Expr<T> &expr) {
if (const Fortran::lower::ExprToValueMap *map =
getConverter().getExprOverrides()) {
if constexpr (std::is_same_v<T, Fortran::evaluate::SomeType>) {
if (auto match = map->find(&expr); match != map->end())
return hlfir::EntityWithAttributes{match->second};
} else {
Fortran::lower::SomeExpr someExpr = toEvExpr(expr);
if (auto match = map->find(&someExpr); match != map->end())
return hlfir::EntityWithAttributes{match->second};
}
}
return std::visit([&](const auto &x) { return gen(x); }, expr.u);
}

Expand Down
154 changes: 55 additions & 99 deletions flang/lib/Lower/DirectivesCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,62 +200,13 @@ static inline void genOmpAccAtomicUpdateStatement(
mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
const Fortran::parser::Expr &assignmentStmtExpr,
[[maybe_unused]] const AtomicListT *leftHandClauseList,
[[maybe_unused]] const AtomicListT *rightHandClauseList) {
[[maybe_unused]] const AtomicListT *rightHandClauseList,
mlir::Operation *atomicCaptureOp = nullptr) {
// Generate `omp.atomic.update` operation for atomic assignment statements
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();

const auto *varDesignator =
std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
&assignmentStmtVariable.u);
assert(varDesignator && "Variable designator for atomic update assignment "
"statement does not exist");
const Fortran::parser::Name *name =
Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value());
if (!name)
TODO(converter.getCurrentLocation(),
"Array references as atomic update variable");
assert(name && name->symbol &&
"No symbol attached to atomic update variable");
if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
converter.bindSymbol(*name->symbol, lhsAddr);

// Lowering is in two steps :
// subroutine sb
// integer :: a, b
// !$omp atomic update
// a = a + b
// end subroutine
//
// 1. Lower to scf.execute_region_op
//
// func.func @_QPsb() {
// %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
// %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
// %2 = scf.execute_region -> i32 {
// %3 = fir.load %0 : !fir.ref<i32>
// %4 = fir.load %1 : !fir.ref<i32>
// %5 = arith.addi %3, %4 : i32
// scf.yield %5 : i32
// }
// return
// }
auto tempOp =
firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
firOpBuilder.createBlock(&tempOp.getRegion());
mlir::Block &block = tempOp.getRegion().back();
firOpBuilder.setInsertionPointToEnd(&block);
Fortran::lower::StatementContext stmtCtx;
mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
mlir::Value convertResult =
firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
// Insert the terminator: YieldOp.
firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
firOpBuilder.setInsertionPointToStart(&block);

// 2. Create the omp.atomic.update Operation using the Operations in the
// temporary scf.execute_region Operation.
// Create the omp.atomic.update Operation
//
// func.func @_QPsb() {
// %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
Expand All @@ -269,11 +220,37 @@ static inline void genOmpAccAtomicUpdateStatement(
// }
// return
// }
mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
updateVar = decl.getBase();

firOpBuilder.setInsertionPointAfter(tempOp);
Fortran::lower::ExprToValueMap exprValueOverrides;
// Lower any non atomic sub-expression before the atomic operation, and
// map its lowered value to the semantic representation.
const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr};
std::visit(
[&](const auto &op) -> void {
using T = std::decay_t<decltype(op)>;
if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
T>::value) {
const auto &exprLeft{std::get<0>(op.t)};
const auto &exprRight{std::get<1>(op.t)};
if (exprLeft.value().source == assignmentStmtVariable.GetSource())
nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
else
nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming we can do something similar here for the allowed intrinsics too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose you can indeed go through the intrinsic function arguments here too yes, @NimishMishra, I do not plan to work on this, so I am happy if you do in a later patch!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

},
assignmentStmtExpr.u);
StatementContext nonAtomicStmtCtx;
if (nonAtomicSubExpr) {
// Generate non atomic part before all the atomic operations.
auto insertionPoint = firOpBuilder.saveInsertionPoint();
if (atomicCaptureOp)
firOpBuilder.setInsertionPoint(atomicCaptureOp);
mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
if (atomicCaptureOp)
firOpBuilder.restoreInsertionPoint(insertionPoint);
}

mlir::Operation *atomicUpdateOp = nullptr;
if constexpr (std::is_same<AtomicListT,
Expand All @@ -289,10 +266,10 @@ static inline void genOmpAccAtomicUpdateStatement(
genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
hint, memoryOrder);
atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
currentLocation, updateVar, hint, memoryOrder);
currentLocation, lhsAddr, hint, memoryOrder);
} else {
atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
currentLocation, updateVar);
currentLocation, lhsAddr);
}

llvm::SmallVector<mlir::Type> varTys = {varType};
Expand All @@ -301,38 +278,25 @@ static inline void genOmpAccAtomicUpdateStatement(
mlir::Value val =
fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));

llvm::SmallVector<mlir::Operation *> ops;
for (mlir::Operation &op : tempOp.getRegion().getOps())
ops.push_back(&op);

// SCF Yield is converted to OMP Yield. All other operations are copied
for (mlir::Operation *op : ops) {
if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
firOpBuilder.setInsertionPointToEnd(
&atomicUpdateOp->getRegion(0).front());
if constexpr (std::is_same<AtomicListT,
Fortran::parser::OmpAtomicClauseList>()) {
firOpBuilder.create<mlir::omp::YieldOp>(currentLocation,
y.getResults());
} else {
firOpBuilder.create<mlir::acc::YieldOp>(currentLocation,
y.getResults());
}
op->erase();
exprValueOverrides.try_emplace(
Fortran::semantics::GetExpr(assignmentStmtVariable), val);
{
// statement context inside the atomic block.
converter.overrideExprValues(&exprValueOverrides);
Fortran::lower::StatementContext atomicStmtCtx;
mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
mlir::Value convertResult =
firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
if constexpr (std::is_same<AtomicListT,
Fortran::parser::OmpAtomicClauseList>()) {
firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
} else {
op->remove();
atomicUpdateOp->getRegion(0).front().push_back(op);
firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
}
converter.resetExprOverrides();
}

// Remove the load and replace all uses of load with the block argument
for (mlir::Operation &op : atomicUpdateOp->getRegion(0).getOps()) {
fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
if (y && y.getMemref() == updateVar)
y.getRes().replaceAllUsesWith(val);
}

tempOp.erase();
firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
}

/// Processes an atomic construct with write clause.
Expand Down Expand Up @@ -423,11 +387,7 @@ void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
mlir::Type varType =
fir::getBase(
converter.genExprValue(
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
.getType();
mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
genOmpAccAtomicUpdateStatement<AtomicListT>(
converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
leftHandClauseList, rightHandClauseList);
Expand All @@ -450,11 +410,7 @@ void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
mlir::Type varType =
fir::getBase(
converter.genExprValue(
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
.getType();
mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
// If atomic-clause is not present on the construct, the behaviour is as if
// the update clause is specified (for both OpenMP and OpenACC).
genOmpAccAtomicUpdateStatement<AtomicListT>(
Expand Down Expand Up @@ -551,7 +507,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
genOmpAccAtomicUpdateStatement<AtomicListT>(
converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
/*leftHandClauseList=*/nullptr,
/*rightHandClauseList=*/nullptr);
/*rightHandClauseList=*/nullptr, atomicCaptureOp);
} else {
// Atomic capture construct is of the form [capture-stmt, write-stmt]
const Fortran::semantics::SomeExpr &fromExpr =
Expand Down Expand Up @@ -580,7 +536,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
genOmpAccAtomicUpdateStatement<AtomicListT>(
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
/*leftHandClauseList=*/nullptr,
/*rightHandClauseList=*/nullptr);
/*rightHandClauseList=*/nullptr, atomicCaptureOp);
}
firOpBuilder.setInsertionPointToEnd(&block);
if constexpr (std::is_same<AtomicListT,
Expand Down
10 changes: 5 additions & 5 deletions flang/test/Lower/OpenACC/acc-atomic-capture.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ program acc_atomic_capture_test

!CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
!CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: acc.atomic.capture {
!CHECK: acc.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
!CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
!CHECK: acc.yield %[[result]] : i32
!CHECK: }
Expand All @@ -23,10 +23,10 @@ program acc_atomic_capture_test
!$acc end atomic


!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: acc.atomic.capture {
!CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
!CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
!CHECK: acc.yield %[[result]] : i32
!CHECK: }
Expand Down Expand Up @@ -76,12 +76,12 @@ subroutine pointers_in_atomic_capture()
!CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
!CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
!CHECK: acc.atomic.capture {
!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
!CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
!CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
!CHECK: acc.atomic.capture {
!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
!CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
!CHECK: acc.yield %[[result]] : i32
!CHECK: }
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ subroutine sb
!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: acc.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
!CHECK: acc.atomic.update %[[X_DECL]]#1 : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG_X:.*]]: i32):
!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
!CHECK: acc.yield %[[X_UPDATE_VAL]] : i32
!CHECK: }
Expand Down