-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][Linalg] Add a pattern to fold concats of fill. #98995
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Linalg] Add a pattern to fold concats of fill. #98995
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: None (MaheshRavishankar) ChangesIf a concat has all its operands as just fills, and the values match, then the fill could happen on the concatenated values of the Full diff: https://github.com/llvm/llvm-project/pull/98995.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..c26eabd811c65 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -879,15 +879,66 @@ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
}
};
+// Fold a concat with all elements being fills of the same value
+// into a fill of the concat result shape.
+struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+ PatternRewriter &rewriter) const override {
+ auto concatOperands = concatOp.getInputs();
+ if (concatOperands.empty()) {
+ return failure();
+ }
+
+ auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
+ if (!firstFillOp) {
+ return failure();
+ }
+ // Prefetch the fill value.
+ OpFoldResult firstFillVal =
+ getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
+ // Collect all the outs values for the fill operations.
+ SmallVector<Value> allOuts;
+ allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
+
+ auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
+ auto fillOp = v.getDefiningOp<linalg::FillOp>();
+ if (!fillOp) {
+ return false;
+ }
+
+ OpFoldResult fillVal =
+ getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
+ if (fillVal != firstFillVal)
+ return false;
+
+ allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
+ return true;
+ };
+ if (!llvm::all_of(concatOperands.drop_front(),
+ isDefinedByCompatibleFillOp)) {
+ return rewriter.notifyMatchFailure(
+ concatOp, "not all operands are defined by a compatible fill op");
+ }
+
+ Value outsConcat = rewriter.create<tensor::ConcatOp>(
+ concatOp.getLoc(), concatOp.getDim(), allOuts);
+ rewriter.replaceOpWithNewOp<linalg::FillOp>(
+ concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
+ return success();
+ }
+};
+
} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
- FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
- FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
- FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
+ results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
+ FoldFillWithPack, FoldFillWithPad,
+ FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+ FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+ FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 928030a81dc02..02b46c405e2bd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1096,3 +1096,30 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
func.return %transpose2 : tensor<3x4x5xf32>
}
+
+// -----
+
+func.func @concats_of_fill(
+ %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
+ -> tensor<5x?x?xf32>
+{
+ %cst0 = arith.constant 0.0 : f32
+ %cst1 = arith.constant 0.0 : f32
+ %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
+ %1 = linalg.fill ins(%cst0 : f32) outs(%0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+ %2 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
+ %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+ %4 = tensor.concat dim(1) %1, %3 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+ return %4 : tensor<5x?x?xf32>
+}
+// CHECK: func @concats_of_fill(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index)
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0
+// CHECK-DAG: %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+// CHECK-DAG: %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
+// CHECK: %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]]
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] :
+// CHECK: return %[[FILL]]
|
} | ||
|
||
auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>(); | ||
if (!firstFillOp) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where this check would then become redundant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow how that becomes redundant.
SmallVector<Value> allOuts; | ||
allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get()); | ||
|
||
auto isDefinedByCompatibleFillOp = [&](Value v) -> bool { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feels like you could have used this lambda on the first value, too, instead of duplicated the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, the first operand is checked to see if this is even a candidate, i.e. if the first operand is not a fill, then we early exist. For the rest of the operands we compare with the first (I find it strange comparing an op with itself...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -879,15 +879,66 @@ struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> { | |||
} | |||
}; | |||
|
|||
// Fold a concat with all elements being fills of the same value | |||
// into a fill of the concat result shape. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Use /// in top level comment
If a concat has all its operands as just fills, and the values match, then the fill could happen on the concatenated values of the `outs` operands. Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
9b79440
to
05042ed
Compare
If a concat has all its operands as just fills, and the values match, then the fill could happen on the concatenated values of the
outs
operands.