From 91261e5af8f626770079cc378d39e976467716fb Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Wed, 17 Sep 2025 21:47:46 -0700 Subject: [PATCH 1/2] [mlir][SCF] NFC refactor for better demarcation of splitting to use different loop types for tiling. Signed-off-by: MaheshRavishankar --- .../SCF/Transforms/TileUsingInterface.h | 34 +- .../SCF/Transforms/TileUsingInterface.cpp | 455 ++++++++++-------- 2 files changed, 279 insertions(+), 210 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 3205da6e448fc..117e1ce1371f2 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -33,6 +33,14 @@ using SCFTileSizeComputationFunction = /// Options to use to control tiling. struct SCFTilingOptions { + /// Specify which loop construct to use for tile and fuse. + enum class LoopType { ForOp, ForallOp}; + LoopType loopType = LoopType::ForOp; + SCFTilingOptions &setLoopType(LoopType type) { + loopType = type; + return *this; + } + /// Computation function that returns the tile sizes to use for each loop. /// Returning a tile size of zero implies no tiling for that loop. If the /// size of the returned vector is smaller than the number of loops, the inner @@ -50,6 +58,17 @@ struct SCFTilingOptions { /// proper interaction with folding. SCFTilingOptions &setTileSizes(ArrayRef tileSizes); + /// The interchange vector to reorder the tiled loops. + SmallVector interchangeVector = {}; + SCFTilingOptions &setInterchange(ArrayRef interchange) { + interchangeVector = llvm::to_vector(interchange); + return *this; + } + + //-------------------------------------------------------------------------// + // Options related to tiling using `scf.forall`. + //-------------------------------------------------------------------------// + /// Computation function that returns the number of threads to use for /// each loop. Returning a num threads of zero implies no tiling for that /// loop. If the size of the returned vector is smaller than the number of @@ -70,21 +89,6 @@ struct SCFTilingOptions { /// function that computes num threads at the point they are needed. SCFTilingOptions &setNumThreads(ArrayRef numThreads); - /// The interchange vector to reorder the tiled loops. - SmallVector interchangeVector = {}; - SCFTilingOptions &setInterchange(ArrayRef interchange) { - interchangeVector = llvm::to_vector(interchange); - return *this; - } - - /// Specify which loop construct to use for tile and fuse. - enum class LoopType { ForOp, ForallOp }; - LoopType loopType = LoopType::ForOp; - SCFTilingOptions &setLoopType(LoopType type) { - loopType = type; - return *this; - } - /// Specify mapping of loops to devices. This is only respected when the loop /// constructs support such a mapping (like `scf.forall`). 
Will be ignored /// when using loop constructs that dont support such a mapping (like diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 834c02126fa53..b77f66b701927 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -155,18 +155,18 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, static LogicalResult checkTileSizes(TilingInterface op, scf::SCFTilingOptions::LoopType loopType, ReductionTilingStrategy reductionStrategy, - ArrayRef tileSizes, + ArrayRef givenTileSizes, ArrayRef numThreads) { auto iterators = op.getLoopIteratorTypes(); - assert(iterators.size() == tileSizes.size() && + assert(iterators.size() == givenTileSizes.size() && "expected as many tile size values as number of loops"); assert((numThreads.empty() || (numThreads.size() == iterators.size())) && "when specified, expected number of threads to use for each loop"); bool isParallelTiling = false; - for (auto [index, iterator, tileSize] : - llvm::enumerate(iterators, tileSizes)) { - if (!isConstantIntValue(tileSize, 0)) { + for (auto [index, iterator, givenTileSize] : + llvm::enumerate(iterators, givenTileSizes)) { + if (!isConstantIntValue(givenTileSize, 0)) { isParallelTiling |= iterator == utils::IteratorType::parallel; } @@ -186,7 +186,7 @@ static LogicalResult checkTileSizes(TilingInterface op, } if (std::optional constTileSize = - getConstantIntValue(tileSize)) { + getConstantIntValue(givenTileSize)) { if (constTileSize.value() > 0 && iterator != utils::IteratorType::parallel) { op.emitWarning() << "tiling is not thread safe at axis #" << index; @@ -207,11 +207,11 @@ static LogicalResult checkTileSizes(TilingInterface op, /// Get the reduction dims that are tiled. This accounts for reduction dims /// that are specified as tiled, but the tile size is 0. static SetVector -getSanitizedReductionDims(ArrayRef tileSizes, +getSanitizedReductionDims(ArrayRef givenTileSizes, const scf::SCFTilingOptions &options) { SetVector reductionDims; for (auto dim : options.reductionDims) { - if (isConstantIntValue(tileSizes[dim], 0)) + if (isConstantIntValue(givenTileSizes[dim], 0)) continue; reductionDims.insert(dim); } @@ -236,14 +236,14 @@ static bool tileDividesIterationDomain(Range loopRange) { /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. 
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, - OpFoldResult tileSize) { - std::optional ts = getConstantIntValue(tileSize); + OpFoldResult givenTileSize) { + std::optional ts = getConstantIntValue(givenTileSize); if (ts && ts.value() == 1) - return tileSize; + return givenTileSize; if (tileDividesIterationDomain( - Range{loopRange.offset, loopRange.size, tileSize})) - return tileSize; + Range{loopRange.offset, loopRange.size, givenTileSize})) + return givenTileSize; // The tile size to use (to avoid out of bounds access) is minimum of // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled @@ -254,15 +254,15 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext()); Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); return affine::makeComposedFoldedAffineMin( - b, loc, minMap, SmallVector{offset, size, tileSize}); + b, loc, minMap, SmallVector{offset, size, givenTileSize}); } /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less /// than `iterationSize`. -static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, +static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize, OpFoldResult numThreads, OpFoldResult iterationSize) { - std::optional tileSizeConst = getConstantIntValue(tileSize); + std::optional tileSizeConst = getConstantIntValue(givenTileSize); std::optional numThreadsConst = getConstantIntValue(numThreads); std::optional iterSizeConst = getConstantIntValue(iterationSize); if (!tileSizeConst || !numThreadsConst || !iterSizeConst) @@ -274,114 +274,51 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, /// `offset`s and `size`s of the tile of the iteration space that the /// innermost loop body of the generated tiled loops corresponds to. static std::tuple, SmallVector> -getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, - ReductionTilingStrategy strategy, ValueRange ivs, +getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef iterationDomain, - ArrayRef tileSizes, - ArrayRef numThreads, - const llvm::SetVector &reductionDims) { + ArrayRef givenTileSizes) { SmallVector offsets, sizes; int materializedLoopNum = 0; - - if (!numThreads.empty()) { - AffineExpr d0, d1, s0, s1; - AffineExpr offsetExpr, residualTileSizeExpr; - bindDims(rewriter.getContext(), d0, d1); - bindSymbols(rewriter.getContext(), s0, s1); - offsetExpr = d0 + d1 * s0; - residualTileSizeExpr = s1 - (d0 + d1 * s0); - - for (auto [index, nt, tileSize, loopRange] : - llvm::enumerate(numThreads, tileSizes, iterationDomain)) { - - // Non-tiled cases, set the offset and size to the - // `loopRange.offset/size`. 
- if (isZeroInteger(nt)) { - offsets.push_back(loopRange.offset); - sizes.push_back(loopRange.size); - continue; - } - - Value iv = ivs[materializedLoopNum++]; - OpFoldResult offset = affine::makeComposedFoldedAffineApply( - rewriter, loc, offsetExpr, - ArrayRef{loopRange.offset, iv, tileSize}); - OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( - rewriter, loc, residualTileSizeExpr, - {loopRange.offset, nt, tileSize, loopRange.size}); - - OpFoldResult size = tileSize; - if (!isZeroInteger(residualTileSize)) { - OpFoldResult sizeMinusOffsetPerThread = - affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, - {offset, loopRange.size}); - size = affine::makeComposedFoldedAffineMin( - rewriter, loc, - AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), - {sizeMinusOffsetPerThread, tileSize}); - } - - // Consider the case where the original loop was `[0, 100)`. - // If number of threads are `7`, the tile size would be computed as - // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) - // - `offset = 0 + 6 * 15 = 105` - // - `tileSize = min(15, 100 - 105) = -5` - // To avoid negative tile sizes, we need to do a further - // `nonNegativeTileSize = affine.max(0, tileSize)`. - // This `max` can be avoided if - // `offset + tileSize * (numThreads - 1) < (ub - lb)` - if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) { - AffineMap maxMap = - AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); - size = affine::makeComposedFoldedAffineMax( - rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); - } - - offsets.push_back(offset); - sizes.push_back(size); + for (auto [givenTileSize, loopRange] : + llvm::zip_equal(givenTileSizes, iterationDomain)) { + + // Non-tiled cases, set the offset and size to the + // `loopRange.offset/size`. + if (isZeroInteger(givenTileSize)) { + offsets.push_back(loopRange.offset); + sizes.push_back(loopRange.size); + continue; } - return {offsets, sizes}; - } else { - for (auto [tileSize, loopRange] : - llvm::zip_equal(tileSizes, iterationDomain)) { - - // Non-tiled cases, set the offset and size to the - // `loopRange.offset/size`. - if (isZeroInteger(tileSize)) { - offsets.push_back(loopRange.offset); - sizes.push_back(loopRange.size); - continue; - } - Value iv = ivs[materializedLoopNum++]; - OpFoldResult offset = getAsOpFoldResult(iv); - offsets.push_back(offset); - OpFoldResult size = - getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize); - sizes.push_back(size); - } - return {offsets, sizes}; + Value iv = ivs[materializedLoopNum++]; + OpFoldResult offset = getAsOpFoldResult(iv); + offsets.push_back(offset); + OpFoldResult size = + getBoundedTileSize(rewriter, loc, loopRange, offset, givenTileSize); + sizes.push_back(size); } + return {offsets, sizes}; } /// Function to return the bounds of the loops to be generated. static std::tuple, SmallVector, SmallVector> getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef loopRanges, - ArrayRef tileSizes) { + ArrayRef givenTileSizes) { SmallVector lbs, ubs, steps; - for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { + for (auto [loopRange, givenTileSize] : + llvm::zip_equal(loopRanges, givenTileSizes)) { // No loop if the tile size is 0. 
- if (isZeroInteger(tileSize)) + if (isZeroInteger(givenTileSize)) continue; lbs.push_back(loopRange.offset); ubs.push_back(loopRange.size); - steps.push_back(tileSize); + steps.push_back(givenTileSize); } return {lbs, ubs, steps}; } -/// A function that allows returning additional yielded values during +/// Typedef for function that allows returning additional yielded values during /// `yieldTiledValuesAndReplace`. /// - `ivs` induction variable for the loop. /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. @@ -402,6 +339,30 @@ using YieldTiledValuesFn = std::function> &resultOffsets, SmallVector> &resultSizes)>; +/// Typedef for function that implements the body of a tiled loop. +/// - `ivs` induction variable for the loop. +/// - `tileOffsets` represents offsets for the tiled iteration space. +/// - `tileSizes` represents the sizes for the tiled iteration space. +/// - `outerDestinationTensors` tensors that hold the results. Same size +/// as the destination operands of the original operation. +/// - `tiledResults` results of the tiled computation, corresponding to +/// tiles of the original operation computed by the loop body. +/// Should be the same size as `destinationTensors`. +/// - `resultOffsets` is of the same size as `tiledResults` and represents +/// the offset to use when writing the corresponding element from +/// `tiledResults` into `destinationTensors`. +/// - `resultSizes` is of the same size as `tiledResults` and represents +/// the size to use when writing the corresponding element from +/// `tiledResults` into `destinationTensors`. +/// If the method needs to return `failure()`, it is expected +/// to clean up any inserted operations. +using GenerateTiledBodyFn = std::function<LogicalResult( + RewriterBase &rewriter, Location loc, ValueRange ivs, + ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes, + ValueRange outerDestinationTensors, SmallVector<Value> &tiledResults, + SmallVector<SmallVector<OpFoldResult>> &resultOffsets, + SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; + /// Clones the operation and updates the destination if the operation /// implements the `DestinationStyleOpInterface`. static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, @@ -417,26 +378,25 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, /// Generate the tile-loop nest using `scf.for` operation. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. -/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. -/// - `destinationTensors` are the init values to use for the outer most loop. -/// - `yieldTiledValuesFn` is called to generated the loop body of the inner +/// - `givenTileSizes` is the tile sizes to use. Zero represents untiled loops. +/// - `outerDestinationTensors` are the init values to use for the outer most +/// loop. +/// - `tiledBodyFn` is called to generate the loop body of the inner /// most /// loop. -/// - `loops` is an in-out parameter into which the generated loops are -/// populated. +/// Returns the generated `scf.for` loops on success.
+static FailureOr> generateLoopNestUsingForOp( RewriterBase &rewriter, Location loc, ArrayRef loopRanges, - ArrayRef tileSizes, ValueRange destinationTensors, - YieldTiledValuesFn yieldTiledValuesFn, - SmallVector &loops) { + ArrayRef givenTileSizes, ValueRange outerDestinationTensors, + GenerateTiledBodyFn tiledBodyFn) { assert(!loopRanges.empty() && "unexpected empty loop ranges"); - assert(loopRanges.size() == tileSizes.size() && + assert(loopRanges.size() == givenTileSizes.size() && "expected as many tile sizes as loop ranges"); OpBuilder::InsertionGuard guard(rewriter); SmallVector lbs, ubs, steps; std::tie(lbs, ubs, steps) = - getLoopBounds(rewriter, loc, loopRanges, tileSizes); + getLoopBounds(rewriter, loc, loopRanges, givenTileSizes); SmallVector lbVals = getValueOrCreateConstantIndexOp(rewriter, loc, lbs); SmallVector ubVals = @@ -445,34 +405,42 @@ static LogicalResult generateLoopNestUsingForOp( getValueOrCreateConstantIndexOp(rewriter, loc, steps); SmallVector ivs; + SmallVector loops; + ValueRange innerDestinationTensors(outerDestinationTensors); for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) { auto loop = - scf::ForOp::create(rewriter, loc, lb, ub, step, destinationTensors, + scf::ForOp::create(rewriter, loc, lb, ub, step, innerDestinationTensors, [](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, ValueRange /*iterArgs*/) {}); loops.push_back(loop); ivs.push_back(loop.getInductionVar()); rewriter.setInsertionPointToEnd(loop.getBody()); - destinationTensors = loop.getRegionIterArgs(); + innerDestinationTensors = loop.getRegionIterArgs(); } + // Compute the `offsets` and `sizes` to use for tiling. + SmallVector offsets, sizes; + std::tie(offsets, sizes) = + getTileOffsetAndSizes(rewriter, loc, ivs, loopRanges, givenTileSizes); + SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; - if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, - tiledResults, resultOffsets, resultSizes))) { + if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes, + innerDestinationTensors, tiledResults, resultOffsets, + resultSizes))) { return rewriter.notifyMatchFailure( loc, "failed to generate inner tile loop body"); } if (loops.empty()) - return success(); + return loops; - assert(tiledResults.size() == destinationTensors.size() && + assert(tiledResults.size() == innerDestinationTensors.size() && "Number of results of body should be equal to number of iter args"); // 6. Yield all the results of the tiled operation. SmallVector yieldedValues; for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : - llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, + llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets, resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); @@ -491,27 +459,108 @@ static LogicalResult generateLoopNestUsingForOp( cast(outerLoop.getOperation()).getBody()); scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults()); } - return success(); + return loops; +} + +/// Compute the `OpFoldResult`s that represents the multi-dimensional +/// `offset`s and `size`s of the tile of the iteration space that the +/// innermost loop body of the generated tiled loops corresponds to +/// when tiling using `forall` op. This is handle separately dut to +/// the special case handling needed for when the tiling is done by +/// specifying number of threads. 
+static std::tuple, SmallVector> +getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc, + ValueRange ivs, + ArrayRef iterationDomain, + ArrayRef givenTileSizes, + ArrayRef numThreads) { + if (numThreads.empty()) { + return getTileOffsetAndSizes(rewriter, loc, ivs, iterationDomain, + givenTileSizes); + } + + SmallVector offsets, sizes; + int materializedLoopNum = 0; + + AffineExpr d0, d1, s0, s1; + AffineExpr offsetExpr, residualTileSizeExpr; + bindDims(rewriter.getContext(), d0, d1); + bindSymbols(rewriter.getContext(), s0, s1); + offsetExpr = d0 + d1 * s0; + residualTileSizeExpr = s1 - (d0 + d1 * s0); + + for (auto [index, nt, givenTileSize, loopRange] : + llvm::enumerate(numThreads, givenTileSizes, iterationDomain)) { + + // Non-tiled cases, set the offset and size to the + // `loopRange.offset/size`. + if (isZeroInteger(nt)) { + offsets.push_back(loopRange.offset); + sizes.push_back(loopRange.size); + continue; + } + + Value iv = ivs[materializedLoopNum++]; + OpFoldResult offset = affine::makeComposedFoldedAffineApply( + rewriter, loc, offsetExpr, + ArrayRef{loopRange.offset, iv, givenTileSize}); + OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( + rewriter, loc, residualTileSizeExpr, + {loopRange.offset, nt, givenTileSize, loopRange.size}); + + OpFoldResult size = givenTileSize; + if (!isZeroInteger(residualTileSize)) { + OpFoldResult sizeMinusOffsetPerThread = + affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, + {offset, loopRange.size}); + size = affine::makeComposedFoldedAffineMin( + rewriter, loc, + AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), + {sizeMinusOffsetPerThread, givenTileSize}); + } + + // Consider the case where the original loop was `[0, 100)`. + // If number of threads are `7`, the tile size would be computed as + // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) + // - `offset = 0 + 6 * 15 = 105` + // - `tileSize = min(15, 100 - 105) = -5` + // To avoid negative tile sizes, we need to do a further + // `nonNegativeTileSize = affine.max(0, tileSize)`. + // This `max` can be avoided if + // `offset + tileSize * (numThreads - 1) < (ub - lb)` + if (!canOmitTileOffsetInBoundsCheck(givenTileSize, nt, loopRange.size)) { + AffineMap maxMap = + AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); + size = affine::makeComposedFoldedAffineMax( + rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); + } + + offsets.push_back(offset); + sizes.push_back(size); + } + return {offsets, sizes}; } /// Generate the tile-loop nest using `scf.forall` operation. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. -/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. -/// - `destinationTensors` are the init values to use for the outer most loop. +/// - `giventileSizes` is the tile sizes to use. Zero represent untiled loops. +/// - `outerDestinationTensors` are the init values to use for the loop. /// - `mappingVector` is the mapping attributes to use for loop construction. /// Can be empty. -/// - `yieldTiledValuesFn` is called to generated the loop body of the inner +/// - `tiledBodyFn` is called to generated the loop body of the inner /// most /// loop. -/// - `loops` is an in-out parameter into which the generated loops are -/// populated. 
-static LogicalResult generateLoopNestUsingForallOp( - RewriterBase &rewriter, Location loc, ArrayRef loopRanges, - ArrayRef tileSizes, ArrayRef numThreads, - ArrayRef mappingVector, ValueRange destinationTensors, - YieldTiledValuesFn tiledBodyFn, SmallVector &loops) { +/// Returns the generated `scf.forall` loop on success. +static FailureOr> +generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, + ArrayRef loopRanges, + ArrayRef givenTileSizes, + ArrayRef numThreads, + ArrayRef mappingVector, + ValueRange outerDestinationTensors, + GenerateTiledBodyFn tiledBodyFn) { assert(!loopRanges.empty() && "unexpected empty loop ranges"); - assert(loopRanges.size() == tileSizes.size() && + assert(loopRanges.size() == givenTileSizes.size() && "expected as many tile sizes as loop ranges"); OpBuilder::InsertionGuard guard(rewriter); @@ -522,6 +571,7 @@ static LogicalResult generateLoopNestUsingForallOp( scf::ForallOp forallOp; bool useNumThreads = !numThreads.empty(); + SmallVector loops; if (useNumThreads) { // Prune the zero numthreads. SmallVector nonZeroNumThreads; @@ -531,29 +581,35 @@ static LogicalResult generateLoopNestUsingForallOp( nonZeroNumThreads.push_back(nt); } forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads, - destinationTensors, mappingAttr); + outerDestinationTensors, mappingAttr); } else { SmallVector lbs, ubs, steps; std::tie(lbs, ubs, steps) = - getLoopBounds(rewriter, loc, loopRanges, tileSizes); + getLoopBounds(rewriter, loc, loopRanges, givenTileSizes); forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps, - destinationTensors, mappingAttr); + outerDestinationTensors, mappingAttr); } loops.push_back(forallOp); rewriter.setInsertionPoint(forallOp.getTerminator()); - destinationTensors = forallOp.getRegionOutArgs(); + ValueRange innerDestinationTensors = forallOp.getRegionOutArgs(); + SmallVector ivs = forallOp.getInductionVars(); + + // Compute the `offsets` and `sizes` to use for tiling. + SmallVector offsets, sizes; + std::tie(offsets, sizes) = getTileOffsetAndSizesWithForAllOp( + rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads); SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; - if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), - destinationTensors, tiledResults, resultOffsets, + if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes, + innerDestinationTensors, tiledResults, resultOffsets, resultSizes))) return rewriter.notifyMatchFailure(loc, "failed to generate loop body"); rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : - llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, + llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets, resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); @@ -562,41 +618,48 @@ static LogicalResult generateLoopNestUsingForallOp( destinationTensor, resultOffset, resultSize, resultStride); } - return success(); + return loops; } /// Generate the tile-loop nest using the loop construct specifed in `options`. /// - `options`: Tiling options specified. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. -/// - `destinationTensors` are the init values to use for the outer most loop. +/// - `outerDestinationTensors` are the init values to use for the outer most +/// loop. 
/// - `yieldTiledValuesFn` is called to generated the loop body of the inner /// most /// loop. -/// - `loops` is an in-out parameter into which the generated loops are -/// populated. -static LogicalResult generateLoopNest( - RewriterBase &rewriter, Location loc, - scf::SCFTilingOptions::LoopType loopType, ArrayRef loopRanges, - ArrayRef tileSizes, ArrayRef numThreads, - ValueRange destinationTensors, ArrayRef mappingVector, - YieldTiledValuesFn tiledBodyFn, SmallVector &loops) { +/// Returns the generated loops on success. +static FailureOr> generateLoopNest( + RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, + ArrayRef loopRanges, ArrayRef givenTileSizes, + ArrayRef numThreads, ValueRange destinationTensors, + GenerateTiledBodyFn tiledBodyFn) { // If the tile sizes are all zero, no loops are generated. Just call the // callback function to handle untiled case. - if (llvm::all_of(tileSizes, isZeroInteger)) { + if (llvm::all_of(givenTileSizes, isZeroInteger)) { SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; - return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, - tiledResults, resultOffsets, resultSizes); + auto tileOffsets = + llvm::map_to_vector(loopRanges, [](Range r) { return r.offset; }); + auto tileSizes = + llvm::map_to_vector(loopRanges, [](Range r) { return r.size; }); + if (failed(tiledBodyFn(rewriter, loc, ValueRange{}, tileOffsets, tileSizes, + destinationTensors, tiledResults, resultOffsets, + resultSizes))) { + return failure(); + } + return SmallVector{}; } - if (loopType == scf::SCFTilingOptions::LoopType::ForOp) { - return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, - destinationTensors, tiledBodyFn, loops); + if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { + return generateLoopNestUsingForOp(rewriter, loc, loopRanges, givenTileSizes, + destinationTensors, tiledBodyFn); } - if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) { + if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { return generateLoopNestUsingForallOp( - rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector, - destinationTensors, tiledBodyFn, loops); + rewriter, loc, loopRanges, givenTileSizes, numThreads, + options.mappingVector, destinationTensors, tiledBodyFn); } return rewriter.notifyMatchFailure(loc, "unhandled loop type"); } @@ -604,7 +667,7 @@ static LogicalResult generateLoopNest( static FailureOr> createInitialTensorsForTiling( RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ArrayRef iterationDomain, - ArrayRef numThreads, ArrayRef tileSizes, + ArrayRef numThreads, ArrayRef givenTileSizes, const SetVector &reductionDims) { SmallVector initTensors; Location loc = op->getLoc(); @@ -626,7 +689,7 @@ static FailureOr> createInitialTensorsForTiling( AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2)); AffineExpr divExpr = s0.ceilDiv(s1); for (auto [index, domain, tileSize] : - llvm::enumerate(iterationDomain, tileSizes)) { + llvm::enumerate(iterationDomain, givenTileSizes)) { if (!numThreads.empty()) { // Untiled case. 
if (isConstantIntValue(numThreads[index], 0)) { @@ -672,7 +735,7 @@ static SmallVector getSplitReductionIvs(RewriterBase &rewriter, Location loc, ReductionTilingStrategy reductionStrategy, ValueRange ivs, ArrayRef numThreads, - ArrayRef tileSizes, + ArrayRef givenTileSizes, const SetVector &reductionDims) { SmallVector splitReductionIvs; splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0)); @@ -689,7 +752,7 @@ getSplitReductionIvs(RewriterBase &rewriter, Location loc, } splitReductionIvs[index] = affine::makeComposedFoldedAffineApply( rewriter, loc, divExpr, - ArrayRef{ivs[ivIndex++], tileSizes[reductionDim]}); + ArrayRef{ivs[ivIndex++], givenTileSizes[reductionDim]}); } } return splitReductionIvs; @@ -701,7 +764,7 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef offsets, ArrayRef sizes, ValueRange ivs, ArrayRef numThreads, - ArrayRef tileSizes, + ArrayRef givenTileSizes, const SetVector &reductionDims) { if (reductionStrategy == ReductionTilingStrategy::FullReduction) { return op.getTiledImplementation(rewriter, offsets, sizes); @@ -717,7 +780,7 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op, SmallVector splitReductionIvs = getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs, - numThreads, tileSizes, reductionDims); + numThreads, givenTileSizes, reductionDims); return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy, regionIterArg, offsets, sizes, reductionDims, splitReductionIvs); @@ -728,7 +791,8 @@ static LogicalResult getResultTilePosition( int64_t index, Value tiledResult, TilingInterface op, ArrayRef offsets, ArrayRef sizes, ValueRange ivs, ArrayRef numThreads, - ArrayRef tileSizes, const SetVector &reductionDims, + ArrayRef givenTileSizes, + const SetVector &reductionDims, SmallVector &resultOffset, SmallVector &resultSize) { @@ -744,7 +808,7 @@ static LogicalResult getResultTilePosition( } SmallVector splitReductionIvs = getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs, - numThreads, tileSizes, reductionDims); + numThreads, givenTileSizes, reductionDims); return redOp.getPartialResultTilePosition( rewriter, index, reductionStrategy, offsets, sizes, reductionDims, splitReductionIvs, resultOffset, resultSize); @@ -999,20 +1063,20 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, SmallVector iterationDomain = op.getIterationDomain(rewriter); // 2. Materialize the tile sizes and/or number of threads; - SmallVector tileSizes, numThreads; - std::tie(tileSizes, numThreads) = + SmallVector givenTileSizes, numThreads; + std::tie(givenTileSizes, numThreads) = getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); // Check if it is safe to tile. This is hold over from previous iterations // of tile to for-all. Consider dropping it. if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy, - tileSizes, numThreads))) { + givenTileSizes, numThreads))) { return failure(); } // Get the reduction dims SetVector reductionDims = - getSanitizedReductionDims(tileSizes, options); + getSanitizedReductionDims(givenTileSizes, options); // 3. If there is an interchange specified, permute the iteration domain and // the tile sizes. 
@@ -1024,7 +1088,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, "expected interchange vector to be a permutation"); applyPermutationToVector(iterationDomain, interchangeVector); - applyPermutationToVector(tileSizes, interchangeVector); + applyPermutationToVector(givenTileSizes, interchangeVector); if (!numThreads.empty()) applyPermutationToVector(numThreads, interchangeVector); } @@ -1032,24 +1096,21 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, FailureOr tilingResult; // 4. Define the lambda function used later to generate the body of the // innermost tiled loop. - YieldTiledValuesFn innerYieldTiledValuesFn = + GenerateTiledBodyFn innerYieldTiledValuesFn = [&](RewriterBase &rewriter, Location loc, ValueRange ivs, + ArrayRef tileOffsets, ArrayRef tileSizes, ValueRange regionIterArgs, SmallVector &tiledResults, SmallVector> &resultOffsets, SmallVector> &resultSizes) -> LogicalResult { - // 4a. Compute the `offsets` and `sizes` to use for tiling. - SmallVector offsets, sizes; - std::tie(offsets, sizes) = getTileOffsetAndSizes( - rewriter, loc, options.reductionStrategy, ivs, iterationDomain, - tileSizes, numThreads, reductionDims); - // 4b. If interchange was provided, apply inverse of the interchange // to get back the offsets/sizes in the order to be specified. + SmallVector tileOffsetsVec = llvm::to_vector(tileOffsets); + SmallVector tileSizesVec = llvm::to_vector(tileSizes); if (!interchangeVector.empty()) { auto inversePermutation = invertPermutationVector(interchangeVector); - applyPermutationToVector(offsets, inversePermutation); - applyPermutationToVector(sizes, inversePermutation); + applyPermutationToVector(tileOffsetsVec, inversePermutation); + applyPermutationToVector(tileSizesVec, inversePermutation); } // 5. Generate the tiled implementation within the inner most loop. @@ -1061,7 +1122,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 5b. Early return cloned op if tiling is not happening. We can not // return the original op because it could lead to `rewriter.replaceOp(op, // op->getResults())` and users would get crash. - if (llvm::all_of(tileSizes, isZeroInteger)) { + if (llvm::all_of(givenTileSizes, isZeroInteger)) { tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); tilingResult = TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), @@ -1070,9 +1131,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, } // 5c. Tile the cloned operation. 
- tilingResult = getTiledImplementation( - rewriter, clonedOp, options.reductionStrategy, regionIterArgs, offsets, - sizes, ivs, numThreads, tileSizes, reductionDims); + tilingResult = + getTiledImplementation(rewriter, clonedOp, options.reductionStrategy, + regionIterArgs, tileOffsetsVec, tileSizesVec, + ivs, numThreads, givenTileSizes, reductionDims); if (failed(tilingResult)) { rewriter.eraseOp(clonedOp); return op.emitOpError("failed to tile operation"); } @@ -1089,8 +1151,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, SmallVector resultOffset, resultSize; if (failed(getResultTilePosition( rewriter, options.reductionStrategy, index, tiledValue, op, - offsets, sizes, ivs, numThreads, tileSizes, reductionDims, - resultOffset, resultSize))) { + tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes, + reductionDims, resultOffset, resultSize))) { for (auto op : tilingResult->tiledOps) { rewriter.eraseOp(op); } @@ -1107,7 +1169,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 6. Find the destination tensors to use for the operation. FailureOr> maybeInits = createInitialTensorsForTiling( rewriter, op, options.reductionStrategy, iterationDomain, numThreads, - tileSizes, reductionDims); + givenTileSizes, reductionDims); if (failed(maybeInits)) { return rewriter.notifyMatchFailure( op, "unable to create initial tensors for tiling"); @@ -1116,13 +1178,16 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 7. Generate the tiled loops nest using the callback defined above. SmallVector loops; - if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType, - iterationDomain, tileSizes, numThreads, - initTensors, options.mappingVector, - innerYieldTiledValuesFn, loops))) - return op.emitOpError("failed to generate tiling loops"); - assert(succeeded(tilingResult) && - "expected tiling result to be computed after loop generation"); + { + FailureOr> loopsOr = generateLoopNest( + rewriter, op.getLoc(), options, iterationDomain, givenTileSizes, + numThreads, initTensors, innerYieldTiledValuesFn); + if (failed(loopsOr)) + return op.emitOpError("failed to generate tiling loops"); + assert(succeeded(tilingResult) && + "expected tiling result to be computed after loop generation"); + std::swap(loops, loopsOr.value()); + } if (loops.empty()) { // If loops are empty, the tiled op is used as the replacement for the From b016e7a6a268b32f8d41014df309fe2aef903a92 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Wed, 17 Sep 2025 21:48:50 -0700 Subject: [PATCH 2/2] [mlir][SCF] Allow using a custom operation to generate loops with `mlir::tileUsingSCF`. This change adds an option to use a custom operation to generate the inter-tile loops during tiling. When the loop type is set to `scf::SCFTilingOptions::LoopType::CustomOp`, the method `mlir::tileUsingSCF` provides two callback functions: 1. The first generates the header of the loop. 2. The second generates the terminator of the loop. These callbacks receive the information needed to generate the loops/terminator and are expected to return the information needed to generate the code for the intra-tile computation. See comments for more details. Presently this adds support only for tiling. Subsequent commits will update this to add support for fusion as well.
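As an illustration (not part of the patch itself), the caller-side wiring looks roughly as follows. The element types in the callback signatures are spelled out here as an assumption (the surrounding diff elides template arguments), and `rewriter`, `tileSizes` and `tilingInterfaceOp` are assumed to be in scope; see the test op added in TestTilingInterfaceTransformOps.cpp for a complete working example that builds a single linearized `scf.for` in the header callback.

```c++
// Sketch: tile a TilingInterface op using a caller-provided loop construct.
// The callback bodies are stubs; a real header callback builds the loop(s),
// sets the insertion point to where the intra-tile computation must go, and
// returns the loops, per-loop tile offsets/sizes and destination tensors.
scf::SCFTilingOptions options;
options.setTileSizes(tileSizes)
    .setLoopType(scf::SCFTilingOptions::LoopType::CustomOp)
    .setCustomLoopGenerationFns(
        /*headerFn=*/
        [](RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
           ArrayRef<OpFoldResult> givenTileSizes,
           ValueRange destinationTensors)
            -> FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> {
          // Build the custom loop nest here and set the insertion point
          // inside it (elided in this sketch).
          return failure();
        },
        /*terminatorFn=*/
        [](RewriterBase &rewriter, Location loc, ValueRange tiledResults,
           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
           ArrayRef<SmallVector<OpFoldResult>> resultSizes,
           ValueRange destinationTensors) -> LogicalResult {
          // Insert each tiled result into its destination at the given
          // offsets/sizes and build the loop terminator (elided).
          return failure();
        });
FailureOr<scf::SCFTilingResult> tilingResult =
    scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
```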
Signed-off-by: MaheshRavishankar --- .../SCF/Transforms/TileUsingInterface.h | 94 ++++++++++- .../SCF/Transforms/TileUsingInterface.cpp | 57 +++++++ .../TilingInterface/tile-using-custom-op.mlir | 60 +++++++ .../TestTilingInterfaceTransformOps.cpp | 148 ++++++++++++++++++ .../TestTilingInterfaceTransformOps.td | 23 +++ 5 files changed, 381 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 117e1ce1371f2..6b05ade37881c 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -34,7 +34,7 @@ using SCFTileSizeComputationFunction = /// Options to use to control tiling. struct SCFTilingOptions { /// Specify which loop construct to use for tile and fuse. - enum class LoopType { ForOp, ForallOp}; + enum class LoopType { ForOp, ForallOp, CustomOp }; LoopType loopType = LoopType::ForOp; SCFTilingOptions &setLoopType(LoopType type) { loopType = type; @@ -121,6 +121,98 @@ struct SCFTilingOptions { reductionDims.insert(dims.begin(), dims.end()); return *this; } + + //-------------------------------------------------------------------------// + // Options related to tiling using custom loop. + //-------------------------------------------------------------------------// + + // For generating the inter-tile loops using a custom loop, two callback + // functions are needed + // 1. That generates the "loop header", i.e. the loop that iterates over the + // different tiles. + // 2. That generates the loop terminator + // + // For `scf.forall` case the call back to generate loop header would generate + // + // ```mlir + // scf.forall (...) = ... { + // .. + // } + // ``` + // + // and the call back to generate the loop terminator would generate the + // `scf.in_parallel` region + // + // ```mlir + // scf.forall (...) = ... { + // scf.in_parallel { + // tensor.parallel_insert_slice ... + // } + // } + // ``` + // + + // Information that is to be returned by the callback to generate the loop + // header needed for the rest of the tiled codegeneration. + // - `loops`: The generated loops + // - `tileOffset`: The values that represent the offset of the iteration space + // tile + // - `tileSizes` : The values that represent the size of the iteration space + // tile. + // - `destinationTensors` : The tensors to use as destinations during tiling. + struct CustomLoopHeaderInfo { + SmallVector loops; + SmallVector tileOffset; + SmallVector tileSizes; + SmallVector destinationTensors; + }; + + // Type of the callback function that generates the loop headers. + // - `loopRanges` : Values that represent the full size of the iteration space + // being tiled. + // - `giveTileSizes` : The tile sizes that are to be used to tile the + // iteration + // space. + // - `destinationTensors` : The tensors to use as destinations for the results + // of the tiled loop for loops that implement + // `DestinationStyleOpInterface`. + // Returns the `CustomLoopHeaderInfo` object (described above). it is expected + // that this function sets the insertion point of `rewriter` to the program + // point where the intra-tile loop computation is to be generated. 
+ using GenerateLoopHeaderFn = std::function( + RewriterBase &rewriter, Location loc, ArrayRef loopRanges, + ArrayRef givenTileSizes, ValueRange destinationTensors)>; + + // Type of the callback function that generates the loop terminator. + // - `tiledResults` : Tiles of the result computed for the iteration space + // tile + // - `resultOffsets` : For each of the `tiledResults`, the offset at which + // the result tile is to be "inserted" back into the + // destination tensor. + // - `resultSizes` : For each of the `tiledResults`, the size of the result + // tile + // that is to be "inserted" back into the destination + // tensor. + // Returns the `CustomLoopHeaderInfo` object (described above) + using GenerateLoopTerminatorFn = std::function> resultOffsets, + ArrayRef> resultSizes, + ValueRange destinationTensors)>; + + // Callback function to generate the inter-tile loop header. + GenerateLoopHeaderFn generateLoopHeaderFn = nullptr; + // Callback function to generate the inter-tile loop terminator. + GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr; + // Helper function to set the callbacks for inter-tile loop header and + // terminator functions when using a custom operation for the loop. + SCFTilingOptions & + setCustomLoopGenerationFns(GenerateLoopHeaderFn headerFn, + GenerateLoopTerminatorFn terminatorFn) { + generateLoopHeaderFn = std::move(headerFn); + generateLoopTerminatorFn = std::move(terminatorFn); + return *this; + } }; /// Transformation information returned after tiling. diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index b77f66b701927..c3899473289e2 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -621,6 +621,57 @@ generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, return loops; } +/// Generate the tile-loop nest using custom loop operation. +/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. +/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. +/// - `destinationTensors` are the init values to use for the outer most loop. +/// - `mappingVector` is the mapping attributes to use for loop construction. +/// Can be empty. +/// - `tiledBodyFn` is called to generated the loop body of the inner +/// most +/// loop. +/// Returns the generated `scf.forall` loop on success. 
+static FailureOr> +generateLoopNestUsingCustomOp( + RewriterBase &rewriter, Location loc, ArrayRef loopRanges, + ArrayRef givenTileSizes, ValueRange outerDestinationTensors, + const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn, + const scf::SCFTilingOptions::GenerateLoopTerminatorFn + &generateLoopTerminatorFn, + GenerateTiledBodyFn tiledBodyFn) { + assert(!loopRanges.empty() && "unexpected empty loop ranges"); + assert(loopRanges.size() == givenTileSizes.size() && + "expected as many tile sizes as loop ranges"); + assert(generateLoopHeaderFn && generateLoopTerminatorFn && + "expected loop header/terminator generation function"); + OpBuilder::InsertionGuard guard(rewriter); + + FailureOr loopHeaderInfo = + generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes, + outerDestinationTensors); + if (failed(loopHeaderInfo)) { + return failure(); + } + + SmallVector ivs; + SmallVector tiledResults; + SmallVector> resultOffsets, resultSizes; + if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset, + loopHeaderInfo->tileSizes, + loopHeaderInfo->destinationTensors, tiledResults, + resultOffsets, resultSizes))) { + return failure(); + } + + if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults, + resultOffsets, resultSizes, + loopHeaderInfo->destinationTensors))) { + return failure(); + } + + return loopHeaderInfo->loops; +} + /// Generate the tile-loop nest using the loop construct specifed in `options`. /// - `options`: Tiling options specified. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. @@ -661,6 +712,12 @@ static FailureOr> generateLoopNest( rewriter, loc, loopRanges, givenTileSizes, numThreads, options.mappingVector, destinationTensors, tiledBodyFn); } + if (options.loopType == scf::SCFTilingOptions::LoopType::CustomOp) { + return generateLoopNestUsingCustomOp( + rewriter, loc, loopRanges, givenTileSizes, destinationTensors, + options.generateLoopHeaderFn, options.generateLoopTerminatorFn, + tiledBodyFn); + } return rewriter.notifyMatchFailure(loc, "unhandled loop type"); } diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir new file mode 100644 index 0000000000000..d335e9c3fb5d0 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --mlir-print-local-scope %s | FileCheck %s + +module { + func.func @generic_parallel(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %empty = tensor.empty(%d0, %d1) : tensor + %generic = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) outs(%empty : tensor) { + ^bb(%b0 : f32, %b1 : f32, %b2 : f32): + %add = arith.addf %b0, %b1 : f32 + linalg.yield %add : f32 + } -> tensor + return %generic : tensor + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops {["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %tiled_op, %loop = transform.test.tile_using_custom_loop %op tile_sizes = [10, 20] + : 
(!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-LABEL: func @generic_parallel +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor +// CHECK-DAG: %[[NITERS0:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 10)>()[%[[D0]]] +// CHECK-DAG: %[[NITERS1:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 20)>()[%[[D1]]] +// CHECK-DAG: %[[NITERS:.+]] = affine.apply affine_map<()[s0, s1] -> ((s0 ceildiv 10) * (s1 ceildiv 20))>()[%[[D0]], %[[D1]]] +// CHECK: %[[FOR:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[NITERS]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[EMPTY]]) +// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IV]] into (%[[NITERS0]], %[[NITERS1]]) +// CHECK-DAG: %[[SIZE0:.+]] = affine.min affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>(%[[DELINEARIZE]]#0)[%[[D0]]] +// CHECK-DAG: %[[SIZE1:.+]] = affine.min affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>(%[[DELINEARIZE]]#1)[%[[D1]]] +// CHECK-DAG: %[[OFFSET0:.+]] = affine.apply affine_map<(d0) -> (d0 * 10)>(%[[DELINEARIZE]]#0) +// CHECK-DAG: %[[OFFSET1:.+]] = affine.apply affine_map<(d0) -> (d0 * 20)>(%[[DELINEARIZE]]#1) +// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[OFFSET0]], %[[OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] [1, 1] +// CHECK-DAG: %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][%[[OFFSET1]]] [%[[SIZE1]]] [1] +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[OFFSET0]], %[[OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] [1, 1] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : +// CHECK-SAME: outs(%[[INIT_SLICE]] : +// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[INIT]] +// CHECK-SAME: [%[[OFFSET0]], %[[OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] [1, 1] +// CHECK: scf.yield %[[INSERT_SLICE]] +// CHECK: return %[[FOR]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 3d24d4ecc4d0d..1e3d5371f1ea8 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -468,6 +469,153 @@ transform::TestTileAndFuseOuterParallelPartialReductionOp::apply( : DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// TestTileAndFuseOuterParallelPartialReduction +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply( + TransformRewriter &transformRewriter, TransformResults &transformResults, + TransformState &state) { + auto target = + dyn_cast(*state.getPayloadOps(getRootOp()).begin()); + if (!target) { + emitOpError("expected root operation 
to implement `TilingInterface`"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + + OpFoldResult oneOfr = transformRewriter.getIndexAttr(1); + + scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn = + [&](RewriterBase &rewriter, Location loc, ArrayRef loopRanges, + ArrayRef givenTileSizes, + ValueRange outerDestinationTensors) + -> FailureOr { + // Check that the strides are all 1 (to make it easier in the test). + if (llvm::any_of(loopRanges, [](Range r) { + return !isConstantIntValue(r.stride, 1); + })) { + return emitOpError("unable to handle loop ranges with strides != 1"); + } + // For testing disallow any of the tile sizes being 0. + if (llvm::any_of(givenTileSizes, isZeroInteger)) { + return emitOpError("unhandled case of zero tile size"); + } + // For testing, only handle tensor tiling. + if (outerDestinationTensors.empty()) { + return emitOpError("expected destination tensors"); + } + + // Compute the number of iterations for each of the loops. + AffineExpr s0, s1, s2; + bindSymbols(rewriter.getContext(), s0, s1, s2); + AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2); // (ub - lb) / tileSize + + SmallVector allNumIters; + allNumIters.reserve(loopRanges.size()); + for (auto [loopRange, tileSize] : + llvm::zip_equal(loopRanges, givenTileSizes)) { + OpFoldResult numIters = affine::makeComposedFoldedAffineApply( + rewriter, loc, numItersExpr, + {loopRange.offset, loopRange.size, tileSize}); + allNumIters.push_back(numIters); + } + if (allNumIters.empty()) { + return emitOpError("unhandled case where all tile sizes are zero"); + } + + AffineExpr mulExpr = s0 * s1; + OpFoldResult cummulative = oneOfr; + for (auto numIters : allNumIters) { + cummulative = affine::makeComposedFoldedAffineApply( + rewriter, loc, mulExpr, {cummulative, numIters}); + } + + Value zeroVal = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value oneVal = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, cummulative); + + SmallVector offsets; + SmallVector sizes; + SmallVector innerDestinationTensors; + offsets.reserve(loopRanges.size()); + sizes.reserve(loopRanges.size()); + + AffineExpr d0; + bindDims(rewriter.getContext(), d0); + AffineExpr offsetExpr = s0 + d0 * s1; // lb + iv * tileSize + AffineMap minMap = + AffineMap::get(1, 2, {s0 - d0, s1}, + rewriter.getContext()); // min(ub - offset, tileSize) + auto forOp = scf::ForOp::create( + rewriter, loc, zeroVal, ub, oneVal, outerDestinationTensors, + [&](OpBuilder &b, Location bodyLoc, Value linearizedIv, + ValueRange destinations) { + auto delinearizeOp = affine::AffineDelinearizeIndexOp::create( + b, bodyLoc, linearizedIv, allNumIters); + for (auto [normalizedIv, range, tileSize] : llvm::zip_equal( + delinearizeOp.getResults(), loopRanges, givenTileSizes)) { + + OpFoldResult normalizedIvOfr = getAsOpFoldResult(normalizedIv); + OpFoldResult offset = affine::makeComposedFoldedAffineApply( + b, bodyLoc, offsetExpr, + {normalizedIvOfr, range.offset, tileSize}); + offsets.push_back(offset); + + OpFoldResult size = affine::makeComposedFoldedAffineMin( + b, bodyLoc, minMap, {offset, range.size, tileSize}); + sizes.push_back(size); + } + innerDestinationTensors = llvm::to_vector(destinations); + }); + rewriter.setInsertionPointToEnd(forOp.getBody()); + return scf::SCFTilingOptions::CustomLoopHeaderInfo{ + {cast(forOp.getOperation())}, + offsets, + sizes, + innerDestinationTensors}; + }; + + scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn = + [&](RewriterBase &rewriter, 
Location loc, ValueRange tiledResults, + ArrayRef> resultOffsets, + ArrayRef> resultSizes, + ValueRange destinationTensors) -> LogicalResult { + SmallVector yieldValues; + yieldValues.reserve(destinationTensors.size()); + for (auto [tiledResult, offsets, sizes, destination] : llvm::zip_equal( + tiledResults, resultOffsets, resultSizes, destinationTensors)) { + SmallVector strides(offsets.size(), oneOfr); + Value insertedVal = tensor::InsertSliceOp::create( + rewriter, loc, tiledResult, destination, offsets, sizes, strides); + yieldValues.push_back(insertedVal); + } + scf::YieldOp::create(rewriter, loc, yieldValues); + return success(); + }; + + scf::SCFTilingOptions tilingOptions; + SmallVector staticTileSizes = + extractFromIntegerArrayAttr(getTileSizes()); + SmallVector tileSizes = + getAsIndexOpFoldResult(transformRewriter.getContext(), staticTileSizes); + tilingOptions.setTileSizes(tileSizes) + .setLoopType(scf::SCFTilingOptions::LoopType::CustomOp) + .setCustomLoopGenerationFns(loopHeaderFn, terminatorFn); + + OpBuilder::InsertionGuard g(transformRewriter); + transformRewriter.setInsertionPoint(target); + FailureOr tiledResults = + scf::tileUsingSCF(transformRewriter, target, tilingOptions); + if (failed(tiledResults)) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + transformRewriter.replaceOp(target, tiledResults->replacements); + transformResults.set(getOperation()->getResult(0), tiledResults->tiledOps); + transformResults.set(getOperation()->getResult(1), tiledResults->loops); + + return DiagnosedSilenceableFailure::success(); +} + #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.cpp.inc" diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 58ccd30bb99a2..694c4229eef62 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -150,4 +150,27 @@ def TestTileAndFuseOuterParallelPartialReductionOp : Op< }]; } +def TestTileUsingCustomLoopOp : Op< + Transform_Dialect, "test.tile_using_custom_loop", + [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Test Transform op to tile an operation using custom loops. + + The test just folds all the loops and into a single loop and then + delinearizes the indices. + }]; + + let arguments = (ins TransformHandleTypeInterface:$root_op, + DefaultValuedAttr:$tile_sizes); + let results = (outs TransformHandleTypeInterface:$tiled_ops, + Variadic:$loops); + + let assemblyFormat = [{ + $root_op `tile_sizes` `=` $tile_sizes + attr-dict `:` functional-type(operands, results) + }]; +} + #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS
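For reference, with the assembly format declared above, the op is written the way the new test exercises it (the tile sizes shown are just the test's values):

```mlir
%tiled_op, %loop = transform.test.tile_using_custom_loop %op tile_sizes = [10, 20]
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```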