Skip to content

Commit

Permalink
[mlir][openacc] Update acc.loop to be a proper loop like operation (#…
Browse files Browse the repository at this point in the history
…67355)

The initial design of the `acc.loop` was to be an operation that
encapsulates a loop like operation. This was an early design and we now
want to change it so the `acc.loop` operation becomes a real loop-like
operation by implementing the LoopLikeInterface.

Differential Revision: https://reviews.llvm.org/D159229

This patch is just moved from Phabricator to github
  • Loading branch information
clementval committed Jan 22, 2024
1 parent c0a74ad commit 3eb4178
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 230 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/OpenACC/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.h.inc"
#include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#define GET_TYPEDEF_CLASSES
Expand Down
33 changes: 21 additions & 12 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define OPENACC_OPS

include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/EnumAttr.td"
Expand Down Expand Up @@ -1474,29 +1475,34 @@ def OpenACC_HostDataOp : OpenACC_Op<"host_data",

def OpenACC_LoopOp : OpenACC_Op<"loop",
[AttrSizedOperandSegments, RecursiveMemoryEffects,
MemoryEffects<[MemWrite<OpenACC_ConstructResource>]>]> {
MemoryEffects<[MemWrite<OpenACC_ConstructResource>]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
let summary = "loop construct";

let description = [{
The "acc.loop" operation represents the OpenACC loop construct.
The "acc.loop" operation represents the OpenACC loop construct. The lower
and upper bounds specify a half-open range: the range includes the lower
bound but does not include the upper bound. If the `inclusive` attribute is
set then the upper bound is included.

Example:

```mlir
acc.loop gang vector {
scf.for %arg3 = %c0 to %c10 step %c1 {
scf.for %arg4 = %c0 to %c10 step %c1 {
scf.for %arg5 = %c0 to %c10 step %c1 {
// ... body
}
}
}
acc.loop gang() vector() (%arg3 : index, %arg4 : index, %arg5 : index) =
(%c0, %c0, %c0 : index, index, index) to
(%c10, %c10, %c10 : index, index, index) step
(%c1, %c1, %c1 : index, index, index) {
// Loop body
acc.yield
} attributes { collapse = [3] }
```
}];

let arguments = (ins
Variadic<IntOrIndex>:$lowerbound,
Variadic<IntOrIndex>:$upperbound,
Variadic<IntOrIndex>:$step,
OptionalAttr<DenseBoolArrayAttr>:$inclusiveUpperbound,
OptionalAttr<I64ArrayAttr>:$collapse,
OptionalAttr<DeviceTypeArrayAttr>:$collapseDeviceType,
Variadic<IntOrIndex>:$gangOperands,
Expand All @@ -1521,7 +1527,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
Variadic<AnyType>:$reductionOperands,
OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes
);
);

let results = (outs Variadic<AnyType>:$results);

Expand All @@ -1539,6 +1545,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);

Block &getBody() { return getLoopRegions().front()->front(); }

/// Return true if the op has the auto attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasAuto();
Expand Down Expand Up @@ -1628,7 +1636,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
`)`
| `cache` `(` $cacheOperands `:` type($cacheOperands) `)`
)
$region
custom<LoopControl>($region, $lowerbound, type($lowerbound), $upperbound,
type($upperbound), $step, type($step))
( `(` type($results)^ `)` )?
attr-dict-with-keyword
}];
Expand Down
110 changes: 89 additions & 21 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1008,17 +1008,12 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
bool needCommaBeforeOperands = false;

// Keyword only
if (failed(parser.parseOptionalLParen())) {
keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
keywordOnlyDeviceType =
ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
return success();
}
if (failed(parser.parseOptionalLParen()))
return failure();

