Skip to content

Commit

Permalink
[mlir][interfaces] Drop dest/tileDestOperands from TilingInterface
Browse files Browse the repository at this point in the history
`getTiledImplementation`/`generateResultTileValue` only computes the tiled operation, but does not insert the result into any tensor.

Differential Revision: https://reviews.llvm.org/D133015
  • Loading branch information
matthias-springer committed Sep 1, 2022
1 parent 7c0cf32 commit 5479428
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 44 deletions.
22 changes: 2 additions & 20 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,12 @@ def TilingInterface : OpInterface<"TilingInterface"> {
`getIterationDomain`. The caller provides the information of the
tile within this iteration space whose implementation the
caller needs.
- `dest` are the Value into which the result of the tiled
operation is to be inserted into. The type of the `dest`
Values is same as the types returned by
`getDestinationOperands` method.
- `offsets` provides the offset of the tile in the coordinate system
of the original iteration space, i.e., if an iteration space
dimension had non-zero offset, it must be included in the offset
provided here (as opposed to zero-based offset "relative" to the
iteration space).
- `sizes` provides the size of the tile.
- `tileDestOperands` specifies whether to also tile `dest` operands
or not. Avoiding tiling `dest` operands can be useful for
composition with various looping container ops.

The method returns the operation that is the tiled
implementation.
Expand All @@ -93,10 +86,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
/*methodName=*/"getTiledImplementation",
/*args=*/(ins
"OpBuilder &":$b,
"ValueRange ":$dest,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
"bool ":$tileDestOperands),
"ArrayRef<OpFoldResult> ":$sizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return {};
Expand Down Expand Up @@ -140,29 +131,20 @@ def TilingInterface : OpInterface<"TilingInterface"> {
tiled to generate the result tile. In practical terms this
implies it cannot be tiled and fused with its consumers.

- `dest` are the Value into which the result of the tiled
operation is to be inserted into. The type of the `dest`
Values is same as the types returned by
`getDestinationOperands` method.
- `offsets` provides the offset of the tile in the coordinate system
of the original iteration space, i.e., if an iteration space
dimension had non-zero offset, it must be included in the offset
provided here (as opposed to zero-based offset "relative" to the
iteration space).
- `sizes` provides the size of the tile.
- `tileDestOperands` specifies whether to also tile `dest` operands
or not. Avoiding tiling `dest` operands can be useful for
composition with various looping container ops.
}],
/*retType=*/"FailureOr<Value>",
/*methodName=*/"generateResultTileValue",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$resultNumber,
"ValueRange":$dest,
"ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes,
"bool":$tileDestOperands),
"ArrayRef<OpFoldResult>":$sizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
Expand Down
7 changes: 2 additions & 5 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,6 @@ static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
if (sliceOps.empty())
return failure();

SmallVector<Value> destinationOperands =
tileableProducer.getDestinationOperands(rewriter);

// Try to fuse the producer in-place.
SmallVector<Operation *> fusedOps;
for (tensor::ExtractSliceOp sliceOp : sliceOps) {
Expand All @@ -253,8 +250,8 @@ static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,

// Tile the producer.
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
rewriter, /*resultNumber=*/0, destinationOperands,
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true);
rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes());
if (failed(tiledProducer))
return failure();
fusedOps.push_back(tiledProducer->getDefiningOp());
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op,

// Create the part as it it were a single tile.
SmallVector<Operation *> tiled =
op.getTiledImplementation(b, resultOperands, offsetsCopy, sizesCopy,
/*tileDestOperands=*/true);
op.getTiledImplementation(b, offsetsCopy, sizesCopy);
assert(tiled.size() == 1 && "expected a single result from tiling");
auto part = cast<TilingInterface>(tiled.front());

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
}

SmallVector<Operation *> tiledOps =
op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
/*tileDestOperands=*/true);
op.getTiledImplementation(b, tiledOffsets, tiledSizes);
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
tiledOp = tiledOps.front();

Expand Down
11 changes: 4 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,9 @@ struct LinalgOpTilingInterface

// Instantiate the tiled implementation of the operation.
SmallVector<Operation *>
getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool tileDestOperands) const {
ArrayRef<OpFoldResult> sizes) const {
// Leave the `sizeBounds` value empty. That is only needed when the `sizes`
// specified could lead to out of bounds accesses.
Location loc = op->getLoc();
Expand Down Expand Up @@ -172,10 +171,8 @@ struct LinalgOpTilingInterface

FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
unsigned resultNumber,
ValueRange dest,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool tileDestOperands) const {
ArrayRef<OpFoldResult> sizes) const {
auto linalgOp = cast<LinalgOp>(op);

// Check that the indexing map used for the output is a projected
Expand Down Expand Up @@ -210,7 +207,7 @@ struct LinalgOpTilingInterface
}

SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
b, iterationTileOffsets, iterationTileSizes);
if (tiledOp.size() != 1)
return op->emitOpError("failed to generate tiled implementation");

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
if (!tilingResult.loops.empty())
rewriter.setInsertionPoint(
tilingResult.loops.back().getBody()->getTerminator());
SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
SmallVector<Operation *> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
if (tiledImplementation.size() != 1) {
return rewriter.notifyMatchFailure(
op, "expected tiled implementation to return a single op");
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
}

SmallVector<Operation *>
getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool /*tileDestOperands*/) const {
ArrayRef<OpFoldResult> sizes) const {
Operation *result =
tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
if (!result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
return failure();

FailureOr<Value> tiledResult = producerOp.generateResultTileValue(
builder, producer.getResultNumber(),
producerOp.getDestinationOperands(builder), sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), true);
builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes());
if (failed(tiledResult))
return failure();

Expand Down

0 comments on commit 5479428

Please sign in to comment.