Skip to content

Commit

Permalink
[mlir][TilingInterface] Modify TilingInterface methods to better re…
Browse files Browse the repository at this point in the history
…turn the state of the transformed IR.

Currently the `getTiledImplementation` and `generateResultTileValue`
return just `SmallVector<Operation *>` and `FailureOr<Value>`.

- For `getTiledImplementation` returning empty implies tiling wasn't
  done. There is also an implicit assumption that the tiled operation
  results correspond to the tiled values of the result of the original
  operation. This cannot handle cases where the tiled implementation
  might use multiple operations to compute the tiled value for the
  results of the untiled operation. Sometimes, the tiled operation
  might not directly give the tiled values, and might require casts,
  etc to get a replacement.
- For `generateResultTileValue`, it is assumed that the op defining
  the returned `Value` is the operation that represents the tiled
  computation. Again presence of casts, etc violate this.

Instead make these methods return
```
struct TilingResult {
  SmallVector<Operation *> tiledOps;
  SmallVector<Value> tiledValues;
};
```

The `tiledOps` represent the operations generated that are relevant
for subsequent transformations. The `tiledValues` represent the tiled
values for the results of the original operation. This better
transmits the state of the transformed IR.

As a consequence the following methods also return `FailureOr<TilingResult>`
- `tensor::replaceExtractSliceWithTiledProducer`
- `tensor::bubbleUpPadSlice`

Differential Revision: https://reviews.llvm.org/D145133
  • Loading branch information
Mahesh Ravishankar committed Mar 16, 2023
1 parent a586c55 commit 809e3d8
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 122 deletions.
11 changes: 7 additions & 4 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
Expand Up @@ -16,6 +16,9 @@
#include "mlir/IR/Dialect.h"

namespace mlir {

struct TilingResult;

namespace tensor {

class PadOp;
Expand All @@ -39,10 +42,10 @@ class PadOp;
/// to guard against the case that we might take a zero-sized slice from the
/// original source. For such cases, we `tensor.generate` to generate the
/// full tensor.
Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool generateZeroSliceGuard = true);
FailureOr<TilingResult> bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool generateZeroSliceGuard = true);

/// Registers external models for Tiling interface for tensor ops.
/// Currently, it registers:
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
Expand Up @@ -13,6 +13,9 @@
#include "mlir/IR/PatternMatch.h"

namespace mlir {

struct TilingResult;

namespace tensor {

/// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op
Expand All @@ -26,7 +29,7 @@ void populateSplitPaddingPatterns(RewritePatternSet &patterns,
/// provide a mechanism to control where the application happens. With use of
/// transform dialect that control is done within the transform dialect. Other
/// use cases can inherit from this pattern and add necessary controls.
FailureOr<Value> replaceExtractSliceWithTiledProducer(
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);

/// Collects patterns to merge consecutive tensor.insert_slice/extract_slice
Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/Interfaces/TilingInterface.h
Expand Up @@ -21,6 +21,20 @@
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"

namespace mlir {

/// Container for result values of tiling.
/// - `tiledOps` contains operations created by the tiling implementation that
/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
/// untiled operation.
struct TilingResult {
SmallVector<Operation *> tiledOps;
SmallVector<Value> tiledValues;
};

} // namespace mlir

/// Include the ODS generated interface header files.
#include "mlir/Interfaces/TilingInterface.h.inc"

Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Expand Up @@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
The method returns the operation that is the tiled
implementation.
}],
/*retType=*/"SmallVector<Operation *>",
/*retType=*/"FailureOr<TilingResult>",
/*methodName=*/"getTiledImplementation",
/*args=*/(ins
"OpBuilder &":$b,
Expand Down Expand Up @@ -119,7 +119,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
iteration space).
- `sizes` provides the size of the tile.
}],
/*retType=*/"FailureOr<Value>",
/*retType=*/"FailureOr<TilingResult>",
/*methodName=*/"generateResultTileValue",
/*args=*/(ins
"OpBuilder &":$b,
Expand Down
66 changes: 33 additions & 33 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Expand Up @@ -431,16 +431,15 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder,
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
static SmallVector<Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp, Operation *containingOp) {
LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
<< "producer is not a TileableInterface: " << *producerOp;
return nullptr;
return {};
}

// Search the producer slices accessed within the containing operation.
Expand All @@ -455,7 +454,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
if (it == tileableProducer->getUsers().end()) {
diag.attachNote(tileableProducer->getLoc())
<< "could not find fusion opportunity for: " << *tileableProducer;
return nullptr;
return {};
}
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);

Expand All @@ -468,27 +467,29 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer)) {
FailureOr<TilingResult> tileAndFuseResult =
tileableProducer.generateResultTileValue(rewriter, resultNumber,
sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tileAndFuseResult)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
return {};
}
for (auto tiledOp : tileAndFuseResult->tiledOps) {
LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
}
LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");

// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
sliceOpToTile->getResult(0)
.getType()
.cast<RankedTensorType>()
.getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
return fusedOp;
return tileAndFuseResult->tiledOps;
}

/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
Expand All @@ -497,7 +498,8 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
/// right before its "extract" use. The tiled op is fused under the
/// `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
static SmallVector<Operation *>
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp) {
LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
Expand All @@ -506,7 +508,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
<< "producer is not a TileableInterface: " << *producerOp;
return nullptr;
return {};
}

// Search the first use by a "scf::ForallOp" user.
Expand All @@ -520,7 +522,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
if (!forallOp || forallOp != containingOp) {
diag.attachNote(tileableProducer->getLoc())
<< "could not find a use by the containing op: " << *tileableProducer;
return nullptr;
return {};
}

// Search the producer slices accessed within the containing
Expand All @@ -542,7 +544,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
if (itBBArgUsers == bbArg.getUsers().end()) {
diag.attachNote(containingOp->getLoc())
<< "could not find fusion opportunity for bbArg: " << bbArg;
return nullptr;
return {};
}
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);

Expand All @@ -562,7 +564,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
destinationTensors))) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to get destination tensors for: " << *tileableProducer;
return nullptr;
return {};
}

IRMapping bvm;
Expand All @@ -573,21 +575,19 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });

// Tile the producer.
FailureOr<Value> tiledProducer =
FailureOr<TilingResult> tileAndFuseResult =
tileableProducerClone.generateResultTileValue(
rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer)) {
if (failed(tileAndFuseResult)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
return {};
}
LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");

// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
sliceOpToTile->getResult(0)
.getType()
.cast<RankedTensorType>()
Expand All @@ -601,7 +601,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
destinationTensors.front());
});

return fusedOp;
return tileAndFuseResult->tiledOps;
}

static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
Expand Down Expand Up @@ -714,21 +714,21 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
// cases, we can tile/clone once and reuse the value for each use.
// Furthermore, producers should then be traversed according to a
// topological sorting.
Operation *tiled =
SmallVector<Operation *> tiledOps =
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (tiled) {
if (!tiledOps.empty()) {
LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
fusedOps.push_back(tiled);
fusedOps.append(tiledOps);
continue;
}

Operation *tiledContainingOpOperand =
SmallVector<Operation *> tiledContainingOpOperand =
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
rewriter, diag, producerOp, containingOp);
if (tiledContainingOpOperand) {
if (!tiledContainingOpOperand.empty()) {
LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
<< *containingOp);
fusedOps.push_back(tiledContainingOpOperand);
fusedOps.append(tiledContainingOpOperand);
continue;
}

Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Dialect/Linalg/Transforms/Split.cpp
Expand Up @@ -41,26 +41,26 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
offsetsCopy[dimension] = offset;

// Create the part as if it were a single tile.
SmallVector<Operation *> tiled =
FailureOr<TilingResult> tilingResult =
op.getTiledImplementation(b, offsetsCopy, sizesCopy);
assert(tiled.size() == 1 && "expected a single result from tiling");
auto part = cast<TilingInterface>(tiled.front());

// Insert the results back and populate the `results` list.
for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) {
for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy,
if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy,
resultOffsets, resultSizes)))
return nullptr;
SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
b.getIndexAttr(1));
Value inserted = b.create<tensor::InsertSliceOp>(
loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes,
loc, result, resultOperands[index], resultOffsets, resultSizes,
resultStrides);
results.push_back(inserted);
}

return part;
// TODO: this part can be generalized maybe to not expect a single op.
assert(tilingResult->tiledOps.size() == 1 &&
"expected split part to return a single tiled operation");
return cast<TilingInterface>(tilingResult->tiledOps[0]);
}

std::pair<TilingInterface, TilingInterface>
Expand Down
16 changes: 9 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Expand Up @@ -388,12 +388,13 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
}

// 4. Tile the cloned op and delete the clone.
SmallVector<Operation *> tiledOps =
FailureOr<TilingResult> tilingResult =
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
tiledSizes);
b.eraseOp(clonedOp);
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
tiledOp = tiledOps.front();
assert(tilingResult->tiledOps.size() == 1 &&
"expected a single produced tiled op");
tiledOp = tilingResult->tiledOps.front();
}

// 5. Parallel insert back into the result tensor.
Expand Down Expand Up @@ -729,12 +730,13 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(

// 5. Tile the cloned op and delete the clone.
if (tileSizes.empty()) {
SmallVector<Operation *> tiledOps =
FailureOr<TilingResult> tilingResult =
cast<TilingInterface>(clonedOp).getTiledImplementation(
b, tiledOffsets, tiledSizes);
assert(tiledOps.size() == 1 && "expected a single produced tiled op");
tiledOp = tiledOps.front();
tilingResults = tiledOp->getResults();
assert(tilingResult->tiledOps.size() == 1 &&
"expected a single produced tiled op");
tiledOp = tilingResult->tiledOps.front();
tilingResults = tilingResult->tiledValues;
} else {
LinalgTilingOptions options;
FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
Expand Down
23 changes: 13 additions & 10 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Expand Up @@ -111,7 +111,7 @@ struct LinalgOpTilingInterface
}

// Instantiate the tiled implementation of the operation.
SmallVector<Operation *>
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
Expand All @@ -129,7 +129,7 @@ struct LinalgOpTilingInterface
Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);

return {tiledOp};
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}

// Return the details of the output tile generated by the tiled
Expand Down Expand Up @@ -160,10 +160,10 @@ struct LinalgOpTilingInterface
return success();
}

FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
auto linalgOp = cast<LinalgOp>(op);

// Check that the indexing map used for the output is a projected
Expand Down Expand Up @@ -197,12 +197,15 @@ struct LinalgOpTilingInterface
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
}

SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
b, iterationTileOffsets, iterationTileSizes);
if (tiledOp.size() != 1)
FailureOr<TilingResult> tilingResult =
tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
iterationTileSizes);
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");

return tiledOp[0]->getResult(resultNumber);
return TilingResult{
tilingResult->tiledOps,
SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
}

LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
Expand Down

0 comments on commit 809e3d8

Please sign in to comment.