[InstCombine] Simplify nested selects with implied condition #83739

Merged
dtcxzyw merged 5 commits into llvm:main from perf/nested-select-implied-condition
Mar 5, 2024


@dtcxzyw dtcxzyw commented Mar 3, 2024

This patch does the following simplification:

sel1 = select cond1, X, Y 
sel2 = select cond2, sel1, Z
-->
sel2 = select cond2, X, Z if cond2 implies cond1
sel2 = select cond2, Y, Z if cond2 implies !cond1

Alive2: https://alive2.llvm.org/ce/z/9A_arU

It cannot be done in CVP/SCCP, since there we would need to guarantee that cond2 is not undef.
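
As an illustration of the rewrite above, here is a minimal C++ sketch that models `select` on plain `int8_t` values and compares the nested and folded forms over every i8 input. The constants mirror the `test_implied_true` and `test_implied_false` cases from this patch; the helper names (`sel`, `nestedImpliedTrue`, and so on) are made up for this example, and the scalar model does not capture undef or poison.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of the IR `select` instruction (no undef/poison).
int8_t sel(bool cond, int8_t t, int8_t f) { return cond ? t : f; }

// cond1 = (x < 10), cond2 = (x < 0): cond2 implies cond1.
int8_t nestedImpliedTrue(int8_t x) {
  int8_t sel1 = sel(x < 10, 0, 5);
  return sel(x < 0, sel1, 20);
}
// Folded form: the inner select is replaced by its true arm.
int8_t foldedImpliedTrue(int8_t x) { return sel(x < 0, 0, 20); }

// cond1 = (x > 10), cond2 = (x < 0): cond2 implies !cond1.
int8_t nestedImpliedFalse(int8_t x) {
  int8_t sel1 = sel(x > 10, 0, 5);
  return sel(x < 0, sel1, 20);
}
// Folded form: the inner select is replaced by its false arm.
int8_t foldedImpliedFalse(int8_t x) { return sel(x < 0, 5, 20); }

// Compares nested and folded forms for all 256 values of i8.
void checkFoldsOnAllI8() {
  for (int v = -128; v <= 127; ++v) {
    int8_t x = static_cast<int8_t>(v);
    assert(nestedImpliedTrue(x) == foldedImpliedTrue(x));
    assert(nestedImpliedFalse(x) == foldedImpliedFalse(x));
  }
}
```

Calling `checkFoldsOnAllI8()` from a test driver exercises both directions of the implication. This is only a sanity check on defined values; the Alive2 link is the actual proof.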

llvmbot commented Mar 3, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/83739.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+17)
  • (modified) llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll (+4-8)
  • (modified) llvm/test/Transforms/InstCombine/nested-select.ll (+84)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 71fa9b9ba41ebb..dc4347fdd713c6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3867,5 +3867,22 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     }
   }
 
+  // Fold nested selects if the inner condition can be implied by the outer
+  // condition.
+  Value *InnerCondVal;
+  const DataLayout &DL = getDataLayout();
+  if (match(TrueVal,
+            m_Select(m_Value(InnerCondVal), m_Value(LHS), m_Value(RHS))) &&
+      CondVal->getType() == InnerCondVal->getType())
+    if (auto Implied =
+            isImpliedCondition(CondVal, InnerCondVal, DL, /*LHSIsTrue=*/true))
+      return replaceOperand(SI, 1, *Implied ? LHS : RHS);
+  if (match(FalseVal,
+            m_Select(m_Value(InnerCondVal), m_Value(LHS), m_Value(RHS))) &&
+      CondVal->getType() == InnerCondVal->getType())
+    if (auto Implied =
+            isImpliedCondition(CondVal, InnerCondVal, DL, /*LHSIsTrue=*/false))
+      return replaceOperand(SI, 2, *Implied ? LHS : RHS);
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll b/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
index d03e22bc4c9fbf..b5ef1f466958d7 100644
--- a/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
+++ b/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
@@ -189,10 +189,8 @@ define i32 @n9_ult_slt_neg17(i32 %x, i32 %replacement_low, i32 %replacement_high
 ; Regression test for PR53252.
 define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
 ; CHECK-LABEL: @n10_ugt_slt(
-; CHECK-NEXT:    [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
-; CHECK-NEXT:    [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp ugt i32 [[X]], 128
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[T1]]
+; CHECK-NEXT:    [[T2:%.*]] = icmp ugt i32 [[X:%.*]], 128
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[REPLACEMENT_HIGH:%.*]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %t0 = icmp slt i32 %x, 0
@@ -204,10 +202,8 @@ define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
 
 define i32 @n11_uge_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
 ; CHECK-LABEL: @n11_uge_slt(
-; CHECK-NEXT:    [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
-; CHECK-NEXT:    [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp ult i32 [[X]], 129
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[T1]], i32 [[X]]
+; CHECK-NEXT:    [[T2:%.*]] = icmp ult i32 [[X:%.*]], 129
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[REPLACEMENT_HIGH:%.*]], i32 [[X]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %t0 = icmp slt i32 %x, 0
diff --git a/llvm/test/Transforms/InstCombine/nested-select.ll b/llvm/test/Transforms/InstCombine/nested-select.ll
index 42a0f81e7b85a2..d4bbf0ae48590a 100644
--- a/llvm/test/Transforms/InstCombine/nested-select.ll
+++ b/llvm/test/Transforms/InstCombine/nested-select.ll
@@ -498,3 +498,87 @@ define i1 @orcond.111.inv.all.conds(i1 %inner.cond, i1 %alt.cond, i1 %inner.sel.
   %outer.sel = select i1 %not.outer.cond, i1 true, i1 %inner.sel
   ret i1 %outer.sel
 }
+
+define i8 @test_implied_true(i8 %x) {
+; CHECK-LABEL: @test_implied_true(
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 0, i8 20
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp slt i8 %x, 10
+  %cmp2 = icmp slt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 %sel1, i8 20
+  ret i8 %sel2
+}
+
+define <2 x i8> @test_implied_true_vec(<2 x i8> %x) {
+; CHECK-LABEL: @test_implied_true_vec(
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt <2 x i8> [[X:%.*]], zeroinitializer
+; CHECK-NEXT:    [[SEL2:%.*]] = select <2 x i1> [[CMP2]], <2 x i8> zeroinitializer, <2 x i8> <i8 20, i8 20>
+; CHECK-NEXT:    ret <2 x i8> [[SEL2]]
+;
+  %cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
+  %cmp2 = icmp slt <2 x i8> %x, zeroinitializer
+  %sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+  %sel2 = select <2 x i1> %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20>
+  ret <2 x i8> %sel2
+}
+
+define i8 @test_implied_true_falseval(i8 %x) {
+; CHECK-LABEL: @test_implied_true_falseval(
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 20, i8 0
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp slt i8 %x, 10
+  %cmp2 = icmp sgt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 20, i8 %sel1
+  ret i8 %sel2
+}
+
+define i8 @test_implied_false(i8 %x) {
+; CHECK-LABEL: @test_implied_false(
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 5, i8 20
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp sgt i8 %x, 10
+  %cmp2 = icmp slt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 %sel1, i8 20
+  ret i8 %sel2
+}
+
+; Negative tests
+
+define i8 @test_imply_fail(i8 %x) {
+; CHECK-LABEL: @test_imply_fail(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], -10
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X]], 0
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp slt i8 %x, -10
+  %cmp2 = icmp slt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 %sel1, i8 20
+  ret i8 %sel2
+}
+
+define <2 x i8> @test_imply_type_mismatch(<2 x i8> %x, i8 %y) {
+; CHECK-LABEL: @test_imply_type_mismatch(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 10, i8 10>
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[Y:%.*]], 0
+; CHECK-NEXT:    [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], <2 x i8> [[SEL1]], <2 x i8> <i8 20, i8 20>
+; CHECK-NEXT:    ret <2 x i8> [[SEL2]]
+;
+  %cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
+  %cmp2 = icmp slt i8 %y, 0
+  %sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+  %sel2 = select i1 %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20>
+  ret <2 x i8> %sel2
+}
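
The C++ change above relies on the three-way answer from `isImpliedCondition` (implied true, implied false, or unknown) and on the `LHSIsTrue` flag that says which arm of the outer select is being examined. The sketch below is a toy model of that control flow, not LLVM's implementation: `impliedSlt`, `pickInnerArm`, and `InnerSelect` are hypothetical names, and conditions are restricted to the shape `x < C` (signed) so that the implication can be decided by comparing constants.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for isImpliedCondition(), limited to conditions of
// the form `x < C`. Given that `x < outerC` is true (outerIsTrue) or false
// (!outerIsTrue), reports whether `x < innerC` is known true, known false,
// or unknown (nullopt).
std::optional<bool> impliedSlt(int outerC, int innerC, bool outerIsTrue) {
  if (outerIsTrue) {
    // x < outerC and outerC <= innerC  ==>  x < innerC.
    if (outerC <= innerC)
      return true;
    return std::nullopt;
  }
  // x >= outerC and outerC >= innerC  ==>  !(x < innerC).
  if (outerC >= innerC)
    return false;
  return std::nullopt;
}

struct InnerSelect {
  int trueVal;  // LHS in the patch
  int falseVal; // RHS in the patch
};

// Mirrors `*Implied ? LHS : RHS`: returns the inner arm that would replace
// the inner select, or nullopt when nothing is implied and no fold applies.
std::optional<int> pickInnerArm(int outerC, int innerC, bool innerOnTrueArm,
                                InnerSelect inner) {
  std::optional<bool> implied = impliedSlt(outerC, innerC, innerOnTrueArm);
  if (!implied)
    return std::nullopt;
  return *implied ? inner.trueVal : inner.falseVal;
}

// Thresholds taken from test_implied_true and test_imply_fail in this patch.
void checkToyModel() {
  assert(pickInnerArm(0, 10, true, InnerSelect{0, 5}) == std::optional<int>(0));
  assert(!pickInnerArm(0, -10, true, InnerSelect{0, 5}).has_value());
}
```

The `std::optional<bool>` result is what lets a single call cover both rewrites: a definite answer of either polarity picks an arm, and only the unknown case leaves the nested select in place.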

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Mar 3, 2024

@nikic nikic left a comment

This looks basically fine, but I think it may subsume some existing folds?

For example, what about this one?

// select(C, select(C, a, b), c) -> select(C, a, c)
if (TrueSI->getCondition() == CondVal) {
  if (SI.getTrueValue() == TrueSI->getTrueValue())
    return nullptr;
  return replaceOperand(SI, 1, TrueSI->getTrueValue());
}

There's also foldAndOrOfSelectUsingImpliedCond() -- I think your fold may be the same, but without the limitation that the top-level select is a logical one?
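
To make the "subsume" point concrete: the same-condition fold quoted above is the case where the outer and inner conditions are identical, so the outer condition trivially implies the inner one. A small C++ sketch (the names `sel` and `sameCondFoldHolds` are invented for this example, and undef/poison are not modeled) checks both the true-arm form and its false-arm mirror:

```cpp
#include <cassert>

// Scalar model of `select`; undef/poison are not modeled.
int sel(bool c, int t, int f) { return c ? t : f; }

// Checks select(C, select(C, a, b), c) == select(C, a, c) and the mirrored
// select(C, a, select(C, b, c)) == select(C, a, c) for both values of C.
bool sameCondFoldHolds(int a, int b, int c) {
  for (int i = 0; i < 2; ++i) {
    bool C = (i != 0);
    if (sel(C, sel(C, a, b), c) != sel(C, a, c))
      return false;
    if (sel(C, a, sel(C, b, c)) != sel(C, a, c))
      return false;
  }
  return true;
}

void checkSameCondFold() { assert(sameCondFoldHolds(1, 2, 3)); }
```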

dtcxzyw commented Mar 4, 2024

We can further remove the following logic if we fold it recursively.

// Try to simplify a binop sandwiched between 2 selects with the same
// condition. This is not valid for div/rem because the select might be
// preventing a division-by-zero.
// TODO: A div/rem restriction is conservative; use something like
//       isSafeToSpeculativelyExecute().
// select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z)
BinaryOperator *TrueBO;
if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && !TrueBO->isIntDivRem()) {
  if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) {
    if (TrueBOSI->getCondition() == CondVal) {
      replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue());
      Worklist.push(TrueBO);
      return &SI;
    }
  }
  if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) {
    if (TrueBOSI->getCondition() == CondVal) {
      replaceOperand(*TrueBO, 1, TrueBOSI->getTrueValue());
      Worklist.push(TrueBO);
      return &SI;
    }
  }
}
// select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W))
BinaryOperator *FalseBO;
if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && !FalseBO->isIntDivRem()) {
  if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) {
    if (FalseBOSI->getCondition() == CondVal) {
      replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue());
      Worklist.push(FalseBO);
      return &SI;
    }
  }
  if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) {
    if (FalseBOSI->getCondition() == CondVal) {
      replaceOperand(*FalseBO, 1, FalseBOSI->getFalseValue());
      Worklist.push(FalseBO);
      return &SI;
    }
  }
}
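
For reference, the equivalence behind the sandwiched-binop fold quoted above can be checked with a small C++ model. This is a sketch under assumptions: `sel` and `sandwichFoldHolds` are names invented here, the operand ranges are arbitrary small integers, and only operators without trapping behavior are tried (division is left out for the reason given in the quoted comment).

```cpp
#include <cassert>
#include <functional>

// Scalar model of `select`. Both arms are evaluated before the call, which
// loosely matches how both operands of an IR select are always computed.
int sel(bool c, int t, int f) { return c ? t : f; }

// Checks, for a given binary operator:
//   select(C, binop(select(C, X, Y), W), Z) == select(C, binop(X, W), Z)
//   select(C, Z, binop(select(C, X, Y), W)) == select(C, Z, binop(Y, W))
template <typename BinOp>
bool sandwichFoldHolds(BinOp op) {
  const int Z = 7;
  for (int i = 0; i < 2; ++i) {
    bool C = (i != 0);
    for (int X = -2; X <= 2; ++X)
      for (int Y = -2; Y <= 2; ++Y)
        for (int W = -2; W <= 2; ++W) {
          if (sel(C, op(sel(C, X, Y), W), Z) != sel(C, op(X, W), Z))
            return false;
          if (sel(C, Z, op(sel(C, X, Y), W)) != sel(C, Z, op(Y, W)))
            return false;
        }
  }
  return true;
}

void checkSandwichFold() {
  assert(sandwichFoldHolds(std::plus<int>()));
  assert(sandwichFoldHolds(std::multiplies<int>()));
}
```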

dtcxzyw commented Mar 4, 2024

Done.

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Mar 4, 2024

@nikic nikic left a comment

LGTM

dtcxzyw commented Mar 4, 2024

The compile-time improvement looks good:
http://llvm-compile-time-tracker.com/compare.php?from=732a5cba8c739ed40a7280b5d74ca717910c2c4c&to=2e94109d115a81f271cefc7d6c0e839920c52bba&stat=instructions%3Au

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Mar 4, 2024
@dtcxzyw dtcxzyw merged commit 0c47363 into llvm:main Mar 5, 2024
4 checks passed
@dtcxzyw dtcxzyw deleted the perf/nested-select-implied-condition branch March 5, 2024 06:11