Skip to content

Commit

Permalink
[mlir] Introduce linalg.tiled_yield terminator for `linalg.tiled_lo…
Browse files Browse the repository at this point in the history
  • Loading branch information
pifon2a committed Jul 19, 2021
1 parent aa69f0d commit 3b03d9b
Show file tree
Hide file tree
Showing 14 changed files with 271 additions and 128 deletions.
34 changes: 27 additions & 7 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Expand Up @@ -492,7 +492,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
RecursiveSideEffects,
SingleBlockImplicitTerminator<"linalg::YieldOp">
SingleBlockImplicitTerminator<"linalg::TiledYieldOp">
]> {
let summary = "Linalg tiled loop operation";
let description = [{
Expand All @@ -509,7 +509,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
every tensor argument of TiledLoopOp.

The body region must contain exactly one block that terminates with
`linalg.yield` with the operands resulting from `insert_slice` operations.
`linalg.tiled_yield`.

Example:

Expand All @@ -528,9 +528,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [

%result_sub = linalg.generic ...

%result = tensor.insert_slice %result_sub into %out[%i, 0][%c4, %c64][1, 1]
: tensor<?x?xi8> into tensor<24x64xi8>
linalg.yield %result : tensor<24x64xi8>
linalg.tiled_yield %result_sub in %out_sub : tensor<?x?xi8>
}
```

Expand All @@ -540,7 +538,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
every memref argument of TiledLoopOp.

The body region must contain exactly one block that terminates with
`linalg.yield` with no operands.
`linalg.tiled_yield` with no operands.

Example:

Expand All @@ -558,7 +556,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
: memref<24x64xi8> to memref<?x?xi8>

%result_sub = linalg.generic ...
linalg.yield
linalg.tiled_yield
}
```
}];
Expand Down Expand Up @@ -747,6 +745,28 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let hasFolder = 1;
}

def Linalg_TiledYieldOp : Linalg_Op<"tiled_yield",
    [NoSideEffect, ReturnLike, Terminator, SameVariadicOperandSize,
     HasParent<"TiledLoopOp">]>,
    Arguments<(ins Variadic<AnyType>:$tiles, Variadic<AnyType>:$outputs)> {
  let summary = "Linalg tiled yield operation";
  let description = [{
    `linalg.tiled_yield` is a special terminator operation for the block inside
    the region of `linalg.tiled_loop` op. It updates the part of the enclosing
    `linalg.tiled_loop` result specified by the `outputs` operands with the
    values from the `tiles` operands. The i-th `tiles` operand is yielded into
    the i-th `outputs` operand, which must be the corresponding output block
    argument of the loop or a `tensor.extract_slice` of it.

    Example:

    ```mlir
    linalg.tiled_loop ... outs(%out_ = %out : tensor<?x?xf32>) {
      %output = tensor.extract_slice %out_... // or %output = %out_
      %tile = "some_computation"
      linalg.tiled_yield %tile in %output : tensor<?x?xf32>
    }
    ```
  }];
  // Allow building the implicit terminator with no operands (memref case).
  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
}

def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>,
Arguments<(ins Confined<I64Attr, [IntMinValue<0>]>:$dim)>,
Results<(outs Index:$result)> {
Expand Down
133 changes: 100 additions & 33 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Expand Up @@ -1497,30 +1497,6 @@ static LogicalResult verify(linalg::YieldOp op) {
return success();
}

if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) {
// Check if output args with tensor types match results types.
SmallVector<Value, 2> tensorOuts;
llvm::copy_if(
tiledLoopOp.outputs(), std::back_inserter(tensorOuts),
[&](Value out) { return out.getType().isa<RankedTensorType>(); });
if (tensorOuts.size() != op.values().size())
return op.emitOpError("expected number of tensor output args = ")
<< tensorOuts.size() << " to match the number of yield operands = "
<< op.values().size();

TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts));
for (auto &item :
llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) {
Type outType, resultType;
unsigned index = item.index();
std::tie(outType, resultType) = item.value();
if (outType != resultType)
return op.emitOpError("expected yield operand ")
<< index << " with type = " << resultType
<< " to match output arg type = " << outType;
}
return success();
}
return op.emitOpError("expected parent op with LinalgOp interface");
}

Expand Down Expand Up @@ -1892,11 +1868,11 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
return failure();

Block *block = tiledLoop.getBody();
auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
auto yieldOp = cast<linalg::TiledYieldOp>(block->getTerminator());

// Match the pattern and collect output buffers that will replace the output
// tensors and also the ops that will be ignored when cloning the body.
SmallVector<Value, 2> newOutputOperands, newYieldArgs;
SmallVector<Value, 2> newOutputOperands, newYieldTileArgs, newYieldOutArgs;
int resultId = 0;
// Store ids of the corresponding old and new output operands.
SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
Expand All @@ -1917,13 +1893,15 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
continue;
}
Value result = tiledLoop.getResult(resultId);
Value yieldArg = yieldOp.getOperand(resultId);
if (yieldArg != outRegionArg || !result.use_empty()) {
Value yieldTileArg = yieldOp.tiles()[resultId];
Value yieldOutArg = yieldOp.outputs()[resultId];
if (yieldTileArg != outRegionArg || !result.use_empty()) {
oldOutputIdToNew[index] = newOutputOperands.size();
oldResultIdToNew[resultId] = newYieldArgs.size();
oldResultIdToNew[resultId] = newYieldTileArgs.size();
resultReplacement[resultId] = out;
newOutputOperands.push_back(out);
newYieldArgs.push_back(yieldArg);
newYieldTileArgs.push_back(yieldTileArg);
newYieldOutArgs.push_back(yieldOutArg);
}
++resultId;
}
Expand Down Expand Up @@ -1952,9 +1930,12 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
for (auto &op : tiledLoop.getBody()->without_terminator())
innerBuilder.clone(op, bvm);
innerBuilder.create<linalg::YieldOp>(
loc, llvm::to_vector<2>(llvm::map_range(
newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
innerBuilder.create<linalg::TiledYieldOp>(
loc,
llvm::to_vector<2>(llvm::map_range(
newYieldTileArgs, [&](Value arg) { return bvm.lookup(arg); })),
llvm::to_vector<2>(llvm::map_range(
newYieldOutArgs, [&](Value arg) { return bvm.lookup(arg); })));

for (const auto &en : llvm::enumerate(oldResultIdToNew))
if (en.value() != kNoMatch)
Expand All @@ -1976,6 +1957,92 @@ LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
return foldMemRefCastInTiledLoopOp(*this);
}

//===----------------------------------------------------------------------===//
// TiledYieldOp
//===----------------------------------------------------------------------===//

/// Prints linalg.tiled_yield in its custom form:
///   `linalg.tiled_yield` (tile `in` output `:` output-type)* attr-dict
/// where consecutive tile/output pairs are comma-separated.
static void print(OpAsmPrinter &p, TiledYieldOp op) {
  p << op.getOperationName();

  StringRef separator = "";
  for (auto pair : llvm::zip(op.tiles(), op.outputs())) {
    Value tile = std::get<0>(pair);
    Value output = std::get<1>(pair);
    // Each entry carries its own leading space; pairs after the first are
    // additionally prefixed with ", " (matching llvm::interleaveComma).
    p << separator << ' ' << tile << " in " << output << " : "
      << output.getType();
    separator = ", ";
  }
  p.printOptionalAttrDict(op->getAttrs());
}

/// Parses linalg.tiled_yield from its custom form:
///   `linalg.tiled_yield` (tile `in` output `:` type (`,`)?)* attr-dict
/// The `tiles` operands are resolved first, then the `outputs` operands, both
/// with the single type written per pair.
static ParseResult parseTiledYieldOp(OpAsmParser &parser,
                                     OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> tiles, outputs;
  SmallVector<Type, 4> types;

  OpAsmParser::OperandType tile;
  while (parser.parseOptionalOperand(tile).hasValue()) {
    Type type;
    OpAsmParser::OperandType output;
    if (parser.parseKeyword("in") || parser.parseOperand(output) ||
        parser.parseColon() || parser.parseType(type))
      return failure();
    tiles.push_back(tile);
    outputs.push_back(output);
    types.push_back(type);
    // The comma between pairs is optional; ignore its absence.
    (void)parser.parseOptionalComma();
  }
  llvm::SMLoc loc = parser.getCurrentLocation();
  if (parser.resolveOperands(tiles, types, loc, result.operands) ||
      parser.resolveOperands(outputs, types, loc, result.operands))
    return failure();

  // Parse optional attributes. Unlike the comma above, a *malformed* attribute
  // dictionary is a hard error, so the result must not be discarded.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

/// Verifies linalg.tiled_yield:
///  - the op must be nested inside a linalg.tiled_loop;
///  - there is exactly one `tiles` operand per tensor output of the loop;
///  - each tile's type matches the type of the paired `outputs` operand;
///  - each `outputs` operand is the corresponding output block argument of the
///    enclosing loop, or a tensor.extract_slice of it.
static LogicalResult verify(TiledYieldOp op) {
  auto loop = op->getParentOfType<TiledLoopOp>();
  // Guard against a misplaced terminator: without this check the accesses
  // below would dereference a null TiledLoopOp.
  if (!loop)
    return op.emitOpError("expected to be nested inside a linalg.tiled_loop");

  // Check if output args with tensor types match results types.
  SmallVector<Value, 2> loopTensorOuts;
  llvm::copy_if(
      loop.outputs(), std::back_inserter(loopTensorOuts),
      [&](Value out) { return out.getType().isa<RankedTensorType>(); });
  if (loopTensorOuts.size() != op.tiles().size())
    return op.emitOpError("expected number of tensor output args = ")
           << loopTensorOuts.size()
           << " to match the number of yield operands = " << op.tiles().size();

  // Check if the `tiles` args types match the `outputs` args types.
  SmallVector<Value, 2> loopTensorOutsBlockArgs;
  llvm::copy_if(
      loop.getRegionOutputArgs(), std::back_inserter(loopTensorOutsBlockArgs),
      [&](Value out) { return out.getType().isa<RankedTensorType>(); });
  for (auto en : llvm::enumerate(
           llvm::zip(op.tiles(), op.outputs(), loopTensorOutsBlockArgs))) {
    size_t index = en.index();
    Type tileType = std::get<0>(en.value()).getType();
    Value yieldOut = std::get<1>(en.value());
    Type yieldOutType = yieldOut.getType();

    if (tileType != yieldOutType)
      return op.emitOpError("expected tile operand with type = ")
             << tileType << " to match output type = " << yieldOutType;

    // Check if yieldOut is either an output bbArg or a slice of it.
    // getDefiningOp() is null for block arguments; dyn_cast_or_null handles
    // that case and leaves `src` as the value itself.
    Value src = yieldOut;
    if (auto extractSlice = llvm::dyn_cast_or_null<tensor::ExtractSliceOp>(
            yieldOut.getDefiningOp()))
      src = extractSlice.source();

    Value loopBlockArg = std::get<2>(en.value());
    if (src != loopBlockArg)
      return op.emitOpError("expected output ")
             << index << " to be a subset of the corresponding block argument";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 19 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Expand Up @@ -372,6 +372,7 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
ReturnOp,
TiledLoopOp,
VectorTransferOpInterface,
linalg::TiledYieldOp,
linalg::YieldOp,
scf::YieldOp>(op)
// clang-format on
Expand Down Expand Up @@ -519,7 +520,7 @@ static Optional<OpResult> getAliasingOpResult(OpOperand &opOperand) {
return None;
return TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
// These terminators legitimately have no result.
.Case<ReturnOp, linalg::YieldOp, scf::YieldOp>(
.Case<ReturnOp, linalg::TiledYieldOp, linalg::YieldOp, scf::YieldOp>(
[&](auto op) { return OpResult(); })
// ConstantOp is never inplaceable.
.Case([&](ConstantOp op) { return op->getResult(0); })
Expand Down Expand Up @@ -570,6 +571,11 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
if (auto linalgOp = dyn_cast<LinalgOp>(opOperand.getOwner()))
return linalgOp.isInputTensor(&opOperand) ||
linalgOp.isInitTensor(&opOperand);
// This is questionable. Should we consider TiledYieldOp as an op that
// bufferizes to "read" for the `tile` args and to "write" for the `output`
// args?
if (isa<TiledYieldOp>(opOperand.getOwner()))
return false;
// All other cases are considered to bufferize to memory reads.
// In particular, terminators are often the last use and need to be considered
// as reads to return the proper value and avoid WAW clobbers.
Expand All @@ -583,7 +589,8 @@ static bool
bufferizesToMemoryWrite(OpOperand &opOperand,
InPlaceSpec inPlaceSpec = InPlaceSpec::None) {
// These terminators are not writes.
if (isa<ReturnOp, linalg::YieldOp, scf::YieldOp>(opOperand.getOwner()))
if (isa<ReturnOp, linalg::TiledYieldOp, linalg::YieldOp, scf::YieldOp>(
opOperand.getOwner()))
return false;
// ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
// may.
Expand Down Expand Up @@ -2110,9 +2117,6 @@ static LogicalResult bufferize(OpBuilder &b, linalg::YieldOp yieldOp,
// No tensors -> success.
if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor))
return success();
// linalg::YieldOp nested under TiledLoop must just canonicalize.
if (yieldOp->getParentOfType<TiledLoopOp>())
return success();
llvm_unreachable("unexpected yieldOp");
}

Expand All @@ -2131,6 +2135,15 @@ static LogicalResult bufferize(OpBuilder &b, tensor::ExtractOp extractOp,
extractOp.replaceAllUsesWith(l);
return success();
}

/// Bufferization for linalg::TiledYieldOp just results in later
/// canonicalization: this hook itself performs no rewriting and always
/// succeeds.
// NOTE(review): presumably the terminator's tensor operands are taken care of
// when the enclosing TiledLoopOp is bufferized — confirm against the
// TiledLoopOp bufferization, since nothing here consumes `bvm`/`aliasInfo`.
static LogicalResult bufferize(OpBuilder &b, linalg::TiledYieldOp yieldOp,
                               BlockAndValueMapping &bvm,
                               BufferizationAliasInfo &aliasInfo) {
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2332,6 +2345,7 @@ static LogicalResult bufferizeFuncOpInternals(
TiledLoopOp,
VectorTransferOpInterface,
linalg::YieldOp,
linalg::TiledYieldOp,
scf::YieldOp>([&](auto op) {
LDBG("Begin bufferize:\n" << op << '\n');
return bufferize(b, op, bvm, aliasInfo);
Expand Down

0 comments on commit 3b03d9b

Please sign in to comment.