diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 668ee6386f71f..7c735d825b445 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -183,6 +183,7 @@ struct SCFTilingOptions { ArrayRef givenTileSizes, ValueRange destinationTensors)>; // Type of the callback function that generates the loop terminator. + // - `loops` : generated loops from the GenerateLoopHeaderFn callback // - `tiledResults` : Tiles of the result computed for the iteration space // tile. // - `resultOffsets` : For each of the `tiledResults`, the offset at which @@ -193,7 +194,8 @@ struct SCFTilingOptions { // tensor. // Returns the `CustomLoopHeaderInfo` object (described above) using GenerateLoopTerminatorFn = std::function loops, + ValueRange tiledResults, ArrayRef> resultOffsets, ArrayRef> resultSizes, ValueRange destinationTensors)>; diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 89e2c57d709dd..36685d3affe03 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -665,8 +665,8 @@ generateLoopNestUsingCustomOp( return failure(); } - if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults, - resultOffsets, resultSizes, + if (failed(generateLoopTerminatorFn(rewriter, loc, loopHeaderInfo->loops, + tiledResults, resultOffsets, resultSizes, loopHeaderInfo->destinationTensors))) { return failure(); } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 7981c72c2f2c8..326fec3ee5cf0 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -581,7 +581,8 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply( }; scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn = - [&](RewriterBase &rewriter, Location loc, ValueRange tiledResults, + [&](RewriterBase &rewriter, Location loc, + ArrayRef loops, ValueRange tiledResults, ArrayRef> resultOffsets, ArrayRef> resultSizes, ValueRange destinationTensors) -> LogicalResult {