Skip to content

Conversation

sebvince
Copy link
Contributor

This PR adds to the generateLoopTerminatorFn callback the loops generated by GenerateLoopHeaderFn. This is needed to correctly set the insertion point with scf.forall ops.

@sebvince sebvince marked this pull request as ready for review September 30, 2025 14:54
@llvmbot
Copy link
Member

llvmbot commented Sep 30, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: None (sebvince)

Changes

This PR adds to the generateLoopTerminatorFn callback the loops generated by GenerateLoopHeaderFn. This is needed to correctly set the insertion point with scf.forall ops.


Full diff: https://github.com/llvm/llvm-project/pull/161386.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+3-1)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+2-2)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+2-1)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 668ee6386f71f..7c735d825b445 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -183,6 +183,7 @@ struct SCFTilingOptions {
       ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
 
   // Type of the callback function that generates the loop terminator.
+  // - `loops` : generated loops from the GenerateLoopHeaderFn callback
   // - `tiledResults` : Tiles of the result computed for the iteration space
   //                    tile.
   // - `resultOffsets` : For each of the `tiledResults`, the offset at which
@@ -193,7 +194,8 @@ struct SCFTilingOptions {
   //                   tensor.
   // Returns the `CustomLoopHeaderInfo` object (described above)
   using GenerateLoopTerminatorFn = std::function<LogicalResult(
-      RewriterBase &rewriter, Location loc, ValueRange tiledResults,
+      RewriterBase &rewriter, Location loc, ArrayRef<LoopLikeOpInterface> loops,
+      ValueRange tiledResults,
       ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
       ArrayRef<SmallVector<OpFoldResult>> resultSizes,
       ValueRange destinationTensors)>;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 89e2c57d709dd..36685d3affe03 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -665,8 +665,8 @@ generateLoopNestUsingCustomOp(
     return failure();
   }
 
-  if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults,
-                                      resultOffsets, resultSizes,
+  if (failed(generateLoopTerminatorFn(rewriter, loc, loopHeaderInfo->loops,
+                                      tiledResults, resultOffsets, resultSizes,
                                       loopHeaderInfo->destinationTensors))) {
     return failure();
   }
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 7981c72c2f2c8..326fec3ee5cf0 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -581,7 +581,8 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
   };
 
   scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn =
-      [&](RewriterBase &rewriter, Location loc, ValueRange tiledResults,
+      [&](RewriterBase &rewriter, Location loc,
+          ArrayRef<LoopLikeOpInterface> loops, ValueRange tiledResults,
           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
           ArrayRef<SmallVector<OpFoldResult>> resultSizes,
           ValueRange destinationTensors) -> LogicalResult {

@MaheshRavishankar MaheshRavishankar merged commit 178651a into llvm:main Sep 30, 2025
14 checks passed
@joker-eph
Copy link
Collaborator

Where is this actually tested? Right now this looks like a dead arguments in-tree.

@sebvince
Copy link
Contributor Author

sebvince commented Oct 1, 2025

Where is this actually tested? Right now this looks like a dead arguments in-tree.

Yes loops is not explicitly tested/used in the generateLoopTerminatorFn callback in the current test.
I've updated the test here : #161490

RiverDave pushed a commit that referenced this pull request Oct 1, 2025
…ing CustomOp. (#161386)

This PR adds to the generateLoopTerminatorFn callback the loops
generated by GenerateLoopHeaderFn. This is needed to correctly set the
insertion point with scf.forall ops.
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…ing CustomOp. (llvm#161386)

This PR adds to the generateLoopTerminatorFn callback the loops
generated by GenerateLoopHeaderFn. This is needed to correctly set the
insertion point with scf.forall ops.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants