Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][flang][openacc] Support device_type on loop construct #76892

Merged
merged 3 commits into from
Jan 5, 2024

Conversation

clementval
Copy link
Contributor

This is adding support for device_type clause representation in the OpenACC MLIR dialect on the acc.loop operation and adjust flang to lower correctly to the new representation.

Each "value" that can be impacted by a device_type clause is now associated with an array attribute that carry this information. This includes:

  • worker clause information
  • gang clause information
  • vector clause information
  • collapse clause information
  • tile clause information

The representation of the gang clause information has been updated and all values are now carried in a single operand segment. This segment is then subdivided by device_type. Each value in a segment is also associated with a GangArgType so it can be differentiated (num/dim/static). This simplify the handling of gang values an limit the number of new attributes needed.

When the clause can be associated with the operation without any value (gang, vector, worker). These are represented by a dedicated attributes with device_type information.

Extra getter functions are provided to make it easier to retrieve a value based on a device_type.

@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openacc flang:fir-hlfir openacc labels Jan 4, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 4, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This is adding support for device_type clause representation in the OpenACC MLIR dialect on the acc.loop operation and adjust flang to lower correctly to the new representation.

Each "value" that can be impacted by a device_type clause is now associated with an array attribute that carry this information. This includes:

  • worker clause information
  • gang clause information
  • vector clause information
  • collapse clause information
  • tile clause information

The representation of the gang clause information has been updated and all values are now carried in a single operand segment. This segment is then subdivided by device_type. Each value in a segment is also associated with a GangArgType so it can be differentiated (num/dim/static). This simplify the handling of gang values an limit the number of new attributes needed.

When the clause can be associated with the operation without any value (gang, vector, worker). These are represented by a dedicated attributes with device_type information.

Extra getter functions are provided to make it easier to retrieve a value based on a device_type.


Patch is 82.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76892.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+122-61)
  • (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+18-18)
  • (modified) flang/test/Lower/OpenACC/acc-loop.f90 (+20-21)
  • (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+18-18)
  • (modified) flang/test/Lower/OpenACC/acc-reduction.f90 (+3-3)
  • (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+18-18)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+127-23)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+342-96)
  • (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+39-18)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+40-40)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecf70818c4ac0f..d8aedca90acf4b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1542,67 +1542,89 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
              const Fortran::parser::AccClauseList &accClauseList,
              bool needEarlyReturnHandling = false) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
-
-  mlir::Value workerNum;
-  mlir::Value vectorNum;
-  mlir::Value gangNum;
-  mlir::Value gangDim;
-  mlir::Value gangStatic;
   llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
-      reductionOperands, cacheOperands;
+      reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
+      gangOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
-  bool hasGang = false, hasVector = false, hasWorker = false;
+  llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments;
+  llvm::SmallVector<int64_t> collapseValues;
+
+  llvm::SmallVector<mlir::Attribute> gangArgTypes;
+  llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes,
+      autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes,
+      vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes,
+      collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes;
+
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
 
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *gangClause =
             std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
       if (gangClause->v) {
+        auto crtGangOperands = gangOperands.size();
         const Fortran::parser::AccGangArgList &x = *gangClause->v;
         for (const Fortran::parser::AccGangArg &gangArg : x.v) {
           if (const auto *num =
                   std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
-            gangNum = fir::getBase(converter.genExprValue(
-                *Fortran::semantics::GetExpr(num->v), stmtCtx));
+            gangOperands.push_back(fir::getBase(converter.genExprValue(
+                *Fortran::semantics::GetExpr(num->v), stmtCtx)));
+            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+                builder.getContext(), mlir::acc::GangArgType::Num));
           } else if (const auto *staticArg =
                          std::get_if<Fortran::parser::AccGangArg::Static>(
                              &gangArg.u)) {
             const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
             if (sizeExpr.v) {
-              gangStatic = fir::getBase(converter.genExprValue(
-                  *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx));
+              gangOperands.push_back(fir::getBase(converter.genExprValue(
+                  *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
             } else {
               // * was passed as value and will be represented as a special
               // constant.
-              gangStatic = builder.createIntegerConstant(
-                  clauseLocation, builder.getIndexType(), starCst);
+              gangOperands.push_back(builder.createIntegerConstant(
+                  clauseLocation, builder.getIndexType(), starCst));
             }
+            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+                builder.getContext(), mlir::acc::GangArgType::Static));
           } else if (const auto *dim =
                          std::get_if<Fortran::parser::AccGangArg::Dim>(
                              &gangArg.u)) {
-            gangDim = fir::getBase(converter.genExprValue(
-                *Fortran::semantics::GetExpr(dim->v), stmtCtx));
+            gangOperands.push_back(fir::getBase(converter.genExprValue(
+                *Fortran::semantics::GetExpr(dim->v), stmtCtx)));
+            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+                builder.getContext(), mlir::acc::GangArgType::Dim));
           }
         }
+        gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands);
+        gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        gangDeviceTypes.push_back(crtDeviceTypeAttr);
       }
-      hasGang = true;
     } else if (const auto *workerClause =
                    std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
       if (workerClause->v) {
-        workerNum = fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
+        workerNumOperands.push_back(fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)));
+        workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
       }
-      hasWorker = true;
     } else if (const auto *vectorClause =
                    std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
       if (vectorClause->v) {
-        vectorNum = fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
+        vectorOperands.push_back(fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)));
+        vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        vectorDeviceTypes.push_back(crtDeviceTypeAttr);
       }
-      hasVector = true;
     } else if (const auto *tileClause =
                    std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
       const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
+      auto crtTileOperands = tileOperands.size();
       for (const auto &accTileExpr : accTileExprList.v) {
         const auto &expr =
             std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
@@ -1618,6 +1640,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
           tileOperands.push_back(tileStar);
         }
       }
+      tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+      tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands);
     } else if (const auto *privateClause =
                    std::get_if<Fortran::parser::AccClause::Private>(
                        &clause.u)) {
@@ -1629,17 +1653,46 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
                        &clause.u)) {
       genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
                     reductionOperands, reductionRecipes);
+    } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
+      seqDeviceTypes.push_back(crtDeviceTypeAttr);
+    } else if (std::get_if<Fortran::parser::AccClause::Independent>(
+                   &clause.u)) {
+      independentDeviceTypes.push_back(crtDeviceTypeAttr);
+    } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
+      autoDeviceTypes.push_back(crtDeviceTypeAttr);
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+          deviceTypeClause->v;
+      assert(deviceTypeExprList.v.size() == 1 &&
+             "expect only one device_type expr");
+      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+    } else if (const auto *collapseClause =
+                   std::get_if<Fortran::parser::AccClause::Collapse>(
+                       &clause.u)) {
+      const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+      const auto &force = std::get<bool>(arg.t);
+      if (force)
+        TODO(clauseLocation, "OpenACC collapse force modifier");
+      const auto &intExpr =
+          std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+      const auto *expr = Fortran::semantics::GetExpr(intExpr);
+      const std::optional<int64_t> collapseValue =
+          Fortran::evaluate::ToInt64(*expr);
+      assert(collapseValue && "expect integer value for the collapse clause");
+      collapseValues.push_back(*collapseValue);
+      collapseDeviceTypes.push_back(crtDeviceTypeAttr);
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value> operands;
   llvm::SmallVector<int32_t> operandSegments;
-  addOperand(operands, operandSegments, gangNum);
-  addOperand(operands, operandSegments, gangDim);
-  addOperand(operands, operandSegments, gangStatic);
-  addOperand(operands, operandSegments, workerNum);
-  addOperand(operands, operandSegments, vectorNum);
+  addOperands(operands, operandSegments, gangOperands);
+  addOperands(operands, operandSegments, workerNumOperands);
+  addOperands(operands, operandSegments, vectorOperands);
   addOperands(operands, operandSegments, tileOperands);
   addOperands(operands, operandSegments, cacheOperands);
   addOperands(operands, operandSegments, privateOperands);
@@ -1657,12 +1710,42 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
       builder, currentLocation, eval, operands, operandSegments,
       /*outerCombined=*/false, retTy, yieldValue);
 
