Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][arith] Fix arith.select canonicalization patterns #84685

Merged
merged 2 commits into from
Mar 13, 2024

Conversation

math-fehr
Copy link
Contributor

Because arith.select does not propagate poison of the second or third operand depending on the condition, some canonicalization patterns are currently incorrect. This patch removes these incorrect patterns, and adds a new pattern to fix the case of i1 select with constants.

Patterns that are removed:

  • select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
  • select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
  • select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
  • select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
  • arith.select %arg, %x, %y : i1 => and(%arg, %x) or and(!%arg, %y)

Pattern that is added:

  • select(pred, false, true) => not(pred) for i1

The first two patterns are incorrect when predB is poison and predA is false, as a non-poison y gets compiled to poison. The next two patterns are incorrect when predB is poison and predA is true, as a non-poison x gets compiled to poison. The last pattern is incorrect as it propagates poison from all operands afer compilation.

Because `arith.select` does not propagate poison of the second
or third operand depending on the condition, some canonicalization
patterns were incorrect. This patch removes these incorrect patterns,
and adds a new pattern to fix the case of `i1` select with constants.

Patterns that are removed:
  * select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
  * select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
  * select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
  * select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
  * arith.select %arg, %x, %y : i1        => and(%arg, %x) or and(!%arg, %y)

The first two patterns are incorrect when `predB` is poison and `predA`
is false, as a non-poison `y` gets compiled to `poison`. The next two
patterns are incorrect when `predB` is poison and `predA` is true, as
a non-poison `x` gets compiled to `poison`. The last pattern is incorrect
as it propagates poison from all operands afer compilation.
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 10, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Fehr Mathieu (math-fehr)

Changes

Because arith.select does not propagate poison of the second or third operand depending on the condition, some canonicalization patterns are currently incorrect. This patch removes these incorrect patterns, and adds a new pattern to fix the case of i1 select with constants.

Patterns that are removed:

  • select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
  • select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
  • select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
  • select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
  • arith.select %arg, %x, %y : i1 => and(%arg, %x) or and(!%arg, %y)

Pattern that is added:

  • select(pred, false, true) => not(pred) for i1

The first two patterns are incorrect when predB is poison and predA is false, as a non-poison y gets compiled to poison. The next two patterns are incorrect when predB is poison and predA is true, as a non-poison x gets compiled to poison. The last pattern is incorrect as it propagates poison from all operands afer compilation.


Full diff: https://github.com/llvm/llvm-project/pull/84685.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+6-28)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+2-33)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (-80)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 11c4a29718e1d9..caca2ff81964f7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -253,9 +253,6 @@ def CmpIExtUI :
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-def GetScalarOrVectorTrueAttribute :
-  NativeCodeCall<"cast<TypedAttr>(getBoolAttribute($0.getType(), true))">;
-
 // select(not(pred), a, b) => select(pred, b, a)
 def SelectNotCond :
     Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
@@ -272,31 +269,12 @@ def RedundantSelectFalse :
     Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
         (SelectOp $pred, $a, $c)>;
 
-// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
-def SelectAndCond :
-    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
-        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
-
-// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
-def SelectAndNotCond :
-    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
-        (SelectOp (Arith_AndIOp $predA,
-                                (Arith_XOrIOp $predB,
-                                (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
-                  $x, $y)>;
-
-// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
-def SelectOrCond :
-    Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
-        (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
-
-// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
-def SelectOrNotCond :
-    Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
-        (SelectOp (Arith_OrIOp $predA,
-                               (Arith_XOrIOp $predB,
-                               (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
-                  $x, $y)>;
+// select(pred, false, true) => not(pred) 
+def SelectI1ToNot :
+    Pat<(SelectOp $pred,
+                  (ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
+                  (ConstantLikeMatcher ConstantAttr<I1Attr, "1">)),
+        (Arith_XOrIOp $pred, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))>;
 
 //===----------------------------------------------------------------------===//
 // IndexCastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0f71c19c23b654..9f64a07f31e3af 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -969,7 +969,6 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
 }
 
-
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -2173,35 +2172,6 @@ void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-// Transforms a select of a boolean to arithmetic operations
-//
-//  arith.select %arg, %x, %y : i1
-//
-//  becomes
-//
-//  and(%arg, %x) or and(!%arg, %y)
-struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
-  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(arith::SelectOp op,
-                                PatternRewriter &rewriter) const override {
-    if (!op.getType().isInteger(1))
-      return failure();
-
-    Value falseConstant =
-        rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
-    Value notCondition = rewriter.create<arith::XOrIOp>(
-        op.getLoc(), op.getCondition(), falseConstant);
-
-    Value trueVal = rewriter.create<arith::AndIOp>(
-        op.getLoc(), op.getCondition(), op.getTrueValue());
-    Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
-                                                    op.getFalseValue());
-    rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
-    return success();
-  }
-};
-
 //  select %arg, %c1, %c0 => extui %arg
 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
@@ -2238,9 +2208,8 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
 
 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
-              SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
-              SelectNotCond, SelectToExtUI>(context);
+  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
+              SelectI1ToNot, SelectToExtUI>(context);
 }
 
 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index cb98a10048a309..bdc6c91d926775 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -116,18 +116,6 @@ func.func @selToNot(%arg0: i1) -> i1 {
   return %res : i1
 }
 
-// CHECK-LABEL: @selToArith
-//       CHECK-NEXT:       %[[trueval:.+]] = arith.constant true
-//       CHECK-NEXT:       %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
-//       CHECK-NEXT:       %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
-//       CHECK-NEXT:       %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
-//       CHECK-NEXT:       %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
-//       CHECK:   return %[[res]]
-func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
-  %res = arith.select %arg0, %arg1, %arg2 : i1
-  return %res : i1
-}
-
 // CHECK-LABEL: @redundantSelectTrue
 //       CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
 //       CHECK-NEXT: return %[[res]]
@@ -160,74 +148,6 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
   return %res1, %res2 : i32, i32
 }
 
-// CHECK-LABEL: @selAndCond
-//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
-//       CHECK-NEXT: return %[[res]]
-func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
-  %sel = arith.select %arg0, %arg2, %arg3 : i32
-  %res = arith.select %arg1, %sel, %arg3 : i32
-  return %res : i32
-}
-
-// CHECK-LABEL: @selAndNotCond
-//       CHECK-NEXT: %[[one:.+]] = arith.constant true
-//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
-//       CHECK-NEXT: return %[[res]]
-func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
-  %sel = arith.select %arg0, %arg2, %arg3 : i32
-  %res = arith.select %arg1, %sel, %arg2 : i32
-  return %res : i32
-}
-
-// CHECK-LABEL: @selAndNotCondVec
-//       CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
-//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
-//       CHECK-NEXT: return %[[res]]
-func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
-  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
-  %res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32>
-  return %res : vector<4xi32>
-}
-
-// CHECK-LABEL: @selOrCond
-//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
-//       CHECK-NEXT: return %[[res]]
-func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
-  %sel = arith.select %arg0, %arg2, %arg3 : i32
-  %res = arith.select %arg1, %arg2, %sel : i32
-  return %res : i32
-}
-
-// CHECK-LABEL: @selOrNotCond
-//       CHECK-NEXT: %[[one:.+]] = arith.constant true
-//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
-//       CHECK-NEXT: return %[[res]]
-func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
-  %sel = arith.select %arg0, %arg2, %arg3 : i32
-  %res = arith.select %arg1, %arg3, %sel : i32
-  return %res : i32
-}
-
-// CHECK-LABEL: @selOrNotCondVec
-//       CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
-//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
-//       CHECK-NEXT: return %[[res]]
-func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
-  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
-  %res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32>
-  return %res : vector<4xi32>
-}
-
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat, thanks for taking care of these.

@math-fehr math-fehr merged commit 7bdba95 into llvm:main Mar 13, 2024
4 checks passed
@math-fehr math-fehr deleted the fehr/fix-select-arith-canonicalize branch March 13, 2024 21:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants