Skip to content

[flang][openacc] Lower loop directive to the new acc.loop op design #65417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 22, 2024

Conversation

clementval
Copy link
Contributor

@clementval clementval commented Sep 5, 2023

acc.loop was redesigned in https://reviews.llvm.org/D159229. This patch updates the lowering to match the new op.

DO CONCURRENT construct will be added in a follow up patch.

Note that the pre-commit ci will fail until D159229 is merged.

Depends on #67355

@clementval clementval requested a review from a team as a code owner September 5, 2023 22:09
@github-actions github-actions bot added the flang Flang issues not falling into any other category label Sep 5, 2023
Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me. Thank you!

std::size_t ivTypeSize = ivSym.size();
if (ivTypeSize == 0)
llvm::report_fatal_error("unexpected induction variable size");
return builder.getIntegerType(ivTypeSize * 8);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment explaining the mapping here would be useful. It is not immediately obvious what ivSym.size() produces or what 8 is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comment

*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
else // If `step` is not present, assume it as `1`.
steps.push_back(builder.createIntegerConstant(
currentLocation, builder.getIntegerType(32), 1));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the same type as bounds be used?

accClauseList);
}
if (loopDirective.v != llvm::acc::ACCD_loop)
llvm::report_fatal_error("Unsupported OpenACC loop construct");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like it should be an assert.

@llvmbot
Copy link
Member

llvmbot commented Nov 15, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-openacc

@llvm/pr-subscribers-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

acc.loop was redesigned in https://reviews.llvm.org/D159229. This patch updates the lowering to match the new op.

DO CONCURRENT construct will be added in a follow up patch.

Note that the pre-commit ci will fail until D159229 is merged.

Depends on #67355


Patch is 109.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65417.diff

9 Files Affected:

  • (modified) flang/include/flang/Lower/OpenACC.h (+5)
  • (modified) flang/lib/Lower/Bridge.cpp (+18-1)
  • (modified) flang/lib/Lower/OpenACC.cpp (+126-21)
  • (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+52-102)
  • (modified) flang/test/Lower/OpenACC/acc-loop.f90 (+93-120)
  • (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+67-114)
  • (modified) flang/test/Lower/OpenACC/acc-private.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-reduction.f90 (+21-24)
  • (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+48-90)
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 409956f0ecb309f..da601c456a1fb2e 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -35,6 +35,7 @@ class FirOpBuilder;
 
 namespace Fortran {
 namespace parser {
+struct AccClauseList;
 struct OpenACCConstruct;
 struct OpenACCDeclarativeConstruct;
 struct OpenACCRoutineConstruct;
@@ -64,6 +65,8 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
 static constexpr llvm::StringRef declarePostDeallocSuffix =
     "_acc_declare_update_desc_post_dealloc";
 
+static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
+
 void genOpenACCConstruct(AbstractConverter &,
                          Fortran::semantics::SemanticsContext &,
                          pft::Evaluation &, const parser::OpenACCConstruct &);
@@ -112,6 +115,8 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
 void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
                           mlir::Location);
 
+int64_t getCollapseValue(const Fortran::parser::AccClauseList &);
+
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 872bf6bc729ecd0..58ae29223165a86 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2383,7 +2383,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
     genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc);
-    for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
+
+    const Fortran::parser::OpenACCLoopConstruct *accLoop =
+        std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
+
+    Fortran::lower::pft::Evaluation *curEval = &getEval();
+
+    if (accLoop) {
+      const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
+          std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
+      const Fortran::parser::AccClauseList &clauseList =
+          std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
+      int64_t collapseValue = Fortran::lower::getCollapseValue(clauseList);
+      curEval = &curEval->getFirstNestedEvaluation();
+      for (int64_t i = 1; i < collapseValue; i++)
+        curEval = &*std::next(curEval->getNestedEvaluations().begin());
+    }
+
+    for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
       genFIR(e);
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e470154ce8c2d0b..a73d1639c32f54e 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -753,7 +753,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
     mlir::Type retTy = getTypeFromBounds(bounds, baseAddr.getType());
     if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
       std::string recipeName =
