@qedawkins
Contributor

This introduces `isOpFusableWithProducerSlices` and `isOpFusableWithConsumerSlice` methods to the TilingInterface. They query whether a tilable op can be fused with a given set of producer slices or a given consumer slice, without generating IR. This is needed to use the tiling interface in pattern rewrites: without such a query, a pattern that invokes one of the tiling methods may generate IR and then fail.
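
To illustrate the motivation, here is a minimal standalone sketch of the "query first, then rewrite" shape this enables. Nothing in it is MLIR API: `ToyIR`, `tileAndFuse`, `isFusable`, and `matchAndRewrite` are hypothetical stand-ins, and the single `slicesLineUp` flag stands in for whatever the real query inspects.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for a block of IR: just a list of op names (hypothetical).
using ToyIR = std::vector<std::string>;

// Hypothetical fusion routine that creates ops before discovering that it
// cannot finish, which is the situation a pattern must avoid.
bool tileAndFuse(ToyIR &ir, bool slicesLineUp) {
  ir.push_back("tensor.extract_slice"); // IR is created eagerly
  if (!slicesLineUp)
    return false; // failure after the IR was already modified
  ir.push_back("tiled_op");
  return true;
}

// Side-effect-free query, playing the role of the new isOpFusableWith*
// interface methods.
bool isFusable(bool slicesLineUp) { return slicesLineUp; }

// Pattern-style entry point: it either succeeds, or fails with `ir` untouched.
bool matchAndRewrite(ToyIR &ir, bool slicesLineUp) {
  if (!isFusable(slicesLineUp))
    return false; // bail out before touching the IR
  bool ok = tileAndFuse(ir, slicesLineUp);
  assert(ok && "the query promised that fusion succeeds");
  (void)ok;
  return true;
}
```

In this toy, calling `matchAndRewrite` with `slicesLineUp == false` returns `false` and leaves the op list unchanged, whereas calling `tileAndFuse` directly in the same situation would leave a stray op behind.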

@llvmbot
Member

llvmbot commented Nov 5, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

Full diff: https://github.com/llvm/llvm-project/pull/166502.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+37)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+46)
  • (added) mlir/test/Interfaces/TilingInterface/query-fusability.mlir (+49)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+58)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+23-1)
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index e0516abdfcf0c..c30782a25e40f 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -360,6 +360,43 @@ def TilingInterface : OpInterface<"TilingInterface"> {
         /*defaultImplementation=*/[{
           return failure();
         }]
+      >,
+      //===------------------------------------------------------------------===//
+      // Interface methods for querying fusability.
+      //===------------------------------------------------------------------===//
+      InterfaceMethod<
+        /*desc=*/[{
+          Indicates whether it is possible to fuse this operation with the given
+          result slice. This method is not allowed to generate any IR.
+        }],
+        /*retTy=*/"bool",
+        /*methodName=*/"isOpFusableWithConsumerSlice",
+        /*args=*/(ins
+          "unsigned":$resultNumber,
+          "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
+          "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return false;
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Indicates whether it is possible to fuse this operation with the given
+          list of operand slices. This method is not allowed to generate any IR.
+        }],
+        /*retTy=*/"bool",
+        /*methodName=*/"isOpFusableWithProducerSlices",
+        /*args=*/(ins
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return false;
+        }]
       >
   ];
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 57b610b31e964..527878786f50f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -359,6 +359,52 @@ struct LinalgOpTilingInterface
     /// Inline the op payload and store the result.
     return inlinePayload(builder, linalgOp, ivs, indexedValues);
   }
+
+  bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes) const {
+    return !cast<LinalgOp>(op).getShapesToLoopsMap();
+  }
+
+  bool isOpFusableWithProducerSlices(
+      Operation *op, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+
+    auto linalgOp = cast<LinalgOp>(op);
+    SmallVector<AffineMap> indexingMaps =
+        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+          OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+          return linalgOp.getMatchingIndexingMap(&opOperand);
+        });
+    // First verify that the iteration domain on operand subranges is well
+    // defined.
+    if (!linalgOp.getShapesToLoopsMap())
+      return false;
+    // Next verify that operand slices are consistent.
+    DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
+    for (auto [indexingMap, offsets, sizes] :
+         llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+      for (auto [resultExpr, offset, size] :
+           llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+        auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+        if (!dimExpr)
+          return false;
+        unsigned position = dimExpr.getPosition();
+        auto it = mappedOffsets.find(position);
+        if (it != mappedOffsets.end()) {
+          OpFoldResult seenOffset = it->second;
+          OpFoldResult seenSize = mappedSizes.lookup(position);
+          if (seenOffset != offset || seenSize != size)
+            return false;
+        } else {
+          mappedOffsets[position] = offset;
+          mappedSizes[position] = size;
+        }
+      }
+    }
+    return true;
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/query-fusability.mlir b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir
new file mode 100644
index 0000000000000..1fa828c9cd868
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
+  %c0 = arith.constant 0 : index
+  %c10 = arith.constant 10 : index
+  %c20 = arith.constant 20 : index
+
+  %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+  %slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+
+  // expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}}
+  %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
+                       outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+  return %result : tensor<100x200xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+    transform.test.query_producer_fusability %add : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
+  %c0 = arith.constant 0 : index
+  %c10 = arith.constant 10 : index
+  %c20 = arith.constant 20 : index
+
+  %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+  %slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+
+  // expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}}
+  %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
+                       outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+  return %result : tensor<100x200xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+    transform.test.query_producer_fusability %add : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 326fec3ee5cf0..d6bb178505d2b 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
@@ -622,6 +623,63 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestQueryProducerFusability
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply(
+    TransformRewriter &rewriter, TransformResults &transformResults,
+    TransformState &state) {
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+    if (!tilingInterfaceOp) {
+      return emitSilenceableError()
+             << "target operation does not implement TilingInterface";
+    }
+
+    // Collect operand numbers and their corresponding producer insert_slice
+    // offsets and sizes.
+    SmallVector<unsigned> operandNumbers;
+    SmallVector<SmallVector<OpFoldResult>> allOffsets;
+    SmallVector<SmallVector<OpFoldResult>> allSizes;
+
+    for (OpOperand &operand : target->getOpOperands()) {
+      Value operandValue = operand.get();
+      Operation *definingOp = operandValue.getDefiningOp();
+
+      // Look for a producer tensor.insert_slice. This is only for testing
+      // purposes and otherwise is not a useful transformation.
+      if (auto insertSliceOp =
+              dyn_cast_or_null<tensor::InsertSliceOp>(definingOp)) {
+        operandNumbers.push_back(operand.getOperandNumber());
+        allOffsets.push_back(insertSliceOp.getMixedOffsets());
+        allSizes.push_back(insertSliceOp.getMixedSizes());
+      }
+    }
+
+    if (!operandNumbers.empty()) {
+      bool isFusable = tilingInterfaceOp.isOpFusableWithProducerSlices(
+          operandNumbers, allOffsets, allSizes);
+
+      if (isFusable) {
+        target->emitRemark()
+            << "can be fused with producer tensor.insert_slice ops";
+      } else {
+        target->emitRemark()
+            << "cannot be fused with producer tensor.insert_slice ops";
+      }
+    }
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestQueryProducerFusability::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsPayload(effects);
+}
+
 #define GET_OP_CLASSES
 #include "TestTilingInterfaceTransformOps.cpp.inc"
 
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 694c4229eef62..4d0998052ba79 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -166,11 +166,33 @@ def TestTileUsingCustomLoopOp : Op<
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
   let results  = (outs TransformHandleTypeInterface:$tiled_ops,
                   Variadic<TransformHandleTypeInterface>:$loops);
-  
+
   let assemblyFormat = [{
     $root_op `tile_sizes` `=` $tile_sizes
     attr-dict `:` functional-type(operands, results)
   }];
 }
 
+def TestQueryProducerFusability : Op<
+    Transform_Dialect, "test.query_producer_fusability",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    Test operation for the producer fusability query method in the
+    TilingInterface.
+
+    For each operation in the target handle, this looks for tensor.insert_slice
+    ops that produce operands to the tilable op. The offset/sizes from those
+    inserts is used as the arguments to `isOpFusableWithProducerSlices` and
+    emits a remark with the result of the query.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` type($target)
+  }];
+}
+
 #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS

/*methodName=*/"isOpFusableWithConsumerSlice",
/*args=*/(ins
"unsigned":$resultNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
Contributor

I think the consumer slice API has been modified to handle multiple results at the same time. I think this needs to do the same.

Contributor Author

getIterationDomainTileFromResultTile still only looks at a single result at a time, so this does too. If producing multiple results via a single tiling call is implemented I would expect the existing method to be updated to reflect that too, so I'd rather keep them consistent right now.

Contributor

Ok, I think I am confused with the terminology. This is basically saying if you can fuse an operation with its consumer slice (as the name correctly reads), but this is for producer fusion, not the consumer fusion cases.

Contributor Author

Consumer is in reference to the slice here, so

This:
tilableOp -> consumer

Consumer Fusion:
LoopOp -> consumer tilable op

Consumer is unsurprisingly overloaded. I'm open to better naming suggestions if you have one (ditto below).

Contributor

Nope, your naming makes sense.

bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) const {
return !cast<LinalgOp>(op).getShapesToLoopsMap();
Contributor

I am not sure I follow the logic here.

Contributor Author

The op's iteration space isn't invertible if the ShapesToLoopsMap doesn't exist (e.g. in the case of exotic indexing maps that include mods and divs). For linalg ops, support for such maps is unimplemented (we could emit loops directly), but we still need to reflect the unimplemented case here.
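
As a rough illustration of when such a map exists, the sketch below models indexing-map results with a hypothetical `Expr` struct and checks that every loop dimension shows up as a plain dimension in at least one operand. This is a simplified model of the idea only. It is not the MLIR implementation of `getShapesToLoopsMap`, and all names in it are made up for the example.

```cpp
#include <cassert>
#include <vector>

// Simplified stand-in for one result of an indexing map (hypothetical type,
// not an MLIR class).
struct Expr {
  bool isPlainDim;   // true for d_i, false for e.g. d_i mod 4 or d_i floordiv 2
  unsigned position; // loop dimension referenced when isPlainDim is true
};
using IndexingMap = std::vector<Expr>;

// Returns true if every loop dimension appears as a plain dim in at least one
// indexing map, i.e. each loop bound can be read off some operand shape.
bool loopsRecoverableFromShapes(unsigned numLoops,
                                const std::vector<IndexingMap> &maps) {
  std::vector<bool> seen(numLoops, false);
  for (const IndexingMap &map : maps) {
    for (const Expr &expr : map) {
      if (!expr.isPlainDim)
        continue; // mods, divs, etc. do not pin down a loop bound here
      assert(expr.position < numLoops && "dimension out of range");
      seen[expr.position] = true;
    }
  }
  for (bool dimSeen : seen)
    if (!dimSeen)
      return false;
  return true;
}
```

Under this model a matmul-like op with maps `(d0, d2)`, `(d2, d1)`, `(d0, d1)` is recoverable, while a two-loop op whose only access to `d1` goes through a mod or div is not.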

ArrayRef<SmallVector<OpFoldResult>> allOffsets,
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {

auto linalgOp = cast<LinalgOp>(op);
Contributor

Again I am not sure about why we need this. In theory it is feasible to fuse with any producer, irrespective of what slice of the operand is used. This seems like a cost function, rather than a structural requirement for fusion.

Contributor Author

The ShapesToLoopsMap check is for the same reason as above.

> In theory it is feasible to fuse with any producer, irrespective of what slice of the operand is used

This is certainly not true for all operations. Not all tensor operands and dimensions are partitionable. For example, the innermost dimension of a tensor.gather must have size equal to the number of gathered dims, meaning we can never fuse into a producer that partitions that dimension. Additionally, when fusing along multiple operands, if their slices don't line up, the feasibility of fusion depends on the type of loop producing them. In the case of scf.for and scf.forall, it is impossible to fuse without fissioning the loop, since we would need to relayout one or more of the mismatched operands.
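
The "slices don't line up" condition can be sketched outside of MLIR as follows. This is a simplified model of the check in the diff, assuming integer offsets and sizes instead of `OpFoldResult` and a hypothetical `Expr` struct instead of `AffineExpr`; the names are illustrative only. It requires C++17.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

// Hypothetical simplified types standing in for AffineMap results and for
// per-operand offsets/sizes.
struct Expr {
  bool isPlainDim;
  unsigned position;
};
using IndexingMap = std::vector<Expr>;
struct Slice {
  std::vector<int64_t> offsets, sizes;
};

// Returns true if all operand slices imply the same (offset, size) for every
// loop dimension they touch. Non-dim results are rejected, as in the diff.
bool slicesAgreeOnLoopDims(const std::vector<IndexingMap> &maps,
                           const std::vector<Slice> &slices) {
  assert(maps.size() == slices.size() && "one slice per indexing map");
  std::map<unsigned, std::pair<int64_t, int64_t>> seen; // dim -> (offset, size)
  for (std::size_t i = 0; i < maps.size(); ++i) {
    const IndexingMap &map = maps[i];
    const Slice &slice = slices[i];
    assert(map.size() == slice.offsets.size() &&
           map.size() == slice.sizes.size());
    for (std::size_t j = 0; j < map.size(); ++j) {
      if (!map[j].isPlainDim)
        return false;
      auto entry = std::make_pair(slice.offsets[j], slice.sizes[j]);
      auto [it, inserted] = seen.emplace(map[j].position, entry);
      if (!inserted && it->second != entry)
        return false; // two operands disagree about this loop dimension
    }
  }
  return true;
}
```

With identity maps on both operands, as for `linalg.add`, two slices at offsets `[0, 0]` with sizes `[10, 20]` agree, while offsets `[0, 0]` and `[10, 20]` do not. This mirrors the two cases in `query-fusability.mlir`.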

Contributor

Again, I got confused by the naming. This is for consumer fusion cases, I think. Isn't this the same logic as the one here: https://github.com/iree-org/llvm-project/blob/7e88b401a4c731fa04524ccbc1542fce56609507/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp#L156 ? Maybe we can reuse that, or factor out the common part?

Contributor Author

I can make it common. When I first wrote this I didn't understand why we were ignoring non-AffineDimExprs in the above matcher and made this one a little more restrictive but I can update the same there.

Contributor

This is kind of based on how the ShapesToLoopsMap works. You could have dimensions used in any arbitrary AffineExpr, but the shapes of the loops are determined by those dimensions that are accessed through an AffineDimExpr. I am doing the same logic in that method:

getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,

If you feel the documentation needs to be updated, please do so.
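
The difference between the two matchers discussed above (ignoring non-`AffineDimExpr` results versus rejecting them) can be sketched with a policy flag. This is an illustrative model built only from the description in this thread. `OperandSlice`, `NonDimPolicy`, and `offsetsConsistent` are hypothetical names, and the body is not the code of `getMappedOffsetAndSize`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical simplified stand-ins for an indexing map result and for one
// operand's slice offsets.
struct Expr {
  bool isPlainDim;
  unsigned position;
};
struct OperandSlice {
  std::vector<Expr> map;        // results of the operand's indexing map
  std::vector<int64_t> offsets; // one offset per map result
};

enum class NonDimPolicy { Reject, Skip };

// Checks that all operands agree on the offset of each loop dimension.
// Reject: any non-dim result makes the query fail (the stricter behavior).
// Skip:   non-dim results are ignored, since they do not define loop shapes.
bool offsetsConsistent(const std::vector<OperandSlice> &operands,
                       NonDimPolicy policy) {
  std::map<unsigned, int64_t> offsetForDim;
  for (const OperandSlice &operand : operands) {
    assert(operand.map.size() == operand.offsets.size());
    for (std::size_t i = 0; i < operand.map.size(); ++i) {
      const Expr &expr = operand.map[i];
      if (!expr.isPlainDim) {
        if (policy == NonDimPolicy::Reject)
          return false;
        continue;
      }
      auto result = offsetForDim.emplace(expr.position, operand.offsets[i]);
      if (!result.second && result.first->second != operand.offsets[i])
        return false;
    }
  }
  return true;
}
```

For an operand accessed as `(d0, <non-dim expression>)`, the `Reject` policy reports the slices as inconsistent while `Skip` accepts them, which is the gap that making the logic common would close.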

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

You're missing a test for the consumer fusability.
