Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][openacc][flang] Support wait devnum and clean async/wait IR #79525

Merged
merged 2 commits into from
Jan 29, 2024

Conversation

clementval
Copy link
Contributor

  • Support wait(devnum: ) with device_type support on all operations that require it
    • devnum value is stored as the first value of waitOperands in its device_type sub-segment. The hasWaitDevnum attribute inform which sub-segment has a wait(devnum) value.
  • Make async/wait information homogenous on compute ops, data and update op.
    • Unify operands/attributes names across operations and use the same custom parser/printer

@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openacc flang:fir-hlfir openacc labels Jan 25, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 25, 2024

@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes
  • Support wait(devnum: ) with device_type support on all operations that require it
    • devnum value is stored as the first value of waitOperands in its device_type sub-segment. The hasWaitDevnum attribute inform which sub-segment has a wait(devnum) value.
  • Make async/wait information homogenous on compute ops, data and update op.
    • Unify operands/attributes names across operations and use the same custom parser/printer

Patch is 55.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79525.diff

14 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+57-36)
  • (modified) flang/test/Lower/OpenACC/acc-data.f90 (+3-3)
  • (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+2-2)
  • (modified) flang/test/Lower/OpenACC/acc-kernels.f90 (+2-2)
  • (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+2-2)
  • (modified) flang/test/Lower/OpenACC/acc-parallel.f90 (+2-2)
  • (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+2-2)
  • (modified) flang/test/Lower/OpenACC/acc-serial.f90 (+2-2)
  • (modified) flang/test/Lower/OpenACC/acc-update.f90 (+1-4)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+62-34)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+198-77)
  • (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+1-9)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+4-4)
  • (modified) mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp (+30-4)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecfdaa5be993584..427b36a12a2df01 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
           builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, descTy);
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
           builder, loc, loadOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, loadOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
@@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
   }
 }
 
-static void
-genWaitClause(Fortran::lower::AbstractConverter &converter,
-              const Fortran::parser::AccClause::Wait *waitClause,
-              llvm::SmallVector<mlir::Value> &waitOperands,
-              llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
-              llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
-              llvm::SmallVector<int32_t> &waitOperandsSegments,
-              mlir::Value &waitDevnum,
-              llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
-              Fortran::lower::StatementContext &stmtCtx) {
+static void genWaitClauseWithDeviceType(
+    Fortran::lower::AbstractConverter &converter,
+    const Fortran::parser::AccClause::Wait *waitClause,
+    llvm::SmallVector<mlir::Value> &waitOperands,
+    llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
+    llvm::SmallVector<bool> &hasDevnums,
+    llvm::SmallVector<int32_t> &waitOperandsSegments,
+    llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
+    Fortran::lower::StatementContext &stmtCtx) {
   const auto &waitClauseValue = waitClause->v;
   if (waitClauseValue) { // wait has a value.
+    llvm::SmallVector<mlir::Value> waitValues;
+
     const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+    const auto &waitDevnumValue =
+        std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+    bool hasDevnum = false;
+    if (waitDevnumValue) {
+      waitValues.push_back(fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
+      hasDevnum = true;
+    }
+
     const auto &waitList =
         std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-    llvm::SmallVector<mlir::Value> waitValues;
     for (const Fortran::parser::ScalarIntExpr &value : waitList) {
       waitValues.push_back(fir::getBase(converter.genExprValue(
           *Fortran::semantics::GetExpr(value), stmtCtx)));
     }
+
     for (auto deviceTypeAttr : deviceTypeAttrs) {
       for (auto value : waitValues)
         waitOperands.push_back(value);
       waitOperandsDeviceTypes.push_back(deviceTypeAttr);
       waitOperandsSegments.push_back(waitValues.size());
+      hasDevnums.push_back(hasDevnum);
     }
-
-    // TODO: move to device_type model.
-    const auto &waitDevnumValue =
-        std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-    if (waitDevnumValue)
-      waitDevnum = fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
   } else {
     for (auto deviceTypeAttr : deviceTypeAttrs)
       waitOnlyDeviceTypes.push_back(deviceTypeAttr);
@@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
       vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
       waitOperandsDeviceTypes, waitOnlyDeviceTypes;
   llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
+  llvm::SmallVector<bool> hasWaitDevnums;
 
   llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
       firstprivateOperands;
   llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
       reductionRecipes;
-  mlir::Value waitDevnum; // TODO not yet implemented on compute op.
 
   // Self clause has optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
@@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
                      asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands,
-                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+      genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+                                  waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                                  hasWaitDevnums, waitOperandsSegments,
+                                  crtDeviceTypes, stmtCtx);
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
@@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
           builder.getDenseI32ArrayAttr(numGangsSegments));
   }
   if (!asyncDeviceTypes.empty())
-    computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+    computeOp.setAsyncOperandsDeviceTypeAttr(
+        builder.getArrayAttr(asyncDeviceTypes));
   if (!asyncOnlyDeviceTypes.empty())
     computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
 
@@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   if (!waitOperandsSegments.empty())
     computeOp.setWaitOperandsSegmentsAttr(
         builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!hasWaitDevnums.empty())
+    computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
   if (!waitOnlyDeviceTypes.empty())
     computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
 
@@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
       waitOperandsDeviceTypes, waitOnlyDeviceTypes;
   llvm::SmallVector<int32_t> waitOperandsSegments;
+  llvm::SmallVector<bool> hasWaitDevnums;
 
   bool hasDefaultNone = false;
   bool hasDefaultPresent = false;
@@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
                      asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands,
-                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+      genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+                                  waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                                  hasWaitDevnums, waitOperandsSegments,
+                                  crtDeviceTypes, stmtCtx);
     } else if(const auto *defaultClause = 
                   std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
       if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<int32_t> operandSegments;
   addOperand(operands, operandSegments, ifCond);
   addOperands(operands, operandSegments, async);
-  addOperand(operands, operandSegments, waitDevnum);
   addOperands(operands, operandSegments, waitOperands);
   addOperands(operands, operandSegments, dataClauseOperands);
 
@@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
       operandSegments);
 
   if (!asyncDeviceTypes.empty())