-          fir::getTypeAsString(retTy, converter.getKindMap(), "privatization");
+          fir::getTypeAsString(retTy, converter.getKindMap(),
+                               Fortran::lower::privatizationRecipePrefix);
       recipe = Fortran::lower::createOrGetPrivateRecipe(builder, recipeName,
                                                         operandLocation, retTy);
       auto op = createDataEntryOp<mlir::acc::PrivateOp>(
@@ -1380,10 +1381,12 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
                          Fortran::lower::pft::Evaluation &eval,
                          const llvm::SmallVectorImpl<mlir::Value> &operands,
                          const llvm::SmallVectorImpl<int32_t> &operandSegments,
-                         bool outerCombined = false) {
-  llvm::ArrayRef<mlir::Type> argTy;
-  Op op = builder.create<Op>(loc, argTy, operands);
-  builder.createBlock(&op.getRegion());
+                         bool outerCombined = false,
+                         mlir::TypeRange argsTy = {},
+                         llvm::SmallVector<mlir::Location> locs = {}) {
+  llvm::ArrayRef<mlir::Type> retTy;
+  Op op = builder.create<Op>(loc, retTy, operands);
+  builder.createBlock(&op.getRegion(), op.getRegion().end(), argsTy, locs);
   mlir::Block &block = op.getRegion().back();
   builder.setInsertionPointToStart(&block);
 
@@ -1487,12 +1490,22 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
   }
 }
 
+mlir::Type getTypeFromIvTypeSize(fir::FirOpBuilder &builder,
+                                 const Fortran::semantics::Symbol &ivSym) {
+  std::size_t ivTypeSize = ivSym.size();
+  if (ivTypeSize == 0)
+    llvm::report_fatal_error("unexpected induction variable size");
+  // ivTypeSize is in bytes and IntegerType needs to be in bits.
+  return builder.getIntegerType(ivTypeSize * 8);
+}
+
 static mlir::acc::LoopOp
 createLoopOp(Fortran::lower::AbstractConverter &converter,
              mlir::Location currentLocation,
-             Fortran::lower::pft::Evaluation &eval,
              Fortran::semantics::SemanticsContext &semanticsContext,
              Fortran::lower::StatementContext &stmtCtx,
+             const Fortran::parser::DoConstruct &outerDoConstruct,
+             Fortran::lower::pft::Evaluation &eval,
              const Fortran::parser::AccClauseList &accClauseList) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
@@ -1501,11 +1514,73 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value gangNum;
   mlir::Value gangDim;
   mlir::Value gangStatic;
-  llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
-      reductionOperands, cacheOperands;
+  llvm::SmallVector<mlir::Value> tileOperands, privateOperands, ivPrivate,
+      reductionOperands, cacheOperands, lowerbounds, upperbounds, steps;
   llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
+
   bool hasGang = false, hasVector = false, hasWorker = false;
 
+  llvm::SmallVector<mlir::Type> ivTypes;
+  llvm::SmallVector<mlir::Location> ivLocs;
+  llvm::SmallVector<bool> inclusiveBounds;
+
+  if (outerDoConstruct.IsDoConcurrent())
+    TODO(currentLocation, "OpenACC loop with DO CONCURRENT");
+
+  int64_t collapseValue = Fortran::lower::getCollapseValue(accClauseList);
+  Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
+  for (unsigned i = 0; i < collapseValue; ++i) {
+    const Fortran::parser::LoopControl *loopControl;
+    if (i == 0) {
+      loopControl = &*outerDoConstruct.GetLoopControl();
+    } else {
+      auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
+      assert(doCons && "expect do construct");
+      loopControl = &*doCons->GetLoopControl();
+    }
+
+    const Fortran::parser::LoopControl::Bounds *bounds =
+        std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+    assert(bounds && "Expected bounds on the loop construct");
+    lowerbounds.push_back(fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
+    upperbounds.push_back(fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
+    if (bounds->step)
+      steps.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
+    else // If `step` is not present, assume it as `1`.
+      steps.push_back(builder.createIntegerConstant(
+          currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));
+
+    Fortran::semantics::Symbol &ivSym =
+        bounds->name.thing.symbol->GetUltimate();
+
+    mlir::Type ivTy = getTypeFromIvTypeSize(builder, ivSym);
+    mlir::Value ivValue = converter.getSymbolAddress(ivSym);
+    ivTypes.push_back(ivTy);
+    ivLocs.push_back(currentLocation);
+    std::string recipeName =
+        fir::getTypeAsString(ivValue.getType(), converter.getKindMap(),
+                             Fortran::lower::privatizationRecipePrefix);
+    auto recipe = Fortran::lower::createOrGetPrivateRecipe(
+        builder, recipeName, currentLocation, ivValue.getType());
+    std::stringstream asFortran;
+    auto op = createDataEntryOp<mlir::acc::PrivateOp>(
+        builder, currentLocation, ivValue, asFortran, {}, true,
+        /*implicit=*/true, mlir::acc::DataClause::acc_private,
+        ivValue.getType());
+
+    privateOperands.push_back(op.getAccPtr());
+    ivPrivate.push_back(op.getAccPtr());
+    privatizations.push_back(mlir::SymbolRefAttr::get(
+        builder.getContext(), recipe.getSymName().str()));
+    inclusiveBounds.push_back(true);
+    converter.bindSymbol(ivSym, op.getAccPtr());
+    if (i < collapseValue - 1)
+      crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+  }
+
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *gangClause =
@@ -1588,6 +1663,9 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value> operands;
   llvm::SmallVector<int32_t> operandSegments;
+  addOperands(operands, operandSegments, lowerbounds);
+  addOperands(operands, operandSegments, upperbounds);
+  addOperands(operands, operandSegments, steps);
   addOperand(operands, operandSegments, gangNum);
   addOperand(operands, operandSegments, gangDim);
   addOperand(operands, operandSegments, gangStatic);
@@ -1599,7 +1677,14 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   addOperands(operands, operandSegments, reductionOperands);
 
   auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
-      builder, currentLocation, eval, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments,
+      /*outerCombined=*/false, ivTypes, ivLocs);
+
+  for (auto [arg, value] : llvm::zip(
+           loopOp.getLoopRegions().front()->front().getArguments(), ivPrivate))
+    builder.create<fir::StoreOp>(currentLocation, arg, value);
+
+  loopOp.setInclusiveUpperbound(inclusiveBounds);
 
   if (hasGang)
     loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1660,12 +1745,16 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
       converter.genLocation(beginLoopDirective.source);
   Fortran::lower::StatementContext stmtCtx;
 
-  if (loopDirective.v == llvm::acc::ACCD_loop) {
-    const auto &accClauseList =
-        std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
-  }
+  assert(loopDirective.v == llvm::acc::ACCD_loop &&
+         "Unsupported OpenACC loop construct");
+
+  const auto &accClauseList =
+      std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
+  const auto &outerDoConstruct =
+      std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t);
+
+  createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+               *outerDoConstruct, eval, accClauseList);
 }
 
 template <typename Op, typename Clause>
@@ -2241,6 +2330,9 @@ genACC(Fortran::lower::AbstractConverter &converter,
       std::get<Fortran::parser::AccCombinedDirective>(beginCombinedDirective.t);
   const auto &accClauseList =
       std::get<Fortran::parser::AccClauseList>(beginCombinedDirective.t);
+  const auto &outerDoConstruct =
+      std::get<std::optional<Fortran::parser::DoConstruct>>(
+          combinedConstruct.t);
 
   mlir::Location currentLocation =
       converter.genLocation(beginCombinedDirective.source);
@@ -2250,20 +2342,20 @@ genACC(Fortran::lower::AbstractConverter &converter,
     createComputeOp<mlir::acc::KernelsOp>(
         converter, currentLocation, eval, semanticsContext, stmtCtx,
         accClauseList, /*outerCombined=*/true);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
+    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+                 *outerDoConstruct, eval, accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
     createComputeOp<mlir::acc::ParallelOp>(
         converter, currentLocation, eval, semanticsContext, stmtCtx,
         accClauseList, /*outerCombined=*/true);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
+    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+                 *outerDoConstruct, eval, accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
     createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
                                          semanticsContext, stmtCtx,
                                          accClauseList, /*outerCombined=*/true);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
+    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+                 *outerDoConstruct, eval, accClauseList);
   } else {
     llvm::report_fatal_error("Unknown combined construct encountered");
   }
@@ -3549,3 +3641,16 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
   else
     builder.create<mlir::acc::TerminatorOp>(loc);
 }
