diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index 9a02d3b3909ef..ad4cffc707535 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -2266,31 +2266,68 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter, return storeOp; } -struct CreateBodyOfOpInfo { +struct OpWithBodyGenInfo { + /// A type for a code-gen callback function. This takes as argument the op for + /// which the code is being generated and returns the arguments of the op's + /// region. + using GenOMPRegionEntryCBFn = + std::function( + mlir::Operation *)>; + + OpWithBodyGenInfo(Fortran::lower::AbstractConverter &converter, + mlir::Location loc, Fortran::lower::pft::Evaluation &eval) + : converter(converter), loc(loc), eval(eval) {} + + OpWithBodyGenInfo &setGenNested(bool value) { + genNested = value; + return *this; + } + + OpWithBodyGenInfo &setOuterCombined(bool value) { + outerCombined = value; + return *this; + } + + OpWithBodyGenInfo &setClauses(const Fortran::parser::OmpClauseList *value) { + clauses = value; + return *this; + } + + OpWithBodyGenInfo &setDataSharingProcessor(DataSharingProcessor *value) { + dsp = value; + return *this; + } + + OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) { + genRegionEntryCB = value; + return *this; + } + + /// [inout] converter to use for the clauses. Fortran::lower::AbstractConverter &converter; - mlir::Location &loc; + /// [in] location in source code. + mlir::Location loc; + /// [in] current PFT node/evaluation. Fortran::lower::pft::Evaluation &eval; + /// [in] whether to generate FIR for nested evaluations bool genNested = true; - const Fortran::parser::OmpClauseList *clauses = nullptr; - const llvm::SmallVector &args = {}; + /// [in] is this an outer operation - prevents privatization. bool outerCombined = false; + /// [in] list of clauses to process. + const Fortran::parser::OmpClauseList *clauses = nullptr; + /// [in] if provided, processes the construct's data-sharing attributes. DataSharingProcessor *dsp = nullptr; + /// [in] if provided, emits the op's region entry. Otherwise, an emtpy block + /// is created in the region. + GenOMPRegionEntryCBFn genRegionEntryCB = nullptr; }; /// Create the body (block) for an OpenMP Operation. /// -/// \param [in] op - the operation the body belongs to. -/// \param [inout] converter - converter to use for the clauses. -/// \param [in] loc - location in source code. -/// \param [in] eval - current PFT node/evaluation. -/// \param [in] genNested - whether to generate FIR for nested evaluations -/// \oaran [in] clauses - list of clauses to process. -/// \param [in] args - block arguments (induction variable[s]) for the -//// region. -/// \param [in] outerCombined - is this an outer operation - prevents -/// privatization. +/// \param [in] op - the operation the body belongs to. +/// \param [in] info - options controlling code-gen for the construction. template -static void createBodyOfOp(Op &op, CreateBodyOfOpInfo info) { +static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) { fir::FirOpBuilder &firOpBuilder = info.converter.getFirOpBuilder(); auto insertMarker = [](fir::FirOpBuilder &builder) { @@ -2303,28 +2340,15 @@ static void createBodyOfOp(Op &op, CreateBodyOfOpInfo info) { // argument. Also update the symbol's address with the mlir argument value. // e.g. For loops the argument is the induction variable. And all further // uses of the induction variable should use this mlir value. - if (info.args.size()) { - std::size_t loopVarTypeSize = 0; - for (const Fortran::semantics::Symbol *arg : info.args) - loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); - mlir::Type loopVarType = getLoopVarType(info.converter, loopVarTypeSize); - llvm::SmallVector tiv(info.args.size(), loopVarType); - llvm::SmallVector locs(info.args.size(), info.loc); - firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs); - // The argument is not currently in memory, so make a temporary for the - // argument, and store it there, then bind that location to the argument. - mlir::Operation *storeOp = nullptr; - for (auto [argIndex, argSymbol] : llvm::enumerate(info.args)) { - mlir::Value indexVal = - fir::getBase(op.getRegion().front().getArgument(argIndex)); - storeOp = createAndSetPrivatizedLoopVar(info.converter, info.loc, - indexVal, argSymbol); + auto regionArgs = + [&]() -> llvm::SmallVector { + if (info.genRegionEntryCB != nullptr) { + return info.genRegionEntryCB(op); } - firOpBuilder.setInsertionPointAfter(storeOp); - } else { - firOpBuilder.createBlock(&op.getRegion()); - } + firOpBuilder.createBlock(&op.getRegion()); + return {}; + }(); // Mark the earliest insertion point. mlir::Operation *marker = insertMarker(firOpBuilder); @@ -2421,8 +2445,8 @@ static void createBodyOfOp(Op &op, CreateBodyOfOpInfo info) { assert(tempDsp.has_value()); tempDsp->processStep2(op, isLoop); } else { - if (isLoop && info.args.size() > 0) - info.dsp->setLoopIV(info.converter.getSymbolAddress(*info.args[0])); + if (isLoop && regionArgs.size() > 0) + info.dsp->setLoopIV(info.converter.getSymbolAddress(*regionArgs[0])); info.dsp->processStep2(op, isLoop); } } @@ -2497,24 +2521,11 @@ static void genBodyOfTargetDataOp( genNestedEvaluations(converter, eval); } -struct GenOpWithBodyInfo { - Fortran::lower::AbstractConverter &converter; - Fortran::lower::pft::Evaluation &eval; - bool genNested = false; - mlir::Location currentLocation; - bool outerCombined = false; - const Fortran::parser::OmpClauseList *clauseList = nullptr; -}; - template -static OpTy genOpWithBody(GenOpWithBodyInfo info, Args &&...args) { +static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { auto op = info.converter.getFirOpBuilder().create( - info.currentLocation, std::forward(args)...); - createBodyOfOp( - op, {info.converter, info.currentLocation, info.eval, info.genNested, - info.clauseList, - /*args*/ llvm::SmallVector{}, - info.outerCombined}); + info.loc, std::forward(args)...); + createBodyOfOp(op, info); return op; } @@ -2523,7 +2534,8 @@ genMasterOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation) { return genOpWithBody( - {converter, eval, genNested, currentLocation}, + OpWithBodyGenInfo(converter, currentLocation, eval) + .setGenNested(genNested), /*resultTypes=*/mlir::TypeRange()); } @@ -2532,7 +2544,8 @@ genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation) { return genOpWithBody( - {converter, eval, genNested, currentLocation}, + OpWithBodyGenInfo(converter, currentLocation, eval) + .setGenNested(genNested), /*simd=*/false); } @@ -2560,7 +2573,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols); return genOpWithBody( - {converter, eval, genNested, currentLocation, outerCombined, &clauseList}, + OpWithBodyGenInfo(converter, currentLocation, eval) + .setGenNested(genNested) + .setOuterCombined(outerCombined) + .setClauses(&clauseList), /*resultTypes=*/mlir::TypeRange(), ifClauseOperand, numThreadsClauseOperand, allocateOperands, allocatorOperands, reductionVars, @@ -2579,8 +2595,9 @@ genSectionOp(Fortran::lower::AbstractConverter &converter, // Currently only private/firstprivate clause is handled, and // all privatization is done within `omp.section` operations. return genOpWithBody( - {converter, eval, genNested, currentLocation, - /*outerCombined=*/false, §ionsClauseList}); + OpWithBodyGenInfo(converter, currentLocation, eval) + .setGenNested(genNested) + .setClauses(§ionsClauseList)); } static mlir::omp::SingleOp @@ -2600,8 +2617,9 @@ genSingleOp(Fortran::lower::AbstractConverter &converter, ClauseProcessor(converter, endClauseList).processNowait(nowaitAttr); return genOpWithBody( - {converter, eval, genNested, currentLocation, - /*outerCombined=*/false, &beginClauseList}, + OpWithBodyGenInfo(converter, currentLocation, eval) + .setGenNested(genNested) + .setClauses(&beginClauseList), allocateOperands, allocatorOperands, nowaitAttr); } @@ -2633,8 +2651,9 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, currentLocation, llvm::omp::Directive::OMPD_task); return genOpWithBody( - {converter, eval, genNested, currentLocation, - /*outerCombined=*/false, &clauseList}, + OpWithBodyGenInfo(converter, currentLocation, eval) + .setGenNested(genNested) + .setClauses(&clauseList), ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr, /*in_reduction_vars=*/mlir::ValueRange(), /*in_reductions=*/nullptr, priorityClauseOperand, @@ -2656,8 +2675,9 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter, cp.processTODO( currentLocation, llvm::omp::Directive::OMPD_taskgroup); return genOpWithBody( - {converter, eval, genNested, currentLocation, - /*outerCombined=*/false, &clauseList}, + OpWithBodyGenInfo(converter, currentLocation, eval) + .setGenNested(genNested) + .setClauses(&clauseList), /*task_reduction_vars=*/mlir::ValueRange(), /*task_reductions=*/nullptr, allocateOperands, allocatorOperands); } @@ -3040,7 +3060,10 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, currentLocation, llvm::omp::Directive::OMPD_teams); return genOpWithBody( - {converter, eval, genNested, currentLocation, outerCombined, &clauseList}, + OpWithBodyGenInfo(converter, currentLocation, eval) + .setGenNested(genNested) + .setOuterCombined(outerCombined) + .setClauses(&clauseList), /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand, allocateOperands, allocatorOperands, reductionVars, @@ -3237,6 +3260,33 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter, } } +static llvm::SmallVector +genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter, + mlir::Location &loc, + const llvm::SmallVector &args) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + auto ®ion = op->getRegion(0); + + std::size_t loopVarTypeSize = 0; + for (const Fortran::semantics::Symbol *arg : args) + loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); + mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + llvm::SmallVector tiv(args.size(), loopVarType); + llvm::SmallVector locs(args.size(), loc); + firOpBuilder.createBlock(®ion, {}, tiv, locs); + // The argument is not currently in memory, so make a temporary for the + // argument, and store it there, then bind that location to the argument. + mlir::Operation *storeOp = nullptr; + for (auto [argIndex, argSymbol] : llvm::enumerate(args)) { + mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex)); + storeOp = + createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); + } + firOpBuilder.setInsertionPointAfter(storeOp); + + return args; +} + static void createSimdLoop(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, @@ -3284,10 +3334,16 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, auto *nestedEval = getCollapsedLoopEval( eval, Fortran::lower::getCollapseValue(loopOpClauseList)); + + auto ivCallback = [&](mlir::Operation *op) { + return genLoopVars(op, converter, loc, iv); + }; + createBodyOfOp( - simdLoopOp, {converter, loc, *nestedEval, - /*genNested=*/true, &loopOpClauseList, iv, - /*outerCombined=*/false, &dsp}); + simdLoopOp, OpWithBodyGenInfo(converter, loc, *nestedEval) + .setClauses(&loopOpClauseList) + .setDataSharingProcessor(&dsp) + .setGenRegionEntryCb(ivCallback)); } static void createWsLoop(Fortran::lower::AbstractConverter &converter, @@ -3360,10 +3416,16 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter, auto *nestedEval = getCollapsedLoopEval( eval, Fortran::lower::getCollapseValue(beginClauseList)); - createBodyOfOp(wsLoopOp, - {converter, loc, *nestedEval, - /*genNested=*/true, &beginClauseList, iv, - /*outerCombined=*/false, &dsp}); + + auto ivCallback = [&](mlir::Operation *op) { + return genLoopVars(op, converter, loc, iv); + }; + + createBodyOfOp( + wsLoopOp, OpWithBodyGenInfo(converter, loc, *nestedEval) + .setClauses(&beginClauseList) + .setDataSharingProcessor(&dsp) + .setGenRegionEntryCb(ivCallback)); } static void createSimdWsLoop( @@ -3644,8 +3706,8 @@ genOMP(Fortran::lower::AbstractConverter &converter, currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), global.getSymName())); }(); - createBodyOfOp(criticalOp, - {converter, currentLocation, eval}); + auto genInfo = OpWithBodyGenInfo(converter, currentLocation, eval); + createBodyOfOp(criticalOp, genInfo); } static void @@ -3687,11 +3749,11 @@ genOMP(Fortran::lower::AbstractConverter &converter, } // SECTIONS construct - genOpWithBody({converter, eval, - /*genNested=*/false, currentLocation}, - /*reduction_vars=*/mlir::ValueRange(), - /*reductions=*/nullptr, allocateOperands, - allocatorOperands, nowaitClauseOperand); + genOpWithBody( + OpWithBodyGenInfo(converter, currentLocation, eval).setGenNested(false), + /*reduction_vars=*/mlir::ValueRange(), + /*reductions=*/nullptr, allocateOperands, allocatorOperands, + nowaitClauseOperand); const auto §ionBlocks = std::get(sectionsConstruct.t);