Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reland "[mlir][arith] Canonicalization patterns for arith.select (#67809)" #68941

Merged
merged 3 commits into from
Oct 13, 2023

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Oct 13, 2023

This cherry-picks the changes in
llvm-project/5bf701a6687a46fd898621f5077959ff202d716b and extends the pattern to handle vector types.

To reuse getBoolAttribute method, it moves the static method above the include of generated file.

…lvm#67809)"

This cherry-picks the changes in
llvm-project/5bf701a6687a46fd898621f5077959ff202d716b with limiting
types to i1.
@hanhanW
Copy link
Contributor Author

hanhanW commented Oct 13, 2023

@peterbell10 Somehow I can not add you to reviewers. Please take a look at it, thank you!

@llvmbot
Copy link

llvmbot commented Oct 13, 2023

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

This cherry-picks the changes in
llvm-project/5bf701a6687a46fd898621f5077959ff202d716b with limiting types to i1.

I.e., it also applies the below patch:

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 9d38513215d3..817de0e06c66 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -263,7 +263,8 @@ def SelectAndNotCond :
     Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
         (SelectOp (Arith_AndIOp $predA,
                                 (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
-                  $x, $y)>;
+                  $x, $y),
+        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
 
 // select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
 def SelectOrCond :
@@ -275,7 +276,8 @@ def SelectOrNotCond :
     Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
         (SelectOp (Arith_OrIOp $predA,
                                (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
-                  $x, $y)>;
+                  $x, $y),
+        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
 
 //===----------------------------------------------------------------------===//
 // IndexCastOp

Full diff: https://github.com/llvm/llvm-project/pull/68941.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+46)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+3-1)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+76)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f3d84d0b261e8dd..817de0e06c661b9 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -233,6 +233,52 @@ def CmpIExtUI :
             CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
                   "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+// select(not(pred), a, b) => select(pred, b, a)
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+        (SelectOp $pred, $b, $a),
+        [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+        (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+        (SelectOp $pred, $a, $c)>;
+
+// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
+def SelectAndCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
+        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
+
+// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
+def SelectAndNotCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
+        (SelectOp (Arith_AndIOp $predA,
+                                (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+                  $x, $y),
+        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+
+// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
+def SelectOrCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
+        (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
+
+// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
+def SelectOrNotCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
+        (SelectOp (Arith_OrIOp $predA,
+                               (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+                  $x, $y),
+        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ae8a6ef350ce191..0ecc288f3b07701 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
 
 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SelectI1Simplify, SelectToExtUI>(context);
+  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
+              SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
+              SelectNotCond, SelectToExtUI>(context);
 }
 
 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f697f3d01458eee..1b0547c9e8f804a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -128,6 +128,82 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
   return %res : i1
 }
 
+// CHECK-LABEL: @redundantSelectTrue
+//       CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
+//       CHECK-NEXT: return %[[res]]
+func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %0, %arg3 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @redundantSelectFalse
+//       CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg3, %arg2
+//       CHECK-NEXT: return %[[res]]
+func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %arg3, %0 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selNotCond
+//       CHECK-NEXT: %[[res1:.+]] = arith.select %arg0, %arg2, %arg1
+//       CHECK-NEXT: %[[res2:.+]] = arith.select %arg0, %arg4, %arg3
+//       CHECK-NEXT: return %[[res1]], %[[res2]]
+func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
+  %one = arith.constant 1 : i1
+  %cond1 = arith.xori %arg0, %one : i1
+  %cond2 = arith.xori %one, %arg0 : i1
+
+  %res1 = arith.select %cond1, %arg1, %arg2 : i32
+  %res2 = arith.select %cond2, %arg3, %arg4 : i32
+  return %res1, %res2 : i32, i32
+}
+
+// CHECK-LABEL: @selAndCond
+//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
+//       CHECK-NEXT: return %[[res]]
+func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %sel, %arg3 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selAndNotCond
+//       CHECK-NEXT: %[[one:.+]] = arith.constant true
+//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
+//       CHECK-NEXT: return %[[res]]
+func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %sel, %arg2 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selOrCond
+//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
+//       CHECK-NEXT: return %[[res]]
+func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %arg2, %sel : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selOrNotCond
+//       CHECK-NEXT: %[[one:.+]] = arith.constant true
+//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
+//       CHECK-NEXT: return %[[res]]
+func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %arg3, %sel : i32
+  return %res : i32
+}
+
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true

@hanhanW
Copy link
Contributor Author

hanhanW commented Oct 13, 2023

thanks for the review!

@hanhanW hanhanW merged commit 6dbc6df into llvm:main Oct 13, 2023
2 of 3 checks passed
@hanhanW hanhanW deleted the arith-fix branch October 13, 2023 19:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants