-
Notifications
You must be signed in to change notification settings - Fork 12.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_for #120118
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesFull diff: https://github.com/llvm/llvm-project/pull/120118.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2e713bca24efc5..081bf9b6d3b239 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1765,8 +1765,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
- TransformHandleTypeInterface:$split_linalg_op,
- TransformHandleTypeInterface:$combining_linalg_op,
+ TransformHandleTypeInterface:$split_op,
+ TransformHandleTypeInterface:$combining_op,
TransformHandleTypeInterface:$for_op);
let builders = [
@@ -1784,7 +1784,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
+ Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8839faf4cafb2d..89122ad56a9f57 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2627,12 +2627,19 @@ void transform::TileReductionUsingForOp::build(
}
DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
- transform::TransformRewriter &rewriter, LinalgOp target,
+ transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
+
+ auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
+ if (!partialReductionOp) {
+ return emitSilenceableFailure(
+ target->getLoc(),
+ "Operation should implement PartialReductionOpInterface");
+ }
FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
- rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
+ rewriter, partialReductionOp,
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
if (failed(result))
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a test maybe?
I don't really know what kind of additional test to add here. The only op that implements this interface upstream is LinalgOp. The fact that it still works and passes previous tests is a test in itself. But i can add a operation to test dialect which implements partial reduction op interface. Not sure if that's something we want though. What do you think? |
The API used internally expects PartialReductionOpInterface. This patch allows any operation implementing this interface to use this transform op (instead of just LinalgOp).