[mlir][scf] Extend option to yield replacement for multiple results case #93144

Merged on Jun 28, 2024 (5 commits)

@Yun-Fly Yun-Fly commented May 23, 2024

Currently, the only option is to yield a replacement for a single result of the fusableProducer, like this:

%1 = op1
%2 = op2(%1)
%3, %y2 = scf.for(%dest1, %dest2){
   %tiled_2 = tiled_op2(...)
   %tiled_3 = tiled_op3(%tiled_2)
   %ins_3 = insert %tiled_3 into %dest1
   %ins_2 = insert %tiled_2 into %dest2 // also yield tiled value, which is decided by caller, usually when tiled value has multiple uses.
   yield  %ins_3, %ins_2
}
// another user of %2
%4 = op4(%y2)

However, there is currently no way to yield replacements for all results of a multi-result producer, as in the following:

%1 = op1
%2_1, %2_2  = op2(%1)
%3, %y2_1, %y2_2 = scf.for(%dest1, %dest2, %dest3){
   %t2_1, %t2_2 = tiled_op2(...)
   %t3 = tiled_op3(%t2_1)
   %ins3 = insert %t3 into %dest1
   %ins2_1 = insert %t2_1 into %dest2
   %ins2_2 = insert %t2_2 into %dest3
   yield  %ins3, %ins2_1, %ins2_2
}
// another user of %2_1
%4 = op4(%y2_1)
// user of %2_2
%5 = op5(%y2_2)

With this method, the original untiled op2 no longer has any uses and is expected to be cleaned up later; otherwise it would remain as unnecessary computation.
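
The position of each yielded replacement among the loop results follows from the layout in the example above. Below is a minimal standalone sketch in plain C++, not MLIR code. `replacementResultIndex` is a hypothetical helper name, and the sketch assumes that the replacements for all producer results are appended, in result order, after the loop's existing results.

```cpp
#include <cassert>

// Hypothetical helper (not part of MLIR): index, among the loop results, of
// the replacement for result `resultNumber` of the fused producer. Assumes
// the replacements for all `producerNumResults` results were appended, in
// order, at the end of the loop's result list.
unsigned replacementResultIndex(unsigned loopNumResults,
                                unsigned producerNumResults,
                                unsigned resultNumber) {
  assert(resultNumber < producerNumResults && "result number out of range");
  assert(producerNumResults <= loopNumResults && "loop has too few results");
  return loopNumResults - (producerNumResults - resultNumber);
}
```

In the example above the loop has three results and op2 has two, so `%2_1` maps to loop result 1 (`%y2_1`) and `%2_2` maps to loop result 2 (`%y2_2`).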

Based on the earlier discussion with @MaheshRavishankar on Discourse, this PR extends yielding replacements to the multiple-results case. Note that, as before, the caller still decides whether to yield a replacement.

Two major changes:

  1. Extract a new tiling interface method called getIterationDomainTileFromResultTile, which is used to compute the tiles of the other results from the candidate sliceOp. This utility is very similar to another one named getIterationDomainTileFromOperandTile in this PR. I think they can be unified further once both are merged.
  2. Enhance yieldReplacementForFusedProducer to handle multiple OpResults, and add an optional argument called yieldResultNumber that indicates which results need to be yielded. If it is not given, all results are yielded by default.
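
The defaulting rule for `yieldResultNumber` in the second item can be sketched in plain C++ as follows. This is an illustration under stated assumptions, not the code in the patch: `resolveYieldResultNumbers` is a hypothetical helper, and `std::optional<std::vector<unsigned>>` stands in for the `std::optional<ArrayRef<unsigned>>` parameter.

```cpp
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical helper (not part of MLIR): resolve which producer results to
// yield. An empty optional means "yield every result". The list is returned
// by value so that it owns its storage in both cases.
std::vector<unsigned>
resolveYieldResultNumbers(const std::optional<std::vector<unsigned>> &requested,
                          unsigned numResults) {
  if (requested)
    return *requested;
  // Default: 0, 1, ..., numResults - 1.
  std::vector<unsigned> all(numResults);
  std::iota(all.begin(), all.end(), 0u);
  return all;
}
```

Returning an owning container is a deliberate choice in this sketch: a non-owning view over a freshly built default list would only be valid while that temporary list is alive.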

Considering the downstream impact, I am not sure whether it would be better to split up the current yieldReplacement option and add a new one for fusionControlFn.

@MaheshRavishankar could you help review this PR? Thanks.

llvmbot commented May 23, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: None (Yun-Fly)

Full diff: https://github.com/llvm/llvm-project/pull/93144.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+5-1)
  • (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+19)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+26-6)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+110-38)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir (+62)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 6d567171e185a..32249b90644a8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -190,10 +190,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 /// where `%0` had other uses as well. If not reconstructed from within the loop
 /// body, uses of `%0` could not be replaced, making it still live and the
 /// fusion immaterial.
+///
+/// The `yieldResultNumber` parameter selects which results are yielded. If not
+/// given, all `OpResult`s of the fused producer are yielded.
 LogicalResult yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops);
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
 
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 14d775d986d20..5ac8eaac402b2 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return tile offset and size of Iteration Domain 
+          based on the given tile info from the certain result.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromResultTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$resultNumber,
+          "ArrayRef<OpFoldResult> ":$resultOffsets,
+          "ArrayRef<OpFoldResult> ":$resultSizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f512be46cc13d..b46c8135d1c7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -160,10 +160,11 @@ struct LinalgOpTilingInterface
     return success();
   }
 
-  FailureOr<TilingResult>
-  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
-                          ArrayRef<OpFoldResult> offsets,
-                          ArrayRef<OpFoldResult> sizes) const {
+  LogicalResult getIterationDomainTileFromResultTile(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> resultOffsets, ArrayRef<OpFoldResult> resultSizes,
+      SmallVector<OpFoldResult> &iterDomainOffsets,
+      SmallVector<OpFoldResult> &iterDomainSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
 
     // Check that the indexing map used for the output is a projected
@@ -193,8 +194,27 @@ struct LinalgOpTilingInterface
     for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
       unsigned dimPosition =
           cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
+      iterationTileOffsets[dimPosition] = resultOffsets[resultExpr.index()];
+      iterationTileSizes[dimPosition] = resultSizes[resultExpr.index()];
+    }
+
+    iterDomainOffsets = iterationTileOffsets;
+    iterDomainSizes = iterationTileSizes;
+
+    return success();
+  }
+
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+
+    SmallVector<OpFoldResult> iterationTileOffsets, iterationTileSizes;
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromResultTile(
+            b, resultNumber, offsets, sizes, iterationTileOffsets,
+            iterationTileSizes))) {
+      return failure();
     }
 
     FailureOr<TilingResult> tilingResult =
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a72dafe725177..ddd0e94f9bd4c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -939,49 +939,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    std::optional<ArrayRef<unsigned>> yieldResultNumber) {
   if (loops.empty())
     return success();
 
-  OpResult fusableProducer = fusedProducerInfo.origProducer;
-  Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
-  FailureOr<Value> initValue = tensor::getOrCreateDestination(
-      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
-  if (succeeded(initValue)) {
-
-    YieldTiledValuesFn newYieldValuesFn =
-        [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
-            ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
-            SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
-            SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
-        -> LogicalResult {
-      OpBuilder::InsertionGuard g(innerRewriter);
-      if (auto tiledDestStyleOp =
-              tiledAndFusedProducer
-                  .getDefiningOp<DestinationStyleOpInterface>()) {
-        rewriter.setInsertionPoint(tiledDestStyleOp);
-        Value newRegionArg = newRegionIterArgs.back();
+  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
+            *tiledOwner = fusedProducerInfo.tiledOps[0];
+
+  Location loc = originalOwner->getLoc();
+  // a. collect all init Value to be appended
+  ArrayRef<unsigned> initNumberList =
+      yieldResultNumber ? yieldResultNumber.value()
+                        : llvm::to_vector(llvm::seq<unsigned>(
+                              0, originalOwner->getNumResults()));
+  SmallVector<Value> initValueList;
+  for (const auto &resultNumber : initNumberList) {
+    FailureOr<Value> initValue = tensor::getOrCreateDestination(
+        rewriter, loc, originalOwner->getResult(resultNumber));
+    if (succeeded(initValue)) {
+      initValueList.push_back(initValue.value());
+    } else {
+      return failure();
+    }
+  }
+
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+
+    // get sliceOp tile information
+    SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
+                              sliceSizes = sliceOp.getMixedSizes();
+
+    // expect all strides of sliceOp being 1
+    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+          return !isConstantIntValue(ofr, 1);
+        }))
+      return failure();
+
+    unsigned sliceResultNumber =
+        fusedProducerInfo.origProducer.getResultNumber();
+
+    auto tilableOp = cast<TilingInterface>(originalOwner);
+    // b. get iterDomain Offset and Sizes based on sliceOp tile
+    SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+    // skip tensor.pack/unpack/pad, which expects single opResult
+    if (tilableOp->getNumResults() > 1 &&
+        failed(tilableOp.getIterationDomainTileFromResultTile(
+            rewriter, sliceResultNumber, sliceOffset, sliceSizes,
+            iterDomainOffset, iterDomainSizes))) {
+      return failure();
+    }
+
+    // c. calculate offsets and sizes info of all OpResults respectively based
+    // on iteration Domain Tile
+    SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
+    for (const auto &resultNumber : initNumberList) {
+      if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
+        offsetList.push_back(sliceOffset);
+        sizesList.push_back(sliceSizes);
+      } else {
+        assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
+        // infer result tile according to the iteration domain tile
+        SmallVector<OpFoldResult> offset, sizes;
+        if (failed(tilableOp.getResultTilePosition(
+                rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
+                offset, sizes))) {
+          return failure();
+        }
+        offsetList.push_back(offset);
+        sizesList.push_back(sizes);
+      }
+    }
+
+    // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+    if (auto tiledDestStyleOp =
+            dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
+      rewriter.setInsertionPoint(tiledDestStyleOp);
+      for (const auto &&[index, newRegionArg] :
+           llvm::enumerate(newRegionIterArgs)) {
         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
-            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
-            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
-        unsigned resultNumber = fusableProducer.getResultNumber();
+            loc, newRegionArg, offsetList[index], sizesList[index],
+            SmallVector<OpFoldResult>(offsetList[index].size(),
+                                      rewriter.getIndexAttr(1)));
+        unsigned resultNumber = initNumberList[index];
         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
         });
       }
-      Block *block = rewriter.getInsertionPoint()->getBlock();
-      rewriter.setInsertionPoint(block->getTerminator());
-      tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
-      tiledOffset.emplace_back(sliceOp.getMixedOffsets());
-      tiledSizes.emplace_back(sliceOp.getMixedSizes());
-      return success();
-    };
+    }
 
-    return addInitOperandsToLoopNest(rewriter, loops,
-                                     SmallVector<Value>{initValue.value()},
-                                     newYieldValuesFn);
-  }
-  return success();
+    // e. prepare tiled offset and sizes for later `insert_slice` creation by
+    // caller
+    Block *block = rewriter.getInsertionPoint()->getBlock();
+    rewriter.setInsertionPoint(block->getTerminator());
+    for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
+      tiledResult.push_back(tiledOwner->getResult(resultNumber));
+      tiledOffset.emplace_back(offsetList[index]);
+      tiledSizes.emplace_back(sizesList[index]);
+    }
+    return success();
+  };
+
+  return addInitOperandsToLoopNest(rewriter, loops, initValueList,
+                                   newYieldValuesFn);
 }
 
 /// Implementation of tile consumer and fuse producer greedily.
@@ -1071,14 +1136,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
       continue;
 
     if (yieldReplacement) {
+      // Reconstruct and yield all opResults of fusableProducerOp by default. The
+      // caller can specify which ones to yield via the optional argument
+      // `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+      Operation *fusableProducerOp = fusableProducer.getOwner();
       if (failed(yieldReplacementForFusedProducer(
               rewriter, candidateSliceOp, fusedResult.value(), loops))) {
         return rewriter.notifyMatchFailure(
-            fusableProducer.getOwner(), "failed to replacement value for this "
-                                        "oepration from within the tiled loop");
+            fusableProducerOp, "failed to replacement value for this "
+                               "operation from within the tiled loop");
+      }
+      for (const auto &result : fusableProducerOp->getResults()) {
+        origValToResultNumber[result] =
+            loops.front()->getNumResults() -
+            (fusableProducerOp->getNumResults() - result.getResultNumber());
       }
-      origValToResultNumber[fusableProducer] =
-          loops.front()->getNumResults() - 1;
     }
 
     if (Operation *tiledAndFusedOp =
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index 7356c11e85ac0..3c0ada9d2cabc 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
 //      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]]
 //      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+
+func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
+                       %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>, 
+                       %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>) 
+                       -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
+  %out0, %out1 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (j, i)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
+  outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
+  ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
+    %4 = arith.mulf %0, %1 : f32
+    %5 = arith.addf %0, %1 : f32
+    linalg.yield %4, %5: f32, f32
+  } -> (tensor<32x32xf32>, tensor<32x32xf32>)
+
+  %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
+
+  return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %add = transform.structured.match ops{["linalg.add"]} in %arg0
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_and_yield %add [16]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @multiple_outputs_fusion_yield_all(
+// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
+//      CHECK:   %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
+//      CHECK:     %[[GENERIC_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME:         outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[ADD_TILE:.+]] = linalg.add
+// CHECK-SAME:         ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
+// CHECK-SAME:         outs(%[[INIT2_TILE]] :
+//      CHECK:     %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:     %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
+//      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
+//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0
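
The tile offsets in the new test can be traced through the two interface directions used by the patch: result tile to iteration-domain tile, then iteration-domain tile to the tile of another result. The following is a standalone sketch in plain C++, not the Linalg implementation. `Tile`, `PermutationMap` and both function names are hypothetical, offsets are plain integers rather than `OpFoldResult`s, and each indexing map is assumed to be a full permutation of the iteration dimensions (the patch handles the more general projected-permutation case).

```cpp
#include <cstddef>
#include <vector>

// Offsets and sizes of a tile, one entry per dimension.
struct Tile {
  std::vector<long> offsets;
  std::vector<long> sizes;
};

// Hypothetical stand-in for an indexing map that is a permutation:
// map[r] is the iteration-domain dimension used by result dimension r.
using PermutationMap = std::vector<std::size_t>;

// Result tile -> iteration-domain tile (the direction covered by
// getIterationDomainTileFromResultTile in the patch).
Tile iterationDomainTileFromResultTile(const PermutationMap &map,
                                       const Tile &resultTile) {
  Tile iter{std::vector<long>(map.size()), std::vector<long>(map.size())};
  for (std::size_t r = 0; r < map.size(); ++r) {
    iter.offsets[map[r]] = resultTile.offsets[r];
    iter.sizes[map[r]] = resultTile.sizes[r];
  }
  return iter;
}

// Iteration-domain tile -> result tile (the direction covered by
// getResultTilePosition).
Tile resultTileFromIterationDomainTile(const PermutationMap &map,
                                       const Tile &iterTile) {
  Tile result{std::vector<long>(map.size()), std::vector<long>(map.size())};
  for (std::size_t r = 0; r < map.size(); ++r) {
    result.offsets[r] = iterTile.offsets[map[r]];
    result.sizes[r] = iterTile.sizes[map[r]];
  }
  return result;
}
```

In the test, result 0 uses the map `(i, j) -> (i, j)` and result 1 uses `(i, j) -> (j, i)`. Taking an illustrative induction value of 16, the slice of result 0 at offsets `[16, 0]` with sizes `[16, 32]` gives the same iteration-domain tile, and mapping that tile through the second map gives offsets `[0, 16]` with sizes `[32, 16]`, which is the `[0, %[[IV]]]` pattern the CHECK lines expect for the second output.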

@MaheshRavishankar MaheshRavishankar left a comment

This mostly looks right to me. Just looking at this once more.

@MaheshRavishankar MaheshRavishankar left a comment

Ok, looking again this does look right to me. Thanks for the addition!

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_yield_multiouts branch from 4235f25 to 4377bf0 Compare June 5, 2024 13:47
@Yun-Fly Yun-Fly changed the title [mlir][TilingInterface] Extend option to yield replacement for multiple results case [mlir][scf] Extend option to yield replacement for multiple results case Jun 5, 2024

@nicolasvasilache nicolasvasilache left a comment

This also looks very ad hoc to me, same as the other PR I just looked at: this seems to call for splitting the op before fusion instead of adding more code to support ever more complex cases.

@MaheshRavishankar

> This also looks very ad hoc to me, same as the other PR I just looked at: this seems to call for splitting the op before fusion instead of adding more code to support ever more complex cases.

It's not always possible to "split the op". I don't know which other PR you are referring to, but as far as I can see this is OK to me. Please provide more targeted feedback to help navigate.

Yun-Fly commented Jun 5, 2024

> This also looks very ad hoc to me, same as the other PR I just looked at: this seems to call for splitting the op before fusion instead of adding more code to support ever more complex cases.

Could you explain in more detail what splitting the op before fusion would involve?

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_yield_multiouts branch from eb824e1 to 74d925e Compare June 6, 2024 01:48

Yun-Fly commented Jun 12, 2024

Hi @MaheshRavishankar @nicolasvasilache, are there any further comments? If not, shall we merge this patch?

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_yield_multiouts branch from 74d925e to eda4bf3 Compare June 17, 2024 08:53

Yun-Fly commented Jun 17, 2024

Hi @MaheshRavishankar, I have updated the documentation to align with your recent PR #95178.

If there are no new comments, I plan to merge this patch as soon as possible to avoid one more rebase.

@MaheshRavishankar MaheshRavishankar left a comment

This looks good to me. I think I already approved it. @nicolasvasilache please comment if you have any more comments. If not maybe your comments can be addressed post landing.

@Yun-Fly Yun-Fly merged commit 7ef08ea into llvm:main Jun 28, 2024
7 checks passed

@Yun-Fly Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

@Yun-Fly Yun-Fly deleted the yunfei/fuse_yield_multiouts branch July 2, 2024 01:13
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
…ase (llvm#93144)

This patch extends yielding replacements to the multiple-results case and adds
an optional argument called `yieldResultNumber` that indicates which result(s)
need to be yielded. If it is not given, all results are yielded by default.