Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][openacc] Add ability to link acc.declare_enter with acc.declare_exit ops #72476

Merged
merged 3 commits into from
Nov 17, 2023

Conversation

clementval
Copy link
Contributor

Add a result token to acc.declare_enter op that can be feed to acc.declare_exit op to represent an implicit declare data region in a function/subroutine.

@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openacc flang:fir-hlfir openacc labels Nov 16, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Nov 16, 2023

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-mlir-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Add a result token to acc.declare_enter op that can be feed to acc.declare_exit op to represent an implicit declare data region in a function/subroutine.


Full diff: https://github.com/llvm/llvm-project/pull/72476.diff

5 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+28-17)
  • (modified) flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 (+14-14)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+6-2)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td (+9)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+6-2)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e470154ce8c2d0b..8c6c22210cf0894 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -163,7 +163,8 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
       builder, loc, boxAddrOp.getResult(), asFortran, bounds,
       /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
   builder.create<mlir::acc::DeclareEnterOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+      mlir::ValueRange(entryOp.getAccPtr()));
 
   modBuilder.setInsertionPointAfter(registerFuncOp);
   builder.restoreInsertionPoint(crtInsPt);
@@ -195,7 +196,7 @@ static void createDeclareDeallocFuncWithArg(
           /*structured=*/false, /*implicit=*/false, clause,
           boxAddrOp.getType());
   builder.create<mlir::acc::DeclareExitOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
 
   mlir::Value varPtr;
   if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
@@ -2762,7 +2763,13 @@ static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
   EntryOp entryOp = createDataEntryOp<EntryOp>(
       builder, loc, addrOp.getResTy(), asFortran, bounds,
       /*structured=*/false, implicit, clause, addrOp.getResTy().getType());
-  builder.create<DeclareOp>(loc, mlir::ValueRange(entryOp.getAccPtr()));
+  if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
+    builder.create<DeclareOp>(
+        loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+        mlir::ValueRange(entryOp.getAccPtr()));
+  else
+    builder.create<DeclareOp>(loc, mlir::Value{},
+                              mlir::ValueRange(entryOp.getAccPtr()));
   mlir::Value varPtr;
   if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
     builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(), varPtr,
@@ -2812,7 +2819,8 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
       builder, loc, boxAddrOp.getResult(), asFortran, bounds,
       /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
   builder.create<mlir::acc::DeclareEnterOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+      mlir::ValueRange(entryOp.getAccPtr()));
 
   modBuilder.setInsertionPointAfter(registerFuncOp);
 }
@@ -2850,7 +2858,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
           boxAddrOp.getType());
 
   builder.create<mlir::acc::DeclareExitOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
 
   mlir::Value varPtr;
   if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
@@ -3092,34 +3100,37 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
 
   mlir::func::FuncOp funcOp = builder.getFunction();
   auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>();
+  mlir::Value declareToken;
   if (ops.empty()) {
-    builder.create<mlir::acc::DeclareEnterOp>(loc, dataClauseOperands);
+    declareToken = builder.create<mlir::acc::DeclareEnterOp>(
+        loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
+        dataClauseOperands);
   } else {
     auto declareOp = *ops.begin();
     auto newDeclareOp = builder.create<mlir::acc::DeclareEnterOp>(
-        loc, declareOp.getDataClauseOperands());
+        loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
+        declareOp.getDataClauseOperands());
     newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands);
+    declareToken = newDeclareOp.getToken();
     declareOp.erase();
   }
 
   openAccCtx.attachCleanup([&builder, loc, createEntryOperands,
                             copyEntryOperands, copyoutEntryOperands,
-                            deviceResidentEntryOperands]() {
+                            deviceResidentEntryOperands, declareToken]() {
     llvm::SmallVector<mlir::Value> operands;
     operands.append(createEntryOperands);
     operands.append(deviceResidentEntryOperands);
     operands.append(copyEntryOperands);
     operands.append(copyoutEntryOperands);
 
-    if (!operands.empty()) {
-      mlir::func::FuncOp funcOp = builder.getFunction();
-      auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
-      if (ops.empty()) {
-        builder.create<mlir::acc::DeclareExitOp>(loc, operands);
-      } else {
-        auto declareOp = *ops.begin();
-        declareOp.getDataClauseOperandsMutable().append(operands);
-      }
+    mlir::func::FuncOp funcOp = builder.getFunction();
+    auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
+    if (ops.empty()) {
+      builder.create<mlir::acc::DeclareExitOp>(loc, declareToken, operands);
+    } else {
+      auto declareOp = *ops.begin();
+      declareOp.getDataClauseOperandsMutable().append(operands);
     }
 
     genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
diff --git a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
index 489f81d297df4d2..b0a78fbd5439fb9 100644
--- a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
+++ b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
@@ -24,11 +24,11 @@ subroutine acc_declare_copy()
 ! ALL: %[[BOUND:.*]] = acc.bounds   lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
 ! FIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copy>, name = "a"}
 ! HLFIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copy>, name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
 
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) {
 ! ALL: }
-! ALL: acc.declare_exit dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
 ! FIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)   bounds(%[[BOUND]]) to varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_copy>, name = "a"}
 ! HLFIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) to varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_copy>, name = "a"}
 
@@ -51,11 +51,11 @@ subroutine acc_declare_create()
 ! ALL: %[[BOUND:.*]] = acc.bounds   lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
 ! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
 ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>)   bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) {
 ! ALL: }
-! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: acc.delete accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) {dataClause = #acc<data_clause acc_create>, name = "a"}
 ! ALL: return
 
@@ -119,10 +119,10 @@ subroutine acc_declare_copyout()
 ! HLFIR: %[[ADECL:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_copyout>, uniq_name = "_QMacc_declareFacc_declare_copyoutEa"} : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xi32>>, !fir.ref<!fir.array<100xi32>>)
 ! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copyout>, name = "a"}
 ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]]#1 : !fir.ref<!fir.array<100xi32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copyout>, name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32)
 
-! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 ! FIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) to varPtr(%[[ADECL]] : !fir.ref<!fir.array<100xi32>>) {name = "a"}
 ! HLFIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)   bounds(%{{.*}}) to varPtr(%[[ADECL]]#1 : !fir.ref<!fir.array<100xi32>>) {name = "a"}
 ! ALL: return
@@ -178,9 +178,9 @@ subroutine acc_declare_device_resident(a)
 ! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_declare_device_resident>, uniq_name = "_QMacc_declareFacc_declare_device_residentEa"} : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xi32>>, !fir.ref<!fir.array<100xi32>>)
 ! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
 ! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32)
-! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "a"}
 
   subroutine acc_declare_device_resident2()
@@ -195,8 +195,8 @@ subroutine acc_declare_device_resident2()
 ! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOCA]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_declare_device_resident>, uniq_name = "_QMacc_declareFacc_declare_device_resident2Edataparam"} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xf32>>, !fir.ref<!fir.array<100xf32>>)
 ! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref<!fir.array<100xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xf32>> {name = "dataparam"}
 ! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xf32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xf32>> {name = "dataparam"}
-! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
-! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
 ! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>) bounds(%{{.*}}) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "dataparam"}
 
   subroutine acc_declare_link2()
@@ -234,10 +234,10 @@ function acc_declare_in_func()
 
 ! ALL-LABEL: func.func @_QMacc_declarePacc_declare_in_func() -> f32 {
 ! HLFIR: %[[DEVICE_RESIDENT:.*]] = acc.declare_device_resident varPtr(%{{.*}}#1 : !fir.ref<!fir.array<1024xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<1024xf32>> {name = "a"}
-! HLFIR: acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
 
 ! HLFIR: %[[LOAD:.*]] = fir.load %{{.*}}#1 : !fir.ref<f32>
-! HLFIR: acc.declare_exit dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR: acc.delete accPtr(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>) bounds(%6) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "a"}
 ! HLFIR: return %[[LOAD]] : f32
 ! ALL: }