-  if (hasGang)
-    loopOp.setHasGangAttr(builder.getUnitAttr());
-  if (hasWorker)
-    loopOp.setHasWorkerAttr(builder.getUnitAttr());
-  if (hasVector)
-    loopOp.setHasVectorAttr(builder.getUnitAttr());
+  if (!gangDeviceTypes.empty())
+    loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes));
+  if (!gangArgTypes.empty())
+    loopOp.setGangOperandsArgTypeAttr(builder.getArrayAttr(gangArgTypes));
+  if (!gangOperandsSegments.empty())
+    loopOp.setGangOperandsSegmentsAttr(
+        builder.getDenseI32ArrayAttr(gangOperandsSegments));
+  if (!gangOperandsDeviceTypes.empty())
+    loopOp.setGangOperandsDeviceTypeAttr(
+        builder.getArrayAttr(gangOperandsDeviceTypes));
+
+  if (!workerNumDeviceTypes.empty())
+    loopOp.setWorkerAttr(builder.getArrayAttr(workerNumDeviceTypes));
+  if (!workerNumOperandsDeviceTypes.empty())
+    loopOp.setWorkerNumOperandsDeviceTypeAttr(
+        builder.getArrayAttr(workerNumOperandsDeviceTypes));
+
+  if (!vectorDeviceTypes.empty())
+    loopOp.setVectorAttr(builder.getArrayAttr(vectorDeviceTypes));
+  if (!vectorOperandsDeviceTypes.empty())
+    loopOp.setVectorOperandsDeviceTypeAttr(
+        builder.getArrayAttr(vectorOperandsDeviceTypes));
+
+  if (!tileOperandsDeviceTypes.empty())
+    loopOp.setTileOperandsDeviceTypeAttr(
+        builder.getArrayAttr(tileOperandsDeviceTypes));
+  if (!tileOperandsSegments.empty())
+    loopOp.setTileOperandsSegmentsAttr(
+        builder.getDenseI32ArrayAttr(tileOperandsSegments));
+
+  if (!seqDeviceTypes.empty())
+    loopOp.setSeqAttr(builder.getArrayAttr(seqDeviceTypes));
+  if (!independentDeviceTypes.empty())
+    loopOp.setIndependentAttr(builder.getArrayAttr(independentDeviceTypes));
+  if (!autoDeviceTypes.empty())
+    loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes));
 
   if (!privatizations.empty())
     loopOp.setPrivatizationsAttr(
@@ -1672,33 +1755,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
     loopOp.setReductionRecipesAttr(
         mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
 
-  // Lower clauses mapped to attributes
-  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
-    mlir::Location clauseLocation = converter.genLocation(clause.source);
-    if (const auto *collapseClause =
-            std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
-      const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
-      const auto &force = std::get<bool>(arg.t);
-      if (force)
-        TODO(clauseLocation, "OpenACC collapse force modifier");
-      const auto &intExpr =
-          std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
-      const auto *expr = Fortran::semantics::GetExpr(intExpr);
-      const std::optional<int64_t> collapseValue =
-          Fortran::evaluate::ToInt64(*expr);
-      if (collapseValue) {
-        loopOp.setCollapseAttr(builder.getI64IntegerAttr(*collapseValue));
-      }
-    } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-      loopOp.setSeqAttr(builder.getUnitAttr());
-    } else if (std::get_if<Fortran::parser::AccClause::Independent>(
-                   &clause.u)) {
-      loopOp.setIndependentAttr(builder.getUnitAttr());
-    } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
-      loopOp->setAttr(mlir::acc::LoopOp::getAutoAttrStrName(),
-                      builder.getUnitAttr());
-    }
-  }
+  if (!collapseValues.empty())
+    loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues));
+  if (!collapseDeviceTypes.empty())
+    loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes));
+
   return loopOp;
 }
 
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 93bc699031d550..b17f2e2c80b20f 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -461,7 +461,7 @@ subroutine acc_kernels_loop
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {seq}
+! CHECK-NEXT:   } attributes {seq = [#acc.device_type<none>]}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -474,7 +474,7 @@ subroutine acc_kernels_loop
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {auto}
+! CHECK-NEXT:   } attributes {auto_ = [#acc.device_type<none>]}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -487,7 +487,7 @@ subroutine acc_kernels_loop
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {independent}
+! CHECK-NEXT:   } attributes {independent = [#acc.device_type<none>]}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -497,10 +497,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop gang {
+! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   }{{$}}
+! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -511,7 +511,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      acc.kernels {
 ! CHECK:        [[GANGNUM1:%.*]] = arith.constant 8 : i32
-! CHECK-NEXT:   acc.loop gang(num=[[GANGNUM1]] : i32) {
+! CHECK-NEXT:   acc.loop gang({num=[[GANGNUM1]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -525,7 +525,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      acc.kernels {
 ! CHECK:        [[GANGNUM2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK-NEXT:   acc.loop gang(num=[[GANGNUM2]] : i32) {
+! CHECK-NEXT:   acc.loop gang({num=[[GANGNUM2]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -538,7 +538,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop gang(num=%{{.*}} : i32, static=%{{.*}} : i32) {
+! CHECK:        acc.loop gang({num=%{{.*}} : i32, static=%{{.*}} : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -550,10 +550,10 @@ subroutine acc_kernels_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop vector {
+! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   }{{$}}
+! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -591,10 +591,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop worker {
+! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   }{{$}}
+! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -624,7 +624,7 @@ subroutine acc_kernels_loop
 ! CHECK:          fir.do_loop
 ! CHECK:            fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {collapse = 2 : i64}
+! CHECK-NEXT:   } attributes {collapse = [2], collapseDeviceType = [#acc.device_type<none>]}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -655,7 +655,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      acc.kernels {
 ! CHECK:        [[TILESIZE:%.*]] = arith.constant 2 : i32
-! CHECK:        acc.loop tile([[TILESIZE]] : i32) {
+! CHECK:        acc.loop tile({[[TILESIZE]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -669,7 +669,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      acc.kernels {
 ! CHECK:        [[TILESIZEM1:%.*]] = arith.constant -1 : i32
-! CHECK:        acc.loop tile([[TILESIZEM1]] : i32) {
+! CHECK:        acc.loop tile({[[TILESIZEM1]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -686,7 +686,7 @@ subroutine acc_kernels_loop
 ! CHECK:      acc.kernels {
 ! CHECK:        [[TILESIZE1:%.*]] = arith.constant 2 : i32
 ! CHECK:        [[TILESIZE2:%.*]] = arith.constant 2 : i32
-! CHECK:        acc.loop tile([[TILESIZE1]], [[TILESIZE2]] : i32, i32) {
+! CHECK:        acc.loop tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -699,7 +699,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop tile(%{{.*}} : i32) {
+! CHECK:        acc.loop tile({%{{.*}} : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -714,7 +714,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop tile(%{{.*}}, %{{.*}} : i32, i32) {
+! CHECK:        acc.loop tile({%{{.*}} : i32, %{{.*}} : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
diff --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90
index 924574512da4c7..e7f65770498fe2 100644
--- a/flang/test/Lower/OpenACC/acc-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-loop.f90
@@ -1,6 +1,5 @@
 ! This test checks lowering of OpenACC loop directive.
 
-! RUN: bbc -fopenacc -emit-fir -hlfir=false %s -o - | FileCheck %s
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
 ! CHECK-LABEL: acc.private.recipe @privatization_ref_10x10xf32 : !fir.ref<!fir.array<10x10xf32>> init {
@@ -41,7 +40,7 @@ program acc_loop
 !CHECK:      acc.loop {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {seq}
+!CHECK-NEXT: } attributes {seq = [#acc.device_type<none>]}
 
   !$acc loop auto
   DO i = 1, n
@@ -51,7 +50,7 @@ program acc_loop
 !CHECK:      acc.loop {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {auto}
+!CHECK-NEXT: } attributes {auto_ = [#acc.device_type<none>]}
 
   !$acc loop independent
   DO i = 1, n
@@ -61,17 +60,17 @@ program acc_loop
 !CHECK:      acc.loop {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {independent}
+!CHECK-NEXT: } attributes {independent = [#acc.device_type<none>]}
 
   !$acc loop gang
   DO i = 1, n
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop gang {
+!CHECK:      acc.loop {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: }{{$}}
+!CHECK-NEXT: } attributes {gang = [#acc.device_type<none>]}{{$}}
 
   !$acc loop gang(num: 8)
   DO i = 1, n
@@ -79,7 +78,7 @@ program acc_loop
   END DO
 
 !CHECK:      [[GANGNUM1:%.*]] = arith.constant 8 : i32
-!CHECK-NEXT: acc.loop gang(num=[[GANGNUM1]] : i32) {
+!CHECK-NEXT: acc.loop gang({num=[[GANGNUM1]] : i32}) {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
 !CHECK-NEXT: }{{$}}
@@ -90,7 +89,7 @@ program acc_loop
   END DO
 
 !CHECK:      [[GANGNUM2:%.*]] = fir.load %{{...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 4, 2024

@llvm/pr-subscribers-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This is adding support for device_type clause representation in the OpenACC MLIR dialect on the acc.loop operation and adjust flang to lower correctly to the new representation.

Each "value" that can be impacted by a device_type clause is now associated with an array attribute that carry this information. This includes:

  • worker clause information
  • gang clause information
  • vector clause information
  • collapse clause information
  • tile clause information

The representation of the gang clause information has been updated and all values are now carried in a single operand segment. This segment is then subdivided by device_type. Each value in a segment is also associated with a GangArgType so it can be differentiated (num/dim/static). This simplify the handling of gang values an limit the number of new attributes needed.

When the clause can be associated with the operation without any value (gang, vector, worker). These are represented by a dedicated attributes with device_type information.

Extra getter functions are provided to make it easier to retrieve a value based on a device_type.


Patch is 82.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76892.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+122-61)
  • (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+18-18)
  • (modified) flang/test/Lower/OpenACC/acc-loop.f90 (+20-21)
  • (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+18-18)
  • (modified) flang/test/Lower/OpenACC/acc-reduction.f90 (+3-3)
  • (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+18-18)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+127-23)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+342-96)
  • (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+39-18)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+40-40)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecf70818c4ac0f..d8aedca90acf4b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1542,67 +1542,89 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
              const Fortran::parser::AccClauseList &accClauseList,
              bool needEarlyReturnHandling = false) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
-
-  mlir::Value workerNum;
-  mlir::Value vectorNum;
-  mlir::Value gangNum;
-  mlir::Value gangDim;
-  mlir::Value gangStatic;
   llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
-      reductionOperands, cacheOperands;
+      reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
+      gangOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
-  bool hasGang = false, hasVector = false, hasWorker = false;
+  llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments;
+  llvm::SmallVector<int64_t> collapseValues;
+
+  llvm::SmallVector<mlir::Attribute> gangArgTypes;
+  llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes,
+      autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes,
+      vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes,
+      collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes;
+
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None);
 
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *gangClause =
             std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
       if (gangClause->v) {
+        auto crtGangOperands = gangOperands.size();
         const Fortran::parser::AccGangArgList &x = *gangClause->v;
         for (const Fortran::parser::AccGangArg &gangArg : x.v) {
           if (const auto *num =
                   std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
-            gangNum = fir::getBase(converter.genExprValue(
-                *Fortran::semantics::GetExpr(num->v), stmtCtx));
+            gangOperands.push_back(fir::getBase(converter.genExprValue(
+                *Fortran::semantics::GetExpr(num->v), stmtCtx)));
+            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+                builder.getContext(), mlir::acc::GangArgType::Num));
           } else if (const auto *staticArg =
                          std::get_if<Fortran::parser::AccGangArg::Static>(
                              &gangArg.u)) {
             const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
             if (sizeExpr.v) {
-              gangStatic = fir::getBase(converter.genExprValue(
-                  *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx));
+              gangOperands.push_back(fir::getBase(converter.genExprValue(
+                  *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
             } else {
               // * was passed as value and will be represented as a special
               // constant.
-              gangStatic = builder.createIntegerConstant(
-                  clauseLocation, builder.getIndexType(), starCst);
+              gangOperands.push_back(builder.createIntegerConstant(
+                  clauseLocation, builder.getIndexType(), starCst));
             }
+            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+                builder.getContext(), mlir::acc::GangArgType::Static));
           } else if (const auto *dim =
                          std::get_if<Fortran::parser::AccGangArg::Dim>(
                              &gangArg.u)) {
-            gangDim = fir::getBase(converter.genExprValue(
-                *Fortran::semantics::GetExpr(dim->v), stmtCtx));
+            gangOperands.push_back(fir::getBase(converter.genExprValue(
+                *Fortran::semantics::GetExpr(dim->v), stmtCtx)));
+            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+                builder.getContext(), mlir::acc::GangArgType::Dim));
           }
         }
+        gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands);
+        gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        gangDeviceTypes.push_back(crtDeviceTypeAttr);
       }
-      hasGang = true;
     } else if (const auto *workerClause =
                    std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
       if (workerClause->v) {
-        workerNum = fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
+        workerNumOperands.push_back(fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)));
+        workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
       }
-      hasWorker = true;
     } else if (const auto *vectorClause =
                    std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
       if (vectorClause->v) {
-        vectorNum = fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
+        vectorOperands.push_back(fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)));
+        vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+      } else {
+        vectorDeviceTypes.push_back(crtDeviceTypeAttr);
       }
-      hasVector = true;
     } else if (const auto *tileClause =
                    std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
       const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
+      auto crtTileOperands = tileOperands.size();
       for (const auto &accTileExpr : accTileExprList.v) {
         const auto &expr =
             std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
@@ -1618,6 +1640,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
           tileOperands.push_back(tileStar);
         }
       }
+      tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+      tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands);
     } else if (const auto *privateClause =
                    std::get_if<Fortran::parser::AccClause::Private>(
                        &clause.u)) {
@@ -1629,17 +1653,46 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
                        &clause.u)) {
       genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
                     reductionOperands, reductionRecipes);
+    } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
+      seqDeviceTypes.push_back(crtDeviceTypeAttr);
+    } else if (std::get_if<Fortran::parser::AccClause::Independent>(
+                   &clause.u)) {
+      independentDeviceTypes.push_back(crtDeviceTypeAttr);
+    } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
+      autoDeviceTypes.push_back(crtDeviceTypeAttr);
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+          deviceTypeClause->v;
+      assert(deviceTypeExprList.v.size() == 1 &&
+             "expect only one device_type expr");
+      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+    } else if (const auto *collapseClause =
+                   std::get_if<Fortran::parser::AccClause::Collapse>(
+                       &clause.u)) {
+      const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+      const auto &force = std::get<bool>(arg.t);
+      if (force)
+        TODO(clauseLocation, "OpenACC collapse force modifier");
+      const auto &intExpr =
+          std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+      const auto *expr = Fortran::semantics::GetExpr(intExpr);
+      const std::optional<int64_t> collapseValue =
+          Fortran::evaluate::ToInt64(*expr);
+      assert(collapseValue && "expect integer value for the collapse clause");
+      collapseValues.push_back(*collapseValue);
+      collapseDeviceTypes.push_back(crtDeviceTypeAttr);
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value> operands;
   llvm::SmallVector<int32_t> operandSegments;
-  addOperand(operands, operandSegments, gangNum);
-  addOperand(operands, operandSegments, gangDim);
-  addOperand(operands, operandSegments, gangStatic);
-  addOperand(operands, operandSegments, workerNum);
-  addOperand(operands, operandSegments, vectorNum);
+  addOperands(operands, operandSegments, gangOperands);
+  addOperands(operands, operandSegments, workerNumOperands);
+  addOperands(operands, operandSegments, vectorOperands);
   addOperands(operands, operandSegments, tileOperands);
   addOperands(operands, operandSegments, cacheOperands);
   addOperands(operands, operandSegments, privateOperands);
@@ -1657,12 +1710,42 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
       builder, currentLocation, eval, operands, operandSegments,
       /*outerCombined=*/false, retTy, yieldValue);
 
-  if (hasGang)
-    loopOp.setHasGangAttr(builder.getUnitAttr());
-  if (hasWorker)
-    loopOp.setHasWorkerAttr(builder.getUnitAttr());
-  if (hasVector)
-    loopOp.setHasVectorAttr(builder.getUnitAttr());
+  if (!gangDeviceTypes.empty())
+    loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes));
+  if (!gangArgTypes.empty())
+    loopOp.setGangOperandsArgTypeAttr(builder.getArrayAttr(gangArgTypes));
+  if (!gangOperandsSegments.empty())
+    loopOp.setGangOperandsSegmentsAttr(
+        builder.getDenseI32ArrayAttr(gangOperandsSegments));
+  if (!gangOperandsDeviceTypes.empty())
+    loopOp.setGangOperandsDeviceTypeAttr(
+        builder.getArrayAttr(gangOperandsDeviceTypes));
+
+  if (!workerNumDeviceTypes.empty())
+    loopOp.setWorkerAttr(builder.getArrayAttr(workerNumDeviceTypes));
+  if (!workerNumOperandsDeviceTypes.empty())
+    loopOp.setWorkerNumOperandsDeviceTypeAttr(
+        builder.getArrayAttr(workerNumOperandsDeviceTypes));
+
+  if (!vectorDeviceTypes.empty())
+    loopOp.setVectorAttr(builder.getArrayAttr(vectorDeviceTypes));
+  if (!vectorOperandsDeviceTypes.empty())
+    loopOp.setVectorOperandsDeviceTypeAttr(
+        builder.getArrayAttr(vectorOperandsDeviceTypes));
+
+  if (!tileOperandsDeviceTypes.empty())
+    loopOp.setTileOperandsDeviceTypeAttr(
+        builder.getArrayAttr(tileOperandsDeviceTypes));
+  if (!tileOperandsSegments.empty())
+    loopOp.setTileOperandsSegmentsAttr(
+        builder.getDenseI32ArrayAttr(tileOperandsSegments));
+
+  if (!seqDeviceTypes.empty())
+    loopOp.setSeqAttr(builder.getArrayAttr(seqDeviceTypes));
+  if (!independentDeviceTypes.empty())
+    loopOp.setIndependentAttr(builder.getArrayAttr(independentDeviceTypes));
+  if (!autoDeviceTypes.empty())
+    loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes));
 
   if (!privatizations.empty())
     loopOp.setPrivatizationsAttr(
@@ -1672,33 +1755,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
     loopOp.setReductionRecipesAttr(
         mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
 
-  // Lower clauses mapped to attributes
-  for (const Fortran::parser::AccClause &clause : accClauseList.v) {
-    mlir::Location clauseLocation = converter.genLocation(clause.source);
-    if (const auto *collapseClause =
-            std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
-      const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
-      const auto &force = std::get<bool>(arg.t);
-      if (force)
-        TODO(clauseLocation, "OpenACC collapse force modifier");
-      const auto &intExpr =
-          std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
-      const auto *expr = Fortran::semantics::GetExpr(intExpr);
-      const std::optional<int64_t> collapseValue =
-          Fortran::evaluate::ToInt64(*expr);
-      if (collapseValue) {
-        loopOp.setCollapseAttr(builder.getI64IntegerAttr(*collapseValue));
-      }
-    } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-      loopOp.setSeqAttr(builder.getUnitAttr());
-    } else if (std::get_if<Fortran::parser::AccClause::Independent>(
-                   &clause.u)) {
-      loopOp.setIndependentAttr(builder.getUnitAttr());
-    } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
-      loopOp->setAttr(mlir::acc::LoopOp::getAutoAttrStrName(),
-                      builder.getUnitAttr());
-    }
-  }
+  if (!collapseValues.empty())
+    loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues));
+  if (!collapseDeviceTypes.empty())
+    loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes));
+
   return loopOp;
 }
 
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 93bc699031d550..b17f2e2c80b20f 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -461,7 +461,7 @@ subroutine acc_kernels_loop
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {seq}
+! CHECK-NEXT:   } attributes {seq = [#acc.device_type<none>]}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -474,7 +474,7 @@ subroutine acc_kernels_loop
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {auto}
+! CHECK-NEXT:   } attributes {auto_ = [#acc.device_type<none>]}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -487,7 +487,7 @@ subroutine acc_kernels_loop
 ! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {independent}
+! CHECK-NEXT:   } attributes {independent = [#acc.device_type<none>]}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -497,10 +497,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop gang {
+! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   }{{$}}
+! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -511,7 +511,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      acc.kernels {
 ! CHECK:        [[GANGNUM1:%.*]] = arith.constant 8 : i32
-! CHECK-NEXT:   acc.loop gang(num=[[GANGNUM1]] : i32) {
+! CHECK-NEXT:   acc.loop gang({num=[[GANGNUM1]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -525,7 +525,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      acc.kernels {
 ! CHECK:        [[GANGNUM2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-! CHECK-NEXT:   acc.loop gang(num=[[GANGNUM2]] : i32) {
+! CHECK-NEXT:   acc.loop gang({num=[[GANGNUM2]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -538,7 +538,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop gang(num=%{{.*}} : i32, static=%{{.*}} : i32) {
+! CHECK:        acc.loop gang({num=%{{.*}} : i32, static=%{{.*}} : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -550,10 +550,10 @@ subroutine acc_kernels_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop vector {
+! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   }{{$}}
+! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -591,10 +591,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop worker {
+! CHECK:        acc.loop {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   }{{$}}
+! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -624,7 +624,7 @@ subroutine acc_kernels_loop
 ! CHECK:          fir.do_loop
 ! CHECK:            fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {collapse = 2 : i64}
+! CHECK-NEXT:   } attributes {collapse = [2], collapseDeviceType = [#acc.device_type<none>]}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -655,7 +655,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      acc.kernels {
 ! CHECK:        [[TILESIZE:%.*]] = arith.constant 2 : i32
-! CHECK:        acc.loop tile([[TILESIZE]] : i32) {
+! CHECK:        acc.loop tile({[[TILESIZE]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -669,7 +669,7 @@ subroutine acc_kernels_loop
 
 ! CHECK:      acc.kernels {
 ! CHECK:        [[TILESIZEM1:%.*]] = arith.constant -1 : i32
-! CHECK:        acc.loop tile([[TILESIZEM1]] : i32) {
+! CHECK:        acc.loop tile({[[TILESIZEM1]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -686,7 +686,7 @@ subroutine acc_kernels_loop
 ! CHECK:      acc.kernels {
 ! CHECK:        [[TILESIZE1:%.*]] = arith.constant 2 : i32
 ! CHECK:        [[TILESIZE2:%.*]] = arith.constant 2 : i32
-! CHECK:        acc.loop tile([[TILESIZE1]], [[TILESIZE2]] : i32, i32) {
+! CHECK:        acc.loop tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -699,7 +699,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop tile(%{{.*}} : i32) {
+! CHECK:        acc.loop tile({%{{.*}} : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
@@ -714,7 +714,7 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop tile(%{{.*}}, %{{.*}} : i32, i32) {
+! CHECK:        acc.loop tile({%{{.*}} : i32, %{{.*}} : i32}) {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
diff --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90
index 924574512da4c7..e7f65770498fe2 100644
--- a/flang/test/Lower/OpenACC/acc-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-loop.f90
@@ -1,6 +1,5 @@
 ! This test checks lowering of OpenACC loop directive.
 
-! RUN: bbc -fopenacc -emit-fir -hlfir=false %s -o - | FileCheck %s
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
 ! CHECK-LABEL: acc.private.recipe @privatization_ref_10x10xf32 : !fir.ref<!fir.array<10x10xf32>> init {
@@ -41,7 +40,7 @@ program acc_loop
 !CHECK:      acc.loop {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {seq}
+!CHECK-NEXT: } attributes {seq = [#acc.device_type<none>]}
 
   !$acc loop auto
   DO i = 1, n
@@ -51,7 +50,7 @@ program acc_loop
 !CHECK:      acc.loop {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {auto}
+!CHECK-NEXT: } attributes {auto_ = [#acc.device_type<none>]}
 
   !$acc loop independent
   DO i = 1, n
@@ -61,17 +60,17 @@ program acc_loop
 !CHECK:      acc.loop {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {independent}
+!CHECK-NEXT: } attributes {independent = [#acc.device_type<none>]}
 
   !$acc loop gang
   DO i = 1, n
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop gang {
+!CHECK:      acc.loop {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: }{{$}}
+!CHECK-NEXT: } attributes {gang = [#acc.device_type<none>]}{{$}}
 
   !$acc loop gang(num: 8)
   DO i = 1, n
@@ -79,7 +78,7 @@ program acc_loop
   END DO
 
 !CHECK:      [[GANGNUM1:%.*]] = arith.constant 8 : i32
-!CHECK-NEXT: acc.loop gang(num=[[GANGNUM1]] : i32) {
+!CHECK-NEXT: acc.loop gang({num=[[GANGNUM1]] : i32}) {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
 !CHECK-NEXT: }{{$}}
@@ -90,7 +89,7 @@ program acc_loop
   END DO
 
 !CHECK:      [[GANGNUM2:%.*]] = fir.load %{{...
[truncated]

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me! Thank you!

@clementval clementval merged commit e456689 into llvm:main Jan 5, 2024
4 checks passed
@clementval clementval deleted the acc_device_type_loop branch January 18, 2024 17:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants