Skip to content

Commit

Permalink
[MLIR][Affine] Add affine.parallel op
Browse files Browse the repository at this point in the history
Summary:
As discussed in https://llvm.discourse.group/t/rfc-add-affine-parallel/350, this is the first in a series of patches to bring in support for the `affine.parallel` operation.

This first patch adds the IR representation along with custom printer/parser implementations.

Reviewers: bondhugula, herhut, mehdi_amini, nicolasvasilache, rriddle, earhart, jbruestle

Reviewed By: bondhugula, nicolasvasilache, rriddle, earhart, jbruestle

Subscribers: jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74288
  • Loading branch information
Frank Laub committed Feb 13, 2020
1 parent c662795 commit fdc7a16
Show file tree
Hide file tree
Showing 6 changed files with 425 additions and 20 deletions.
73 changes: 73 additions & 0 deletions mlir/include/mlir/Dialect/AffineOps/AffineOps.td
Expand Up @@ -271,6 +271,79 @@ def AffineMaxOp : AffineMinMaxOpBase<"max", [NoSideEffect]> {
}];
}

def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> {
let summary = "multi-index parallel band operation";
let description = [{
The "affine.parallel" operation represents a hyper-rectangular affine
parallel band, defining multiple SSA values for its induction variables. It
has one region capturing the parallel band body. The induction variables are
represented as arguments of this region. These SSA values always have type
index, which is the size of the machine word. The strides, represented by
steps, are positive constant integers which defaults to "1" if not present.
The lower and upper bounds specify a half-open range: the range includes the
lower bound but does not include the upper bound. The body region must
contain exactly one block that terminates with "affine.terminator".

The lower and upper bounds of a parallel operation are represented as an
application of an affine mapping to a list of SSA values passed to the map.
The same restrictions hold for these SSA values as for all bindings of SSA
values to dimensions and symbols.

Note: Calling AffineParallelOp::build will create the required region and
block, and insert the required terminator. Parsing will also create the
required region, block, and terminator, even when they are missing from the
textual representation.

Example:

```mlir
affine.parallel (%i, %j) = (0, 0) to (10, 10) step (1, 1) {
...
}
```
}];

let arguments = (ins
AffineMapAttr:$lowerBoundsMap,
AffineMapAttr:$upperBoundsMap,
I64ArrayAttr:$steps,
Variadic<Index>:$mapOperands);
let regions = (region SizedRegion<1>:$region);

let builders = [
OpBuilder<"Builder* builder, OperationState& result,"
"ArrayRef<int64_t> ranges">,
OpBuilder<"Builder* builder, OperationState& result, AffineMap lbMap,"
"ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs">,
OpBuilder<"Builder* builder, OperationState& result, AffineMap lbMap,"
"ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs,"
"ArrayRef<int64_t> steps">
];

let extraClassDeclaration = [{
/// Get the number of dimensions.
unsigned getNumDims();

operand_range getLowerBoundsOperands();
operand_range getUpperBoundsOperands();

AffineValueMap getLowerBoundsValueMap();
AffineValueMap getUpperBoundsValueMap();
AffineValueMap getRangesValueMap();

/// Get ranges as constants, may fail in dynamic case.
Optional<SmallVector<int64_t, 8>> getConstantRanges();

Block *getBody();
OpBuilder getBodyBuilder();
void setSteps(ArrayRef<int64_t> newSteps);

static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; }
static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; }
static StringRef getStepsAttrName() { return "steps"; }
}];
}

