From 9365ed1e10e92c48ad3dbe4b257b0fdc045b74a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Thu, 16 Nov 2023 16:41:50 -0800 Subject: [PATCH] [flang][openacc] Add ability to link acc.declare_enter with acc.declare_exit ops (#72476) --- flang/lib/Lower/OpenACC.cpp | 45 ++++++++++++------- .../test/Lower/OpenACC/HLFIR/acc-declare.f90 | 28 ++++++------ .../mlir/Dialect/OpenACC/OpenACCOps.td | 8 +++- .../mlir/Dialect/OpenACC/OpenACCOpsTypes.td | 9 ++++ mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 10 +++-- 5 files changed, 64 insertions(+), 36 deletions(-) diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index e470154ce8c2d..8c6c22210cf08 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -163,7 +163,8 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder, builder, loc, boxAddrOp.getResult(), asFortran, bounds, /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType()); builder.create( - loc, mlir::ValueRange(entryOp.getAccPtr())); + loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()), + mlir::ValueRange(entryOp.getAccPtr())); modBuilder.setInsertionPointAfter(registerFuncOp); builder.restoreInsertionPoint(crtInsPt); @@ -195,7 +196,7 @@ static void createDeclareDeallocFuncWithArg( /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType()); builder.create( - loc, mlir::ValueRange(entryOp.getAccPtr())); + loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr())); mlir::Value varPtr; if constexpr (std::is_same_v || @@ -2762,7 +2763,13 @@ static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder, EntryOp entryOp = createDataEntryOp( builder, loc, addrOp.getResTy(), asFortran, bounds, /*structured=*/false, implicit, clause, addrOp.getResTy().getType()); - builder.create(loc, mlir::ValueRange(entryOp.getAccPtr())); + if constexpr (std::is_same_v) + builder.create( + loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()), + mlir::ValueRange(entryOp.getAccPtr())); + else + builder.create(loc, mlir::Value{}, + mlir::ValueRange(entryOp.getAccPtr())); mlir::Value varPtr; if constexpr (std::is_same_v) { builder.create(entryOp.getLoc(), entryOp.getAccPtr(), varPtr, @@ -2812,7 +2819,8 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder, builder, loc, boxAddrOp.getResult(), asFortran, bounds, /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType()); builder.create( - loc, mlir::ValueRange(entryOp.getAccPtr())); + loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()), + mlir::ValueRange(entryOp.getAccPtr())); modBuilder.setInsertionPointAfter(registerFuncOp); } @@ -2850,7 +2858,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder, boxAddrOp.getType()); builder.create( - loc, mlir::ValueRange(entryOp.getAccPtr())); + loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr())); mlir::Value varPtr; if constexpr (std::is_same_v || @@ -3092,34 +3100,37 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, mlir::func::FuncOp funcOp = builder.getFunction(); auto ops = funcOp.getOps(); + mlir::Value declareToken; if (ops.empty()) { - builder.create(loc, dataClauseOperands); + declareToken = builder.create( + loc, mlir::acc::DeclareTokenType::get(builder.getContext()), + dataClauseOperands); } else { auto declareOp = *ops.begin(); auto newDeclareOp = builder.create( - loc, declareOp.getDataClauseOperands()); + loc, mlir::acc::DeclareTokenType::get(builder.getContext()), + declareOp.getDataClauseOperands()); newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands); + declareToken = newDeclareOp.getToken(); declareOp.erase(); } openAccCtx.attachCleanup([&builder, loc, createEntryOperands, copyEntryOperands, copyoutEntryOperands, - deviceResidentEntryOperands]() { + deviceResidentEntryOperands, declareToken]() { llvm::SmallVector operands; operands.append(createEntryOperands); operands.append(deviceResidentEntryOperands); operands.append(copyEntryOperands); operands.append(copyoutEntryOperands); - if (!operands.empty()) { - mlir::func::FuncOp funcOp = builder.getFunction(); - auto ops = funcOp.getOps(); - if (ops.empty()) { - builder.create(loc, operands); - } else { - auto declareOp = *ops.begin(); - declareOp.getDataClauseOperandsMutable().append(operands); - } + mlir::func::FuncOp funcOp = builder.getFunction(); + auto ops = funcOp.getOps(); + if (ops.empty()) { + builder.create(loc, declareToken, operands); + } else { + auto declareOp = *ops.begin(); + declareOp.getDataClauseOperandsMutable().append(operands); } genDataExitOperations( diff --git a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 index 489f81d297df4..b0a78fbd5439f 100644 --- a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 +++ b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 @@ -24,11 +24,11 @@ subroutine acc_declare_copy() ! ALL: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index) ! FIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {dataClause = #acc, name = "a"} ! HLFIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]]#1 : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {dataClause = #acc, name = "a"} -! ALL: acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref>) +! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref>) ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) { ! ALL: } -! ALL: acc.declare_exit dataOperands(%[[COPYIN]] : !fir.ref>) +! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[COPYIN]] : !fir.ref>) ! FIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref>) bounds(%[[BOUND]]) to varPtr(%[[DECL]] : !fir.ref>) {dataClause = #acc, name = "a"} ! HLFIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref>) bounds(%[[BOUND]]) to varPtr(%[[DECL]]#1 : !fir.ref>) {dataClause = #acc, name = "a"} @@ -51,11 +51,11 @@ subroutine acc_declare_create() ! ALL: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index) ! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a"} ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]]#1 : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a"} -! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref>) +! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref>) ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) { ! ALL: } -! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref>) +! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref>) ! ALL: acc.delete accPtr(%[[CREATE]] : !fir.ref>) bounds(%[[BOUND]]) {dataClause = #acc, name = "a"} ! ALL: return @@ -119,10 +119,10 @@ subroutine acc_declare_copyout() ! HLFIR: %[[ADECL:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {acc.declare = #acc.declare, uniq_name = "_QMacc_declareFacc_declare_copyoutEa"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) ! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]] : !fir.ref>) bounds(%{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a"} ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]]#1 : !fir.ref>) bounds(%{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a"} -! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref>) +! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref>) ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32) -! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref>) +! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref>) ! FIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref>) bounds(%{{.*}}) to varPtr(%[[ADECL]] : !fir.ref>) {name = "a"} ! HLFIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref>) bounds(%{{.*}}) to varPtr(%[[ADECL]]#1 : !fir.ref>) {name = "a"} ! ALL: return @@ -178,9 +178,9 @@ subroutine acc_declare_device_resident(a) ! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]](%{{.*}}) {acc.declare = #acc.declare, uniq_name = "_QMacc_declareFacc_declare_device_residentEa"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) ! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref>) bounds(%{{.*}}) -> !fir.ref> {name = "a"} ! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref>) bounds(%{{.*}}) -> !fir.ref> {name = "a"} -! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref>) +! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref>) ! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32) -! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref>) +! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref>) ! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref>) bounds(%{{.*}}) {dataClause = #acc, name = "a"} subroutine acc_declare_device_resident2() @@ -195,8 +195,8 @@ subroutine acc_declare_device_resident2() ! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOCA]](%{{.*}}) {acc.declare = #acc.declare, uniq_name = "_QMacc_declareFacc_declare_device_resident2Edataparam"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) ! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref>) bounds(%{{.*}}) -> !fir.ref> {name = "dataparam"} ! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref>) bounds(%{{.*}}) -> !fir.ref> {name = "dataparam"} -! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref>) -! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref>) +! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref>) +! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref>) ! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref>) bounds(%{{.*}}) {dataClause = #acc, name = "dataparam"} subroutine acc_declare_link2() @@ -234,10 +234,10 @@ function acc_declare_in_func() ! ALL-LABEL: func.func @_QMacc_declarePacc_declare_in_func() -> f32 { ! HLFIR: %[[DEVICE_RESIDENT:.*]] = acc.declare_device_resident varPtr(%{{.*}}#1 : !fir.ref>) bounds(%{{.*}}) -> !fir.ref> {name = "a"} -! HLFIR: acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref>) +! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref>) ! HLFIR: %[[LOAD:.*]] = fir.load %{{.*}}#1 : !fir.ref -! HLFIR: acc.declare_exit dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref>) +! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref>) ! HLFIR: acc.delete accPtr(%[[DEVICE_RESIDENT]] : !fir.ref>) bounds(%6) {dataClause = #acc, name = "a"} ! HLFIR: return %[[LOAD]] : f32 ! ALL: } @@ -254,10 +254,10 @@ function acc_declare_in_func2(i) ! HLFIR: %[[ALLOCA_A:.*]] = fir.alloca !fir.array<1024xf32> {bindc_name = "a", uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"} ! HLFIR: %[[DECL_A:.*]]:2 = hlfir.declare %[[ALLOCA_A]](%{{.*}}) {acc.declare = #acc.declare, uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) ! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL_A]]#1 : !fir.ref>) bounds(%7) -> !fir.ref> {name = "a"} -! HLFIR: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref>) +! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref>) ! HLFIR: cf.br ^bb1 ! HLFIR: ^bb1: -! HLFIR: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref>) +! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref>) ! HLFIR: acc.delete accPtr(%[[CREATE]] : !fir.ref>) bounds(%7) {dataClause = #acc, name = "a"} ! ALL: return %{{.*}} : f32 ! ALL: } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index d0b52a0b40241..391e77e0c4081 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -1430,6 +1430,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> { }]; let arguments = (ins Variadic:$dataClauseOperands); + let results = (outs OpenACC_DeclareTokenType:$token); let assemblyFormat = [{ oilist( @@ -1441,7 +1442,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> { let hasVerifier = 1; } -def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> { +def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", [AttrSizedOperandSegments]> { let summary = "declare directive - exit from implicit data region"; let description = [{ @@ -1458,10 +1459,13 @@ def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> { ``` }]; - let arguments = (ins Variadic:$dataClauseOperands); + let arguments = (ins + Optional:$token, + Variadic:$dataClauseOperands); let assemblyFormat = [{ oilist( + `token` `(` $token `)` | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` ) attr-dict-with-keyword diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td index 4a930ad94c3f1..92ea71a7e8418 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td @@ -24,4 +24,13 @@ def OpenACC_DataBoundsType : OpenACC_Type<"DataBounds", "data_bounds_ty"> { let summary = "Type for representing acc data clause bounds information"; } +def OpenACC_DeclareTokenType : OpenACC_Type<"DeclareToken", "declare_token"> { + let summary = "declare token type"; + let description = [{ + `acc.declare_token` is a type returned by a `declare_enter` operation and + can be passed to a `declare_exit` operation to represent an implicit + data region. + }]; +} + #endif // OPENACC_OPS_TYPES diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 6e5df705fee05..08e83cad48220 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -1087,9 +1087,10 @@ LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); } //===----------------------------------------------------------------------===// template -static LogicalResult checkDeclareOperands(Op &op, - const mlir::ValueRange &operands) { - if (operands.empty()) +static LogicalResult +checkDeclareOperands(Op &op, const mlir::ValueRange &operands, + bool requireAtLeastOneOperand = true) { + if (operands.empty() && requireAtLeastOneOperand) return emitError( op->getLoc(), "at least one operand must appear on the declare operation"); @@ -1151,6 +1152,9 @@ LogicalResult acc::DeclareEnterOp::verify() { //===----------------------------------------------------------------------===// LogicalResult acc::DeclareExitOp::verify() { + if (getToken()) + return checkDeclareOperands(*this, this->getDataClauseOperands(), + /*requireAtLeastOneOperand=*/false); return checkDeclareOperands(*this, this->getDataClauseOperands()); }