Skip to content

Commit

Permalink
[mlir][openacc][flang] Support wait devnum and clean async/wait IR (#…
Browse files Browse the repository at this point in the history
…79525)

- Support wait(devnum: ) with device_type support on all operations that
require it
- The devnum value is stored as the first value of waitOperands in its
device_type sub-segment. The hasWaitDevnum attribute indicates which
sub-segments have a wait(devnum) value.
- Make async/wait information homogeneous across compute, data, and
update ops.
- Unify operands/attributes names across operations and use the same
custom parser/printer
  • Loading branch information
clementval committed Jan 29, 2024
1 parent 86b3f85 commit c09dc2d
Show file tree
Hide file tree
Showing 15 changed files with 381 additions and 187 deletions.
93 changes: 57 additions & 36 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, descTy);
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);

Expand Down Expand Up @@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
builder, loc, loadOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, loadOp.getType());
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
Expand Down Expand Up @@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
}
}

static void
genWaitClause(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClause::Wait *waitClause,
llvm::SmallVector<mlir::Value> &waitOperands,
llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
llvm::SmallVector<int32_t> &waitOperandsSegments,
mlir::Value &waitDevnum,
llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
Fortran::lower::StatementContext &stmtCtx) {
static void genWaitClauseWithDeviceType(
Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClause::Wait *waitClause,
llvm::SmallVector<mlir::Value> &waitOperands,
llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
llvm::SmallVector<bool> &hasDevnums,
llvm::SmallVector<int32_t> &waitOperandsSegments,
llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
Fortran::lower::StatementContext &stmtCtx) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
llvm::SmallVector<mlir::Value> waitValues;

const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const auto &waitDevnumValue =
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
bool hasDevnum = false;
if (waitDevnumValue) {
waitValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
hasDevnum = true;
}

const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
llvm::SmallVector<mlir::Value> waitValues;
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
waitValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}

for (auto deviceTypeAttr : deviceTypeAttrs) {
for (auto value : waitValues)
waitOperands.push_back(value);
waitOperandsDeviceTypes.push_back(deviceTypeAttr);
waitOperandsSegments.push_back(waitValues.size());
hasDevnums.push_back(hasDevnum);
}

// TODO: move to device_type model.
const auto &waitDevnumValue =
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
if (waitDevnumValue)
waitDevnum = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
} else {
for (auto deviceTypeAttr : deviceTypeAttrs)
waitOnlyDeviceTypes.push_back(deviceTypeAttr);
Expand Down Expand Up @@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
llvm::SmallVector<bool> hasWaitDevnums;

llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;
mlir::Value waitDevnum; // TODO not yet implemented on compute op.

// Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
Expand Down Expand Up @@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
hasWaitDevnums, waitOperandsSegments,
crtDeviceTypes, stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
Expand Down Expand Up @@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
builder.getDenseI32ArrayAttr(numGangsSegments));
}
if (!asyncDeviceTypes.empty())
computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
computeOp.setAsyncOperandsDeviceTypeAttr(
builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));

Expand All @@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
computeOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
if (!hasWaitDevnums.empty())
computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));

Expand Down Expand Up @@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> waitOperandsSegments;
llvm::SmallVector<bool> hasWaitDevnums;

bool hasDefaultNone = false;
bool hasDefaultPresent = false;
Expand Down Expand Up @@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
hasWaitDevnums, waitOperandsSegments,
crtDeviceTypes, stmtCtx);
} else if(const auto *defaultClause =
std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
Expand All @@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<int32_t> operandSegments;
addOperand(operands, operandSegments, ifCond);
addOperands(operands, operandSegments, async);
addOperand(operands, operandSegments, waitDevnum);
addOperands(operands, operandSegments, waitOperands);
addOperands(operands, operandSegments, dataClauseOperands);

Expand All @@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
operandSegments);

if (!asyncDeviceTypes.empty())
dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
dataOp.setAsyncOperandsDeviceTypeAttr(
builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
if (!waitOperandsDeviceTypes.empty())
Expand All @@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
dataOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
if (!hasWaitDevnums.empty())
dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));

Expand Down Expand Up @@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
}

/// Build a bool array attribute from \p values using builder \p b.
/// Yields a null attribute when \p values is empty.
static inline mlir::ArrayAttr
getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
  if (values.empty())
    return nullptr;
  return b.getBoolArrayAttr(values);
}

static inline mlir::DenseI32ArrayAttr
getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
llvm::SmallVector<int32_t> &values) {
Expand All @@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
waitOperands, deviceTypeOperands, asyncOperands;
llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<bool> hasWaitDevnums;
llvm::SmallVector<int32_t> waitOperandsSegments;

fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Expand Down Expand Up @@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
hasWaitDevnums, waitOperandsSegments,
crtDeviceTypes, stmtCtx);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
Expand Down Expand Up @@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
builder.create<mlir::acc::UpdateOp>(
currentLocation, ifCond, asyncOperands,
getArrayAttr(builder, asyncOperandsDeviceTypes),
getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
getDenseI32ArrayAttr(builder, waitOperandsSegments),
getArrayAttr(builder, waitOperandsDeviceTypes),
getBoolArrayAttr(builder, hasWaitDevnums),
getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
ifPresent);

Expand Down Expand Up @@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);

Expand Down Expand Up @@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Lower/OpenACC/acc-data.f90
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ subroutine acc_data
!$acc data present(a) wait
!$acc end data

! CHECK: acc.data dataOperands(%{{.*}}) {
! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK: acc.data dataOperands(%{{.*}}) wait {
! CHECK: }

!$acc data present(a) wait(1)
!$acc end data
Expand All @@ -176,7 +176,7 @@ subroutine acc_data
!$acc data present(a) wait(devnum: 0: 1)
!$acc end data

! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
! CHECK: acc.data dataOperands(%{{.*}}) wait({devnum: %{{.*}} : i32, %{{.*}} : i32}) {
! CHECK: }{{$}}

!$acc data default(none)
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-kernels-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ subroutine acc_kernels_loop
a(i) = b(i)
END DO

! CHECK: acc.kernels {
! CHECK: acc.kernels wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc kernels loop wait(1)
DO i = 1, n
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-kernels.f90
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ subroutine acc_kernels
!$acc kernels wait
!$acc end kernels

! CHECK: acc.kernels {
! CHECK: acc.kernels wait {
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc kernels wait(1)
!$acc end kernels
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-parallel-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ subroutine acc_parallel_loop
a(i) = b(i)
END DO

! CHECK: acc.parallel {
! CHECK: acc.parallel wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc parallel loop wait(1)
DO i = 1, n
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-parallel.f90
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ subroutine acc_parallel
!$acc parallel wait
!$acc end parallel

! CHECK: acc.parallel {
! CHECK: acc.parallel wait {
! CHECK: acc.yield
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc parallel wait(1)
!$acc end parallel
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-serial-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ subroutine acc_serial_loop
a(i) = b(i)
END DO

! CHECK: acc.serial {
! CHECK: acc.serial wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc serial loop wait(1)
DO i = 1, n
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-serial.f90
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ subroutine acc_serial
!$acc serial wait
!$acc end serial

! CHECK: acc.serial {
! CHECK: acc.serial wait {
! CHECK: acc.yield
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc serial wait(1)
!$acc end serial
Expand Down
5 changes: 1 addition & 4 deletions flang/test/Lower/OpenACC/acc-update.f90
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,7 @@ subroutine acc_update

!$acc update host(a) wait(devnum: 1: queues: 1, 2)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update wait({devnum: %c1{{.*}} : i32, %c1{{.*}} : i32, %c2{{.*}} : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}

!$acc update host(a) device_type(host, nvidia) async
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/Frontend/Directive/DirectiveBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class Directive<string d> {
// allowedClauses and requiredClauses lists.

// List of allowed clauses for the directive.
list<VersionedClause> allowedClauses = [];
list<VersionedClause> allowedClauses = [];

// List of clauses that are allowed to appear only once.
list<VersionedClause> allowedOnceClauses = [];
Expand Down
Loading

0 comments on commit c09dc2d

Please sign in to comment.