
[mlir][Linalg] Add a pattern to fold concats of fill. #98995


Merged: 1 commit merged into llvm:main on Jul 30, 2024

Conversation

MaheshRavishankar (Contributor):

If all the operands of a concat are just fills and the fill values match, the fill can instead be performed on the concatenation of the fills' outs operands.
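As a rough before/after sketch of the rewrite (the SSA names, dynamic sizes, and the `tensor<5x?x?xf32>` shapes below are made up for illustration and simply mirror the test added in this PR):

```mlir
// Before: two linalg.fill ops of the same value feed a tensor.concat.
%cst = arith.constant 0.0 : f32
%empty0 = tensor.empty(%d0, %d1) : tensor<5x?x?xf32>
%fill0 = linalg.fill ins(%cst : f32) outs(%empty0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
%empty1 = tensor.empty(%d2, %d3) : tensor<5x?x?xf32>
%fill1 = linalg.fill ins(%cst : f32) outs(%empty1 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
%concat = tensor.concat dim(1) %fill0, %fill1
    : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>

// After: concatenate the outs operands first, then fill the result once.
%outs = tensor.concat dim(1) %empty0, %empty1
    : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
%result = linalg.fill ins(%cst : f32) outs(%outs : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
```

Only a single `linalg.fill` remains, applied to the concatenation of the original outs tensors; the pattern bails out if any concat operand is not produced by a fill or if the fill values differ.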

@llvmbot (Member) commented on Jul 16, 2024:

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

If a concat has all its operands as just fills, and the values match, then the fill could happen on the concatenated values of the outs operands.


Full diff: https://github.com/llvm/llvm-project/pull/98995.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+56-5)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+27)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..c26eabd811c65 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -879,15 +879,66 @@ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+// Fold a concat with all elements being fills of the same value
+// into a fill of the concat result shape.
+struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOperands = concatOp.getInputs();
+    if (concatOperands.empty()) {
+      return failure();
+    }
+
+    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
+    if (!firstFillOp) {
+      return failure();
+    }
+    // Prefetch the fill value.
+    OpFoldResult firstFillVal =
+        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
+    // Collect all the outs values for the fill operations.
+    SmallVector<Value> allOuts;
+    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
+
+    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
+      auto fillOp = v.getDefiningOp<linalg::FillOp>();
+      if (!fillOp) {
+        return false;
+      }
+
+      OpFoldResult fillVal =
+          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
+      if (fillVal != firstFillVal)
+        return false;
+
+      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
+      return true;
+    };
+    if (!llvm::all_of(concatOperands.drop_front(),
+                      isDefinedByCompatibleFillOp)) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "not all operands are defined by a compatible fill op");
+    }
+
+    Value outsConcat = rewriter.create<tensor::ConcatOp>(
+        concatOp.getLoc(), concatOp.getDim(), allOuts);
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(
+        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results
-      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
-           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-           FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
+  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
+              FoldFillWithPack, FoldFillWithPad,
+              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 928030a81dc02..02b46c405e2bd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1096,3 +1096,30 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }
 
+
+// -----
+
+func.func @concats_of_fill(
+    %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
+    -> tensor<5x?x?xf32>
+{
+  %cst0 = arith.constant 0.0 : f32
+  %cst1 = arith.constant 0.0 : f32
+  %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
+  %1 = linalg.fill ins(%cst0 : f32) outs(%0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+  %2 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
+  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+  %4 = tensor.concat dim(1) %1, %3 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+  return %4 : tensor<5x?x?xf32>
+}
+//       CHECK: func @concats_of_fill(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index)
+//   CHECK-DAG:   %[[CST:.+]] = arith.constant 0.0
+//   CHECK-DAG:   %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+//   CHECK-DAG:   %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
+//       CHECK:   %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]]
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] :
+//       CHECK:   return %[[FILL]]

Review comment from a Member on mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp, quoting:

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {

Where this check would then become redundant.

MaheshRavishankar (Contributor, Author) replied:

Not sure I follow how that becomes redundant.

Review comment from a Member on mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp, quoting:

    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {

Feels like you could have used this lambda on the first value, too, instead of duplicating the code.

MaheshRavishankar (Contributor, Author) replied:

Well, the first operand is checked to see if this is even a candidate, i.e. if the first operand is not a fill, then we exit early. For the rest of the operands we compare with the first (I find it strange comparing an op with itself...).

@cxy-1993 (Contributor) left a comment:

LGTM

Review comment from a Contributor on LinalgOps.cpp, quoting the new top-level comment:

    // Fold a concat with all elements being fills of the same value
    // into a fill of the concat result shape.

nit: Use /// in top level comment

Commit message:

If a concat has all its operands as just fills, and the values match,
then the fill could happen on the concatenated values of the `outs`
operands.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar merged commit d9c8533 into llvm:main on Jul 30, 2024. 7 checks passed.