[MLIR][SCF] Add loops as parameter to LoopTerminator callback when using CustomOp. #161386
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf
Author: None (sebvince)
Changes
This PR passes the loops generated by the GenerateLoopHeaderFn callback to the GenerateLoopTerminatorFn callback as an additional `loops` parameter. The terminator callback needs them to set the insertion point correctly when the generated loops are scf.forall ops.
Full diff: https://github.com/llvm/llvm-project/pull/161386.diff
3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 668ee6386f71f..7c735d825b445 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -183,6 +183,7 @@ struct SCFTilingOptions {
ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
// Type of the callback function that generates the loop terminator.
+ // - `loops` : generated loops from the GenerateLoopHeaderFn callback
// - `tiledResults` : Tiles of the result computed for the iteration space
// tile.
// - `resultOffsets` : For each of the `tiledResults`, the offset at which
@@ -193,7 +194,8 @@ struct SCFTilingOptions {
// tensor.
// Returns the `CustomLoopHeaderInfo` object (described above)
using GenerateLoopTerminatorFn = std::function<LogicalResult(
- RewriterBase &rewriter, Location loc, ValueRange tiledResults,
+ RewriterBase &rewriter, Location loc, ArrayRef<LoopLikeOpInterface> loops,
+ ValueRange tiledResults,
ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
ArrayRef<SmallVector<OpFoldResult>> resultSizes,
ValueRange destinationTensors)>;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 89e2c57d709dd..36685d3affe03 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -665,8 +665,8 @@ generateLoopNestUsingCustomOp(
return failure();
}
- if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults,
- resultOffsets, resultSizes,
+ if (failed(generateLoopTerminatorFn(rewriter, loc, loopHeaderInfo->loops,
+ tiledResults, resultOffsets, resultSizes,
loopHeaderInfo->destinationTensors))) {
return failure();
}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 7981c72c2f2c8..326fec3ee5cf0 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -581,7 +581,8 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
};
scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn =
- [&](RewriterBase &rewriter, Location loc, ValueRange tiledResults,
+ [&](RewriterBase &rewriter, Location loc,
+ ArrayRef<LoopLikeOpInterface> loops, ValueRange tiledResults,
ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
ArrayRef<SmallVector<OpFoldResult>> resultSizes,
ValueRange destinationTensors) -> LogicalResult {
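The plumbing in the diff above can be modeled without MLIR. The following is a minimal sketch, not the in-tree implementation: `MockLoop`, `MockLoopHeaderInfo`, `MockHeaderFn`, `MockTerminatorFn`, `generateLoopNest`, and `runExample` are all hypothetical stand-ins invented for illustration, and strings stand in for SSA values and operations. It only shows why the terminator callback benefits from receiving the loops: with a loop such as scf.forall, the results are written inside the loop's terminating region, so the callback needs a handle to the loop to know where to emit them.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::LoopLikeOpInterface. It records only the
// loop kind and the ops emitted into the loop's terminating region.
struct MockLoop {
  std::string kind;
  std::vector<std::string> terminatorOps;
};

// Hypothetical stand-in for CustomLoopHeaderInfo, reduced to its loops.
struct MockLoopHeaderInfo {
  std::vector<MockLoop> loops;
};

// Models the header callback: it creates the loops.
using MockHeaderFn = std::function<MockLoopHeaderInfo()>;

// Models the terminator callback after this PR: the loops are passed in
// ahead of the tiled results. Returning false models failure().
using MockTerminatorFn = std::function<bool(
    std::vector<MockLoop> &loops,
    const std::vector<std::string> &tiledResults)>;

// Models the driver: build the loops, then forward them to the terminator
// callback together with the tiled results.
bool generateLoopNest(const MockHeaderFn &headerFn,
                      const MockTerminatorFn &terminatorFn,
                      const std::vector<std::string> &tiledResults,
                      MockLoopHeaderInfo &info) {
  info = headerFn();
  return terminatorFn(info.loops, tiledResults);
}

// Example run: one scf.forall-like loop whose terminating region receives
// one insert per tiled result.
MockLoopHeaderInfo runExample() {
  MockHeaderFn headerFn = [] {
    return MockLoopHeaderInfo{{MockLoop{"scf.forall", {}}}};
  };

  MockTerminatorFn terminatorFn =
      [](std::vector<MockLoop> &loops,
         const std::vector<std::string> &tiledResults) {
        if (loops.empty())
          return false;
        // With the loops available, the callback can target the innermost
        // loop's terminating region directly.
        MockLoop &innermost = loops.back();
        for (const std::string &result : tiledResults)
          innermost.terminatorOps.push_back("parallel_insert_slice(" +
                                            result + ")");
        return true;
      };

  const std::vector<std::string> tiledResults = {"%tile0", "%tile1"};
  MockLoopHeaderInfo info;
  generateLoopNest(headerFn, terminatorFn, tiledResults, info);
  return info;
}
```

Calling `runExample()` yields a single mock loop whose `terminatorOps` holds one entry per tiled result. Before this change, the equivalent callback would have had to capture the loops through some side channel, because only the tiled results, offsets, sizes, and destination tensors were passed in.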
Where is this actually tested? Right now this looks like a dead argument in-tree.
Yes