Skip to content

Commit

Permalink
[mlir] Add a builder to linalg.tiled_loop.
Browse files Browse the repository at this point in the history
  • Loading branch information
pifon2a committed Feb 24, 2021
1 parent 6201017 commit 7377ef9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Expand Up @@ -484,6 +484,7 @@ def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
linalg.yield %f0, %f1 : f32, f32
```
}];
let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>];
}

def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
Expand Down Expand Up @@ -537,6 +538,21 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
ArrayAttr:$iterator_types);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region SizedRegion<1>:$region);

let builders = [
OpBuilderDAG<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
"ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
"ArrayRef<StringRef>":$iteratorTypes,
CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
"nullptr">:$bodyBuilderFn)>,
];

let extraClassDeclaration = [{
ValueRange getInductionVars() {
return getBody()->getArguments();
}
unsigned getNumLoops() { return step().size(); }
}];
}


Expand Down
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Expand Up @@ -1701,6 +1701,40 @@ static LogicalResult verify(linalg::YieldOp op) {
// TiledLoopOp
//===----------------------------------------------------------------------===//

void TiledLoopOp::build(
OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
ValueRange upperBounds, ValueRange steps, ValueRange inputs,
ValueRange outputs, ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
result.addOperands(lowerBounds);
result.addOperands(upperBounds);
result.addOperands(steps);
result.addOperands(inputs);
result.addOperands(outputs);
result.addAttribute(
TiledLoopOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
static_cast<int32_t>(upperBounds.size()),
static_cast<int32_t>(steps.size()),
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
result.addAttribute(getIteratorTypesAttrName(),
builder.getStrArrayAttr(iteratorTypes));
result.addTypes(outputs.getTypes());

OpBuilder::InsertionGuard guard(builder);
unsigned numIVs = steps.size();
SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
Region *bodyRegion = result.addRegion();
Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);

if (bodyBuilderFn) {
builder.setInsertionPointToStart(bodyBlock);
bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
}
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
}

static void print(OpAsmPrinter &p, TiledLoopOp op) {
p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
<< op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
Expand Down

0 comments on commit 7377ef9

Please sign in to comment.