From 809e3d8c98a80fc61c8bdbb3745d1d50a3f1d365 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Wed, 1 Mar 2023 16:33:14 -0800 Subject: [PATCH] [mlir][TilingInterface] Modify `TilingInterface` methods to better return the state of the transformed IR. Currently the `getTiledImplementation` and `generateResultTileValue` return just `SmallVector` and `FailureOr`. - For `getTiledImplementation` returning empty implies tiling wasnt done. There is also an implicit assumption that the tiled operation results correspond to the tiled values of the result of the original operation. This cannot handle cases where the tiled implementation might use multiple operations to compute the tiled value for the results of the untiled operation. Sometimes, the tiled operation might not directly give the tiled values, and might require casts, etc to get a replacement. - For `generateResultTileValue`, it is assumed that the op defining the returned `Value` is the operation that represents the tiled computation. Again presence of casts, etc violate this. Instead make these methods return ``` struct TilingResult { SmallVector tiledOps; SmallVector tiledValues; }; ``` The `tiledOps` represent the operations generated that are relevant for subsequent transformations. The `tiledValues` represent the tiled values for the results of the original operation. This better transmits the state of the transformed IR. As a consequence the following methods also return `FailureOr` - `tensor::replaceExtractSliceWithTiledProducer` - `tensor::bubbleUpPadSlice` Differential Revision: https://reviews.llvm.org/D145133 --- .../Tensor/IR/TensorTilingInterfaceImpl.h | 11 ++- .../Dialect/Tensor/Transforms/Transforms.h | 5 +- .../include/mlir/Interfaces/TilingInterface.h | 14 ++++ .../mlir/Interfaces/TilingInterface.td | 4 +- .../TransformOps/LinalgTransformOps.cpp | 66 +++++++-------- mlir/lib/Dialect/Linalg/Transforms/Split.cpp | 16 ++-- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 16 ++-- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 23 +++--- .../Dialect/Linalg/Transforms/Transforms.cpp | 6 +- .../SCF/Transforms/TileUsingInterface.cpp | 41 +++++----- .../Tensor/IR/TensorTilingInterfaceImpl.cpp | 80 +++++++++++-------- .../SwapExtractSliceWithProducerPatterns.cpp | 4 +- 12 files changed, 164 insertions(+), 122 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h index 30a5026cd68b3..7228a5a297ad8 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h @@ -16,6 +16,9 @@ #include "mlir/IR/Dialect.h" namespace mlir { + +struct TilingResult; + namespace tensor { class PadOp; @@ -39,10 +42,10 @@ class PadOp; /// to guard against the case that we might take a zero-sized slice from the /// original source. For such cases, we `tensor.generate` to generate the /// full tensor. -Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, - ArrayRef offsets, - ArrayRef sizes, - bool generateZeroSliceGuard = true); +FailureOr bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, + ArrayRef offsets, + ArrayRef sizes, + bool generateZeroSliceGuard = true); /// Registers external models for Tiling interface for tensor ops. /// Currently, it registers: diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h index 01985c943527c..4cdf360c51d72 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -13,6 +13,9 @@ #include "mlir/IR/PatternMatch.h" namespace mlir { + +struct TilingResult; + namespace tensor { /// Populates `patterns` with patterns to wrap a tensor.pad op with an scf.if op @@ -26,7 +29,7 @@ void populateSplitPaddingPatterns(RewritePatternSet &patterns, /// provide a mechanism to control where the application happens. With use of /// transform dialect that control is done within the transform dialect. Other /// use cases can inherit from this pattern and add necessary controls. -FailureOr replaceExtractSliceWithTiledProducer( +FailureOr replaceExtractSliceWithTiledProducer( OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp); /// Collects patterns to merge consecutive tensor.insert_slice/extract_slice diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h index 99cbe21b178ca..ca570490ccf5b 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.h +++ b/mlir/include/mlir/Interfaces/TilingInterface.h @@ -21,6 +21,20 @@ #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" +namespace mlir { + +/// Container for result values of tiling. +/// - `tiledOps` contains operations created by the tiling implementation that +/// are returned to the caller for further transformations. +/// - `tiledValues` contains the tiled value corresponding to the result of the +/// untiled operation. +struct TilingResult { + SmallVector tiledOps; + SmallVector tiledValues; +}; + +} // namespace mlir + /// Include the ODS generated interface header files. #include "mlir/Interfaces/TilingInterface.h.inc" diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 6cc0685bdae41..66382f29c2424 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { The method returns the operation that is the tiled implementation. }], - /*retType=*/"SmallVector", + /*retType=*/"FailureOr", /*methodName=*/"getTiledImplementation", /*args=*/(ins "OpBuilder &":$b, @@ -119,7 +119,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { iteration space). - `sizes` provides the size of the tile. }], - /*retType=*/"FailureOr", + /*retType=*/"FailureOr", /*methodName=*/"generateResultTileValue", /*args=*/(ins "OpBuilder &":$b, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 4aae6458ff128..4503d451a405c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -431,16 +431,15 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder, /// Find the first "extract" user of `producerOp` and tile it right before its /// use. The tiled op is fused under the `containingOp`. /// Return this fused op on success or nullptr if anything fails. -static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, - Diagnostic &diag, - Operation *producerOp, - Operation *containingOp) { +static SmallVector +tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, + Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; - return nullptr; + return {}; } // Search the producer slices accessed within the containing operation. @@ -455,7 +454,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, if (it == tileableProducer->getUsers().end()) { diag.attachNote(tileableProducer->getLoc()) << "could not find fusion opportunity for: " << *tileableProducer; - return nullptr; + return {}; } auto sliceOpToTile = cast(*it); @@ -468,27 +467,29 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, sliceOpToTile.getSource().cast().getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); - FailureOr tiledProducer = tileableProducer.generateResultTileValue( - rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), - sliceOpToTile.getMixedSizes()); - if (failed(tiledProducer)) { + FailureOr tileAndFuseResult = + tileableProducer.generateResultTileValue(rewriter, resultNumber, + sliceOpToTile.getMixedOffsets(), + sliceOpToTile.getMixedSizes()); + if (failed(tileAndFuseResult)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; - return nullptr; + return {}; + } + for (auto tiledOp : tileAndFuseResult->tiledOps) { + LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n"); } - LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. - Operation *fusedOp = tiledProducer->getDefiningOp(); auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( - rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), + rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], sliceOpToTile->getResult(0) .getType() .cast() .getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); - return fusedOp; + return tileAndFuseResult->tiledOps; } /// First, find the first "scf::ForallOp" user of `producerOp` and ensure @@ -497,7 +498,8 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, /// right before its "extract" use. The tiled op is fused under the /// `containingOp`. /// Return this fused op on success or nullptr if anything fails. -static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( +static SmallVector +tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n"); @@ -506,7 +508,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) << "producer is not a TileableInterface: " << *producerOp; - return nullptr; + return {}; } // Search the first use by a "scf::ForallOp" user. @@ -520,7 +522,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( if (!forallOp || forallOp != containingOp) { diag.attachNote(tileableProducer->getLoc()) << "could not find a use by the containing op: " << *tileableProducer; - return nullptr; + return {}; } // Search the producer slices accessed within the containing @@ -542,7 +544,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( if (itBBArgUsers == bbArg.getUsers().end()) { diag.attachNote(containingOp->getLoc()) << "could not find fusion opportunity for bbArg: " << bbArg; - return nullptr; + return {}; } auto sliceOpToTile = cast(*itBBArgUsers); @@ -562,7 +564,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( destinationTensors))) { diag.attachNote(tileableProducer->getLoc()) << "failed to get destination tensors for: " << *tileableProducer; - return nullptr; + return {}; } IRMapping bvm; @@ -573,21 +575,19 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); // Tile the producer. - FailureOr tiledProducer = + FailureOr tileAndFuseResult = tileableProducerClone.generateResultTileValue( rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), sliceOpToTile.getMixedSizes()); - if (failed(tiledProducer)) { + if (failed(tileAndFuseResult)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; - return nullptr; + return {}; } - LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. - Operation *fusedOp = tiledProducer->getDefiningOp(); auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( - rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), + rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], sliceOpToTile->getResult(0) .getType() .cast() @@ -601,7 +601,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( destinationTensors.front()); }); - return fusedOp; + return tileAndFuseResult->tiledOps; } static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, @@ -714,21 +714,21 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, // cases, we can tile/clone once and reuse the value for each use. // Futhermore, producers should then be traversed according to a // topological sorting. - Operation *tiled = + SmallVector tiledOps = tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); - if (tiled) { + if (!tiledOps.empty()) { LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp); - fusedOps.push_back(tiled); + fusedOps.append(tiledOps); continue; } - Operation *tiledContainingOpOperand = + SmallVector tiledContainingOpOperand = tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter, diag, producerOp, containingOp); - if (tiledContainingOpOperand) { + if (!tiledContainingOpOperand.empty()) { LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n" << *containingOp); - fusedOps.push_back(tiledContainingOpOperand); + fusedOps.append(tiledContainingOpOperand); continue; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp index c8c9c0bd4af89..e6fce56d4140b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -41,26 +41,26 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op, offsetsCopy[dimension] = offset; // Create the part as it it were a single tile. - SmallVector tiled = + FailureOr tilingResult = op.getTiledImplementation(b, offsetsCopy, sizesCopy); - assert(tiled.size() == 1 && "expected a single result from tiling"); - auto part = cast(tiled.front()); // Insert the results back and populate the `results` list. - for (auto i : llvm::seq(0, part->getNumResults())) { + for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) { SmallVector resultOffsets, resultSizes; - if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy, + if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy, resultOffsets, resultSizes))) return nullptr; SmallVector resultStrides(resultOffsets.size(), b.getIndexAttr(1)); Value inserted = b.create( - loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes, + loc, result, resultOperands[index], resultOffsets, resultSizes, resultStrides); results.push_back(inserted); } - - return part; + // TODO: this part can be generalized maybe to not expect a single op. + assert(tilingResult->tiledOps.size() == 1 && + "expected split part to return a single tiled operation"); + return cast(tilingResult->tiledOps[0]); } std::pair diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 62eef97a17448..1e404cabbb518 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -388,12 +388,13 @@ static FailureOr tileToForallOpImpl( } // 4. Tile the cloned op and delete the clone. - SmallVector tiledOps = + FailureOr tilingResult = cast(clonedOp).getTiledImplementation(b, tiledOffsets, tiledSizes); b.eraseOp(clonedOp); - assert(tiledOps.size() == 1 && "expected a single produced tiled op"); - tiledOp = tiledOps.front(); + assert(tilingResult->tiledOps.size() == 1 && + "expected a single produced tiled op"); + tiledOp = tilingResult->tiledOps.front(); } // 5. Parallel insert back into the result tensor. @@ -729,12 +730,13 @@ FailureOr linalg::tileReductionUsingForall( // 5. Tile the cloned op and delete the clone. if (tileSizes.empty()) { - SmallVector tiledOps = + FailureOr tilingResult = cast(clonedOp).getTiledImplementation( b, tiledOffsets, tiledSizes); - assert(tiledOps.size() == 1 && "expected a single produced tiled op"); - tiledOp = tiledOps.front(); - tilingResults = tiledOp->getResults(); + assert(tilingResult->tiledOps.size() == 1 && + "expected a single produced tiled op"); + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; } else { LinalgTilingOptions options; FailureOr maybeTiled = tileLinalgOpImpl( diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index cfc27ca44e421..676d6330cde3e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -111,7 +111,7 @@ struct LinalgOpTilingInterface } // Instantiate the tiled implementation of the operation. - SmallVector + FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, ArrayRef sizes) const { @@ -129,7 +129,7 @@ struct LinalgOpTilingInterface Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); offsetIndices(b, cast(tiledOp), offsets); - return {tiledOp}; + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; } // Return the details of the output tile generated by the tiled @@ -160,10 +160,10 @@ struct LinalgOpTilingInterface return success(); } - FailureOr generateResultTileValue(Operation *op, OpBuilder &b, - unsigned resultNumber, - ArrayRef offsets, - ArrayRef sizes) const { + FailureOr + generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes) const { auto linalgOp = cast(op); // Check that the indexing map used for the output is a projected @@ -197,12 +197,15 @@ struct LinalgOpTilingInterface iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; } - SmallVector tiledOp = tilingInterfaceOp.getTiledImplementation( - b, iterationTileOffsets, iterationTileSizes); - if (tiledOp.size() != 1) + FailureOr tilingResult = + tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets, + iterationTileSizes); + if (tilingResult->tiledOps.size() != 1) return op->emitOpError("failed to generate tiled implementation"); - return tiledOp[0]->getResult(resultNumber); + return TilingResult{ + tilingResult->tiledOps, + SmallVector{tilingResult->tiledValues[resultNumber]}}; } LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 17c46182eb5d1..e001f59b21e93 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -952,12 +952,14 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( return failure(); } - Operation *tiledPadOp = + FailureOr tilingResult = tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), zeroSliceGuard); + if (failed(tilingResult)) + return failure(); // All shapes are static and the data source is actually used. Rewrite into // pad(extract_slice(x)). - rewriter.replaceOp(sliceOp, tiledPadOp->getResults()); + rewriter.replaceOp(sliceOp, tilingResult->tiledValues); return success(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 915e4b4ed1c56..6706f54662839 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -251,18 +251,20 @@ updateDestinationOperandsForTiledOp(OpBuilder &builder, /// a destination passing style op. static SmallVector yieldTiledValues(RewriterBase &rewriter, ArrayRef initValues, - Operation *tiledOp, + TilingResult tilingResult, ArrayRef> tileOffsetsList, ArrayRef> tileSizesList, MutableArrayRef loops) { SmallVector replacements = - yieldTiledValues(rewriter, initValues, tiledOp->getResults(), + yieldTiledValues(rewriter, initValues, tilingResult.tiledValues, tileOffsetsList, tileSizesList, loops); - if (auto dstOp = dyn_cast(tiledOp)) { - auto innerMostLoop = loops.back(); - SmallVector tiledOpDestinationTensors = dstOp.getDpsInitOperands(); - updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, - innerMostLoop.getRegionIterArgs()); + for (auto tiledOp : tilingResult.tiledOps) { + if (auto dstOp = dyn_cast(tiledOp)) { + auto innerMostLoop = loops.back(); + SmallVector tiledOpDestinationTensors = dstOp.getDpsInitOperands(); + updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, + innerMostLoop.getRegionIterArgs()); + } } return replacements; } @@ -345,9 +347,9 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, if (!tilingResult.loops.empty()) rewriter.setInsertionPoint( tilingResult.loops.back().getBody()->getTerminator()); - SmallVector tiledImplementation = + FailureOr tiledImplementation = op.getTiledImplementation(rewriter, offsets, sizes); - tilingResult.tiledOps.append(tiledImplementation); + tilingResult.tiledOps.append(tiledImplementation->tiledOps); if (op->getNumResults() == 0) { // nothing more to do. return tilingResult; @@ -356,9 +358,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, // If loops are empty, the tiled op is used as the replacement for the untiled // op. if (tilingResult.loops.empty()) { - tilingResult.replacements = llvm::to_vector( - llvm::map_range(tiledImplementation[0]->getResults(), - [](OpResult result) -> Value { return result; })); + tilingResult.replacements = tiledImplementation->tiledValues; return tilingResult; } @@ -384,7 +384,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, return rewriter.notifyMatchFailure(op, "failed to get destinations"); tilingResult.replacements = yieldTiledValues( - rewriter, destinationTensors, tilingResult.tiledOps.back(), + rewriter, destinationTensors, tiledImplementation.value(), resultOffsetsList, resultSizesList, tilingResult.loops); LLVM_DEBUG({ @@ -523,12 +523,13 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, // 2. Generate the tiled implementation of the producer of the source OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(candidateSliceOp); - FailureOr fusedProducerValue = + FailureOr tileAndFuseResult = tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, fusableProducer); - if (failed(fusedProducerValue)) + if (failed(tileAndFuseResult)) return std::nullopt; - rewriter.replaceAllUsesWith(candidateSliceOp, fusedProducerValue.value()); + rewriter.replaceAllUsesWith(candidateSliceOp, + tileAndFuseResult->tiledValues[0]); // 3. If the slice is for a destination operand, for example, // @@ -592,8 +593,10 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, outerMostLoop.setIterArg(iterArgNumber.value(), dstOp.getTiedOpOperand(fusableProducer)->get()); } - if (auto dstOp = fusedProducerValue.value() - .getDefiningOp()) { + for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) { + auto dstOp = dyn_cast(tileAndFusedOp); + if (!dstOp) + continue; scf::ForOp innerMostLoop = loops.back(); updateDestinationOperandsForTiledOp( rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), @@ -601,7 +604,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, } } return scf::SCFFuseProducerOfSliceResult{fusableProducer, - fusedProducerValue.value()}; + tileAndFuseResult->tiledValues[0]}; } /// Reconstruct the fused producer from within the tiled-and-fused code. diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 1c4db01dc8f28..0faa29ade8047 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -46,15 +46,15 @@ struct PadOpTiling : public TilingInterface::ExternalModel { return loopRanges; } - SmallVector + FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, ArrayRef sizes) const { - Operation *result = + FailureOr result = tensor::bubbleUpPadSlice(b, cast(op), offsets, sizes); - if (!result) - return {}; - return {result}; + if (failed(result)) + return failure(); + return result.value(); } LogicalResult @@ -117,7 +117,7 @@ struct PackOpTiling return getPackUnPackIterationDomain(cast(op), b); } - SmallVector + FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, ArrayRef sizes) const { @@ -192,7 +192,8 @@ struct PackOpTiling Operation *tiledPackOp = b.create( loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); - return {tiledPackOp}; + return TilingResult{{tiledPackOp}, + SmallVector(tiledPackOp->getResults())}; } LogicalResult @@ -353,7 +354,7 @@ struct UnPackOpTiling /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we /// can get the actual result. - SmallVector + FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, ArrayRef sizes) const { @@ -412,12 +413,13 @@ struct UnPackOpTiling loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); if (isPerfectTilingCase) - return {tiledUnpackOp}; + return TilingResult{{tiledUnpackOp}, + SmallVector(tiledUnpackOp->getResults())}; - Operation *extractSlice = + auto extractSlice = b.create(loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes, destStrides); - return {tiledUnpackOp, extractSlice}; + return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}}; } LogicalResult @@ -431,26 +433,29 @@ struct UnPackOpTiling return success(); } - FailureOr generateResultTileValue(Operation *op, OpBuilder &b, - unsigned resultNumber, - ArrayRef offsets, - ArrayRef sizes) const { - return getTiledImplementation(op, b, offsets, sizes) - .back() - ->getResult(resultNumber); + FailureOr + generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes) const { + FailureOr tilingResult = + getTiledImplementation(op, b, offsets, sizes); + if (failed(tilingResult)) + return failure(); + return tilingResult.value(); } }; } // namespace -Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, - ArrayRef offsets, - ArrayRef sizes, - bool generateZeroSliceGuard) { +FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, + tensor::PadOp padOp, + ArrayRef offsets, + ArrayRef sizes, + bool generateZeroSliceGuard) { // Only constant padding value supported. Value padValue = padOp.getConstantPaddingValue(); if (!padValue) - return nullptr; + return failure(); // Helper variables and functions for various arithmetic operations. These // are used extensively for computing new offset/length and padding values. @@ -584,10 +589,9 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, RankedTensorType::get(shape, padOp.getResultType().getElementType()); // Insert cast to ensure that types match. (May be folded away.) - auto castResult = [&](Operation *op) -> Operation * { - Value val = op->getResult(0); + auto castResult = [&](Value val) -> Value { if (resultType == val.getType()) - return op; + return val; return b.create(loc, resultType, val); }; @@ -601,7 +605,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, [&](OpBuilder &builder, Location gLoc, ValueRange indices) { builder.create(gLoc, padValue); }); - return castResult(generateOp); + return generateOp; }; // Emit a SliceOp and a PadOp. Should not be used in cases where @@ -617,30 +621,38 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm); // Cast result and return. - return castResult(newPadOp); + return newPadOp; }; // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that // the original data source x is not used. - if (hasZeroLen) - return createGenerateOp(); + if (hasZeroLen) { + Operation *generateOp = createGenerateOp(); + return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}}; + } // If there are dynamic dimensions: Generate an scf.if check to avoid // creating SliceOps with result dimensions of size 0 at runtime. if (generateZeroSliceGuard && dynHasZeroLenCond) { + Operation *thenOp; + Operation *elseOp; auto result = b.create( loc, dynHasZeroLenCond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { - b.create(loc, createGenerateOp()->getResult(0)); + thenOp = createGenerateOp(); + b.create(loc, castResult(thenOp->getResult(0))); }, /*elseBuilder=*/ [&](OpBuilder &b, Location loc) { - b.create(loc, createPadOfExtractSlice()->getResult(0)); + elseOp = createPadOfExtractSlice(); + b.create(loc, castResult(elseOp->getResult(0))); }); - return result; + return TilingResult{{result}, SmallVector(result->getResults())}; } - return createPadOfExtractSlice(); + + Operation *newPadOp = createPadOfExtractSlice(); + return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}}; } void mlir::tensor::registerTilingInterfaceExternalModels( diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp index 65176ed7b9e74..40d79c2053817 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp @@ -20,7 +20,7 @@ using namespace mlir; -FailureOr tensor::replaceExtractSliceWithTiledProducer( +FailureOr tensor::replaceExtractSliceWithTiledProducer( OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) { auto producerOp = dyn_cast(producer.getOwner()); if (!producerOp) @@ -32,7 +32,7 @@ FailureOr tensor::replaceExtractSliceWithTiledProducer( })) return failure(); - FailureOr tiledResult = producerOp.generateResultTileValue( + FailureOr tiledResult = producerOp.generateResultTileValue( builder, producer.getResultNumber(), sliceOp.getMixedOffsets(), sliceOp.getMixedSizes()); if (failed(tiledResult))