-    dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+    dataOp.setAsyncOperandsDeviceTypeAttr(
+        builder.getArrayAttr(asyncDeviceTypes));
   if (!asyncOnlyDeviceTypes.empty())
     dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
   if (!waitOperandsDeviceTypes.empty())
@@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   if (!waitOperandsSegments.empty())
     dataOp.setWaitOperandsSegmentsAttr(
         builder.getDenseI32ArrayAttr(waitOperandsSegments));
+  if (!hasWaitDevnums.empty())
+    dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
   if (!waitOnlyDeviceTypes.empty())
     dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
 
@@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
   return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
 }
 
+static inline mlir::ArrayAttr
+getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
+  return values.empty() ? nullptr : b.getBoolArrayAttr(values);
+}
+
 static inline mlir::DenseI32ArrayAttr
 getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
                      llvm::SmallVector<int32_t> &values) {
@@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
       waitOperands, deviceTypeOperands, asyncOperands;
   llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
       asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<bool> hasWaitDevnums;
   llvm::SmallVector<int32_t> waitOperandsSegments;
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
@@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
                      crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands,
-                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+      genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+                                  waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                                  hasWaitDevnums, waitOperandsSegments,
+                                  crtDeviceTypes, stmtCtx);
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
@@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   builder.create<mlir::acc::UpdateOp>(
       currentLocation, ifCond, asyncOperands,
       getArrayAttr(builder, asyncOperandsDeviceTypes),
-      getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+      getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
       getDenseI32ArrayAttr(builder, waitOperandsSegments),
       getArrayAttr(builder, waitOperandsDeviceTypes),
+      getBoolArrayAttr(builder, hasWaitDevnums),
       getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
       ifPresent);
 
@@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index 75ffd1fc3fcab2f..5b4ab5a65ee6bd2 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -164,8 +164,8 @@ subroutine acc_data
   !$acc data present(a) wait
   !$acc end data
 
-! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK: acc.data dataOperands(%{{.*}}) wait {
+! CHECK: }
 
   !$acc data present(a) wait(1)
   !$acc end data
@@ -176,7 +176,7 @@ subroutine acc_data
   !$acc data present(a) wait(devnum: 0: 1)
   !$acc end data
 
-! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait({devnum: %{{.*}} : i32, %{{.*}} : i32}) {
 ! CHECK: }{{$}}
 
   !$acc data default(none)
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 21660d5c3a13163..d2134e8d2337ce6 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -93,12 +93,12 @@ subroutine acc_kernels_loop
     a(i) = b(i)
   END DO
 
-! CHECK:      acc.kernels {
+! CHECK:      acc.kernels wait {
 ! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc kernels loop wait(1)
   DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 99629bb8351723b..06194edbe165498 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -61,9 +61,9 @@ subroutine acc_kernels
   !$acc kernels wait
   !$acc end kernels
 
-! CHECK:      acc.kernels  {
+! CHECK:      acc.kernels wait {
 ! CHECK:        acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc kernels wait(1)
   !$acc end kernels
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 614d201f98e26c4..24e443a20c895d1 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -95,12 +95,12 @@ subroutine acc_parallel_loop
     a(i) = b(i)
   END DO
 
-! CHECK:      acc.parallel {
+! CHECK:      acc.parallel wait {
 ! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc parallel loop wait(1)
   DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index a369bf01f259955..6b37ecb5fab9aa6 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -83,9 +83,9 @@ subroutine acc_parallel
   !$acc parallel wait
   !$acc end parallel
 
-! CHECK:      acc.parallel {
+! CHECK:      acc.parallel wait {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc parallel wait(1)
   !$acc end parallel
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 4134f9ff0ccf577..9c0dbff0d7dac16 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -114,12 +114,12 @@ subroutine acc_serial_loop
     a(i) = b(i)
   END DO
 
-! CHECK:      acc.serial {
+! CHECK:      acc.serial wait {
 ! CHECK:        acc.loop {{.*}} {
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc serial loop wait(1)
   DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index d05e51d3d274f45..d0fa9436be14a14 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -83,9 +83,9 @@ subroutine acc_serial
   !$acc serial wait
   !$acc end serial
 
-! CHECK:      acc.serial {
+! CHECK:      acc.serial wait {
 ! CHECK:        acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
 
   !$acc serial wait(1)
   !$acc end serial
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index ba036ac92811826..f42ae1356664b67 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -101,10 +101,7 @@ subroutine acc_update
 
   !$acc update host(a) wait(devnum: 1: queues: 1, 2)
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
-! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
-! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({devnum: %c1{{.*}} : i32, %c1{{.*}} : i32, %c2{{.*}} : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) device_type(host, nvidia) async
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 87fd587782e7c35..9398cbfdacee469 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -903,12 +903,13 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
   }];
 
   let arguments = (ins
-      Variadic<IntOrIndex>:$async,
-      OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+      Variadic<IntOrIndex>:$asyncOperands,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
       OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
       Variadic<IntOrIndex>:$waitOperands,
       OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
       OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
       Variadic<IntOrIndex>:$numGangs,
       OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
@@ -979,13 +980,18 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     /// present.
     mlir::Operation::operand_range
     getWaitValues(mlir::acc::DeviceType deviceType);
+    /// Return the wait devnum value clause if present;
+    mlir::Value getWaitDevnum();
+    /// Return the wait devnum value clause for the given device_type if
+    /// present.
+    mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
-      | `async` `(` custom<DeviceTypeOperands>($async,
-            type($async), $asyncDeviceType) `)`
+      | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+            type($asyncOp...
[truncated]

@clementval clementval changed the title [mlir][openacc][flang] Support wait devnum and homogenize async/wait IR [mlir][openacc][flang] Support wait devnum and clean async/wait IR Jan 26, 2024
Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Thank you!

@clementval clementval merged commit c09dc2d into llvm:main Jan 29, 2024
3 of 4 checks passed
@clementval clementval deleted the acc_wait_devnum branch January 29, 2024 05:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants