diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index 53896072675ab..aba40025ca504 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -430,6 +430,28 @@ template std::optional ExtractCoarrayRef(const A &x) { } } +struct ExtractSubstringHelper { + template static std::optional visit(T &&) { + return std::nullopt; + } + + static std::optional visit(const Substring &e) { return e; } + + template + static std::optional visit(const Designator &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } + + template + static std::optional visit(const Expr &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } +}; + +template std::optional ExtractSubstring(const A &x) { + return ExtractSubstringHelper::visit(x); +} + // If an expression is simply a whole symbol data designator, // extract and return that symbol, else null. template const Symbol *UnwrapWholeSymbolDataRef(const A &x) { diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 9987cd73fc767..6e45a939333d6 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -87,7 +87,7 @@ getSimdModifier(const omp::clause::Schedule &clause) { static void genAllocateClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpAllocateClause &ompAllocateClause, + const omp::clause::Allocate &clause, llvm::SmallVectorImpl &allocatorOperands, llvm::SmallVectorImpl &allocateOperands) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -95,21 +95,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext stmtCtx; mlir::Value allocatorOperand; - const Fortran::parser::OmpObjectList &ompObjectList = - std::get(ompAllocateClause.t); - const auto &allocateModifier = std::get< - std::optional>( - ompAllocateClause.t); + const omp::ObjectList &objectList = std::get(clause.t); + const auto &modifier = + std::get>(clause.t); // If the allocate modifier is present, check if we only use the allocator // submodifier. ALIGN in this context is unimplemented const bool onlyAllocator = - allocateModifier && - std::holds_alternative< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); + modifier && + std::holds_alternative( + modifier->u); - if (allocateModifier && !onlyAllocator) { + if (modifier && !onlyAllocator) { TODO(currentLocation, "OmpAllocateClause ALIGN modifier"); } @@ -117,20 +114,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, // to list of allocators, otherwise, add default allocator to // list of allocators. if (onlyAllocator) { - const auto &allocatorValue = std::get< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); - allocatorOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx)); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); + const auto &value = + std::get(modifier->u); + mlir::Value operand = + fir::getBase(converter.genExprValue(value.v, stmtCtx)); + allocatorOperands.append(objectList.size(), operand); } else { - allocatorOperand = firOpBuilder.createIntegerConstant( + mlir::Value operand = firOpBuilder.createIntegerConstant( currentLocation, firOpBuilder.getI32Type(), 1); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); + allocatorOperands.append(objectList.size(), operand); } - genObjectList(ompObjectList, converter, allocateOperands); + genObjectList(objectList, converter, allocateOperands); } static mlir::omp::ClauseProcBindKindAttr @@ -157,20 +151,17 @@ genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder, static mlir::omp::ClauseTaskDependAttr genDependKindAttr(fir::FirOpBuilder &firOpBuilder, - const Fortran::parser::OmpClause::Depend *dependClause) { + const omp::clause::Depend &clause) { mlir::omp::ClauseTaskDepend pbKind; - switch ( - std::get( - std::get(dependClause->v.u) - .t) - .v) { - case Fortran::parser::OmpDependenceType::Type::In: + const auto &inOut = std::get(clause.u); + switch (std::get(inOut.t)) { + case omp::clause::Depend::Type::In: pbKind = mlir::omp::ClauseTaskDepend::taskdependin; break; - case Fortran::parser::OmpDependenceType::Type::Out: + case omp::clause::Depend::Type::Out: pbKind = mlir::omp::ClauseTaskDepend::taskdependout; break; - case Fortran::parser::OmpDependenceType::Type::Inout: + case omp::clause::Depend::Type::Inout: pbKind = mlir::omp::ClauseTaskDepend::taskdependinout; break; default: @@ -181,45 +172,41 @@ genDependKindAttr(fir::FirOpBuilder &firOpBuilder, pbKind); } -static mlir::Value getIfClauseOperand( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClause::If *ifClause, - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Location clauseLocation) { +static mlir::Value +getIfClauseOperand(Fortran::lower::AbstractConverter &converter, + const omp::clause::If &clause, + omp::clause::If::DirectiveNameModifier directiveName, + mlir::Location clauseLocation) { // Only consider the clause if it's intended for the given directive. - auto &directive = std::get< - std::optional>( - ifClause->v.t); + auto &directive = + std::get>(clause.t); if (directive && directive.value() != directiveName) return nullptr; Fortran::lower::StatementContext stmtCtx; fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto &expr = std::get(ifClause->v.t); mlir::Value ifVal = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + converter.genExprValue(std::get(clause.t), stmtCtx)); return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), ifVal); } static void addUseDeviceClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpObjectList &useDeviceClause, + const omp::ObjectList &objects, llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSymbols) { - genObjectList(useDeviceClause, converter, operands); + genObjectList(objects, converter, operands); for (mlir::Value &operand : operands) { checkMapType(operand.getLoc(), operand.getType()); useDeviceTypes.push_back(operand.getType()); useDeviceLocs.push_back(operand.getLoc()); } - for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - useDeviceSymbols.push_back(sym); - } + for (const omp::Object &object : objects) + useDeviceSymbols.push_back(object.id()); } //===----------------------------------------------------------------------===// @@ -527,10 +514,10 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { bool ClauseProcessor::processAllocate( llvm::SmallVectorImpl &allocatorOperands, llvm::SmallVectorImpl &allocateOperands) const { - return findRepeatableClause( - [&](const ClauseTy::Allocate *allocateClause, + return findRepeatableClause( + [&](const omp::clause::Allocate &clause, const Fortran::parser::CharBlock &) { - genAllocateClause(converter, allocateClause->v, allocatorOperands, + genAllocateClause(converter, clause, allocatorOperands, allocateOperands); }); } @@ -547,12 +534,12 @@ bool ClauseProcessor::processCopyin() const { if (converter.isPresentShallowLookup(*sym)) converter.copyHostAssociateVar(*sym, copyAssignIP); }; - bool hasCopyin = findRepeatableClause( - [&](const ClauseTy::Copyin *copyinClause, + bool hasCopyin = findRepeatableClause( + [&](const omp::clause::Copyin &clause, const Fortran::parser::CharBlock &) { - const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + for (const omp::Object &object : clause.v) { + Fortran::semantics::Symbol *sym = object.id(); + assert(sym && "Expecting symbol"); if (const auto *commonDetails = sym->detailsIf()) { for (const auto &mem : commonDetails->objects()) @@ -716,13 +703,11 @@ bool ClauseProcessor::processCopyPrivate( copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp)); }; - bool hasCopyPrivate = findRepeatableClause( - [&](const ClauseTy::Copyprivate *copyPrivateClause, + bool hasCopyPrivate = findRepeatableClause( + [&](const clause::Copyprivate &clause, const Fortran::parser::CharBlock &) { - const Fortran::parser::OmpObjectList &ompObjectList = - copyPrivateClause->v; - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + for (const Object &object : clause.v) { + Fortran::semantics::Symbol *sym = object.id(); if (const auto *commonDetails = sym->detailsIf()) { for (const auto &mem : commonDetails->objects()) @@ -741,38 +726,30 @@ bool ClauseProcessor::processDepend( llvm::SmallVectorImpl &dependOperands) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause( - [&](const ClauseTy::Depend *dependClause, + return findRepeatableClause( + [&](const omp::clause::Depend &clause, const Fortran::parser::CharBlock &) { - const std::list &depVal = - std::get>( - std::get( - dependClause->v.u) - .t); + assert(std::holds_alternative(clause.u) && + "Only InOut is handled at the moment"); + const auto &inOut = std::get(clause.u); + const auto &objects = std::get(inOut.t); + mlir::omp::ClauseTaskDependAttr dependTypeOperand = - genDependKindAttr(firOpBuilder, dependClause); - dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(), - dependTypeOperand); - for (const Fortran::parser::Designator &ompObject : depVal) { - Fortran::semantics::Symbol *sym = nullptr; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::DataRef &designator) { - if (const Fortran::parser::Name *name = - std::get_if(&designator.u)) { - sym = name->symbol; - } else if (std::get_if>( - &designator.u)) { - TODO(converter.getCurrentLocation(), - "array sections not supported for task depend"); - } - }, - [&](const Fortran::parser::Substring &designator) { - TODO(converter.getCurrentLocation(), - "substring not supported for task depend"); - }}, - (ompObject).u); + genDependKindAttr(firOpBuilder, clause); + dependTypeOperands.append(objects.size(), dependTypeOperand); + + for (const omp::Object &object : objects) { + assert(object.ref() && "Expecting designator"); + + if (Fortran::evaluate::ExtractSubstring(*object.ref())) { + TODO(converter.getCurrentLocation(), + "substring not supported for task depend"); + } else if (Fortran::evaluate::IsArrayElement(*object.ref())) { + TODO(converter.getCurrentLocation(), + "array sections not supported for task depend"); + } + + Fortran::semantics::Symbol *sym = object.id(); const mlir::Value variable = converter.getSymbolAddress(*sym); dependOperands.push_back(variable); } @@ -780,14 +757,14 @@ bool ClauseProcessor::processDepend( } bool ClauseProcessor::processIf( - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, + omp::clause::If::DirectiveNameModifier directiveName, mlir::Value &result) const { bool found = false; - findRepeatableClause( - [&](const ClauseTy::If *ifClause, + findRepeatableClause( + [&](const omp::clause::If &clause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); - mlir::Value operand = getIfClauseOperand(converter, ifClause, + mlir::Value operand = getIfClauseOperand(converter, clause, directiveName, clauseLocation); // Assume that, at most, a single 'if' clause will be applicable to the // given directive. @@ -801,12 +778,11 @@ bool ClauseProcessor::processIf( bool ClauseProcessor::processLink( llvm::SmallVectorImpl &result) const { - return findRepeatableClause( - [&](const ClauseTy::Link *linkClause, - const Fortran::parser::CharBlock &) { + return findRepeatableClause( + [&](const omp::clause::Link &clause, const Fortran::parser::CharBlock &) { // Case: declare target link(var1, var2)... gatherFuncAndVarSyms( - linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result); + clause.v, mlir::omp::DeclareTargetCaptureClause::link, result); }); } @@ -843,7 +819,7 @@ bool ClauseProcessor::processMap( llvm::SmallVectorImpl *mapSymbols) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause( + return findRepeatableClause2( [&](const ClauseTy::Map *mapClause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); @@ -935,43 +911,41 @@ bool ClauseProcessor::processReduction( llvm::SmallVectorImpl &reductionDeclSymbols, llvm::SmallVectorImpl *reductionSymbols) const { - return findRepeatableClause( - [&](const ClauseTy::Reduction *reductionClause, + return findRepeatableClause( + [&](const omp::clause::Reduction &clause, const Fortran::parser::CharBlock &) { ReductionProcessor rp; - rp.addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols, - reductionSymbols); + rp.addReductionDecl(currentLocation, converter, clause, reductionVars, + reductionDeclSymbols, reductionSymbols); }); } bool ClauseProcessor::processSectionsReduction( mlir::Location currentLocation) const { - return findRepeatableClause( - [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) { + return findRepeatableClause( + [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) { TODO(currentLocation, "OMPC_Reduction"); }); } bool ClauseProcessor::processTo( llvm::SmallVectorImpl &result) const { - return findRepeatableClause( - [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) { + return findRepeatableClause( + [&](const omp::clause::To &clause, const Fortran::parser::CharBlock &) { // Case: declare target to(func, var1, var2)... - gatherFuncAndVarSyms(toClause->v, + gatherFuncAndVarSyms(clause.v, mlir::omp::DeclareTargetCaptureClause::to, result); }); } bool ClauseProcessor::processEnter( llvm::SmallVectorImpl &result) const { - return findRepeatableClause( - [&](const ClauseTy::Enter *enterClause, + return findRepeatableClause( + [&](const omp::clause::Enter &clause, const Fortran::parser::CharBlock &) { // Case: declare target enter(func, var1, var2)... - gatherFuncAndVarSyms(enterClause->v, - mlir::omp::DeclareTargetCaptureClause::enter, - result); + gatherFuncAndVarSyms( + clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result); }); } @@ -981,11 +955,11 @@ bool ClauseProcessor::processUseDeviceAddr( llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSymbols) const { - return findRepeatableClause( - [&](const ClauseTy::UseDeviceAddr *devAddrClause, + return findRepeatableClause( + [&](const omp::clause::UseDeviceAddr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devAddrClause->v, operands, - useDeviceTypes, useDeviceLocs, useDeviceSymbols); + addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, + useDeviceLocs, useDeviceSymbols); }); } @@ -995,10 +969,10 @@ bool ClauseProcessor::processUseDevicePtr( llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSymbols) const { - return findRepeatableClause( - [&](const ClauseTy::UseDevicePtr *devPtrClause, + return findRepeatableClause( + [&](const omp::clause::UseDevicePtr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes, + addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, useDeviceLocs, useDeviceSymbols); }); } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index c87fc30c88bb9..3f6adcce8ae87 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -105,9 +105,8 @@ class ClauseProcessor { llvm::SmallVectorImpl &dependOperands) const; bool processEnter(llvm::SmallVectorImpl &result) const; - bool - processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Value &result) const; + bool processIf(omp::clause::If::DirectiveNameModifier directiveName, + mlir::Value &result) const; bool processLink(llvm::SmallVectorImpl &result) const; @@ -178,6 +177,10 @@ class ClauseProcessor { /// if at least one instance was found. template bool findRepeatableClause( + std::function + callbackFn) const; + template + bool findRepeatableClause2( std::function callbackFn) const; @@ -195,7 +198,7 @@ template bool ClauseProcessor::processMotionClauses( Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl &mapOperands) { - return findRepeatableClause( + return findRepeatableClause2( [&](const T *motionClause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -295,6 +298,24 @@ const T *ClauseProcessor::findUniqueClause( template bool ClauseProcessor::findRepeatableClause( + std::function + callbackFn) const { + bool found = false; + ClauseIterator nextIt, endIt = clauses.end(); + for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) { + nextIt = findClause(it, endIt); + + if (nextIt != endIt) { + callbackFn(std::get(nextIt->u), nextIt->source); + found = true; + ++nextIt; + } + } + return found; +} + +template +bool ClauseProcessor::findRepeatableClause2( std::function callbackFn) const { bool found = false; diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index a807814d0229d..70f232a4858e1 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -210,12 +210,6 @@ namespace clause { #undef EMPTY_CLASS #undef WRAPPER_CLASS -using DefinedOperator = tomp::clause::DefinedOperatorT; -using ProcedureDesignator = - tomp::clause::ProcedureDesignatorT; -using ReductionOperator = - tomp::clause::ReductionOperatorT; - DefinedOperator makeDefinedOperator(const parser::DefinedOperator &inp, semantics::SemanticsContext &semaCtx) { return std::visit( diff --git a/flang/lib/Lower/OpenMP/Clauses.h b/flang/lib/Lower/OpenMP/Clauses.h index fdf45ec21e8e6..1d1a112aac3be 100644 --- a/flang/lib/Lower/OpenMP/Clauses.h +++ b/flang/lib/Lower/OpenMP/Clauses.h @@ -124,6 +124,12 @@ namespace clause { #undef EMPTY_CLASS #undef WRAPPER_CLASS +using DefinedOperator = tomp::clause::DefinedOperatorT; +using ProcedureDesignator = + tomp::clause::ProcedureDesignatorT; +using ReductionOperator = + tomp::clause::ReductionOperatorT; + // "Requires" clauses are handled early on, and the aggregated information // is stored in the Symbol details of modules, programs, and subprograms. // These clauses are still handled here to cover all alternatives in the diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 25bb4d9cff5d1..5d4db06ddafa9 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -574,8 +574,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector reductionSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Parallel, ifClauseOperand); cp.processNumThreads(stmtCtx, numThreadsClauseOperand); cp.processProcBind(procBindKindAttr); cp.processDefault(); @@ -751,8 +750,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, dependOperands; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Task, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); cp.processDefault(); cp.processFinal(stmtCtx, finalClauseOperand); @@ -865,8 +863,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector useDeviceSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::TargetData, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs, useDeviceSymbols); @@ -911,20 +908,17 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector mapOperands, dependOperands; llvm::SmallVector dependTypeOperands; - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName; + clause::If::DirectiveNameModifier directiveName; // GCC 9.3.0 emits a (probably) bogus warning about an unused variable. [[maybe_unused]] llvm::omp::Directive directive; if constexpr (std::is_same_v) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData; + directiveName = clause::If::DirectiveNameModifier::TargetEnterData; directive = llvm::omp::Directive::OMPD_target_enter_data; } else if constexpr (std::is_same_v) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData; + directiveName = clause::If::DirectiveNameModifier::TargetExitData; directive = llvm::omp::Directive::OMPD_target_exit_data; } else if constexpr (std::is_same_v) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate; + directiveName = clause::If::DirectiveNameModifier::TargetUpdate; directive = llvm::omp::Directive::OMPD_target_update; } else { return nullptr; @@ -1126,8 +1120,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector mapSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Target, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processThreadLimit(stmtCtx, threadLimitOperand); cp.processDepend(dependTypeOperands, dependOperands); @@ -1258,8 +1251,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector reductionDeclSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Teams, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); cp.processDefault(); cp.processNumTeams(stmtCtx, numTeamsClauseOperand); @@ -1298,8 +1290,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( if (const auto *objectList{ Fortran::parser::Unwrap(spec.u)}) { + ObjectList objects{makeList(*objectList, semaCtx)}; // Case: declare target(func, var1, var2) - gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to, + gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to, symbolAndClause); } else if (const auto *clauseList{ Fortran::parser::Unwrap( @@ -1438,7 +1431,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter, if (const auto &ompObjectList = std::get>( flushConstruct.t)) - genObjectList(*ompObjectList, converter, operandRange); + genObjectList2(*ompObjectList, converter, operandRange); const auto &memOrderClause = std::get>>( flushConstruct.t); @@ -1600,8 +1593,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, loopVarTypeSize); cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); cp.processReduction(loc, reductionVars, reductionDeclSymbols); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand); cp.processSimdlen(simdlenClauseOperand); cp.processSafelen(safelenClauseOperand); cp.processTODO clauses{makeList(clauseList, semaCtx)}; + + for (const Clause &clause : clauses) { if (const auto &reductionClause = - std::get_if(&clause.u)) { - const auto &redOperator{std::get( - reductionClause->v.t)}; - const auto &objectList{ - std::get(reductionClause->v.t)}; + std::get_if(&clause.u)) { + const auto &redOperator{ + std::get(reductionClause->t)}; + const auto &objects{std::get(reductionClause->t)}; if (const auto *reductionOp = - std::get_if(&redOperator.u)) { + std::get_if(&redOperator.u)) { const auto &intrinsicOp{ - std::get( + std::get( reductionOp->u)}; switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case clause::DefinedOperator::IntrinsicOperator::Add: + case clause::DefinedOperator::IntrinsicOperator::Multiply: + case clause::DefinedOperator::IntrinsicOperator::AND: + case clause::DefinedOperator::IntrinsicOperator::EQV: + case clause::DefinedOperator::IntrinsicOperator::OR: + case clause::DefinedOperator::IntrinsicOperator::NEQV: break; default: continue; } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp()) - reductionVal = declOp.getBase(); - mlir::Type reductionType = - reductionVal.getType().cast().getEleTy(); - if (!reductionType.isa()) { - if (!reductionType.isIntOrIndexOrFloat()) - continue; - } - for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - if (reductionType.isa()) { - mlir::Operation *reductionOp = findReductionChain(loadVal); - fir::ConvertOp convertOp = - getConvertFromReductionOp(reductionOp, loadVal); - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal, &convertOp); - removeStoreOp(reductionOp, reductionVal); - } else if (mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal)) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } + for (const Object &object : objects) { + if (const Fortran::semantics::Symbol *symbol = object.id()) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp()) + reductionVal = declOp.getBase(); + mlir::Type reductionType = + reductionVal.getType().cast().getEleTy(); + if (!reductionType.isa()) { + if (!reductionType.isIntOrIndexOrFloat()) + continue; + } + for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { + if (auto loadOp = + mlir::dyn_cast(reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + if (reductionType.isa()) { + mlir::Operation *reductionOp = findReductionChain(loadVal); + fir::ConvertOp convertOp = + getConvertFromReductionOp(reductionOp, loadVal); + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal, &convertOp); + removeStoreOp(reductionOp, reductionVal); + } else if (mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal)) { + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } } } } } else if (const auto *reductionIntrinsic = - std::get_if( - &redOperator.u)) { + std::get_if(&redOperator.u)) { if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) continue; ReductionProcessor::ReductionIdentifier redId = ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp()) - reductionVal = declOp.getBase(); - for (const mlir::OpOperand &reductionValUse : - reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - // Max is lowered as a compare -> select. - // Match the pattern here. - mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal); - if (reductionOp == nullptr) - continue; - - if (redId == ReductionProcessor::ReductionIdentifier::MAX || - redId == ReductionProcessor::ReductionIdentifier::MIN) { - assert(mlir::isa(reductionOp) && - "Selection Op not found in reduction intrinsic"); - mlir::Operation *compareOp = - getCompareFromReductionOp(reductionOp, loadVal); - updateReduction(compareOp, firOpBuilder, loadVal, - reductionVal); - } - if (redId == ReductionProcessor::ReductionIdentifier::IOR || - redId == ReductionProcessor::ReductionIdentifier::IEOR || - redId == ReductionProcessor::ReductionIdentifier::IAND) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } + for (const Object &object : objects) { + if (const Fortran::semantics::Symbol *symbol = object.id()) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp()) + reductionVal = declOp.getBase(); + for (const mlir::OpOperand &reductionValUse : + reductionVal.getUses()) { + if (auto loadOp = + mlir::dyn_cast(reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + // Max is lowered as a compare -> select. + // Match the pattern here. + mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal); + if (reductionOp == nullptr) + continue; + + if (redId == ReductionProcessor::ReductionIdentifier::MAX || + redId == ReductionProcessor::ReductionIdentifier::MIN) { + assert(mlir::isa(reductionOp) && + "Selection Op not found in reduction intrinsic"); + mlir::Operation *compareOp = + getCompareFromReductionOp(reductionOp, loadVal); + updateReduction(compareOp, firOpBuilder, loadVal, + reductionVal); + } + if (redId == ReductionProcessor::ReductionIdentifier::IOR || + redId == ReductionProcessor::ReductionIdentifier::IEOR || + redId == ReductionProcessor::ReductionIdentifier::IAND) { + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } } diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp index e6a63dd4b939c..6dc467c4f69bc 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -30,9 +30,9 @@ namespace lower { namespace omp { ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( - const Fortran::parser::ProcedureDesignator &pd) { + const omp::clause::ProcedureDesignator &pd) { auto redType = llvm::StringSwitch>( - ReductionProcessor::getRealName(pd).ToString()) + getRealName(pd.v.id()).ToString()) .Case("max", ReductionIdentifier::MAX) .Case("min", ReductionIdentifier::MIN) .Case("iand", ReductionIdentifier::IAND) @@ -44,21 +44,21 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( } ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { + omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: return ReductionIdentifier::ADD; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: + case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: return ReductionIdentifier::SUBTRACT; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: return ReductionIdentifier::MULTIPLY; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return ReductionIdentifier::AND; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return ReductionIdentifier::EQV; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return ReductionIdentifier::OR; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return ReductionIdentifier::NEQV; default: llvm_unreachable("unexpected intrinsic operator in reduction"); @@ -66,13 +66,11 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( } bool ReductionProcessor::supportedIntrinsicProcReduction( - const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - if (!name->symbol->GetUltimate().attrs().test( - Fortran::semantics::Attr::INTRINSIC)) + const omp::clause::ProcedureDesignator &pd) { + Fortran::semantics::Symbol *sym = pd.v.id(); + if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) return false; - auto redType = llvm::StringSwitch(getRealName(name).ToString()) + auto redType = llvm::StringSwitch(getRealName(sym).ToString()) .Case("max", true) .Case("min", true) .Case("iand", true) @@ -99,24 +97,24 @@ std::string ReductionProcessor::getReductionName(llvm::StringRef name, } std::string ReductionProcessor::getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty, bool isByRef) { + omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, mlir::Type ty, + bool isByRef) { std::string reductionName; switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: reductionName = "add_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: reductionName = "multiply_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return "and_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return "eqv_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return "or_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return "neqv_reduction"; default: reductionName = "other_reduction"; @@ -364,7 +362,7 @@ bool ReductionProcessor::doReductionByRef( void ReductionProcessor::addReductionDecl( mlir::Location currentLocation, Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, + const omp::clause::Reduction &reduction, llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reductionDeclSymbols, llvm::SmallVectorImpl @@ -372,13 +370,12 @@ void ReductionProcessor::addReductionDecl( fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::omp::ReductionDeclareOp decl; const auto &redOperator{ - std::get(reduction.t)}; - const auto &objectList{std::get(reduction.t)}; + std::get(reduction.t)}; + const auto &objectList{std::get(reduction.t)}; - if (!std::holds_alternative( - redOperator.u)) { + if (!std::holds_alternative(redOperator.u)) { if (const auto *reductionIntrinsic = - std::get_if(&redOperator.u)) { + std::get_if(&redOperator.u)) { if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) { return; @@ -388,27 +385,23 @@ void ReductionProcessor::addReductionDecl( } } - // initial pass to collect all recuction vars so we can figure out if this + // initial pass to collect all reduction vars so we can figure out if this // should happen byref - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - if (reductionSymbols) - reductionSymbols->push_back(symbol); - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp()) - symVal = declOp.getBase(); - reductionVars.push_back(symVal); - } - } + for (const Object &object : objectList) { + const Fortran::semantics::Symbol *symbol = object.id(); + if (reductionSymbols) + reductionSymbols->push_back(symbol); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); + reductionVars.push_back(symVal); } const bool isByRef = doReductionByRef(reductionVars); if (const auto &redDefinedOp = - std::get_if(&redOperator.u)) { + std::get_if(&redOperator.u)) { const auto &intrinsicOp{ - std::get( + std::get( redDefinedOp->u)}; ReductionIdentifier redId = getReductionType(intrinsicOp); switch (redId) { @@ -424,73 +417,63 @@ void ReductionProcessor::addReductionDecl( "Reduction of some intrinsic operators is not supported"); break; } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp()) - symVal = declOp.getBase(); - auto redType = symVal.getType().cast(); - if (redType.getEleTy().isa()) - decl = createReductionDecl( - firOpBuilder, - getReductionName(intrinsicOp, firOpBuilder.getI1Type(), - isByRef), - redId, redType, currentLocation, isByRef); - else if (redType.getEleTy().isIntOrIndexOrFloat()) { - decl = createReductionDecl( - firOpBuilder, getReductionName(intrinsicOp, redType, isByRef), - redId, redType, currentLocation, isByRef); - } else { - TODO(currentLocation, "Reduction of some types is not supported"); - } - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } + + for (const Object &object : objectList) { + const Fortran::semantics::Symbol *symbol = object.id(); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); + auto redType = symVal.getType().cast(); + if (redType.getEleTy().isa()) + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type(), isByRef), + redId, redType, currentLocation, isByRef); + else if (redType.getEleTy().isIntOrIndexOrFloat()) { + decl = createReductionDecl( + firOpBuilder, getReductionName(intrinsicOp, redType, isByRef), + redId, redType, currentLocation, isByRef); + } else { + TODO(currentLocation, "Reduction of some types is not supported"); } + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); } } else if (const auto *reductionIntrinsic = - std::get_if( + std::get_if( &redOperator.u)) { if (ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) { ReductionProcessor::ReductionIdentifier redId = ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp()) - symVal = declOp.getBase(); - auto redType = symVal.getType().cast(); - assert(redType.getEleTy().isIntOrIndexOrFloat() && - "Unsupported reduction type"); - decl = createReductionDecl( - firOpBuilder, - getReductionName(getRealName(*reductionIntrinsic).ToString(), - redType, isByRef), - redId, redType, currentLocation, isByRef); - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } + for (const Object &object : objectList) { + const Fortran::semantics::Symbol *symbol = object.id(); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); + auto redType = symVal.getType().cast(); + assert(redType.getEleTy().isIntOrIndexOrFloat() && + "Unsupported reduction type"); + decl = createReductionDecl( + firOpBuilder, + getReductionName(getRealName(*reductionIntrinsic).ToString(), + redType, isByRef), + redId, redType, currentLocation, isByRef); + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); } } } } const Fortran::semantics::SourceName -ReductionProcessor::getRealName(const Fortran::parser::Name *name) { - return name->symbol->GetUltimate().name(); +ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) { + return symbol->GetUltimate().name(); } -const Fortran::semantics::SourceName ReductionProcessor::getRealName( - const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - return getRealName(name); +const Fortran::semantics::SourceName +ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { + return getRealName(pd.v.id()); } int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h index 679580f2a3cac..ef6339407c135 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.h +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h @@ -13,6 +13,7 @@ #ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H +#include "Clauses.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Parser/parse-tree.h" @@ -58,19 +59,19 @@ class ReductionProcessor { }; static ReductionIdentifier - getReductionType(const Fortran::parser::ProcedureDesignator &pd); + getReductionType(const omp::clause::ProcedureDesignator &pd); - static ReductionIdentifier getReductionType( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp); + static ReductionIdentifier + getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp); - static bool supportedIntrinsicProcReduction( - const Fortran::parser::ProcedureDesignator &pd); + static bool + supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd); static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::Name *name); + getRealName(const Fortran::semantics::Symbol *symbol); static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::ProcedureDesignator &pd); + getRealName(const omp::clause::ProcedureDesignator &pd); static bool doReductionByRef(const llvm::SmallVectorImpl &reductionVars); @@ -78,9 +79,9 @@ class ReductionProcessor { static std::string getReductionName(llvm::StringRef name, mlir::Type ty, bool isByRef); - static std::string getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty, bool isByRef); + static std::string + getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty, bool isByRef); /// This function returns the identity value of the operator \p /// reductionOpName. For example: @@ -119,7 +120,7 @@ class ReductionProcessor { static void addReductionDecl(mlir::Location currentLocation, Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, + const omp::clause::Reduction &reduction, llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reductionDeclSymbols, llvm::SmallVectorImpl diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 49517f62895df..fa4a51e338483 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "Utils.h" +#include "Clauses.h" #include #include @@ -34,19 +35,33 @@ namespace Fortran { namespace lower { namespace omp { -void genObjectList(const Fortran::parser::OmpObjectList &objectList, +void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl &operands) { + for (const Object &object : objects) { + const Fortran::semantics::Symbol *sym = object.id(); + assert(sym && "Expected Symbol"); + if (mlir::Value variable = converter.getSymbolAddress(*sym)) { + operands.push_back(variable); + } else if (const auto *details = + sym->detailsIf()) { + operands.push_back(converter.getSymbolAddress(details->symbol())); + converter.copySymbolBinding(details->symbol(), *sym); + } + } +} + +void genObjectList2(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl &operands) { auto addOperands = [&](Fortran::lower::SymbolRef sym) { const mlir::Value variable = converter.getSymbolAddress(sym); if (variable) { operands.push_back(variable); - } else { - if (const auto *details = - sym->detailsIf()) { - operands.push_back(converter.getSymbolAddress(details->symbol())); - converter.copySymbolBinding(details->symbol(), sym); - } + } else if (const auto *details = + sym->detailsIf()) { + operands.push_back(converter.getSymbolAddress(details->symbol())); + converter.copySymbolBinding(details->symbol(), sym); } }; for (const Fortran::parser::OmpObject &ompObject : objectList.v) { @@ -56,24 +71,10 @@ void genObjectList(const Fortran::parser::OmpObjectList &objectList, } void gatherFuncAndVarSyms( - const Fortran::parser::OmpObjectList &objList, - mlir::omp::DeclareTargetCaptureClause clause, + const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, llvm::SmallVectorImpl &symbolAndClause) { - for (const Fortran::parser::OmpObject &ompObject : objList.v) { - Fortran::common::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - symbolAndClause.emplace_back(clause, *name->symbol); - } - }, - [&](const Fortran::parser::Name &name) { - symbolAndClause.emplace_back(clause, *name.symbol); - }}, - ompObject.u); - } + for (const Object &object : objects) + symbolAndClause.emplace_back(clause, *object.id()); } Fortran::semantics::Symbol * diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 76a15e8bcaab9..176ab2b5238a4 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -9,6 +9,7 @@ #ifndef FORTRAN_LOWER_OPENMPUTILS_H #define FORTRAN_LOWER_OPENMPUTILS_H +#include "Clauses.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/Value.h" @@ -51,17 +52,20 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, bool isVal = false); void gatherFuncAndVarSyms( - const Fortran::parser::OmpObjectList &objList, - mlir::omp::DeclareTargetCaptureClause clause, + const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, llvm::SmallVectorImpl &symbolAndClause); Fortran::semantics::Symbol * getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject); -void genObjectList(const Fortran::parser::OmpObjectList &objectList, +void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl &operands); +void genObjectList2(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl &operands); + } // namespace omp } // namespace lower } // namespace Fortran