Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_for #120118

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Groverkss
Copy link
Member

@Groverkss Groverkss commented Dec 16, 2024

The API used internally expects PartialReductionOpInterface. This patch allows any operation implementing this interface to use this transform op (instead of just LinalgOp).

@llvmbot
Copy link
Member

llvmbot commented Dec 16, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/120118.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+3-3)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+9-2)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2e713bca24efc5..081bf9b6d3b239 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1765,8 +1765,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
   let arguments = (ins TransformHandleTypeInterface:$target,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
-                      TransformHandleTypeInterface:$split_linalg_op,
-                      TransformHandleTypeInterface:$combining_linalg_op,
+                      TransformHandleTypeInterface:$split_op,
+                      TransformHandleTypeInterface:$combining_op,
                       TransformHandleTypeInterface:$for_op);
 
   let builders = [
@@ -1784,7 +1784,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
+        Operation *target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8839faf4cafb2d..89122ad56a9f57 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2627,12 +2627,19 @@ void transform::TileReductionUsingForOp::build(
 }
 
 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
-    transform::TransformRewriter &rewriter, LinalgOp target,
+    transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
+
+  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
+  if (!partialReductionOp) {
+    return emitSilenceableFailure(
+        target->getLoc(),
+        "Operation should implement PartialReductionOpInterface");
+  }
   FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
-      rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
+      rewriter, partialReductionOp,
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
   if (failed(result))

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test maybe?

@Groverkss
Copy link
Member Author

Add a test maybe?

I don't really know what kind of additional test to add here. The only op that implements this interface upstream is LinalgOp. The fact that it still works and passes previous tests is a test in itself.

But i can add a operation to test dialect which implements partial reduction op interface. Not sure if that's something we want though. What do you think?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants