diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 2e32abcc42049..4313413065be3 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -1353,19 +1353,20 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, const Fortran::parser::AccClauseList &accClauseList) { - mlir::Value ifCond, async; + mlir::Value ifCond, async, waitDevnum; llvm::SmallVector attachEntryOperands, createEntryOperands, - copyEntryOperands, copyoutEntryOperands, dataClauseOperands; + copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands; - // Async has an optional value but can be present with + // Async and wait have an optional value but can be present with // no value as well. When there is no value, the op has an attribute to // represent the clause. bool addAsyncAttr = false; + bool addWaitAttr = false; fir::FirOpBuilder &builder = converter.getFirOpBuilder(); // Lower clauses values mapped to operands. - // Keep track of each group of operands separatly as clauses can appear + // Keep track of each group of operands separately as clauses can appear // more than once. for (const Fortran::parser::AccClause &clause : accClauseList.v) { mlir::Location clauseLocation = converter.genLocation(clause.source); @@ -1450,12 +1451,15 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector operandSegments; addOperand(operands, operandSegments, ifCond); addOperand(operands, operandSegments, async); + addOperand(operands, operandSegments, waitDevnum); + addOperands(operands, operandSegments, waitOperands); addOperands(operands, operandSegments, dataClauseOperands); auto dataOp = createRegionOp( builder, currentLocation, operands, operandSegments); dataOp.setAsyncAttr(addAsyncAttr); + dataOp.setAsyncAttr(addWaitAttr); auto insPt = builder.saveInsertionPoint(); builder.setInsertionPointAfter(dataOp); diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 5960dfadbc44f..a1af1cf7fea89 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -867,6 +867,9 @@ def OpenACC_DataOp : OpenACC_Op<"data", let arguments = (ins Optional:$ifCond, Optional:$async, UnitAttr:$asyncAttr, + Optional:$waitDevnum, + Variadic:$waitOperands, + UnitAttr:$waitAttr, Variadic:$dataClauseOperands, OptionalAttr:$defaultAttr); @@ -885,6 +888,8 @@ def OpenACC_DataOp : OpenACC_Op<"data", `if` `(` $ifCond `)` | `async` `(` $async `:` type($async) `)` | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` + | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)` + | `wait` `(` $waitOperands `:` type($waitOperands) `)` ) $region attr-dict-with-keyword }]; diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 2ac6899ae7ab8..9c6cffa8399dc 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -868,6 +868,7 @@ unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); } Value DataOp::getDataOperand(unsigned i) { unsigned numOptional = getIfCond() ? 1 : 0; numOptional += getAsync() ? 1 : 0; + numOptional += getWaitOperands().size(); return getOperand(numOptional + i); } diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index a1f94323864d6..fa2f8c9f90ca2 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -822,6 +822,17 @@ func.func @testdataop(%a: memref, %b: memref, %c: memref) -> () { acc.data async(%a1 : i64) { } attributes { defaultAttr = #acc, async } + acc.data { + } attributes { defaultAttr = #acc, wait } + + %w1 = arith.constant 1 : i64 + acc.data wait(%w1 : i64) { + } attributes { defaultAttr = #acc, wait } + + %wd1 = arith.constant 1 : i64 + acc.data wait_devnum(%wd1 : i64) wait(%w1 : i64) { + } attributes { defaultAttr = #acc, wait } + return } @@ -927,6 +938,15 @@ func.func @testdataop(%a: memref, %b: memref, %c: memref) -> () { // CHECK: acc.data async(%{{.*}} : i64) { // CHECK-NEXT: } attributes {async, defaultAttr = #acc} +// CHECK: acc.data { +// CHECK-NEXT: } attributes {defaultAttr = #acc, wait} + +// CHECK: acc.data wait(%{{.*}} : i64) { +// CHECK-NEXT: } attributes {defaultAttr = #acc, wait} + +// CHECK: acc.data wait_devnum(%{{.*}} : i64) wait(%{{.*}} : i64) { +// CHECK-NEXT: } attributes {defaultAttr = #acc, wait} + // ----- func.func @testupdateop(%a: memref, %b: memref, %c: memref) -> () {