Skip to content

Commit

Permalink
[mlir][openacc] Add device_type support for data operation (#76126)
Browse files Browse the repository at this point in the history
Following #75864, this patch adds device_type support to the data
operation on the async and wait operands and attributes.
  • Loading branch information
clementval committed Jan 5, 2024
1 parent 3096353 commit 71ec301
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 55 deletions.
125 changes: 88 additions & 37 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,24 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
}
}

/// Lower an OpenACC `async` clause into the device_type-aware containers.
///
/// When the clause carries an expression, its lowered value is appended to
/// \p async and \p deviceTypeAttr is appended to \p asyncDeviceTypes. When the
/// clause has no argument, only \p deviceTypeAttr is recorded, in
/// \p asyncOnlyDeviceTypes.
static void
genAsyncClause(Fortran::lower::AbstractConverter &converter,
               const Fortran::parser::AccClause::Async *asyncClause,
               llvm::SmallVector<mlir::Value> &async,
               llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
               llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
               mlir::acc::DeviceTypeAttr deviceTypeAttr,
               Fortran::lower::StatementContext &stmtCtx) {
  // Bare `async` (no argument): only remember which device_type it applies to.
  if (!asyncClause->v) {
    asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
    return;
  }

  // `async(expr)`: lower the expression and pair it with its device_type.
  mlir::Value asyncOperand = fir::getBase(converter.genExprValue(
      *Fortran::semantics::GetExpr(*asyncClause->v), stmtCtx));
  async.push_back(asyncOperand);
  asyncDeviceTypes.push_back(deviceTypeAttr);
}

static mlir::acc::DeviceType
getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
switch (device) {
Expand Down Expand Up @@ -1533,6 +1551,39 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
}
}

/// Lower an OpenACC `wait` clause into the device_type-aware containers.
///
/// With an argument, every expression of the wait list is lowered and appended
/// to \p waitOperands; \p deviceTypeAttr is appended to
/// \p waitOperandsDeviceTypes and the number of operands added by this clause
/// is appended to \p waitOperandsSegments. If the argument also carries a
/// devnum expression, its lowered value is stored in \p waitDevnum. Without an
/// argument, only \p deviceTypeAttr is recorded, in \p waitOnlyDeviceTypes.
static void
genWaitClause(Fortran::lower::AbstractConverter &converter,
              const Fortran::parser::AccClause::Wait *waitClause,
              llvm::SmallVector<mlir::Value> &waitOperands,
              llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
              llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
              llvm::SmallVector<int32_t> &waitOperandsSegments,
              mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
              Fortran::lower::StatementContext &stmtCtx) {
  // Bare `wait` (no argument): only remember which device_type it applies to.
  if (!waitClause->v) {
    waitOnlyDeviceTypes.push_back(deviceTypeAttr);
    return;
  }

  const Fortran::parser::AccWaitArgument &waitArg = *waitClause->v;

  // Remember where this clause's operands start so the segment size can be
  // computed once they have all been appended.
  const auto firstOperandIdx = waitOperands.size();
  for (const Fortran::parser::ScalarIntExpr &waitExpr :
       std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t)) {
    mlir::Value waitOperand = fir::getBase(converter.genExprValue(
        *Fortran::semantics::GetExpr(waitExpr), stmtCtx));
    waitOperands.push_back(waitOperand);
  }
  waitOperandsDeviceTypes.push_back(deviceTypeAttr);
  waitOperandsSegments.push_back(waitOperands.size() - firstOperandIdx);

  // TODO: move to device_type model.
  const auto &devnumExpr =
      std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
  if (!devnumExpr)
    return;
  waitDevnum = fir::getBase(converter.genExprValue(
      *Fortran::semantics::GetExpr(*devnumExpr), stmtCtx));
}

static mlir::acc::LoopOp
createLoopOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
Expand Down Expand Up @@ -1795,6 +1846,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;
mlir::Value waitDevnum; // TODO not yet implemented on compute op.

// Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
Expand All @@ -1818,31 +1870,14 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
const auto &asyncClauseValue = asyncClause->v;
if (asyncClauseValue) { // async has a value.
async.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
asyncDeviceTypes.push_back(crtDeviceTypeAttr);
} else {
asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
}
genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
auto crtWaitOperands = waitOperands.size();
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
waitOperands.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}
waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
} else {
waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
}
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
Expand Down Expand Up @@ -2126,21 +2161,24 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
mlir::Value ifCond, async, waitDevnum;
mlir::Value ifCond, waitDevnum;
llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands;

