Skip to content

Commit

Permalink
Revert "[mlir][openacc] Add device_type support for compute operations (#75864)"
Browse files Browse the repository at this point in the history

This reverts commit 8b885eb.
  • Loading branch information
clementval committed Dec 21, 2023
1 parent 7c9c807 commit 10df608
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 1,216 deletions.
106 changes: 22 additions & 84 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
return mlir::acc::DeviceType::Multicore;
}
return mlir::acc::DeviceType::None;
return mlir::acc::DeviceType::Default;
}

static void gatherDeviceTypeAttrs(
Expand Down Expand Up @@ -1781,89 +1781,60 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
bool outerCombined = false) {

// Parallel operation operands
mlir::Value async;
mlir::Value numWorkers;
mlir::Value vectorLength;
mlir::Value ifCond;
mlir::Value selfCond;
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
dataClauseOperands, numGangs, numWorkers, vectorLength, async;
llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
dataClauseOperands, numGangs;

llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;

// Self clause has optional values but can be present with
// Async, wait and self clause have optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
// represent the clause.
bool addAsyncAttr = false;
bool addWaitAttr = false;
bool addSelfAttr = false;

bool hasDefaultNone = false;
bool hasDefaultPresent = false;

fir::FirOpBuilder &builder = converter.getFirOpBuilder();

// device_type attribute is set to `none` until a device_type clause is
// encountered.
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
builder.getContext(), mlir::acc::DeviceType::None);

// Lower clauses values mapped to operands.
// Keep track of each group of operands separately as clauses can appear
// more than once.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
const auto &asyncClauseValue = asyncClause->v;
if (asyncClauseValue) { // async has a value.
async.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
asyncDeviceTypes.push_back(crtDeviceTypeAttr);
} else {
asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
}
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
auto crtWaitOperands = waitOperands.size();
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
waitOperands.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}
waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
} else {
waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
}
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
addWaitAttr, stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
auto crtNumGangs = numGangs.size();
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
numGangs.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(expr), stmtCtx)));
numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
numGangsSegments.push_back(numGangs.size() - crtNumGangs);
} else if (const auto *numWorkersClause =
std::get_if<Fortran::parser::AccClause::NumWorkers>(
&clause.u)) {
numWorkers.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
numWorkers = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
} else if (const auto *vectorLengthClause =
std::get_if<Fortran::parser::AccClause::VectorLength>(
&clause.u)) {
vectorLength.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
vectorLength = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
} else if (const auto *ifClause =
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
Expand Down Expand Up @@ -2014,27 +1985,18 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
else if ((defaultClause->v).v ==
llvm::acc::DefaultValue::ACC_Default_present)
hasDefaultPresent = true;
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
deviceTypeClause->v;
assert(deviceTypeExprList.v.size() == 1 &&
"expect only one device_type expr");
crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
}
}

// Prepare the operand segment size attribute and the operands value range.
llvm::SmallVector<mlir::Value, 8> operands;
llvm::SmallVector<int32_t, 8> operandSegments;
addOperands(operands, operandSegments, async);
addOperand(operands, operandSegments, async);
addOperands(operands, operandSegments, waitOperands);
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
addOperands(operands, operandSegments, numGangs);
addOperands(operands, operandSegments, numWorkers);
addOperands(operands, operandSegments, vectorLength);
addOperand(operands, operandSegments, numWorkers);
addOperand(operands, operandSegments, vectorLength);
}
addOperand(operands, operandSegments, ifCond);
addOperand(operands, operandSegments, selfCond);
Expand All @@ -2055,6 +2017,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
builder, currentLocation, eval, operands, operandSegments,
outerCombined);

if (addAsyncAttr)
computeOp.setAsyncAttrAttr(builder.getUnitAttr());
if (addWaitAttr)
computeOp.setWaitAttrAttr(builder.getUnitAttr());
if (addSelfAttr)
computeOp.setSelfAttrAttr(builder.getUnitAttr());

Expand All @@ -2063,34 +2029,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (hasDefaultPresent)
computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);

if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
if (!numWorkersDeviceTypes.empty())
computeOp.setNumWorkersDeviceTypeAttr(
mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
if (!vectorLengthDeviceTypes.empty())
computeOp.setVectorLengthDeviceTypeAttr(
mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
if (!numGangsDeviceTypes.empty())
computeOp.setNumGangsDeviceTypeAttr(
mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
if (!numGangsSegments.empty())
computeOp.setNumGangsSegmentsAttr(
builder.getDenseI32ArrayAttr(numGangsSegments));
}
if (!asyncDeviceTypes.empty())
computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));

if (!waitOperandsDeviceTypes.empty())
computeOp.setWaitOperandsDeviceTypeAttr(
builder.getArrayAttr(waitOperandsDeviceTypes));
if (!waitOperandsSegments.empty())
computeOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
if (!waitOnlyDeviceTypes.empty())
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));

if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
if (!privatizations.empty())
computeOp.setPrivatizationsAttr(
Expand Down
44 changes: 0 additions & 44 deletions flang/test/Lower/OpenACC/acc-device-type.f90

This file was deleted.

14 changes: 7 additions & 7 deletions flang/test/Lower/OpenACC/acc-kernels-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ subroutine acc_kernels_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
! CHECK-NEXT: } attributes {asyncAttr}

!$acc kernels loop async(1)
DO i = 1, n
Expand Down Expand Up @@ -103,15 +103,15 @@ subroutine acc_kernels_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: } attributes {waitAttr}

!$acc kernels loop wait(1)
DO i = 1, n
a(i) = b(i)
END DO

! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand All @@ -126,7 +126,7 @@ subroutine acc_kernels_loop

! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand All @@ -141,7 +141,7 @@ subroutine acc_kernels_loop

! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand All @@ -155,7 +155,7 @@ subroutine acc_kernels_loop
END DO

! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand All @@ -169,7 +169,7 @@ subroutine acc_kernels_loop
END DO

! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand Down
14 changes: 7 additions & 7 deletions flang/test/Lower/OpenACC/acc-kernels.f90
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ subroutine acc_kernels

! CHECK: acc.kernels {
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
! CHECK-NEXT: } attributes {asyncAttr}

!$acc kernels async(1)
!$acc end kernels
Expand All @@ -63,13 +63,13 @@ subroutine acc_kernels

! CHECK: acc.kernels {
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: } attributes {waitAttr}

!$acc kernels wait(1)
!$acc end kernels

! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

Expand All @@ -78,7 +78,7 @@ subroutine acc_kernels

! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

Expand All @@ -87,23 +87,23 @@ subroutine acc_kernels

! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

!$acc kernels num_gangs(1)
!$acc end kernels

! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

!$acc kernels num_gangs(numGangs)
!$acc end kernels

! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

Expand Down

0 comments on commit 10df608

Please sign in to comment.