// Parse keyword only attributes
if (succeeded(parser.parseOptionalLSquare())) {
// Parse keyword only attributes
if (failed(parser.parseCommaSeparatedList([&]() {
if (parser.parseAttribute(
keywordOnlyDeviceTypeAttributes.emplace_back()))
Expand All @@ -1029,6 +1024,13 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
if (parser.parseRSquare())
return failure();
needCommaBeforeOperands = true;
} else if (succeeded(parser.parseOptionalRParen())) {
// Keyword only
keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
keywordOnlyDeviceType =
ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
return success();
}

if (needCommaBeforeOperands && failed(parser.parseComma()))
Expand Down Expand Up @@ -1065,15 +1067,18 @@ static void printDeviceTypeOperandsWithKeywordOnly(
mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {

p << "(";

if (operands.begin() == operands.end() && keywordOnlyDeviceTypes &&
keywordOnlyDeviceTypes->size() == 1) {
auto deviceTypeAttr =
mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*keywordOnlyDeviceTypes)[0]);
if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) {
p << ")";
return;
}
}

p << "(";
printDeviceTypes(p, keywordOnlyDeviceTypes);
if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
hasDeviceTypeValues(deviceTypes))
Expand Down Expand Up @@ -1323,17 +1328,12 @@ static ParseResult parseGangClause(
bool needCommaBetweenValues = false;
bool needCommaBeforeOperands = false;

// Gang only keyword
if (failed(parser.parseOptionalLParen())) {
gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
gangOnlyDeviceType =
ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
return success();
}
if (failed(parser.parseOptionalLParen()))
return failure();

// Parse gang only attributes
if (succeeded(parser.parseOptionalLSquare())) {
// Parse gang only attributes
if (failed(parser.parseCommaSeparatedList([&]() {
if (parser.parseAttribute(
gangOnlyDeviceTypeAttributes.emplace_back()))
Expand All @@ -1344,6 +1344,13 @@ static ParseResult parseGangClause(
if (parser.parseRSquare())
return failure();
needCommaBeforeOperands = true;
} else if (succeeded(parser.parseOptionalRParen())) {
// Gang only keyword
gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
gangOnlyDeviceType =
ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
return success();
}

auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
Expand Down Expand Up @@ -1443,16 +1450,18 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
std::optional<mlir::DenseI32ArrayAttr> segments,
std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {

p << "(";
if (operands.begin() == operands.end() &&
hasDeviceTypeValues(gangOnlyDeviceTypes) &&
gangOnlyDeviceTypes->size() == 1) {
auto deviceTypeAttr =
mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[0]);
if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) {
p << ")";
return;
}
}

p << "(";
printDeviceTypes(p, gangOnlyDeviceTypes);

if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
Expand Down Expand Up @@ -1516,6 +1525,11 @@ LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
}

LogicalResult acc::LoopOp::verify() {
if (!getUpperbound().empty() && getInclusiveUpperbound() &&
(getUpperbound().size() != getInclusiveUpperbound()->size()))
return emitError() << "inclusiveUpperbound size is expected to be the same"
<< " as upperbound size";

// Check collapse
if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
return emitOpError() << "collapse device_type attr must be define when"
Expand Down Expand Up @@ -1629,7 +1643,9 @@ unsigned LoopOp::getNumDataOperands() {
}

Value LoopOp::getDataOperand(unsigned i) {
unsigned numOptional = getGangOperands().size();
unsigned numOptional =
getLowerbound().size() + getUpperbound().size() + getStep().size();
numOptional += getGangOperands().size();
numOptional += getVectorOperands().size();
numOptional += getWorkerNumOperands().size();
numOptional += getTileOperands().size();
Expand Down Expand Up @@ -1748,6 +1764,58 @@ bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
return hasDeviceType(getGang(), deviceType);
}

llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
return {&getRegion()};
}

/// loop-control ::= `(` ssa-id-and-type-list `)` `=` `(` ssa-id-and-type-list
/// `)` `to` `(` ssa-id-and-type-list `)` `step` `(` ssa-id-and-type-list `)`
ParseResult
parseLoopControl(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerbound,
SmallVectorImpl<Type> &lowerboundType,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperbound,
SmallVectorImpl<Type> &upperboundType,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &step,
SmallVectorImpl<Type> &stepType) {

SmallVector<OpAsmParser::Argument> inductionVars;
if (succeeded(parser.parseOptionalLParen())) {
if (parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
/*allowType=*/true) ||
parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
parser.parseOperandList(lowerbound, inductionVars.size(),
OpAsmParser::Delimiter::None) ||
parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
parser.parseKeyword("to") || parser.parseLParen() ||
parser.parseOperandList(upperbound, inductionVars.size(),
OpAsmParser::Delimiter::None) ||
parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
parser.parseKeyword("step") || parser.parseLParen() ||
parser.parseOperandList(step, inductionVars.size(),
OpAsmParser::Delimiter::None) ||
parser.parseColonTypeList(stepType) || parser.parseRParen())
return failure();
}
return parser.parseRegion(region, inductionVars);
}

void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange lowerbound, TypeRange lowerboundType,
ValueRange upperbound, TypeRange upperboundType,
ValueRange steps, TypeRange stepType) {
ValueRange regionArgs = region.front().getArguments();
if (!regionArgs.empty()) {
p << "(";
llvm::interleaveComma(regionArgs, p,
[&p](Value v) { p << v << " : " << v.getType(); });
p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
<< upperbound << " : " << upperboundType << ") "
<< " step (" << steps << " : " << stepType << ") ";
}
p.printRegion(region, /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// DataOp
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 6 additions & 4 deletions mlir/test/Dialect/OpenACC/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,16 @@ func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {

func.func @testhostdataop(%a: memref<f32>, %ifCond: i1) -> () {
%0 = acc.use_device varPtr(%a : memref<f32>) -> memref<f32>
%1 = arith.constant 1 : i32
%2 = arith.constant 10 : i32
%false = arith.constant false
acc.host_data dataOperands(%0 : memref<f32>) if(%false) {
acc.loop {
acc.loop (%iv : i32) = (%1 : i32) to (%2 : i32) step (%1 : i32) {
acc.yield
}
acc.loop {
} attributes { inclusiveUpperbound = array<i1: true> }
acc.loop (%iv : i32) = (%1 : i32) to (%2 : i32) step (%1 : i32) {
acc.yield
}
} attributes { inclusiveUpperbound = array<i1: true> }
acc.terminator
}
return
Expand Down
Loading

0 comments on commit 3eb4178

Please sign in to comment.