// Async and wait have an optional value but can be present with
// no value as well. When there is no value, the op has an attribute to
// represent the clause.
bool addAsyncAttr = false;
bool addWaitAttr = false;
copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands,
async;
llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> waitOperandsSegments;

bool hasDefaultNone = false;
bool hasDefaultPresent = false;

fir::FirOpBuilder &builder = converter.getFirOpBuilder();

// device_type attribute is set to `none` until a device_type clause is
// encountered.
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
builder.getContext(), mlir::acc::DeviceType::None);

// Lower clauses values mapped to operands.
// Keep track of each group of operands separately as clauses can appear
// more than once.
Expand Down Expand Up @@ -2221,11 +2259,14 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
dataClauseOperands.end());
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
addWaitAttr, stmtCtx);
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
stmtCtx);
} else if(const auto *defaultClause =
std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
Expand All @@ -2239,7 +2280,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> operands;
llvm::SmallVector<int32_t> operandSegments;
addOperand(operands, operandSegments, ifCond);
addOperand(operands, operandSegments, async);
addOperands(operands, operandSegments, async);
addOperand(operands, operandSegments, waitDevnum);
addOperands(operands, operandSegments, waitOperands);
addOperands(operands, operandSegments, dataClauseOperands);
Expand All @@ -2250,8 +2291,18 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
builder, currentLocation, eval, operands, operandSegments);

dataOp.setAsyncAttr(addAsyncAttr);
dataOp.setWaitAttr(addWaitAttr);
if (!asyncDeviceTypes.empty())
dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
if (!waitOperandsDeviceTypes.empty())
dataOp.setWaitOperandsDeviceTypeAttr(
builder.getArrayAttr(waitOperandsDeviceTypes));
if (!waitOperandsSegments.empty())
dataOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
if (!waitOnlyDeviceTypes.empty())
dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));

if (hasDefaultNone)
dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
Expand Down
8 changes: 4 additions & 4 deletions flang/test/Lower/OpenACC/acc-data.f90
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ subroutine acc_data
!$acc end data

! CHECK: acc.data dataOperands(%{{.*}}) {
! CHECK: } attributes {asyncAttr}
! CHECK: } attributes {asyncOnly = [#acc.device_type<none>]}

!$acc data present(a) async(1)
!$acc end data
Expand All @@ -165,18 +165,18 @@ subroutine acc_data
!$acc end data

! CHECK: acc.data dataOperands(%{{.*}}) {
! CHECK: } attributes {waitAttr}
! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}

!$acc data present(a) wait(1)
!$acc end data

! CHECK: acc.data dataOperands(%{{.*}}) wait(%{{.*}} : i32) {
! CHECK: acc.data dataOperands(%{{.*}}) wait({%{{.*}} : i32}) {
! CHECK: }{{$}}

!$acc data present(a) wait(devnum: 0: 1)
!$acc end data

! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait(%{{.*}} : i32) {
! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
! CHECK: }{{$}}

!$acc data default(none)
Expand Down
47 changes: 38 additions & 9 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1236,13 +1236,16 @@ def OpenACC_DataOp : OpenACC_Op<"data",


let arguments = (ins Optional<I1>:$ifCond,
Optional<IntOrIndex>:$async,
UnitAttr:$asyncAttr,
Optional<IntOrIndex>:$waitDevnum,
Variadic<IntOrIndex>:$waitOperands,
UnitAttr:$waitAttr,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr);
Variadic<IntOrIndex>:$async,
OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
Optional<IntOrIndex>:$waitDevnum,
Variadic<IntOrIndex>:$waitOperands,
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr);

let regions = (region AnyRegion:$region);

Expand All @@ -1252,15 +1255,41 @@ def OpenACC_DataOp : OpenACC_Op<"data",

/// The i-th data operand passed.
Value getDataOperand(unsigned i);

/// Return true if the op has the async attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasAsyncOnly();
/// Return true if the op has the async attribute for the given device_type.
bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
/// Return the value of the async clause if present.
mlir::Value getAsyncValue();
/// Return the value of the async clause for the given device_type if
/// present.
mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);

/// Return true if the op has the wait attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasWaitOnly();
/// Return true if the op has the wait attribute for the given device_type.
bool hasWaitOnly(mlir::acc::DeviceType deviceType);
/// Return the values of the wait clause if present.
mlir::Operation::operand_range getWaitValues();
/// Return the values of the wait clause for the given device_type if
/// present.
mlir::Operation::operand_range
getWaitValues(mlir::acc::DeviceType deviceType);
}];

let assemblyFormat = [{
oilist(
`if` `(` $ifCond `)`
| `async` `(` $async `:` type($async) `)`
| `async` `(` custom<DeviceTypeOperands>($async,
type($async), $asyncDeviceType) `)`
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
type($waitOperands), $waitOperandsDeviceType,
$waitOperandsSegments) `)`
)
$region attr-dict-with-keyword
}];
Expand Down
43 changes: 42 additions & 1 deletion mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1417,11 +1417,52 @@ unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }

/// Return the i-th data clause operand of the op.
///
/// The op's operands are laid out as: ifCond, async (variadic), waitDevnum,
/// waitOperands (variadic), dataClauseOperands (variadic). Computing the offset
/// by hand is error-prone: `async` may now hold several values (one per
/// device_type) and `waitDevnum` also precedes the data operands. Index the
/// data clause operand segment directly instead, so the result stays correct
/// regardless of how many leading operands are present.
Value DataOp::getDataOperand(unsigned i) {
  return getDataClauseOperands()[i];
}

/// Return true if the op has the async attribute for the
/// mlir::acc::DeviceType::None device_type.
bool acc::DataOp::hasAsyncOnly() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return hasAsyncOnly(defaultDeviceType);
}

/// Return true if the op has the async attribute for the given device_type,
/// i.e. the `asyncOnly` array attribute is set and contains \p deviceType.
bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
  auto asyncOnlyAttr = getAsyncOnly();
  return asyncOnlyAttr && findSegment(*asyncOnlyAttr, deviceType);
}

/// Return the value of the async clause for the
/// mlir::acc::DeviceType::None device_type if present.
mlir::Value DataOp::getAsyncValue() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getAsyncValue(defaultDeviceType);
}

/// Return the value of the async clause for the given device_type if present.
/// The lookup is delegated to getValueInDeviceTypeSegment, using the
/// `asyncDeviceType` attribute together with the `async` operands.
mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
  auto asyncDeviceTypes = getAsyncDeviceType();
  auto asyncOperands = getAsync();
  return getValueInDeviceTypeSegment(asyncDeviceTypes, asyncOperands,
                                     deviceType);
}

/// Return true if the op has the wait attribute for the
/// mlir::acc::DeviceType::None device_type.
bool DataOp::hasWaitOnly() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return hasWaitOnly(defaultDeviceType);
}

/// Return true if the op has the wait attribute for the given device_type,
/// i.e. the `waitOnly` array attribute is set and contains \p deviceType.
bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
  auto waitOnlyAttr = getWaitOnly();
  return waitOnlyAttr && findSegment(*waitOnlyAttr, deviceType);
}

/// Return the values of the wait clause for the
/// mlir::acc::DeviceType::None device_type if present.
mlir::Operation::operand_range DataOp::getWaitValues() {
  constexpr auto defaultDeviceType = mlir::acc::DeviceType::None;
  return getWaitValues(defaultDeviceType);
}

/// Return the values of the wait clause for the given device_type if present.
/// The lookup is delegated to getValuesFromSegments, using the
/// `waitOperandsDeviceType` and `waitOperandsSegments` attributes together
/// with the `waitOperands` operands.
mlir::Operation::operand_range
DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  auto waitDeviceTypes = getWaitOperandsDeviceType();
  auto allWaitOperands = getWaitOperands();
  auto waitSegments = getWaitOperandsSegments();
  return getValuesFromSegments(waitDeviceTypes, allWaitOperands, waitSegments,
                               deviceType);
}

//===----------------------------------------------------------------------===//
// ExitDataOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/OpenACC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -836,11 +836,11 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
} attributes { defaultAttr = #acc<defaultvalue none>, wait }

%w1 = arith.constant 1 : i64
acc.data wait(%w1 : i64) {
acc.data wait({%w1 : i64}) {
} attributes { defaultAttr = #acc<defaultvalue none>, wait }

%wd1 = arith.constant 1 : i64
acc.data wait_devnum(%wd1 : i64) wait(%w1 : i64) {
acc.data wait_devnum(%wd1 : i64) wait({%w1 : i64}) {
} attributes { defaultAttr = #acc<defaultvalue none>, wait }

return
Expand Down Expand Up @@ -951,10 +951,10 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
// CHECK: acc.data {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}

// CHECK: acc.data wait(%{{.*}} : i64) {
// CHECK: acc.data wait({%{{.*}} : i64}) {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}

// CHECK: acc.data wait_devnum(%{{.*}} : i64) wait(%{{.*}} : i64) {
// CHECK: acc.data wait_devnum(%{{.*}} : i64) wait({%{{.*}} : i64}) {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}

// -----
Expand Down

0 comments on commit 71ec301

Please sign in to comment.