Skip to content

Commit

Permalink
[flang][openacc] Add ability to link acc.declare_enter with acc.decla…
Browse files Browse the repository at this point in the history
…re_exit ops (#72476)
  • Loading branch information
clementval committed Nov 17, 2023
1 parent c6f7b63 commit 9365ed1
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 36 deletions.
45 changes: 28 additions & 17 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
builder, loc, boxAddrOp.getResult(), asFortran, bounds,
/*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
builder.create<mlir::acc::DeclareEnterOp>(
loc, mlir::ValueRange(entryOp.getAccPtr()));
loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
mlir::ValueRange(entryOp.getAccPtr()));

modBuilder.setInsertionPointAfter(registerFuncOp);
builder.restoreInsertionPoint(crtInsPt);
Expand Down Expand Up @@ -195,7 +196,7 @@ static void createDeclareDeallocFuncWithArg(
/*structured=*/false, /*implicit=*/false, clause,
boxAddrOp.getType());
builder.create<mlir::acc::DeclareExitOp>(
loc, mlir::ValueRange(entryOp.getAccPtr()));
loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));

mlir::Value varPtr;
if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
Expand Down Expand Up @@ -2762,7 +2763,13 @@ static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
EntryOp entryOp = createDataEntryOp<EntryOp>(
builder, loc, addrOp.getResTy(), asFortran, bounds,
/*structured=*/false, implicit, clause, addrOp.getResTy().getType());
builder.create<DeclareOp>(loc, mlir::ValueRange(entryOp.getAccPtr()));
if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>)
builder.create<DeclareOp>(
loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
mlir::ValueRange(entryOp.getAccPtr()));
else
builder.create<DeclareOp>(loc, mlir::Value{},
mlir::ValueRange(entryOp.getAccPtr()));
mlir::Value varPtr;
if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) {
builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(), varPtr,
Expand Down Expand Up @@ -2812,7 +2819,8 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
builder, loc, boxAddrOp.getResult(), asFortran, bounds,
/*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
builder.create<mlir::acc::DeclareEnterOp>(
loc, mlir::ValueRange(entryOp.getAccPtr()));
loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
mlir::ValueRange(entryOp.getAccPtr()));

modBuilder.setInsertionPointAfter(registerFuncOp);
}
Expand Down Expand Up @@ -2850,7 +2858,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
boxAddrOp.getType());

builder.create<mlir::acc::DeclareExitOp>(
loc, mlir::ValueRange(entryOp.getAccPtr()));
loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));

mlir::Value varPtr;
if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
Expand Down Expand Up @@ -3092,34 +3100,37 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,

mlir::func::FuncOp funcOp = builder.getFunction();
auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>();
mlir::Value declareToken;
if (ops.empty()) {
builder.create<mlir::acc::DeclareEnterOp>(loc, dataClauseOperands);
declareToken = builder.create<mlir::acc::DeclareEnterOp>(
loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
dataClauseOperands);
} else {
auto declareOp = *ops.begin();
auto newDeclareOp = builder.create<mlir::acc::DeclareEnterOp>(
loc, declareOp.getDataClauseOperands());
loc, mlir::acc::DeclareTokenType::get(builder.getContext()),
declareOp.getDataClauseOperands());
newDeclareOp.getDataClauseOperandsMutable().append(dataClauseOperands);
declareToken = newDeclareOp.getToken();
declareOp.erase();
}

openAccCtx.attachCleanup([&builder, loc, createEntryOperands,
copyEntryOperands, copyoutEntryOperands,
deviceResidentEntryOperands]() {
deviceResidentEntryOperands, declareToken]() {
llvm::SmallVector<mlir::Value> operands;
operands.append(createEntryOperands);
operands.append(deviceResidentEntryOperands);
operands.append(copyEntryOperands);
operands.append(copyoutEntryOperands);

if (!operands.empty()) {
mlir::func::FuncOp funcOp = builder.getFunction();
auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
if (ops.empty()) {
builder.create<mlir::acc::DeclareExitOp>(loc, operands);
} else {
auto declareOp = *ops.begin();
declareOp.getDataClauseOperandsMutable().append(operands);
}
mlir::func::FuncOp funcOp = builder.getFunction();
auto ops = funcOp.getOps<mlir::acc::DeclareExitOp>();
if (ops.empty()) {
builder.create<mlir::acc::DeclareExitOp>(loc, declareToken, operands);
} else {
auto declareOp = *ops.begin();
declareOp.getDataClauseOperandsMutable().append(operands);
}

genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
Expand Down
28 changes: 14 additions & 14 deletions flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ subroutine acc_declare_copy()
! ALL: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
! FIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copy>, name = "a"}
! HLFIR: %[[COPYIN:.*]] = acc.copyin varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copy>, name = "a"}
! ALL: acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)

! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) {
! ALL: }
! ALL: acc.declare_exit dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>)
! FIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) to varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_copy>, name = "a"}
! HLFIR: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) to varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_copy>, name = "a"}

Expand All @@ -51,11 +51,11 @@ subroutine acc_declare_create()
! ALL: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)

! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index, i32) {
! ALL: }
! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
! ALL: acc.delete accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) {dataClause = #acc<data_clause acc_create>, name = "a"}
! ALL: return

Expand Down Expand Up @@ -119,10 +119,10 @@ subroutine acc_declare_copyout()
! HLFIR: %[[ADECL:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {acc.declare = #acc.declare<dataClause = acc_copyout>, uniq_name = "_QMacc_declareFacc_declare_copyoutEa"} : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xi32>>, !fir.ref<!fir.array<100xi32>>)
! FIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copyout>, name = "a"}
! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[ADECL]]#1 : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {dataClause = #acc<data_clause acc_copyout>, name = "a"}
! ALL: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32)

! ALL: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>)
! FIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) to varPtr(%[[ADECL]] : !fir.ref<!fir.array<100xi32>>) {name = "a"}
! HLFIR: acc.copyout accPtr(%[[CREATE]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) to varPtr(%[[ADECL]]#1 : !fir.ref<!fir.array<100xi32>>) {name = "a"}
! ALL: return
Expand Down Expand Up @@ -178,9 +178,9 @@ subroutine acc_declare_device_resident(a)
! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]](%{{.*}}) {acc.declare = #acc.declare<dataClause = acc_declare_device_resident>, uniq_name = "_QMacc_declareFacc_declare_device_residentEa"} : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xi32>>, !fir.ref<!fir.array<100xi32>>)
! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xi32>> {name = "a"}
! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
! ALL: %{{.*}}:2 = fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%arg{{.*}} = %{{.*}}) -> (index, i32)
! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>)
! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref<!fir.array<100xi32>>) bounds(%{{.*}}) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "a"}

subroutine acc_declare_device_resident2()
Expand All @@ -195,8 +195,8 @@ subroutine acc_declare_device_resident2()
! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOCA]](%{{.*}}) {acc.declare = #acc.declare<dataClause = acc_declare_device_resident>, uniq_name = "_QMacc_declareFacc_declare_device_resident2Edataparam"} : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<100xf32>>, !fir.ref<!fir.array<100xf32>>)
! FIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]] : !fir.ref<!fir.array<100xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xf32>> {name = "dataparam"}
! HLFIR: %[[DEVICERES:.*]] = acc.declare_device_resident varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<100xf32>> {name = "dataparam"}
! ALL: acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
! ALL: acc.declare_exit dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
! ALL: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
! ALL: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>)
! ALL: acc.delete accPtr(%[[DEVICERES]] : !fir.ref<!fir.array<100xf32>>) bounds(%{{.*}}) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "dataparam"}

subroutine acc_declare_link2()
Expand Down Expand Up @@ -234,10 +234,10 @@ function acc_declare_in_func()

! ALL-LABEL: func.func @_QMacc_declarePacc_declare_in_func() -> f32 {
! HLFIR: %[[DEVICE_RESIDENT:.*]] = acc.declare_device_resident varPtr(%{{.*}}#1 : !fir.ref<!fir.array<1024xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<1024xf32>> {name = "a"}
! HLFIR: acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)

! HLFIR: %[[LOAD:.*]] = fir.load %{{.*}}#1 : !fir.ref<f32>
! HLFIR: acc.declare_exit dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>)
! HLFIR: acc.delete accPtr(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>) bounds(%6) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "a"}
! HLFIR: return %[[LOAD]] : f32
! ALL: }
Expand All @@ -254,10 +254,10 @@ function acc_declare_in_func2(i)
! HLFIR: %[[ALLOCA_A:.*]] = fir.alloca !fir.array<1024xf32> {bindc_name = "a", uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"}
! HLFIR: %[[DECL_A:.*]]:2 = hlfir.declare %[[ALLOCA_A]](%{{.*}}) {acc.declare = #acc.declare<dataClause = acc_create>, uniq_name = "_QMacc_declareFacc_declare_in_func2Ea"} : (!fir.ref<!fir.array<1024xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xf32>>, !fir.ref<!fir.array<1024xf32>>)
! HLFIR: %[[CREATE:.*]] = acc.create varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<1024xf32>>) bounds(%7) -> !fir.ref<!fir.array<1024xf32>> {name = "a"}
! HLFIR: acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
! HLFIR: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
! HLFIR: cf.br ^bb1
! HLFIR: ^bb1:
! HLFIR: acc.declare_exit dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
! HLFIR: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>)
! HLFIR: acc.delete accPtr(%[[CREATE]] : !fir.ref<!fir.array<1024xf32>>) bounds(%7) {dataClause = #acc<data_clause acc_create>, name = "a"}
! ALL: return %{{.*}} : f32
! ALL: }
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> {
}];

let arguments = (ins Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
let results = (outs OpenACC_DeclareTokenType:$token);

let assemblyFormat = [{
oilist(
Expand All @@ -1441,7 +1442,7 @@ def OpenACC_DeclareEnterOp : OpenACC_Op<"declare_enter", []> {
let hasVerifier = 1;
}

def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> {
def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", [AttrSizedOperandSegments]> {
let summary = "declare directive - exit from implicit data region";

let description = [{
Expand All @@ -1458,10 +1459,13 @@ def OpenACC_DeclareExitOp : OpenACC_Op<"declare_exit", []> {
```
}];

let arguments = (ins Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);
let arguments = (ins
Optional<OpenACC_DeclareTokenType>:$token,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands);

let assemblyFormat = [{
oilist(
`token` `(` $token `)` |
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
)
attr-dict-with-keyword
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,13 @@ def OpenACC_DataBoundsType : OpenACC_Type<"DataBounds", "data_bounds_ty"> {
let summary = "Type for representing acc data clause bounds information";
}

def OpenACC_DeclareTokenType : OpenACC_Type<"DeclareToken", "declare_token"> {
let summary = "declare token type";
let description = [{
`acc.declare_token` is a type returned by a `declare_enter` operation and
can be passed to a `declare_exit` operation to represent an implicit
data region.
}];
}

#endif // OPENACC_OPS_TYPES
10 changes: 7 additions & 3 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1087,9 +1087,10 @@ LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
//===----------------------------------------------------------------------===//

template <typename Op>
static LogicalResult checkDeclareOperands(Op &op,
const mlir::ValueRange &operands) {
if (operands.empty())
static LogicalResult
checkDeclareOperands(Op &op, const mlir::ValueRange &operands,
bool requireAtLeastOneOperand = true) {
if (operands.empty() && requireAtLeastOneOperand)
return emitError(
op->getLoc(),
"at least one operand must appear on the declare operation");
Expand Down Expand Up @@ -1151,6 +1152,9 @@ LogicalResult acc::DeclareEnterOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult acc::DeclareExitOp::verify() {
if (getToken())
return checkDeclareOperands(*this, this->getDataClauseOperands(),
/*requireAtLeastOneOperand=*/false);
return checkDeclareOperands(*this, this->getDataClauseOperands());
}

Expand Down

0 comments on commit 9365ed1

Please sign in to comment.