Skip to content

Commit

Permalink
[flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProc… (
Browse files Browse the repository at this point in the history
#81623)

…essor

Rename `findRepeatableClause` to `findRepeatableClause2`, and make the
new `findRepeatableClause` operate on new `omp::Clause` objects.

Leave `Map` unchanged, because it will require more changes for it to
work.

[Clause representation 3/6]
  • Loading branch information
kparzysz committed Mar 15, 2024
1 parent 03bad4b commit 63e70c0
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 366 deletions.
22 changes: 22 additions & 0 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,28 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
}
}

struct ExtractSubstringHelper {
template <typename T> static std::optional<Substring> visit(T &&) {
return std::nullopt;
}

static std::optional<Substring> visit(const Substring &e) { return e; }

template <typename T>
static std::optional<Substring> visit(const Designator<T> &e) {
return std::visit([](auto &&s) { return visit(s); }, e.u);
}

template <typename T>
static std::optional<Substring> visit(const Expr<T> &e) {
return std::visit([](auto &&s) { return visit(s); }, e.u);
}
};

template <typename A> std::optional<Substring> ExtractSubstring(const A &x) {
return ExtractSubstringHelper::visit(x);
}

// If an expression is simply a whole symbol data designator,
// extract and return that symbol, else null.
template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
Expand Down
218 changes: 96 additions & 122 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Large diffs are not rendered by default.

29 changes: 25 additions & 4 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,8 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
mlir::Value &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;

Expand Down Expand Up @@ -178,6 +177,10 @@ class ClauseProcessor {
/// if at least one instance was found.
template <typename T>
bool findRepeatableClause(
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
callbackFn) const;
template <typename T>
bool findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const;

Expand All @@ -195,7 +198,7 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
return findRepeatableClause<T>(
return findRepeatableClause2<T>(
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Expand Down Expand Up @@ -295,6 +298,24 @@ const T *ClauseProcessor::findUniqueClause(

template <typename T>
bool ClauseProcessor::findRepeatableClause(
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
ClauseIterator nextIt, endIt = clauses.end();
for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
nextIt = findClause<T>(it, endIt);

if (nextIt != endIt) {
callbackFn(std::get<T>(nextIt->u), nextIt->source);
found = true;
++nextIt;
}
}
return found;
}

template <typename T>
bool ClauseProcessor::findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
Expand Down
6 changes: 0 additions & 6 deletions flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,6 @@ namespace clause {
#undef EMPTY_CLASS
#undef WRAPPER_CLASS

using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
using ProcedureDesignator =
tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
using ReductionOperator =
tomp::clause::ReductionOperatorT<SymIdent, SymReference>;

DefinedOperator makeDefinedOperator(const parser::DefinedOperator &inp,
semantics::SemanticsContext &semaCtx) {
return std::visit(
Expand Down
6 changes: 6 additions & 0 deletions flang/lib/Lower/OpenMP/Clauses.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ namespace clause {
#undef EMPTY_CLASS
#undef WRAPPER_CLASS

using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
using ProcedureDesignator =
tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
using ReductionOperator =
tomp::clause::ReductionOperatorT<SymIdent, SymReference>;

// "Requires" clauses are handled early on, and the aggregated information
// is stored in the Symbol details of modules, programs, and subprograms.
// These clauses are still handled here to cover all alternatives in the
Expand Down
186 changes: 86 additions & 100 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,8 +574,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
cp.processProcBind(procBindKindAttr);
cp.processDefault();
Expand Down Expand Up @@ -751,8 +750,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
dependOperands;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Task, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processDefault();
cp.processFinal(stmtCtx, finalClauseOperand);
Expand Down Expand Up @@ -865,8 +863,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::TargetData, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
useDeviceSymbols);
Expand Down Expand Up @@ -911,20 +908,17 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;

Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
clause::If::DirectiveNameModifier directiveName;
// GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
[[maybe_unused]] llvm::omp::Directive directive;
if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
directive = llvm::omp::Directive::OMPD_target_enter_data;
} else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
directiveName = clause::If::DirectiveNameModifier::TargetExitData;
directive = llvm::omp::Directive::OMPD_target_exit_data;
} else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
directive = llvm::omp::Directive::OMPD_target_update;
} else {
return nullptr;
Expand Down Expand Up @@ -1126,8 +1120,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Target, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLimitOperand);
cp.processDepend(dependTypeOperands, dependOperands);
Expand Down Expand Up @@ -1258,8 +1251,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processDefault();
cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
Expand Down Expand Up @@ -1298,8 +1290,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(

if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
ObjectList objects{makeList(*objectList, semaCtx)};
// Case: declare target(func, var1, var2)
gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
symbolAndClause);
} else if (const auto *clauseList{
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
Expand Down Expand Up @@ -1438,7 +1431,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
if (const auto &ompObjectList =
std::get<std::optional<Fortran::parser::OmpObjectList>>(
flushConstruct.t))
genObjectList(*ompObjectList, converter, operandRange);
genObjectList2(*ompObjectList, converter, operandRange);
const auto &memOrderClause =
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
flushConstruct.t);
Expand Down Expand Up @@ -1600,8 +1593,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
loopVarTypeSize);
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
cp.processSimdlen(simdlenClauseOperand);
cp.processSafelen(safelenClauseOperand);
cp.processTODO<Fortran::parser::OmpClause::Aligned,
Expand Down Expand Up @@ -2419,106 +2411,100 @@ void Fortran::lower::genOpenMPReduction(
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

for (const Fortran::parser::OmpClause &clause : clauseList.v) {
List<Clause> clauses{makeList(clauseList, semaCtx)};

for (const Clause &clause : clauses) {
if (const auto &reductionClause =
std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
reductionClause->v.t)};
const auto &objectList{
std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
std::get_if<clause::Reduction>(&clause.u)) {
const auto &redOperator{
std::get<clause::ReductionOperator>(reductionClause->t)};
const auto &objects{std::get<ObjectList>(reductionClause->t)};
if (const auto *reductionOp =
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
std::get_if<clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
std::get<clause::DefinedOperator::IntrinsicOperator>(
reductionOp->u)};

switch (intrinsicOp) {
case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
case clause::DefinedOperator::IntrinsicOperator::Add:
case clause::DefinedOperator::IntrinsicOperator::Multiply:
case clause::DefinedOperator::IntrinsicOperator::AND:
case clause::DefinedOperator::IntrinsicOperator::EQV:
case clause::DefinedOperator::IntrinsicOperator::OR:
case clause::DefinedOperator::IntrinsicOperator::NEQV:
break;
default:
continue;
}
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
reductionVal = declOp.getBase();
mlir::Type reductionType =
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
if (!reductionType.isa<fir::LogicalType>()) {
if (!reductionType.isIntOrIndexOrFloat())
continue;
}
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
reductionValUse.getOwner())) {
mlir::Value loadVal = loadOp.getRes();
if (reductionType.isa<fir::LogicalType>()) {
mlir::Operation *reductionOp = findReductionChain(loadVal);
fir::ConvertOp convertOp =
getConvertFromReductionOp(reductionOp, loadVal);
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal, &convertOp);
removeStoreOp(reductionOp, reductionVal);
} else if (mlir::Operation *reductionOp =
findReductionChain(loadVal, &reductionVal)) {
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal);
}
for (const Object &object : objects) {
if (const Fortran::semantics::Symbol *symbol = object.id()) {
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
reductionVal = declOp.getBase();
mlir::Type reductionType =
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
if (!reductionType.isa<fir::LogicalType>()) {
if (!reductionType.isIntOrIndexOrFloat())
continue;
}
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
if (auto loadOp =
mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
mlir::Value loadVal = loadOp.getRes();
if (reductionType.isa<fir::LogicalType>()) {
mlir::Operation *reductionOp = findReductionChain(loadVal);
fir::ConvertOp convertOp =
getConvertFromReductionOp(reductionOp, loadVal);
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal, &convertOp);
removeStoreOp(reductionOp, reductionVal);
} else if (mlir::Operation *reductionOp =
findReductionChain(loadVal, &reductionVal)) {
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal);
}
}
}
}
}
} else if (const auto *reductionIntrinsic =
std::get_if<Fortran::parser::ProcedureDesignator>(
&redOperator.u)) {
std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic))
continue;
ReductionProcessor::ReductionIdentifier redId =
ReductionProcessor::getReductionType(*reductionIntrinsic);
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
reductionVal = declOp.getBase();
for (const mlir::OpOperand &reductionValUse :
reductionVal.getUses()) {
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
reductionValUse.getOwner())) {
mlir::Value loadVal = loadOp.getRes();
// Max is lowered as a compare -> select.
// Match the pattern here.
mlir::Operation *reductionOp =
findReductionChain(loadVal, &reductionVal);
if (reductionOp == nullptr)
continue;

if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
redId == ReductionProcessor::ReductionIdentifier::MIN) {
assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
"Selection Op not found in reduction intrinsic");
mlir::Operation *compareOp =
getCompareFromReductionOp(reductionOp, loadVal);
updateReduction(compareOp, firOpBuilder, loadVal,
reductionVal);
}
if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
redId == ReductionProcessor::ReductionIdentifier::IEOR ||
redId == ReductionProcessor::ReductionIdentifier::IAND) {
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal);
}
for (const Object &object : objects) {
if (const Fortran::semantics::Symbol *symbol = object.id()) {
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
reductionVal = declOp.getBase();
for (const mlir::OpOperand &reductionValUse :
reductionVal.getUses()) {
if (auto loadOp =
mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
mlir::Value loadVal = loadOp.getRes();
// Max is lowered as a compare -> select.
// Match the pattern here.
mlir::Operation *reductionOp =
findReductionChain(loadVal, &reductionVal);
if (reductionOp == nullptr)
continue;

if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
redId == ReductionProcessor::ReductionIdentifier::MIN) {
assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
"Selection Op not found in reduction intrinsic");
mlir::Operation *compareOp =
getCompareFromReductionOp(reductionOp, loadVal);
updateReduction(compareOp, firOpBuilder, loadVal,
reductionVal);
}
if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
redId == ReductionProcessor::ReductionIdentifier::IEOR ||
redId == ReductionProcessor::ReductionIdentifier::IAND) {
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal);
}
}
}
Expand Down

0 comments on commit 63e70c0

Please sign in to comment.