+
+int64_t Fortran::lower::getCollapseValue(
+    const Fortran::parser::AccClauseList &clauseList) {
+  for (const Fortran::parser::AccClause &clause : clauseList.v) {
+    if (const auto *collapseClause =
+            std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+      const parser::AccCollapseArg &arg = collapseClause->v;
+      const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
+      return *Fortran::semantics::GetIntValue(collapseValue);
+    }
+  }
+  return 1;
+}
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index eee3f09732a3137..68b27cca38ef2d0 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -44,8 +44,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -58,8 +57,7 @@ subroutine acc_kernels_loop
   !$acc end kernels loop
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -72,8 +70,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[ASYNC1:%.*]] = arith.constant 1 : i32
 ! CHECK:      acc.kernels async([[ASYNC1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -86,8 +83,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[ASYNC2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels async([[ASYNC2]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -99,8 +95,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -113,8 +108,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
 ! CHECK:      acc.kernels wait([[WAIT1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -128,8 +122,7 @@ subroutine acc_kernels_loop
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
 ! CHECK:      acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -143,8 +136,7 @@ subroutine acc_kernels_loop
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -157,8 +149,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
 ! CHECK:      acc.kernels num_gangs([[NUMGANGS1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -171,8 +162,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels num_gangs([[NUMGANGS2]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -185,8 +175,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[NUMWORKERS1:%.*]] = arith.constant 10 : i32
 ! CHECK:      acc.kernels num_workers([[NUMWORKERS1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -199,8 +188,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[NUMWORKERS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels num_workers([[NUMWORKERS2]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -213,8 +201,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[VECTORLENGTH1:%.*]] = arith.constant 128 : i32
 ! CHECK:      acc.kernels vector_length([[VECTORLENGTH1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -227,8 +214,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[VECTORLENGTH2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels vector_length([[VECTORLENGTH2]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -241,8 +227,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[IF1:%.*]] = arith.constant true
 ! CHECK:      acc.kernels if([[IF1]]) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -256,8 +241,7 @@ subroutine acc_kernels_loop
 ! CHECK:      [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
 ! CHECK:      [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
 ! CHECK:      acc.kernels if([[IF2]]) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -270,8 +254,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[SELF1:%.*]] = arith.constant true
 ! CHECK:      acc.kernels self([[SELF1]]) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -283,8 +266,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 15, 2023

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

acc.loop was redesigned in https://reviews.llvm.org/D159229. This patch updates the lowering to match the new op.

DO CONCURRENT construct will be added in a follow up patch.

Note that the pre-commit ci will fail until D159229 is merged.

Depends on #67355


Patch is 109.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65417.diff

9 Files Affected:

  • (modified) flang/include/flang/Lower/OpenACC.h (+5)
  • (modified) flang/lib/Lower/Bridge.cpp (+18-1)
  • (modified) flang/lib/Lower/OpenACC.cpp (+126-21)
  • (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+52-102)
  • (modified) flang/test/Lower/OpenACC/acc-loop.f90 (+93-120)
  • (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+67-114)
  • (modified) flang/test/Lower/OpenACC/acc-private.f90 (+7-7)
  • (modified) flang/test/Lower/OpenACC/acc-reduction.f90 (+21-24)
  • (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+48-90)
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 409956f0ecb309f..da601c456a1fb2e 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -35,6 +35,7 @@ class FirOpBuilder;
 
 namespace Fortran {
 namespace parser {
+struct AccClauseList;
 struct OpenACCConstruct;
 struct OpenACCDeclarativeConstruct;
 struct OpenACCRoutineConstruct;
@@ -64,6 +65,8 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
 static constexpr llvm::StringRef declarePostDeallocSuffix =
     "_acc_declare_update_desc_post_dealloc";
 
+static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
+
 void genOpenACCConstruct(AbstractConverter &,
                          Fortran::semantics::SemanticsContext &,
                          pft::Evaluation &, const parser::OpenACCConstruct &);
@@ -112,6 +115,8 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
 void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
                           mlir::Location);
 
+int64_t getCollapseValue(const Fortran::parser::AccClauseList &);
+
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 872bf6bc729ecd0..58ae29223165a86 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2383,7 +2383,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
     genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc);
-    for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
+
+    const Fortran::parser::OpenACCLoopConstruct *accLoop =
+        std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
+
+    Fortran::lower::pft::Evaluation *curEval = &getEval();
+
+    if (accLoop) {
+      const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
+          std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
+      const Fortran::parser::AccClauseList &clauseList =
+          std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
+      int64_t collapseValue = Fortran::lower::getCollapseValue(clauseList);
+      curEval = &curEval->getFirstNestedEvaluation();
+      for (int64_t i = 1; i < collapseValue; i++)
+        curEval = &*std::next(curEval->getNestedEvaluations().begin());
+    }
+
+    for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
       genFIR(e);
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e470154ce8c2d0b..a73d1639c32f54e 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -753,7 +753,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
     mlir::Type retTy = getTypeFromBounds(bounds, baseAddr.getType());
     if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
       std::string recipeName =
-          fir::getTypeAsString(retTy, converter.getKindMap(), "privatization");
+          fir::getTypeAsString(retTy, converter.getKindMap(),
+                               Fortran::lower::privatizationRecipePrefix);
       recipe = Fortran::lower::createOrGetPrivateRecipe(builder, recipeName,
                                                         operandLocation, retTy);
       auto op = createDataEntryOp<mlir::acc::PrivateOp>(
@@ -1380,10 +1381,12 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
                          Fortran::lower::pft::Evaluation &eval,
                          const llvm::SmallVectorImpl<mlir::Value> &operands,
                          const llvm::SmallVectorImpl<int32_t> &operandSegments,
-                         bool outerCombined = false) {
-  llvm::ArrayRef<mlir::Type> argTy;
-  Op op = builder.create<Op>(loc, argTy, operands);
-  builder.createBlock(&op.getRegion());
+                         bool outerCombined = false,
+                         mlir::TypeRange argsTy = {},
+                         llvm::SmallVector<mlir::Location> locs = {}) {
+  llvm::ArrayRef<mlir::Type> retTy;
+  Op op = builder.create<Op>(loc, retTy, operands);
+  builder.createBlock(&op.getRegion(), op.getRegion().end(), argsTy, locs);
   mlir::Block &block = op.getRegion().back();
   builder.setInsertionPointToStart(&block);
 
@@ -1487,12 +1490,22 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
   }
 }
 
+mlir::Type getTypeFromIvTypeSize(fir::FirOpBuilder &builder,
+                                 const Fortran::semantics::Symbol &ivSym) {
+  std::size_t ivTypeSize = ivSym.size();
+  if (ivTypeSize == 0)
+    llvm::report_fatal_error("unexpected induction variable size");
+  // ivTypeSize is in bytes and IntegerType needs to be in bits.
+  return builder.getIntegerType(ivTypeSize * 8);
+}
+
 static mlir::acc::LoopOp
 createLoopOp(Fortran::lower::AbstractConverter &converter,
              mlir::Location currentLocation,
-             Fortran::lower::pft::Evaluation &eval,
              Fortran::semantics::SemanticsContext &semanticsContext,
              Fortran::lower::StatementContext &stmtCtx,
+             const Fortran::parser::DoConstruct &outerDoConstruct,
+             Fortran::lower::pft::Evaluation &eval,
              const Fortran::parser::AccClauseList &accClauseList) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
@@ -1501,11 +1514,73 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value gangNum;
   mlir::Value gangDim;
   mlir::Value gangStatic;
-  llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
-      reductionOperands, cacheOperands;
+  llvm::SmallVector<mlir::Value> tileOperands, privateOperands, ivPrivate,
+      reductionOperands, cacheOperands, lowerbounds, upperbounds, steps;
   llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
+
   bool hasGang = false, hasVector = false, hasWorker = false;
 
+  llvm::SmallVector<mlir::Type> ivTypes;
+  llvm::SmallVector<mlir::Location> ivLocs;
+  llvm::SmallVector<bool> inclusiveBounds;
+
+  if (outerDoConstruct.IsDoConcurrent())
+    TODO(currentLocation, "OpenACC loop with DO CONCURRENT");
+
+  int64_t collapseValue = Fortran::lower::getCollapseValue(accClauseList);
+  Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
+  for (unsigned i = 0; i < collapseValue; ++i) {
+    const Fortran::parser::LoopControl *loopControl;
+    if (i == 0) {
+      loopControl = &*outerDoConstruct.GetLoopControl();
+    } else {
+      auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
+      assert(doCons && "expect do construct");
+      loopControl = &*doCons->GetLoopControl();
+    }
+
+    const Fortran::parser::LoopControl::Bounds *bounds =
+        std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+    assert(bounds && "Expected bounds on the loop construct");
+    lowerbounds.push_back(fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
+    upperbounds.push_back(fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
+    if (bounds->step)
+      steps.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
+    else // If `step` is not present, assume it as `1`.
+      steps.push_back(builder.createIntegerConstant(
+          currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));
+
+    Fortran::semantics::Symbol &ivSym =
+        bounds->name.thing.symbol->GetUltimate();
+
+    mlir::Type ivTy = getTypeFromIvTypeSize(builder, ivSym);
+    mlir::Value ivValue = converter.getSymbolAddress(ivSym);
+    ivTypes.push_back(ivTy);
+    ivLocs.push_back(currentLocation);
+    std::string recipeName =
+        fir::getTypeAsString(ivValue.getType(), converter.getKindMap(),
+                             Fortran::lower::privatizationRecipePrefix);
+    auto recipe = Fortran::lower::createOrGetPrivateRecipe(
+        builder, recipeName, currentLocation, ivValue.getType());
+    std::stringstream asFortran;
+    auto op = createDataEntryOp<mlir::acc::PrivateOp>(
+        builder, currentLocation, ivValue, asFortran, {}, true,
+        /*implicit=*/true, mlir::acc::DataClause::acc_private,
+        ivValue.getType());
+
+    privateOperands.push_back(op.getAccPtr());
+    ivPrivate.push_back(op.getAccPtr());
+    privatizations.push_back(mlir::SymbolRefAttr::get(
+        builder.getContext(), recipe.getSymName().str()));
+    inclusiveBounds.push_back(true);
+    converter.bindSymbol(ivSym, op.getAccPtr());
+    if (i < collapseValue - 1)
+      crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+  }
+
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *gangClause =
@@ -1588,6 +1663,9 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value> operands;
   llvm::SmallVector<int32_t> operandSegments;
+  addOperands(operands, operandSegments, lowerbounds);
+  addOperands(operands, operandSegments, upperbounds);
+  addOperands(operands, operandSegments, steps);
   addOperand(operands, operandSegments, gangNum);
   addOperand(operands, operandSegments, gangDim);
   addOperand(operands, operandSegments, gangStatic);
@@ -1599,7 +1677,14 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   addOperands(operands, operandSegments, reductionOperands);
 
   auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
-      builder, currentLocation, eval, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments,
+      /*outerCombined=*/false, ivTypes, ivLocs);
+
+  for (auto [arg, value] : llvm::zip(
+           loopOp.getLoopRegions().front()->front().getArguments(), ivPrivate))
+    builder.create<fir::StoreOp>(currentLocation, arg, value);
+
+  loopOp.setInclusiveUpperbound(inclusiveBounds);
 
   if (hasGang)
     loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1660,12 +1745,16 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
       converter.genLocation(beginLoopDirective.source);
   Fortran::lower::StatementContext stmtCtx;
 
-  if (loopDirective.v == llvm::acc::ACCD_loop) {
-    const auto &accClauseList =
-        std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
-  }
+  assert(loopDirective.v == llvm::acc::ACCD_loop &&
+         "Unsupported OpenACC loop construct");
+
+  const auto &accClauseList =
+      std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
+  const auto &outerDoConstruct =
+      std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t);
+
+  createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+               *outerDoConstruct, eval, accClauseList);
 }
 
 template <typename Op, typename Clause>
@@ -2241,6 +2330,9 @@ genACC(Fortran::lower::AbstractConverter &converter,
       std::get<Fortran::parser::AccCombinedDirective>(beginCombinedDirective.t);
   const auto &accClauseList =
       std::get<Fortran::parser::AccClauseList>(beginCombinedDirective.t);
+  const auto &outerDoConstruct =
+      std::get<std::optional<Fortran::parser::DoConstruct>>(
+          combinedConstruct.t);
 
   mlir::Location currentLocation =
       converter.genLocation(beginCombinedDirective.source);
@@ -2250,20 +2342,20 @@ genACC(Fortran::lower::AbstractConverter &converter,
     createComputeOp<mlir::acc::KernelsOp>(
         converter, currentLocation, eval, semanticsContext, stmtCtx,
         accClauseList, /*outerCombined=*/true);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
+    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+                 *outerDoConstruct, eval, accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
     createComputeOp<mlir::acc::ParallelOp>(
         converter, currentLocation, eval, semanticsContext, stmtCtx,
         accClauseList, /*outerCombined=*/true);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
+    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+                 *outerDoConstruct, eval, accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
     createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
                                          semanticsContext, stmtCtx,
                                          accClauseList, /*outerCombined=*/true);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
+    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+                 *outerDoConstruct, eval, accClauseList);
   } else {
     llvm::report_fatal_error("Unknown combined construct encountered");
   }
@@ -3549,3 +3641,16 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
   else
     builder.create<mlir::acc::TerminatorOp>(loc);
 }
+
+int64_t Fortran::lower::getCollapseValue(
+    const Fortran::parser::AccClauseList &clauseList) {
+  for (const Fortran::parser::AccClause &clause : clauseList.v) {
+    if (const auto *collapseClause =
+            std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+      const parser::AccCollapseArg &arg = collapseClause->v;
+      const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
+      return *Fortran::semantics::GetIntValue(collapseValue);
+    }
+  }
+  return 1;
+}
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index eee3f09732a3137..68b27cca38ef2d0 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -44,8 +44,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -58,8 +57,7 @@ subroutine acc_kernels_loop
   !$acc end kernels loop
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -72,8 +70,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[ASYNC1:%.*]] = arith.constant 1 : i32
 ! CHECK:      acc.kernels async([[ASYNC1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -86,8 +83,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[ASYNC2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels async([[ASYNC2]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -99,8 +95,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -113,8 +108,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[WAIT1:%.*]] = arith.constant 1 : i32
 ! CHECK:      acc.kernels wait([[WAIT1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -128,8 +122,7 @@ subroutine acc_kernels_loop
 ! CHECK:      [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK:      [[WAIT3:%.*]] = arith.constant 2 : i32
 ! CHECK:      acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -143,8 +136,7 @@ subroutine acc_kernels_loop
 ! CHECK:      [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -157,8 +149,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[NUMGANGS1:%.*]] = arith.constant 1 : i32
 ! CHECK:      acc.kernels num_gangs([[NUMGANGS1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -171,8 +162,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels num_gangs([[NUMGANGS2]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -185,8 +175,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[NUMWORKERS1:%.*]] = arith.constant 10 : i32
 ! CHECK:      acc.kernels num_workers([[NUMWORKERS1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -199,8 +188,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[NUMWORKERS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels num_workers([[NUMWORKERS2]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -213,8 +201,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[VECTORLENGTH1:%.*]] = arith.constant 128 : i32
 ! CHECK:      acc.kernels vector_length([[VECTORLENGTH1]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -227,8 +214,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[VECTORLENGTH2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
 ! CHECK:      acc.kernels vector_length([[VECTORLENGTH2]] : i32) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -241,8 +227,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[IF1:%.*]] = arith.constant true
 ! CHECK:      acc.kernels if([[IF1]]) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -256,8 +241,7 @@ subroutine acc_kernels_loop
 ! CHECK:      [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
 ! CHECK:      [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
 ! CHECK:      acc.kernels if([[IF2]]) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -270,8 +254,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      [[SELF1:%.*]] = arith.constant true
 ! CHECK:      acc.kernels self([[SELF1]]) {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
@@ -283,8 +266,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
-! CHECK:          fir.do_loop
+! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ...
[truncated]

Copy link

github-actions bot commented Nov 29, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me. Merge order is sensitive here if you don't merge the updated dialect with updated lowering.

acc.loop was redesigned in https://reviews.llvm.org/D159229. This patch
updates the lowering to match the new op.

DO CONCURRENT construct will be added in a follow up patch.
@clementval clementval merged commit 5062a17 into llvm:main Jan 22, 2024
@clementval clementval deleted the acc_loop_lowering branch January 22, 2024 18:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants