-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reland "[mlir][arith] Canonicalization patterns for arith.select
(#67809)"
#68941
Conversation
…lvm#67809)" This cherry-picks the changes in llvm-project/5bf701a6687a46fd898621f5077959ff202d716b with limiting types to i1.
@peterbell10 Somehow I can not add you to reviewers. Please take a look at it, thank you! |
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Han-Chung Wang (hanhanW) ChangesThis cherry-picks the changes in I.e., it also applies the below patch:
Full diff: https://github.com/llvm/llvm-project/pull/68941.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f3d84d0b261e8dd..817de0e06c661b9 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -233,6 +233,52 @@ def CmpIExtUI :
CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
"$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+// select(not(pred), a, b) => select(pred, b, a)
+def SelectNotCond :
+ Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+ (SelectOp $pred, $b, $a),
+ [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+ Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+ (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+ Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+ (SelectOp $pred, $a, $c)>;
+
+// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
+def SelectAndCond :
+ Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
+ (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
+
+// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
+def SelectAndNotCond :
+ Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
+ (SelectOp (Arith_AndIOp $predA,
+ (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+ $x, $y),
+ [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+
+// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
+def SelectOrCond :
+ Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
+ (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
+
+// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
+def SelectOrNotCond :
+ Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
+ (SelectOp (Arith_OrIOp $predA,
+ (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+ $x, $y),
+ [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ae8a6ef350ce191..0ecc288f3b07701 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SelectI1Simplify, SelectToExtUI>(context);
+ results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
+ SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
+ SelectNotCond, SelectToExtUI>(context);
}
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f697f3d01458eee..1b0547c9e8f804a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -128,6 +128,82 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
return %res : i1
}
+// CHECK-LABEL: @redundantSelectTrue
+// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
+// CHECK-NEXT: return %[[res]]
+func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+ %0 = arith.select %arg0, %arg1, %arg2 : i32
+ %res = arith.select %arg0, %0, %arg3 : i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @redundantSelectFalse
+// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg3, %arg2
+// CHECK-NEXT: return %[[res]]
+func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+ %0 = arith.select %arg0, %arg1, %arg2 : i32
+ %res = arith.select %arg0, %arg3, %0 : i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @selNotCond
+// CHECK-NEXT: %[[res1:.+]] = arith.select %arg0, %arg2, %arg1
+// CHECK-NEXT: %[[res2:.+]] = arith.select %arg0, %arg4, %arg3
+// CHECK-NEXT: return %[[res1]], %[[res2]]
+func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
+ %one = arith.constant 1 : i1
+ %cond1 = arith.xori %arg0, %one : i1
+ %cond2 = arith.xori %one, %arg0 : i1
+
+ %res1 = arith.select %cond1, %arg1, %arg2 : i32
+ %res2 = arith.select %cond2, %arg3, %arg4 : i32
+ return %res1, %res2 : i32, i32
+}
+
+// CHECK-LABEL: @selAndCond
+// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
+// CHECK-NEXT: return %[[res]]
+func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+ %sel = arith.select %arg0, %arg2, %arg3 : i32
+ %res = arith.select %arg1, %sel, %arg3 : i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @selAndNotCond
+// CHECK-NEXT: %[[one:.+]] = arith.constant true
+// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
+// CHECK-NEXT: return %[[res]]
+func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+ %sel = arith.select %arg0, %arg2, %arg3 : i32
+ %res = arith.select %arg1, %sel, %arg2 : i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @selOrCond
+// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
+// CHECK-NEXT: return %[[res]]
+func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+ %sel = arith.select %arg0, %arg2, %arg3 : i32
+ %res = arith.select %arg1, %arg2, %sel : i32
+ return %res : i32
+}
+
+// CHECK-LABEL: @selOrNotCond
+// CHECK-NEXT: %[[one:.+]] = arith.constant true
+// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
+// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
+// CHECK-NEXT: return %[[res]]
+func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+ %sel = arith.select %arg0, %arg2, %arg3 : i32
+ %res = arith.select %arg1, %arg3, %sel : i32
+ return %res : i32
+}
+
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = arith.constant true
|
thanks for the review! |
This cherry-picks the changes in
llvm-project/5bf701a6687a46fd898621f5077959ff202d716b and extends the pattern to handle vector types.
To reuse
getBoolAttribute
method, it moves the static method above the include of generated file.