def AffinePrefetchOp : Affine_Op<"prefetch"> {
let summary = "affine prefetch operation";
let description = [{
Expand Down
12 changes: 8 additions & 4 deletions mlir/include/mlir/IR/OpImplementation.h
Expand Up @@ -105,7 +105,8 @@ class OpAsmPrinter {
if (types.begin() != types.end())
printArrowTypeList(types);
}
template <typename TypeRange> void printArrowTypeList(TypeRange &&types) {
template <typename TypeRange>
void printArrowTypeList(TypeRange &&types) {
auto &os = getStream() << " -> ";

bool wrapped = !has_single_element(types) ||
Expand Down Expand Up @@ -517,7 +518,8 @@ class OpAsmParser {
virtual ParseResult
parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
SmallVectorImpl<NamedAttribute> &attrs,
Delimiter delimiter = Delimiter::Square) = 0;

//===--------------------------------------------------------------------===//
// Region Parsing
Expand Down Expand Up @@ -579,7 +581,8 @@ class OpAsmParser {
virtual ParseResult parseType(Type &result) = 0;

/// Parse a type of a specific type.
template <typename TypeT> ParseResult parseType(TypeT &result) {
template <typename TypeT>
ParseResult parseType(TypeT &result) {
llvm::SMLoc loc = getCurrentLocation();

// Parse any kind of type.
Expand Down Expand Up @@ -614,7 +617,8 @@ class OpAsmParser {
virtual ParseResult parseColonType(Type &result) = 0;

/// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
template <typename TypeType> ParseResult parseColonType(TypeType &result) {
template <typename TypeType>
ParseResult parseColonType(TypeType &result) {
llvm::SMLoc loc = getCurrentLocation();

// Parse any kind of type.
Expand Down
236 changes: 234 additions & 2 deletions mlir/lib/Dialect/AffineOps/AffineOps.cpp
Expand Up @@ -134,9 +134,11 @@ bool mlir::isValidDim(Value value) {
return isTopLevelValue(dimOp.getOperand());
return false;
}
// This value has to be a block argument for a FuncOp or an affine.for.
// This value has to be a block argument of a FuncOp, an 'affine.for', or an
// 'affine.parallel'.
auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp);
return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp) ||
isa<AffineParallelOp>(parentOp);
}

/// Returns true if the 'index' dimension of the `memref` defined by
Expand Down Expand Up @@ -2150,6 +2152,236 @@ LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// AffineParallelOp
//===----------------------------------------------------------------------===//

void AffineParallelOp::build(Builder *builder, OperationState &result,
ArrayRef<int64_t> ranges) {
// Default initalize empty maps.
auto lbMap = AffineMap::get(builder->getContext());
auto ubMap = AffineMap::get(builder->getContext());
// If there are ranges, set each to [0, N).
if (ranges.size()) {
SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
builder->getAffineConstantExpr(0));
lbMap = AffineMap::get(0, 0, lbExprs);
SmallVector<AffineExpr, 8> ubExprs;
for (int64_t range : ranges)
ubExprs.push_back(builder->getAffineConstantExpr(range));
ubMap = AffineMap::get(0, 0, ubExprs);
}
build(builder, result, lbMap, {}, ubMap, {});
}

void AffineParallelOp::build(Builder *builder, OperationState &result,
AffineMap lbMap, ValueRange lbArgs,
AffineMap ubMap, ValueRange ubArgs) {
auto numDims = lbMap.getNumResults();
// Verify that the dimensionality of both maps are the same.
assert(numDims == ubMap.getNumResults() &&
"num dims and num results mismatch");
// Make default step sizes of 1.
SmallVector<int64_t, 8> steps(numDims, 1);
build(builder, result, lbMap, lbArgs, ubMap, ubArgs, steps);
}

void AffineParallelOp::build(Builder *builder, OperationState &result,
AffineMap lbMap, ValueRange lbArgs,
AffineMap ubMap, ValueRange ubArgs,
ArrayRef<int64_t> steps) {
auto numDims = lbMap.getNumResults();
// Verify that the dimensionality of the maps matches the number of steps.
assert(numDims == ubMap.getNumResults() &&
"num dims and num results mismatch");
assert(numDims == steps.size() && "num dims and num steps mismatch");
result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
result.addAttribute(getStepsAttrName(), builder->getI64ArrayAttr(steps));
result.addOperands(lbArgs);
result.addOperands(ubArgs);
// Create a region and a block for the body.
auto bodyRegion = result.addRegion();
auto body = new Block();
// Add all the block arguments.
for (unsigned i = 0; i < numDims; ++i)
body->addArgument(IndexType::get(builder->getContext()));
bodyRegion->push_back(body);
ensureTerminator(*bodyRegion, *builder, result.location);
}

unsigned AffineParallelOp::getNumDims() { return steps().size(); }

AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
return getOperands().take_front(lowerBoundsMap().getNumInputs());
}

AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
return getOperands().drop_front(lowerBoundsMap().getNumInputs());
}

AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
}

AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
}

AffineValueMap AffineParallelOp::getRangesValueMap() {
AffineValueMap out;
AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
&out);
return out;
}

Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
// Try to convert all the ranges to constant expressions.
SmallVector<int64_t, 8> out;
AffineValueMap rangesValueMap = getRangesValueMap();
out.reserve(rangesValueMap.getNumResults());
for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
auto expr = rangesValueMap.getResult(i);
auto cst = expr.dyn_cast<AffineConstantExpr>();
if (!cst)
return llvm::None;
out.push_back(cst.getValue());
}
return out;
}

Block *AffineParallelOp::getBody() { return &region().front(); }

OpBuilder AffineParallelOp::getBodyBuilder() {
return OpBuilder(getBody(), std::prev(getBody()->end()));
}

void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
assert(newSteps.size() == getNumDims() && "steps & num dims mismatch");
setAttr(getStepsAttrName(), getBodyBuilder().getI64ArrayAttr(newSteps));
}

static LogicalResult verify(AffineParallelOp op) {
auto numDims = op.getNumDims();
if (op.lowerBoundsMap().getNumResults() != numDims ||
op.upperBoundsMap().getNumResults() != numDims ||
op.steps().size() != numDims ||
op.getBody()->getNumArguments() != numDims) {
return op.emitOpError("region argument count and num results of upper "
"bounds, lower bounds, and steps must all match");
}
// Verify that the bound operands are valid dimension/symbols.
/// Lower bounds.
if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
op.lowerBoundsMap().getNumDims())))
return failure();
/// Upper bounds.
if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
op.upperBoundsMap().getNumDims())))
return failure();
return success();
}

static void print(OpAsmPrinter &p, AffineParallelOp op) {
p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(),
op.getLowerBoundsOperands());
p << ") to (";
p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(),
op.getUpperBoundsOperands());
p << ')';
SmallVector<int64_t, 4> steps;
bool elideSteps = true;
for (auto attr : op.steps()) {
auto step = attr.cast<IntegerAttr>().getInt();
elideSteps &= (step == 1);
steps.push_back(step);
}
if (!elideSteps) {
p << " step (";
interleaveComma(steps, p);
p << ')';
}
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
p.printOptionalAttrDict(
op.getAttrs(),
/*elidedAttrs=*/{AffineParallelOp::getLowerBoundsMapAttrName(),
AffineParallelOp::getUpperBoundsMapAttrName(),
AffineParallelOp::getStepsAttrName()});
}

//
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)`
// `to` `(` map-of-ssa-ids `)` steps? region attr-dict?
// steps ::= `steps` `(` integer-literals `)`
//
static ParseResult parseAffineParallelOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
auto indexType = builder.getIndexType();
AffineMapAttr lowerBoundsAttr, upperBoundsAttr;
SmallVector<OpAsmParser::OperandType, 4> ivs;
SmallVector<OpAsmParser::OperandType, 4> lowerBoundsMapOperands;
SmallVector<OpAsmParser::OperandType, 4> upperBoundsMapOperands;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser.parseEqual() ||
parser.parseAffineMapOfSSAIds(
lowerBoundsMapOperands, lowerBoundsAttr,
AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes,
OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(lowerBoundsMapOperands, indexType,
result.operands) ||
parser.parseKeyword("to") ||
parser.parseAffineMapOfSSAIds(
upperBoundsMapOperands, upperBoundsAttr,
AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes,
OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(upperBoundsMapOperands, indexType,
result.operands))
return failure();

AffineMapAttr stepsMapAttr;
SmallVector<NamedAttribute, 1> stepsAttrs;
SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
if (failed(parser.parseOptionalKeyword("step"))) {
SmallVector<int64_t, 4> steps(ivs.size(), 1);
result.addAttribute(AffineParallelOp::getStepsAttrName(),
builder.getI64ArrayAttr(steps));
} else {
if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
AffineParallelOp::getStepsAttrName(),
stepsAttrs,
OpAsmParser::Delimiter::Paren))
return failure();

// Convert steps from an AffineMap into an I64ArrayAttr.
SmallVector<int64_t, 4> steps;
auto stepsMap = stepsMapAttr.getValue();
for (const auto &result : stepsMap.getResults()) {
auto constExpr = result.dyn_cast<AffineConstantExpr>();
if (!constExpr)
return parser.emitError(parser.getNameLoc(),
"steps must be constant integers");
steps.push_back(constExpr.getValue());
}
result.addAttribute(AffineParallelOp::getStepsAttrName(),
builder.getI64ArrayAttr(steps));
}

// Now parse the body.
Region *body = result.addRegion();
SmallVector<Type, 4> types(ivs.size(), indexType);
if (parser.parseRegion(*body, ivs, types) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();

// Add a terminator if none was parsed.
AffineParallelOp::ensureTerminator(*body, builder, result.location);
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit fdc7a16

Please sign in to comment.