@@ -254,10 +254,10 @@ function acc_declare_in_func2(i)
 ! HLFIR: %[[ALLOCA_A:.*]] = fir.alloca !fir.array<1024xf32> {bindc_name = "a", uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"}
 ! HLFIR: %[[DECL_A:.*]]:2 = hlfir.declare %[[ALLOCA_A]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_create>, uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"} : (!fir.ref<!fir.array<1024xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xf32>>, !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<1024xf32>>) bounds(%7) -> !fir.ref<!fir.array<1024xf32>> {name = "a"}
-! HLFIR: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR:   cf.br ^bb1
 ! HLFIR: ^bb1:
-! HLFIR: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR: acc.delete accPtr(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>) bounds(%7) {dataClause = #acc<data_clause acc_create>, name = "a"}
 ! ALL:   return %{{.*}} : f32
 ! ALL: }
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index d0b52a0b4024172..391e77e0c4081a3 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1430,6 +1430,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> {
   }];
 
   let arguments = (ins Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
+  let results = (outs OpenACC_DeclareTokenType:$token);
 
   let assemblyFormat = [{
     oilist(
@@ -1441,7 +1442,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> {
   let hasVerifier = 1;
 }
 
-def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> {
+def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", [AttrSizedOperandSegments]> {
   let summary = "declare directive - exit from implicit data region";
 
   let description = [{
@@ -1458,10 +1459,13 @@ def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> {
     ```
   }];
 
-  let arguments = (ins Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
+  let arguments = (ins
+      Optional<OpenACC_DeclareTokenType>:$token,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
 
   let assemblyFormat = [{
     oilist(
+        `token` `(` $token `)` |
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
     )
     attr-dict-with-keyword
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
index 4a930ad94c3f175..92ea71a7e8418de 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
@@ -24,4 +24,13 @@ def OpenACC_DataBoundsType : OpenACC_Type<"DataBounds", "data_bounds_ty"> {
   let summary = "Type for representing acc data clause bounds information";
 }
 
+def OpenACC_DeclareTokenType : OpenACC_Type<"DeclareToken", "declare_token"> {
+  let summary = "declare token type";
+  let description = [{
+    `acc.declare_token` is a type returned by a `declare_enter` operation and
+    can be passed to a `declare_exit` operation to represent an implicit
+    data region.
+  }];
+}
+
 #endif // OPENACC_OPS_TYPES
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6e5df705fee05d8..b30ffe638999f9f 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1088,8 +1088,9 @@ LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
 
 template <typename Op>
 static LogicalResult checkDeclareOperands(Op &op,
-                                          const mlir::ValueRange &operands) {
-  if (operands.empty())
+                                          const mlir::ValueRange &operands,
+                                          bool requireAtLeastOnOperand = true) {
+  if (operands.empty() && requireAtLeastOnOperand)
     return emitError(
         op->getLoc(),
         "at least one operand must appear on the declare operation");
@@ -1151,6 +1152,9 @@ LogicalResult acc::DeclareEnterOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult acc::DeclareExitOp::verify() {
+  if (getToken())
+    return checkDeclareOperands(*this, this->getDataClauseOperands(),
+                                /*requireAtLeastOneOperand=*/false);
   return checkDeclareOperands(*this, this->getDataClauseOperands());
 }
 

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 16, 2023

@llvm/pr-subscribers-mlir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Add a result token to acc.declare_enter op that can be feed to acc.declare_exit op to represent an implicit declare data region in a function/subroutine.


Full diff: https://github.com/llvm/llvm-project/pull/72476.diff

5 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+28-17)
  • (modified) flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 (+14-14)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+6-2)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td (+9)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+6-2)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e470154ce8c2d0b..8c6c22210cf0894 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -163,7 +163,8 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
       builder, loc, boxAddrOp.getResult(), asFortran, bounds,
       /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
   builder.create<mlir::acc::DeclareEnterOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+      mlir::ValueRange(entryOp.getAccPtr()));
 
   modBuilder.setInsertionPointAfter(registerFuncOp);
   builder.restoreInsertionPoint(crtInsPt);
@@ -195,7 +196,7 @@ static void createDeclareDeallocFuncWithArg(
           /*structured=*/false, /*implicit=*/false, clause,
           boxAddrOp.getType());
   builder.create<mlir::acc::DeclareExitOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
 
   mlir::Value varPtr;
   if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
@@ -2762,7 +2763,13 @@ static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
   EntryOp entryOp = createDataEntryOp<EntryOp>(
       builder, loc, addrOp.getResTy(), asFortran, bounds,
       /*structured=*/false, implicit, clause, addrOp.getResTy().getType());
-  builder.create<DeclareOp>(loc, mlir::ValueRange(entryOp.getAccPtr()));
+  if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
+    builder.create<DeclareOp>(
+        loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+        mlir::ValueRange(entryOp.getAccPtr()));
+  else
+    builder.create<DeclareOp>(loc, mlir::Value{},
+                              mlir::ValueRange(entryOp.getAccPtr()));
   mlir::Value varPtr;
   if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
     builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(), varPtr,
@@ -2812,7 +2819,8 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
       builder, loc, boxAddrOp.getResult(), asFortran, bounds,
       /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
   builder.create<mlir::acc::DeclareEnterOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+      mlir::ValueRange(entryOp.getAccPtr()));
 
   modBuilder.setInsertionPointAfter(registerFuncOp);
 }
@@ -2850,7 +2858,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
           boxAddrOp.getType());
 
   builder.create<mlir::acc::DeclareExitOp>(
-      loc, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
 
   mlir::Value varPtr;
   if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
@@ -3092,34 +3100,37 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
 
   mlir::func::FuncOp funcOp = builder.getFunction();
   auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>();
+  mlir::Value declareToken;
   if (ops.empty()) {
-    builder.create<mlir::acc::DeclareEnterOp>(loc, dataClauseOperands);
+    declareToken = builder.create<mlir::acc::DeclareEnterOp>(
+        loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
+        dataClauseOperands);
   } else {
     auto declareOp = *ops.begin();
     auto newDeclareOp = builder.create<mlir::acc::DeclareEnterOp>(
-        loc, declareOp.getDataClauseOperands());
+        loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
+        declareOp.getDataClauseOperands());
     newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands);
+    declareToken = newDeclareOp.getToken();
     declareOp.erase();
   }
 
   openAccCtx.attachCleanup([&builder, loc, createEntryOperands,
                             copyEntryOperands, copyoutEntryOperands,
-                            deviceResidentEntryOperands]() {
+                            deviceResidentEntryOperands, declareToken]() {
     llvm::SmallVector<mlir::Value> operands;
     operands.append(createEntryOperands);
     operands.append(deviceResidentEntryOperands);
     operands.append(copyEntryOperands);
     operands.append(copyoutEntryOperands);
 
-    if (!operands.empty()) {
-      mlir::func::FuncOp funcOp = builder.getFunction();
-      auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
-      if (ops.empty()) {
-        builder.create<mlir::acc::DeclareExitOp>(loc, operands);
-      } else {
-        auto declareOp = *ops.begin();
-        declareOp.getDataClauseOperandsMutable().append(operands);
-      }
+    mlir::func::FuncOp funcOp = builder.getFunction();
+    auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
+    if (ops.empty()) {
+      builder.create<mlir::acc::DeclareExitOp>(loc, declareToken, operands);
+    } else {
+      auto declareOp = *ops.begin();
+      declareOp.getDataClauseOperandsMutable().append(operands);
     }
 
     genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
diff --git a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
index 489f81d297df4d2..b0a78fbd5439fb9 100644
--- a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
+++ b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
@@ -24,11 +24,11 @@ subroutine acc_declare_copy()
 ! ALL: %[[BOUND:.*]] = acc.bounds   lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
 ! FIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copy>, name = "a"}
 ! HLFIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copy>, name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
 
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) {
 ! ALL: }
-! ALL: acc.declare_exit dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
 ! FIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)   bounds(%[[BOUND]]) to varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_copy>, name = "a"}
 ! HLFIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) to varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_copy>, name = "a"}
 
@@ -51,11 +51,11 @@ subroutine acc_declare_create()
 ! ALL: %[[BOUND:.*]] = acc.bounds   lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
 ! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
 ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>)   bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) {
 ! ALL: }
-! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: acc.delete accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) {dataClause = #acc<data_clause acc_create>, name = "a"}
 ! ALL: return
 
@@ -119,10 +119,10 @@ subroutine acc_declare_copyout()
 ! HLFIR: %[[ADECL:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_copyout>, uniq_name = "_QMacc_declareFacc_declare_copyoutEa"} : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xi32>>, !fir.ref<!fir.array<100xi32>>)
 ! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copyout>, name = "a"}
 ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]]#1 : !fir.ref<!fir.array<100xi32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copyout>, name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32)
 
-! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
 ! FIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) to varPtr(%[[ADECL]] : !fir.ref<!fir.array<100xi32>>) {name = "a"}
 ! HLFIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)   bounds(%{{.*}}) to varPtr(%[[ADECL]]#1 : !fir.ref<!fir.array<100xi32>>) {name = "a"}
 ! ALL: return
@@ -178,9 +178,9 @@ subroutine acc_declare_device_resident(a)
 ! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_declare_device_resident>, uniq_name = "_QMacc_declareFacc_declare_device_residentEa"} : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xi32>>, !fir.ref<!fir.array<100xi32>>)
 ! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
 ! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
-! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32)
-! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
 ! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "a"}
 
   subroutine acc_declare_device_resident2()
@@ -195,8 +195,8 @@ subroutine acc_declare_device_resident2()
 ! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOCA]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_declare_device_resident>, uniq_name = "_QMacc_declareFacc_declare_device_resident2Edataparam"} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xf32>>, !fir.ref<!fir.array<100xf32>>)
 ! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref<!fir.array<100xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xf32>> {name = "dataparam"}
 ! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xf32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xf32>> {name = "dataparam"}
-! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
-! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
+! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
+! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
 ! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>) bounds(%{{.*}}) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "dataparam"}
 
   subroutine acc_declare_link2()
@@ -234,10 +234,10 @@ function acc_declare_in_func()
 
 ! ALL-LABEL: func.func @_QMacc_declarePacc_declare_in_func() -> f32 {
 ! HLFIR: %[[DEVICE_RESIDENT:.*]] = acc.declare_device_resident varPtr(%{{.*}}#1 : !fir.ref<!fir.array<1024xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<1024xf32>> {name = "a"}
-! HLFIR: acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
 
 ! HLFIR: %[[LOAD:.*]] = fir.load %{{.*}}#1 : !fir.ref<f32>
-! HLFIR: acc.declare_exit dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR: acc.delete accPtr(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>) bounds(%6) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "a"}
 ! HLFIR: return %[[LOAD]] : f32
 ! ALL: }
@@ -254,10 +254,10 @@ function acc_declare_in_func2(i)
 ! HLFIR: %[[ALLOCA_A:.*]] = fir.alloca !fir.array<1024xf32> {bindc_name = "a", uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"}
 ! HLFIR: %[[DECL_A:.*]]:2 = hlfir.declare %[[ALLOCA_A]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_create>, uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"} : (!fir.ref<!fir.array<1024xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xf32>>, !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<1024xf32>>) bounds(%7) -> !fir.ref<!fir.array<1024xf32>> {name = "a"}
-! HLFIR: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR:   cf.br ^bb1
 ! HLFIR: ^bb1:
-! HLFIR: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
+! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
 ! HLFIR: acc.delete accPtr(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>) bounds(%7) {dataClause = #acc<data_clause acc_create>, name = "a"}
 ! ALL:   return %{{.*}} : f32
 ! ALL: }
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index d0b52a0b4024172..391e77e0c4081a3 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1430,6 +1430,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> {
   }];
 
   let arguments = (ins Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
+  let results = (outs OpenACC_DeclareTokenType:$token);
 
   let assemblyFormat = [{
     oilist(
@@ -1441,7 +1442,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> {
   let hasVerifier = 1;
 }
 
-def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> {
+def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", [AttrSizedOperandSegments]> {
   let summary = "declare directive - exit from implicit data region";
 
   let description = [{
@@ -1458,10 +1459,13 @@ def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> {
     ```
   }];
 
-  let arguments = (ins Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
+  let arguments = (ins
+      Optional<OpenACC_DeclareTokenType>:$token,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
 
   let assemblyFormat = [{
     oilist(
+        `token` `(` $token `)` |
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
     )
     attr-dict-with-keyword
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
index 4a930ad94c3f175..92ea71a7e8418de 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
@@ -24,4 +24,13 @@ def OpenACC_DataBoundsType : OpenACC_Type<"DataBounds", "data_bounds_ty"> {
   let summary = "Type for representing acc data clause bounds information";
 }
 
+def OpenACC_DeclareTokenType : OpenACC_Type<"DeclareToken", "declare_token"> {
+  let summary = "declare token type";
+  let description = [{
+    `acc.declare_token` is a type returned by a `declare_enter` operation and
+    can be passed to a `declare_exit` operation to represent an implicit
+    data region.
+  }];
+}
+
 #endif // OPENACC_OPS_TYPES
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6e5df705fee05d8..b30ffe638999f9f 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1088,8 +1088,9 @@ LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
 
 template <typename Op>
 static LogicalResult checkDeclareOperands(Op &op,
-                                          const mlir::ValueRange &operands) {
-  if (operands.empty())
+                                          const mlir::ValueRange &operands,
+                                          bool requireAtLeastOnOperand = true) {
+  if (operands.empty() && requireAtLeastOnOperand)
     return emitError(
         op->getLoc(),
         "at least one operand must appear on the declare operation");
@@ -1151,6 +1152,9 @@ LogicalResult acc::DeclareEnterOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult acc::DeclareExitOp::verify() {
+  if (getToken())
+    return checkDeclareOperands(*this, this->getDataClauseOperands(),
+                                /*requireAtLeastOneOperand=*/false);
   return checkDeclareOperands(*this, this->getDataClauseOperands());
 }
 

Copy link

github-actions bot commented Nov 16, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@clementval clementval merged commit 9365ed1 into llvm:main Nov 17, 2023
3 checks passed
sr-tream pushed a commit to sr-tream/llvm-project that referenced this pull request Nov 20, 2023
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
@clementval clementval deleted the acc_declare_token branch January 18, 2024 17:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants