-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][arith] Fix arith.select
canonicalization patterns
#84685
[mlir][arith] Fix arith.select
canonicalization patterns
#84685
Conversation
Because `arith.select` does not propagate poison of the second or third operand depending on the condition, some canonicalization patterns were incorrect. This patch removes these incorrect patterns, and adds a new pattern to fix the case of `i1` select with constants. Patterns that are removed: * select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y) * select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y) * select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y) * select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y) * arith.select %arg, %x, %y : i1 => and(%arg, %x) or and(!%arg, %y) The first two patterns are incorrect when `predB` is poison and `predA` is false, as a non-poison `y` gets compiled to `poison`. The next two patterns are incorrect when `predB` is poison and `predA` is true, as a non-poison `x` gets compiled to `poison`. The last pattern is incorrect as it propagates poison from all operands afer compilation.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Fehr Mathieu (math-fehr) ChangesBecause Patterns that are removed:
Pattern that is added:
The first two patterns are incorrect when Full diff: https://github.com/llvm/llvm-project/pull/84685.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 11c4a29718e1d9..caca2ff81964f7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -253,9 +253,6 @@ def CmpIExtUI :
// SelectOp
//===----------------------------------------------------------------------===//
-def GetScalarOrVectorTrueAttribute :
- NativeCodeCall<"cast<TypedAttr>(getBoolAttribute($0.getType(), true))">;
-
// select(not(pred), a, b) => select(pred, b, a)
def SelectNotCond :
Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
@@ -272,31 +269,12 @@ def RedundantSelectFalse :
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
(SelectOp $pred, $a, $c)>;
-// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
-def SelectAndCond :
- Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
- (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
-
-// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
-def SelectAndNotCond :
- Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
- (SelectOp (Arith_AndIOp $predA,
- (Arith_XOrIOp $predB,
- (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
- $x, $y)>;
-
-// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
-def SelectOrCond :
- Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
- (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
-
-// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
-def SelectOrNotCond :
- Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
- (SelectOp (Arith_OrIOp $predA,
- (Arith_XOrIOp $predB,
- (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
- $x, $y)>;
+// select(pred, false, true) => not(pred)
+def SelectI1ToNot :
+ Pat<(SelectOp $pred,
+ (ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
+ (ConstantLikeMatcher ConstantAttr<I1Attr, "1">)),
+ (Arith_XOrIOp $pred, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))>;
//===----------------------------------------------------------------------===//
// IndexCastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0f71c19c23b654..9f64a07f31e3af 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -969,7 +969,6 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}
-
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
@@ -2173,35 +2172,6 @@ void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// SelectOp
//===----------------------------------------------------------------------===//
-// Transforms a select of a boolean to arithmetic operations
-//
-// arith.select %arg, %x, %y : i1
-//
-// becomes
-//
-// and(%arg, %x) or and(!%arg, %y)
-struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
- using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arith::SelectOp op,
- PatternRewriter &rewriter) const override {
- if (!op.getType().isInteger(1))
- return failure();
-
- Value falseConstant =
- rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
- Value notCondition = rewriter.create<arith::XOrIOp>(
- op.getLoc(), op.getCondition(), falseConstant);
-
- Value trueVal = rewriter.create<arith::AndIOp>(
- op.getLoc(), op.getCondition(), op.getTrueValue());
- Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
- op.getFalseValue());
- rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
- return success();
- }
-};
-
// select %arg, %c1, %c0 => extui %arg
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
@@ -2238,9 +2208,8 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
- SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
- SelectNotCond, SelectToExtUI>(context);
+ results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
+ SelectI1ToNot, SelectToExtUI>(context);
}
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index cb98a10048a309..bdc6c91d926775 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -116,18 +116,6 @@ func.func @selToNot(%arg0: i1) -> i1 {
return %res : i1
}
-// CHECK-LABEL: @selToArith
-// CHECK-NEXT: %[[trueval:.+]] = arith.constant true
-// CHECK-NEXT: %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
-// CHECK-NEXT: %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
-// CHECK-NEXT: %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
-// CHECK-NEXT: %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
-// CHECK: return %[[res]]
-func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
- %res = arith.select %arg0, %arg1, %arg2 : i1
- return %res : i1
-}
-
// CHECK-LABEL: @redundantSelectTrue
// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
// CHECK-NEXT: return %[[res]]
@@ -160,74 +148,6 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
return %res1, %res2 : i32, i32
}
-// CHECK-LABEL: @selAndCond
-// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
-// CHECK-NEXT: return %[[res]]
-func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
- %sel = arith.select %arg0, %arg2, %arg3 : i32
- %res = arith.select %arg1, %sel, %arg3 : i32
- return %res : i32
-}
-
-// CHECK-LABEL: @selAndNotCond
-// CHECK-NEXT: %[[one:.+]] = arith.constant true
-// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
-// CHECK-NEXT: return %[[res]]
-func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
- %sel = arith.select %arg0, %arg2, %arg3 : i32
- %res = arith.select %arg1, %sel, %arg2 : i32
- return %res : i32
-}
-
-// CHECK-LABEL: @selAndNotCondVec
-// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
-// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
-// CHECK-NEXT: return %[[res]]
-func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
- %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
- %res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32>
- return %res : vector<4xi32>
-}
-
-// CHECK-LABEL: @selOrCond
-// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
-// CHECK-NEXT: return %[[res]]
-func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
- %sel = arith.select %arg0, %arg2, %arg3 : i32
- %res = arith.select %arg1, %arg2, %sel : i32
- return %res : i32
-}
-
-// CHECK-LABEL: @selOrNotCond
-// CHECK-NEXT: %[[one:.+]] = arith.constant true
-// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
-// CHECK-NEXT: return %[[res]]
-func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
- %sel = arith.select %arg0, %arg2, %arg3 : i32
- %res = arith.select %arg1, %arg3, %sel : i32
- return %res : i32
-}
-
-// CHECK-LABEL: @selOrNotCondVec
-// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
-// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
-// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
-// CHECK-NEXT: return %[[res]]
-func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
- %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
- %res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32>
- return %res : vector<4xi32>
-}
-
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = arith.constant true
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat, thanks for taking care of these.
Because
arith.select
does not propagate poison of the second or third operand depending on the condition, some canonicalization patterns are currently incorrect. This patch removes these incorrect patterns, and adds a new pattern to fix the case ofi1
select with constants.Patterns that are removed:
Pattern that is added:
The first two patterns are incorrect when
predB
is poison andpredA
is false, as a non-poisony
gets compiled topoison
. The next two patterns are incorrect whenpredB
is poison andpredA
is true, as a non-poisonx
gets compiled topoison
. The last pattern is incorrect as it propagates poison from all operands afer compilation.