diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h index 409956f0ecb30..f23e4726f33e0 100644 --- a/flang/include/flang/Lower/OpenACC.h +++ b/flang/include/flang/Lower/OpenACC.h @@ -64,9 +64,10 @@ static constexpr llvm::StringRef declarePreDeallocSuffix = static constexpr llvm::StringRef declarePostDeallocSuffix = "_acc_declare_update_desc_post_dealloc"; -void genOpenACCConstruct(AbstractConverter &, - Fortran::semantics::SemanticsContext &, - pft::Evaluation &, const parser::OpenACCConstruct &); +/* Lowers an OpenACC construct. Per the definition in OpenACC.cpp, the result is whatever the loop-construct lowering returned; it stays a null mlir::Value for every other construct kind. */ mlir::Value genOpenACCConstruct(AbstractConverter &, + Fortran::semantics::SemanticsContext &, + pft::Evaluation &, + const parser::OpenACCConstruct &); void genOpenACCDeclarativeConstruct(AbstractConverter &, Fortran::semantics::SemanticsContext &, StatementContext &, @@ -112,6 +113,12 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &, void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *, mlir::Location); +/* True when the region holding the builder's current block has an enclosing op of the kind its definition looks up with getParentOfType (presumably the OpenACC loop op). */ bool isInOpenACCLoop(fir::FirOpBuilder &); + +/* If such an enclosing op exists, moves the insertion point to just after it; otherwise leaves the builder untouched. */ void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &); + +/* Creates an i1 constant 1 and an op taking it at the current insertion point (acc.yield in the accompanying lowering test). */ void genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &, mlir::Location); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 23c48cc7bd978..45da1355df168 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -2382,11 +2382,25 @@ class FirConverter : public Fortran::lower::AbstractConverter { void genFIR(const Fortran::parser::OpenACCConstruct &acc) { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); localSymbols.pushScope(); - genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc); + mlir::Value exitCond = genOpenACCConstruct( + *this, bridge.getSemanticsContext(), getEval(), acc); for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) genFIR(e); localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); + + const
Fortran::parser::OpenACCLoopConstruct *accLoop = + std::get_if(&acc.u); + /* Only a loop construct that produced an exit condition needs a check after it: split off a continuation block at the end of the current block, create an op on exitCond targeting funit->finalBlock and that continuation (cf.cond_br in the accompanying test), then keep lowering in the continuation. */ if (accLoop && exitCond) { + Fortran::lower::pft::FunctionLikeUnit *funit = + getEval().getOwningProcedure(); + assert(funit && "not inside main program, function or subroutine"); + mlir::Block *continueBlock = + builder->getBlock()->splitBlock(builder->getBlock()->end()); + builder->create(toLocation(), exitCond, + funit->finalBlock, continueBlock); + builder->setInsertionPointToEnd(continueBlock); + } } void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) { @@ -4091,10 +4105,15 @@ class FirConverter : public Fortran::lower::AbstractConverter { // Branch to the last block of the SUBROUTINE, which has the actual return. if (!funit->finalBlock) { mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint(); + /* Move the insertion point past any enclosing OpenACC loop op before the final block is created. */ Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(*builder); funit->finalBlock = builder->createBlock(&builder->getRegion()); builder->restoreInsertionPoint(insPt); } - builder->create(loc, funit->finalBlock); + + /* Inside an OpenACC loop the return goes through genEarlyReturnInOpenACCLoop instead of a direct branch to funit->finalBlock. */ if (Fortran::lower::isInOpenACCLoop(*builder)) + Fortran::lower::genEarlyReturnInOpenACCLoop(*builder, loc); + else + builder->create(loc, funit->finalBlock); } void genFIR(const Fortran::parser::CycleStmt &) { diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 8c6c22210cf08..e2abed1b9f4f6 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -25,10 +25,12 @@ #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Builder/IntrinsicCall.h" #include "flang/Optimizer/Builder/Todo.h" +#include "flang/Parser/parse-tree-visitor.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/expression.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/tools.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "llvm/Frontend/OpenACC/ACC.h.inc" // Special value for * passed in device_type or gang clauses. 
@@ -1381,9 +1383,10 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc, Fortran::lower::pft::Evaluation &eval, const llvm::SmallVectorImpl &operands, const llvm::SmallVectorImpl &operandSegments, - bool outerCombined = false) { - llvm::ArrayRef argTy; - Op op = builder.create(loc, argTy, operands); + bool outerCombined = false, + /* retTy: result types handed to the op builder (empty by default); yieldValue: optional value for the region terminator (null by default). */ llvm::SmallVector retTy = {}, + mlir::Value yieldValue = {}) { + Op op = builder.create(loc, retTy, operands); builder.createBlock(&op.getRegion()); mlir::Block &block = op.getRegion().back(); builder.setInsertionPointToStart(&block); @@ -1401,7 +1404,16 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc, mlir::acc::YieldOp>( builder, eval.getNestedEvaluations()); - builder.create(loc); + /* With a yield value, the terminator carries it only in the if-constexpr branch (the type condition is not visible here); in that branch the value's defining op is moved to just before the terminator. Every other path creates the terminator from loc alone. */ if (yieldValue) { + if constexpr (std::is_same_v) { + Terminator yieldOp = builder.create(loc, yieldValue); + yieldValue.getDefiningOp()->moveBefore(yieldOp); + } else { + builder.create(loc); + } + } else { + builder.create(loc); + } builder.setInsertionPointToStart(&block); return op; } @@ -1494,7 +1506,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::AccClauseList &accClauseList) { + const Fortran::parser::AccClauseList &accClauseList, + bool needEarlyReturnHandling = false) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::Value workerNum; @@ -1599,8 +1612,17 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, addOperands(operands, operandSegments, privateOperands); addOperands(operands, operandSegments, reductionOperands); + llvm::SmallVector retTy; + mlir::Value yieldValue; + /* Early-return handling: the op gets a single i1 result type and the default yield value is an i1 constant 0. */ if (needEarlyReturnHandling) { + mlir::Type i1Ty = builder.getI1Type(); + yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0); + retTy.push_back(i1Ty); + } + auto loopOp = createRegionOp( - builder, 
currentLocation, eval, operands, operandSegments); + builder, currentLocation, eval, operands, operandSegments, + /*outerCombined=*/false, retTy, yieldValue); if (hasGang) loopOp.setHasGangAttr(builder.getUnitAttr()); @@ -1647,16 +1669,34 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, return loopOp; } -static void genACC(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
+/// Return true if any evaluation nested (at any depth) in \p eval is a RETURN.
+static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
+  bool hasReturnStmt = false;
+  for (auto &e : eval.getNestedEvaluations()) {
+    e.visit(Fortran::common::visitors{
+        [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
+        [&](const auto &) {}});
+    // OR in the nested result so a later RETURN-free nest cannot clear it.
+    hasReturnStmt |= e.hasNestedEvaluations() && hasEarlyReturn(e);
+  }
+  return hasReturnStmt;
+}
+ +static mlir::Value +genACC(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { const auto &beginLoopDirective = std::get(loopConstruct.t); const auto &loopDirective = std::get(beginLoopDirective.t); + bool needEarlyExitHandling = false; + if (eval.lowerAsUnstructured()) + needEarlyExitHandling = hasEarlyReturn(eval); + mlir::Location currentLocation = converter.genLocation(beginLoopDirective.source); Fortran::lower::StatementContext stmtCtx; @@ -1664,9 +1704,13 @@ static void genACC(Fortran::lower::AbstractConverter &converter, if (loopDirective.v == llvm::acc::ACCD_loop) { const auto &accClauseList = std::get(beginLoopDirective.t); - createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx, - accClauseList); + auto loopOp = + createLoopOp(converter, currentLocation, eval, semanticsContext, + stmtCtx, accClauseList, needEarlyExitHandling); + if (needEarlyExitHandling) + return 
loopOp.getResult(0); } + return mlir::Value{}; } template @@ -3431,12 +3475,13 @@ genACC(Fortran::lower::AbstractConverter &converter, builder.restoreInsertionPoint(crtPos); } -void Fortran::lower::genOpenACCConstruct( +mlir::Value Fortran::lower::genOpenACCConstruct( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenACCConstruct &accConstruct) { + /* Stays null unless the loop-construct visitor below assigns it. */ mlir::Value exitCond; std::visit( common::visitors{ [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { @@ -3447,7 +3492,7 @@ void Fortran::lower::genOpenACCConstruct( genACC(converter, semanticsContext, eval, combinedConstruct); }, [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { - genACC(converter, semanticsContext, eval, loopConstruct); + exitCond = genACC(converter, semanticsContext, eval, loopConstruct); }, [&](const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) { @@ -3467,6 +3512,7 @@ void Fortran::lower::genOpenACCConstruct( }, }, accConstruct.u); + return exitCond; } void Fortran::lower::genOpenACCDeclarativeConstruct( @@ -3560,3 +3606,23 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder, else builder.create(loc); } + +/* Returns true when the region holding the builder's current block has an enclosing op of the type given to getParentOfType (template argument not visible here; presumably the OpenACC loop op), and false otherwise. */ bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) { + if (builder.getBlock()->getParent()->getParentOfType()) + return true; + return false; +} + +/* Same lookup as isInOpenACCLoop; when the enclosing op is found the insertion point is set to just after it, otherwise the builder is left unchanged. */ void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside( + fir::FirOpBuilder &builder) { + if (auto loopOp = + builder.getBlock()->getParent()->getParentOfType()) + builder.setInsertionPointAfter(loopOp); +} + +/* Creates an i1 constant 1 at loc and then an op taking that value (op type not visible here; the accompanying test expects acc.yield %true). */ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder, + mlir::Location loc) { + mlir::Value yieldValue = + builder.createIntegerConstant(loc, builder.getI1Type(), 1); + builder.create(loc, yieldValue); +} diff --git a/flang/test/Lower/OpenACC/acc-loop-exit.f90 b/flang/test/Lower/OpenACC/acc-loop-exit.f90 new file 
mode 100644
index 0000000000000..75f1c30733272
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-exit.f90
@@ -0,0 +1,38 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+! Verify that a RETURN inside an !$acc loop makes acc.loop yield an i1 exit flag that is tested right after the loop.
+
+subroutine sub1(x, a)
+  real :: x(200)
+  integer :: a
+
+  !$acc loop
+  do i = 100, 200
+    x(i) = 1.0
+    if (i == a) return
+  end do
+
+  i = 2
+end
+
+! CHECK-LABEL: func.func @_QPsub1
+! CHECK: %[[A:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFsub1Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[EXIT_COND:.*]] = acc.loop {
+! CHECK: ^bb{{.*}}:
+! CHECK: ^bb{{.*}}:
+! CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+! CHECK: %[[CMP:.*]] = arith.cmpi eq, %{{.*}}, %[[LOAD_A]] : i32
+! CHECK: cf.cond_br %[[CMP]], ^[[EARLY_RET:.*]], ^[[NO_RET:.*]]
+! CHECK: ^[[EARLY_RET]]:
+! CHECK: acc.yield %true : i1
+! CHECK: ^[[NO_RET]]:
+! CHECK: cf.br ^bb{{.*}}
+! CHECK: ^bb{{.*}}:
+! CHECK: acc.yield %false : i1
+! CHECK: }(i1)
+! CHECK: cf.cond_br %[[EXIT_COND]], ^[[EXIT_BLOCK:.*]], ^[[CONTINUE_BLOCK:.*]]
+! CHECK: ^[[CONTINUE_BLOCK]]:
+! CHECK: hlfir.assign
+! CHECK: cf.br ^[[EXIT_BLOCK]]
+! CHECK: ^[[EXIT_BLOCK]]:
+! CHECK: return
+! CHECK: }