Skip to content

Conversation

bangtianliu
Copy link
Contributor

@bangtianliu bangtianliu commented Sep 10, 2025

Following PR #120118, this PR extends transform.structured.tile_reduction_using_forall so that it can be applied to any operation implementing PartialReductionOpInterface, rather than being restricted to LinalgOp.

Existing tests relevant to linalg ops remain valid:

%1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0

Additional tests for non-Linalg operations (e.g., IREE custom ops that implement PartialReductionOpInterface) will be added on the IREE side.

@llvmbot
Copy link
Member

llvmbot commented Sep 10, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Bangtian Liu (bangtianliu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/157932.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+3-3)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+11-4)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a19cce4b919a8..8f3232f01544f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2017,8 +2017,8 @@ def TileReductionUsingForallOp :
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
-                      TransformHandleTypeInterface:$split_linalg_op,
-                      TransformHandleTypeInterface:$combining_linalg_op,
+                      TransformHandleTypeInterface:$split_op,
+                      TransformHandleTypeInterface:$combining_op,
                       TransformHandleTypeInterface:$forall_op);
 
   let builders = [
@@ -2042,7 +2042,7 @@ def TileReductionUsingForallOp :
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
+        Operation *target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..f3db8f7ccfaa1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3033,10 +3033,17 @@ void transform::TileReductionUsingForallOp::build(
 }
 
 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
-    transform::TransformRewriter &rewriter, LinalgOp target,
+    transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
+
+  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
+  if (!partialReductionOp) {
+    return emitSilenceableFailure(
+        target->getLoc(),
+        "Operation should implement PartialReductionOpInterface");
+  }
   SmallVector<OpFoldResult> numThreads =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
   SmallVector<OpFoldResult> tileSizes =
@@ -3058,14 +3065,14 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
       extractFromIntegerArrayAttr<unsigned>(getReductionDims());
   if (reductionDims.empty()) {
     for (auto [idx, iteratorType] :
-         llvm::enumerate(target.getIteratorTypesArray())) {
+         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
       if (iteratorType == utils::IteratorType::reduction)
         reductionDims.push_back(idx);
     }
   }
   options.setReductionDims(reductionDims);
-  FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
-      rewriter, cast<TilingInterface>(target.getOperation()), options);
+  FailureOr<scf::SCFTilingResult> result =
+      scf::tileUsingSCF(rewriter, partialReductionOp, options);
 
   if (failed(result)) {
     auto diag = emitSilenceableError() << "could not tile reduction";

@bangtianliu bangtianliu force-pushed the tile_reduction_using_forall branch from 915facf to f5a938e Compare September 10, 2025 18:54
…n_using_forall

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
@bangtianliu bangtianliu marked this pull request as ready for review September 10, 2025 19:18
@bangtianliu bangtianliu merged commit 17723e4 into llvm:main Sep 10, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants