Skip to content

Commit

Permalink
[flang][OpenACC] Lower enter data directive
Browse files Browse the repository at this point in the history
This patch adds lowering for the `!$acc enter data` directive
from the PFT to OpenACC dialect.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D122384
  • Loading branch information
clementval committed Mar 24, 2022
1 parent 30b4421 commit 12d22ce
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 37 deletions.
8 changes: 6 additions & 2 deletions flang/lib/Lower/Bridge.cpp
Expand Up @@ -1089,8 +1089,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
TODO(toLocation(), "CompilerDirective lowering");
}

void genFIR(const Fortran::parser::OpenACCConstruct &) {
TODO(toLocation(), "OpenACCConstruct lowering");
void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
genOpenACCConstruct(*this, getEval(), acc);
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
genFIR(e);
builder->restoreInsertionPoint(insertPt);
}

void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &) {
Expand Down
94 changes: 59 additions & 35 deletions flang/lib/Lower/OpenACC.cpp
Expand Up @@ -120,6 +120,56 @@ static Op createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
return op;
}

static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClause::Async *asyncClause,
mlir::Value &async, bool &addAsyncAttr,
Fortran::lower::StatementContext &stmtCtx) {
const auto &asyncClauseValue = asyncClause->v;
if (asyncClauseValue) { // async has a value.
async = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
} else {
addAsyncAttr = true;
}
}

static void genIfClause(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClause::If *ifClause,
mlir::Value &ifCond,
Fortran::lower::StatementContext &stmtCtx) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Value cond = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(ifClause->v), stmtCtx));
ifCond = firOpBuilder.createConvert(converter.getCurrentLocation(),
firOpBuilder.getI1Type(), cond);
}

static void genWaitClause(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClause::Wait *waitClause,
SmallVectorImpl<mlir::Value> &operands,
mlir::Value &waitDevnum, bool &addWaitAttr,
Fortran::lower::StatementContext &stmtCtx) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const std::list<Fortran::parser::ScalarIntExpr> &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
mlir::Value v = fir::getBase(
converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx));
operands.push_back(v);
}

const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
if (waitDevnumValue)
waitDevnum = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
} else {
addWaitAttr = true;
}
}

static void genACC(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
Expand Down Expand Up @@ -540,7 +590,7 @@ static void
genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClauseList &accClauseList) {
mlir::Value ifCond, async, waitDevnum;
SmallVector<Value, 2> copyinOperands, createOperands, createZeroOperands,
SmallVector<mlir::Value> copyinOperands, createOperands, createZeroOperands,
attachOperands, waitOperands;

// Async, wait and self clause have optional values but can be present with
Expand All @@ -549,50 +599,24 @@ genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
bool addAsyncAttr = false;
bool addWaitAttr = false;

auto &firOpBuilder = converter.getFirOpBuilder();
auto currentLocation = converter.getCurrentLocation();
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
Fortran::lower::StatementContext stmtCtx;

// Lower clauses values mapped to operands.
// Keep track of each group of operands separatly as clauses can appear
// more than once.
for (const auto &clause : accClauseList.v) {
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
if (const auto *ifClause =
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
mlir::Value cond = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(ifClause->v), stmtCtx));
ifCond = firOpBuilder.createConvert(currentLocation,
firOpBuilder.getI1Type(), cond);
genIfClause(converter, ifClause, ifCond, stmtCtx);
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
const auto &asyncClauseValue = asyncClause->v;
if (asyncClauseValue) { // async has a value.
async = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
} else {
addAsyncAttr = true;
}
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const std::list<Fortran::parser::ScalarIntExpr> &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
mlir::Value v = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx));
waitOperands.push_back(v);
}

const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
if (waitDevnumValue)
waitDevnum = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
} else {
addWaitAttr = true;
}
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
addWaitAttr, stmtCtx);
} else if (const auto *copyinClause =
std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
const Fortran::parser::AccObjectListWithModifier &listWithModifier =
Expand Down Expand Up @@ -627,7 +651,7 @@ genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
addOperands(operands, operandSegments, createZeroOperands);
addOperands(operands, operandSegments, attachOperands);

auto enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
mlir::acc::EnterDataOp enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
firOpBuilder, currentLocation, operands, operandSegments);

if (addAsyncAttr)
Expand Down
69 changes: 69 additions & 0 deletions flang/test/Lower/OpenACC/acc-enter-data.f90
@@ -0,0 +1,69 @@
! This test checks lowering of OpenACC enter data directive.

! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s

subroutine acc_enter_data
integer :: async = 1
real, dimension(10, 10) :: a, b, c
real, pointer :: d
logical :: ifCondition = .TRUE.

!CHECK: [[A:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
!CHECK: [[B:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
!CHECK: [[C:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ec"}
!CHECK: [[D:%.*]] = fir.alloca !fir.box<!fir.ptr<f32>> {bindc_name = "d", uniq_name = "{{.*}}Ed"}

!$acc enter data create(a)
!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}

!$acc enter data create(a) if(.true.)
!CHECK: [[IF1:%.*]] = arith.constant true
!CHECK: acc.enter_data if([[IF1]]) create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}

!$acc enter data create(a) if(ifCondition)
!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
!CHECK: acc.enter_data if([[IF2]]) create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}

!$acc enter data create(a) create(b) create(c)
!CHECK: acc.enter_data create([[A]], [[B]], [[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}

!$acc enter data create(a) create(b) create(zero: c)
!CHECK: acc.enter_data create([[A]], [[B]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) create_zero([[C]] : !fir.ref<!fir.array<10x10xf32>>){{$}}

!$acc enter data copyin(a) create(b) attach(d)
!CHECK: acc.enter_data copyin([[A]] : !fir.ref<!fir.array<10x10xf32>>) create([[B]] : !fir.ref<!fir.array<10x10xf32>>) attach([[D]] : !fir.ref<!fir.box<!fir.ptr<f32>>>){{$}}

!$acc enter data create(a) async
!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}

!$acc enter data create(a) wait
!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}

!$acc enter data create(a) async wait
!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}

!$acc enter data create(a) async(1)
!CHECK: [[ASYNC1:%.*]] = arith.constant 1 : i32
!CHECK: acc.enter_data async([[ASYNC1]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)

!$acc enter data create(a) async(async)
!CHECK: [[ASYNC2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
!CHECK: acc.enter_data async([[ASYNC2]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)

!$acc enter data create(a) wait(1)
!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
!CHECK: acc.enter_data wait([[WAIT1]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)

!$acc enter data create(a) wait(queues: 1, 2)
!CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
!CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
!CHECK: acc.enter_data wait([[WAIT2]], [[WAIT3]] : i32, i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)

!$acc enter data create(a) wait(devnum: 1: queues: 1, 2)
!CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
!CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
!CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
!CHECK: acc.enter_data wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)

end subroutine acc_enter_data

0 comments on commit 12d22ce

Please sign in to comment.