[InstCombine] Fold A pred C ? (A >> BW - 1) : 1 --> ZExt(A pred C ? A < 0 : 1) #69961

Open
wants to merge 2 commits into
base: main

Conversation

elhewaty
Contributor

@elhewaty elhewaty commented Oct 23, 2023

@llvmbot
Collaborator

llvmbot commented Oct 23, 2023

@llvm/pr-subscribers-llvm-transforms

Author: None (elhewaty)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/69961.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+14)
  • (modified) llvm/test/Transforms/InstCombine/icmp-select.ll (+12)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7a15c0dee492b5a..4fd3f4f594d7a77 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3415,6 +3415,20 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
                                 TrueVal);
   }
 
+  // select (icmp eq a, 0), 1, (lshr a, 31) -> zext (icmp sle a, 0),
+  // which is then canonicalized to zext (icmp slt a, 1)
+  CmpInst::Predicate Pred;
+  Value *A;
+  const APInt *C;
+  if (match(CondVal, m_Cmp(Pred, m_Value(A), m_Zero())) &&
+      match(TrueVal, m_One()) &&
+      match(FalseVal, m_LShr(m_Specific(A), m_APInt(C))) &&
+      Pred == ICmpInst::ICMP_EQ && *C == 31) {
+    auto *Cond = Builder.CreateICmpSLE(A,
+                                       ConstantInt::getNullValue(A->getType()));
+    return new ZExtInst(Cond, A->getType());
+  }
+
   if (Instruction *R = foldSelectOfBools(SI))
     return R;
 
diff --git a/llvm/test/Transforms/InstCombine/icmp-select.ll b/llvm/test/Transforms/InstCombine/icmp-select.ll
index 0d723c9df32e2f4..7fe37200788dc5e 100644
--- a/llvm/test/Transforms/InstCombine/icmp-select.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-select.ll
@@ -5,6 +5,18 @@ declare void @use(i8)
 declare void @use.i1(i1)
 declare i8 @llvm.umin.i8(i8, i8)
 
+define i32 @test_icmp_select_lte_0(i32 %0) {
+; CHECK-LABEL: @test_icmp_select_lte_0(
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp slt i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[RE:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[RE]]
+;
+  %cml = icmp eq i32 %0, 0
+  %lshr = lshr i32 %0, 31
+  %re = select i1 %cml, i32 1, i32 %lshr
+  ret i32 %re
+}
+
 define i1 @icmp_select_const(i8 %x, i8 %y) {
 ; CHECK-LABEL: @icmp_select_const(
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i8 [[X:%.*]], 0

@github-actions

github-actions bot commented Oct 23, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@elhewaty
Contributor Author

@dtcxzyw @goldsteinn @arsenm

if (match(CondVal, m_Cmp(Pred, m_Value(A), m_Zero())) &&
match(TrueVal, m_One()) &&
match(FalseVal, m_LShr(m_Specific(A), m_APInt(C))) &&
Pred == ICmpInst::ICMP_EQ && *C == 31) {
Contributor

Hardcoded assumption of i32 instead of bit width - 1?

Contributor

@arsenm arsenm left a comment

Needs tests with different sized integers and vectors
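
The two review remarks above (hardcoded 31, and tests for other integer sizes) can be illustrated outside LLVM. The following is a minimal sketch, not part of the patch: it models the fold in plain C++ with the shift amount derived from the bit width and checks it by brute force. The function name `foldHoldsForWidth` is hypothetical, and C++ fixed-width integers only stand in for LLVM's `iN` types; vectors are not modeled.

```cpp
#include <cstdint>
#include <limits>
#include <type_traits>

// Brute-force model of:  select (icmp eq a, 0), 1, (lshr a, BW - 1)
// compared against:      zext (icmp sle a, 0)
// S is a signed C++ integer type standing in for an LLVM iN type.
template <typename S>
constexpr bool foldHoldsForWidth() {
  using U = std::make_unsigned_t<S>;
  constexpr unsigned BW = std::numeric_limits<U>::digits; // bit width of the type
  for (long long v = std::numeric_limits<S>::min();
       v <= std::numeric_limits<S>::max(); ++v) {
    S a = static_cast<S>(v);
    // lshr is a logical shift, so shift the unsigned bit pattern.
    U lshr = static_cast<U>(static_cast<U>(a) >> (BW - 1));
    U original = (a == 0) ? U(1) : lshr;
    U folded = (a <= 0) ? U(1) : U(0);
    if (original != folded)
      return false;
  }
  return true;
}

// Checked at compile time for the 8-bit case; wider types can be checked at run time.
static_assert(foldHoldsForWidth<std::int8_t>(), "fold must hold for i8");
```

Calling `foldHoldsForWidth<std::int16_t>()` at run time exercises a second width with the same code, which is the property a `BW - 1` based matcher relies on.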

@XChy
Member

XChy commented Oct 24, 2023

Similar patterns such as a > 0 ? a < 0 : 1 could be considered here (proof).
I think this patch could be expressed in a more general way based on ConstantRange.

@elhewaty
Contributor Author

elhewaty commented Oct 24, 2023

I think this patch could be expressed in a more general way based on ConstantRange.

@XChy Can you explain further?

@XChy
Member

XChy commented Oct 24, 2023

For example, cmp1 ? cmp2 : 1 is equivalent to (cmp1 & cmp2) | (!cmp1). When the constant range of cmp1 is a subset of that of cmp2, cmp1 implies cmp2, so the expression reduces to cmp1 | !cmp1 -> 1.

However, I found that InstCombine already handles such patterns; what is missing here is treating a >> (BW - 1) as a < 0.
My 2c is to transform (A pred C) ? (A >> (BW - 1)) : 1 into ZExt((A pred C) ? A < 0 : 1).
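
The boolean equivalence mentioned in this comment can be checked with a small truth table. This is only an illustrative sketch in plain C++: `selectForm`, `logicForm`, `identityHolds` and `alwaysTrueWhenImplied` are hypothetical names, and the implied-range case uses `a < -5` and `a < 0` on 8-bit values as an arbitrary example of one range being a subset of the other.

```cpp
// cmp1 ? cmp2 : 1, written as a select on booleans.
constexpr bool selectForm(bool c1, bool c2) { return c1 ? c2 : true; }

// The same value written with logic operations: (cmp1 & cmp2) | !cmp1.
constexpr bool logicForm(bool c1, bool c2) { return (c1 && c2) || !c1; }

// Exhaustive truth-table comparison of the two forms.
constexpr bool identityHolds() {
  for (int i = 0; i < 4; ++i) {
    bool c1 = (i & 1) != 0;
    bool c2 = (i & 2) != 0;
    if (selectForm(c1, c2) != logicForm(c1, c2))
      return false;
  }
  return true;
}

// If cmp1 implies cmp2 (here: a < -5 implies a < 0), the select is always 1.
constexpr bool alwaysTrueWhenImplied() {
  for (int a = -128; a <= 127; ++a)
    if (!selectForm(a < -5, a < 0))
      return false;
  return true;
}

static_assert(identityHolds(), "select form and logic form must agree");
static_assert(alwaysTrueWhenImplied(), "implied condition folds to true");
```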

Contributor

@nikic nikic left a comment

Agree with @XChy on how to handle this. That will enable more combinations of predicates to fold.

This is something of a recurring pattern, see https://reviews.llvm.org/D154791 and #68244 for related patches.

@elhewaty
Contributor Author

My 2c is to transform (A pred C) ? (A >> (BW - 1)) : 1 into ZExt((A pred C) ? A < 0 : 1).

@XChy Should I first transform A pred C ? (A >> (BW - 1)) : 1 --> ZExt(A pred C ? A < 0 : 1), and then
fold ZExt(A pred C ? A < 0 : 1) with patterns similar to the ones in this pull request?

@XChy
Member

XChy commented Oct 24, 2023

Just A pred C ? (A >> (BW - 1)) : 1 --> ZExt(A pred C ? A < 0 : 1). The latter form will be folded automatically elsewhere.

@elhewaty
Contributor Author

Then the optimization in this patch is handled somewhere?

@XChy
Member

XChy commented Oct 24, 2023

Then the optimization in this patch is handled somewhere?

Yes. All we need to do is rewrite A pred C ? (A >> (BW - 1)) : 1 -> ZExt(A pred C ? A < 0 : 1). This form is easier for other optimizations to fold.
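
To make the proposed rewrite concrete, here is a hedged brute-force model on 8-bit values. It is a sketch, not LLVM code: the `Pred` enum, `evalPred`, `original`, `rewritten` and `sameForAllA` are hypothetical names, only four predicates are modeled, and `>>` is modeled as a logical shift on the unsigned bit pattern.

```cpp
#include <cstdint>

enum class Pred { EQ, NE, SLT, SGT };

// Evaluate "a pred c" for the modeled predicates.
constexpr bool evalPred(Pred p, std::int8_t a, std::int8_t c) {
  switch (p) {
  case Pred::EQ:  return a == c;
  case Pred::NE:  return a != c;
  case Pred::SLT: return a < c;
  case Pred::SGT: return a > c;
  }
  return false;
}

// Original form: A pred C ? (A >> (BW - 1)) : 1, with BW = 8.
constexpr std::uint8_t original(Pred p, std::int8_t a, std::int8_t c) {
  return evalPred(p, a, c)
             ? static_cast<std::uint8_t>(static_cast<std::uint8_t>(a) >> 7)
             : std::uint8_t{1};
}

// Rewritten form: ZExt(A pred C ? (A < 0) : true).
constexpr std::uint8_t rewritten(Pred p, std::int8_t a, std::int8_t c) {
  bool sel = evalPred(p, a, c) ? (a < 0) : true; // i1-typed select
  return sel ? 1 : 0;                            // zext i1 -> i8
}

// Compare both forms for every 8-bit value of A.
constexpr bool sameForAllA(Pred p, std::int8_t c) {
  for (int v = -128; v <= 127; ++v) {
    std::int8_t a = static_cast<std::int8_t>(v);
    if (original(p, a, c) != rewritten(p, a, c))
      return false;
  }
  return true;
}

static_assert(sameForAllA(Pred::EQ, 0), "rewrite must preserve the value");
static_assert(sameForAllA(Pred::SLT, 5), "rewrite must preserve the value");
```

The model only shows that the two forms compute the same value; whether later folds simplify the rewritten select is a separate question that this sketch does not address.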

@elhewaty
Contributor Author

Then the optimization in this patch is handled somewhere?

I mean, is there already a patch for this?

@XChy
Member

XChy commented Oct 24, 2023

I mean, is there already a patch for this?

Things like A pred1 C1 ? A pred2 C2 : 1 are folded by previous patches, but there is none yet for A pred C ? (A >> (BW - 1)) : 1.

Just fold A pred C ? (A >> (BW - 1)) : 1 -> ZExt(A pred C ? A < 0 : 1), which the previous patches can then fold.
Does this explanation make sense to you?

@elhewaty
Contributor Author

Yes, it makes sense, thanks.

Will a ? (a < 0) : 1 --> (a <= 0) then be folded?

@XChy
Member

XChy commented Oct 24, 2023

Yes, godbolt.

@elhewaty
Contributor Author

@XChy What is wrong with this approach?

CmpInst::Predicate Pred;
  Value *A;
  ConstantInt *C1, *C2;
  const APInt *C3;
  if (match(CondVal, m_Cmp(Pred, m_Value(A), m_ConstantInt(C1))) &&
      match(TrueVal, m_LShr(m_Specific(A), m_ConstantInt(C2))) &&
      match(FalseVal, m_APInt(C3)) && C3->isOne() &&
      C2->getValue() == C2->getValue().getBitWidth() - 1) {
    auto *Cond = Builder.CreateICmp(Pred, A, C1);
    auto *LArm = Builder.CreateICmpSLE(A, ConstantInt::getNullValue(A->getType()));
    Constant *RArm = ConstantInt::get(SelType, *C3);
    return SelectInst::Create(Cond, LArm, RArm);
  }

@XChy
Member

XChy commented Oct 25, 2023

@XChy What is wrong with this approach?

CmpInst::Predicate Pred;
  Value *A;
  ConstantInt *C1, *C2;
  const APInt *C3;
  if (match(CondVal, m_Cmp(Pred, m_Value(A), m_ConstantInt(C1))) &&
      match(TrueVal, m_LShr(m_Specific(A), m_ConstantInt(C2))) &&
      match(FalseVal, m_APInt(C3)) && C3->isOne() &&
      C2->getValue() == C2->getValue().getBitWidth() - 1) {
    auto *Cond = Builder.CreateICmp(Pred, A, C1);
    auto *LArm = Builder.CreateICmpSLE(A, ConstantInt::getNullValue(A->getType()));
    Constant *RArm = ConstantInt::get(SelType, *C3);
    return SelectInst::Create(Cond, LArm, RArm);
  }

m_Cmp -> m_ICmp
m_ConstantInt -> m_SpecificInt(getScalarSizeInBits() - 1)
m_APInt(C3) -> m_One()
Builder.CreateICmpSLE -> Builder.CreateICmpSLT
RArm -> C3

You may need to refer to https://reviews.llvm.org/D154791 and #68244.

@XChy
Copy link
Member

XChy commented Oct 25, 2023

return SelectInst::Create(Cond, LArm, RArm);

And a ZExt instruction should be emitted here, surrounding the whole select.

@elhewaty
Contributor Author

@XChy I tried the following approach.

CmpInst::Predicate Pred;
  Value *A;
  if (match(CondVal,
            m_ICmp(Pred, m_Value(A),
                   m_SpecificInt(A->getType()->getScalarSizeInBits() - 1))) &&
      match(TrueVal,
            m_LShr(m_Specific(A),
                   m_SpecificInt(A->getType()->getScalarSizeInBits() - 1))) &&
      match(FalseVal, m_One())) {
    auto *NewTrue = Builder.CreateICmpSLE(A,
                                          Constant::getNullValue(A->getType()));
    replaceOperand(SI, 1, NewTrue);
    return new ZExtInst(&SI, SelType);
  }

but I got this,

Stack dump:
0.	Program arguments: /media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt -S -passes=instcombine
 #0 0x0000556c41181420 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt+0x2cd0420)
 #1 0x0000556c4117e82f llvm::sys::RunSignalHandlers() (/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt+0x2ccd82f)
 #2 0x0000556c4117e985 SignalHandler(int) Signals.cpp:0:0
 #3 0x00007fd378e42520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x0000556c40d70b5d llvm::InstCombinerImpl::visitSelectInst(llvm::SelectInst&) InstCombineSelect.cpp:0:0
 #5 0x0000556c40c94cec llvm::InstCombinerImpl::run() InstructionCombining.cpp:0:0
 #6 0x0000556c40c963ad combineInstructionsOverFunction(llvm::Function&, llvm::InstructionWorklist&, llvm::AAResults*, llvm::AssumptionCache&, llvm::TargetLibraryInfo&, llvm::TargetTransformInfo&, llvm::DominatorTree&, llvm::OptimizationRemarkEmitter&, llvm::BlockFrequencyInfo*, llvm::ProfileSummaryInfo*, unsigned int, bool, llvm::LoopInfo*) InstructionCombining.cpp:0:0
 #7 0x0000556c40c97361 llvm::InstCombinePass::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt+0x27e6361)
 #8 0x0000556c4138dc66 llvm::detail::PassModel<llvm::Function, llvm::InstCombinePass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) PassBuilder.cpp:0:0
 #9 0x0000556c40b66ce1 llvm::PassManager<llvm::Function, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt+0x26b5ce1)
#10 0x0000556c413866a6 llvm::detail::PassModel<llvm::Function, llvm::PassManager<llvm::Function, llvm::AnalysisManager<llvm::Function>>, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) PassBuilder.cpp:0:0
#11 0x0000556c40b659db llvm::ModuleToFunctionPassAdaptor::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt+0x26b49db)
#12 0x0000556c4138d8f6 llvm::detail::PassModel<llvm::Module, llvm::ModuleToFunctionPassAdaptor, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Module>>::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) PassBuilder.cpp:0:0
#13 0x0000556c40b637c1 llvm::PassManager<llvm::Module, llvm::AnalysisManager<llvm::Module>>::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt+0x26b27c1)
#14 0x0000556c3fae8b35 llvm::runPassPipeline(llvm::StringRef, llvm::Module&, llvm::TargetMachine*, llvm::TargetLibraryInfoImpl*, llvm::ToolOutputFile*, llvm::ToolOutputFile*, llvm::ToolOutputFile*, llvm::StringRef, llvm::ArrayRef<llvm::PassPlugin>, llvm::opt_tool::OutputKind, llvm::opt_tool::VerifierKind, bool, bool, bool, bool, bool, bool, bool) (/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt+0x1637b35)
#15 0x0000556c3faf809f main (/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt+0x164709f)
#16 0x00007fd378e29d90 __libc_start_call_main ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
#17 0x00007fd378e29e40 call_init ./csu/../csu/libc-start.c:128:20
#18 0x00007fd378e29e40 __libc_start_main ./csu/../csu/libc-start.c:379:5
#19 0x0000556c3fadb325 _start (/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt+0x162a325)
Segmentation fault (core dumped)
Traceback (most recent call last):
  File "/media/mohamed/Local-Disk/open-source/New-LLVM/llvm-project/llvm/utils/update_test_checks.py", line 327, in <module>
    main()
  File "/media/mohamed/Local-Disk/open-source/New-LLVM/llvm-project/llvm/utils/update_test_checks.py", line 154, in main
    raw_tool_output = common.invoke_tool(
  File "/media/mohamed/Local-Disk/open-source/New-LLVM/llvm-project/llvm/utils/UpdateTestChecks/common.py", line 453, in invoke_tool
    stdout = subprocess.check_output(
  File "/usr/lib/python3.10/subprocess.py", line 421, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/usr/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '/media/mohamed/Local-Disk/open-source/New-LLVM/build/bin/opt -S -passes=instcombine' returned non-zero exit status 139.

@XChy
Member

XChy commented Nov 14, 2023

Don't reuse SI when replacing it with a new instruction; that makes the instruction reference itself.
Please use Builder.CreateSelect(CondVal, SLTZero, One) to create a new select instruction instead of reusing SI via replaceOperand(SI, 1, NewTrue).

@elhewaty
Contributor Author

What's wrong with this?

CmpInst::Predicate Pred;
  Value *A;
  if (match(CondVal,
            m_ICmp(Pred, m_Value(A),
                   m_SpecificInt(A->getType()->getScalarSizeInBits() - 1))) &&
      match(TrueVal,
            m_LShr(m_Specific(A),
                   m_SpecificInt(A->getType()->getScalarSizeInBits() - 1))) &&
      match(FalseVal, m_One())) {
    Constant *One = ConstantInt::get(SelType, 1);
    auto *SLTZero
      = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(A->getType()));
    auto *Sel = Builder.CreateSelect(CondVal, SLTZero, One);
    Type *Ty = A->getType();
    auto *Zext = Builder.CreateZExt(Sel, Ty);
    return replaceInstUsesWith(SI, Zext);
  }

@nikic
Contributor

nikic commented Nov 14, 2023

@elhewaty I think the type of your One is wrong. It should have the same type as SLTZero (an i1 or i1 vector), not the type of the original select.

@elhewaty
Contributor Author

@dtcxzyw, Can you please review this, too?

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Jan 31, 2024
@dtcxzyw
Member

dtcxzyw commented Jan 31, 2024

@dtcxzyw, Can you please review this, too?

Could you please rebase this patch first? I cannot apply it to my local repo:(

@elhewaty
Contributor Author

@dtcxzyw, done

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Jan 31, 2024
@dtcxzyw
Member

dtcxzyw commented Jan 31, 2024

@k-arrows @elhewaty Does this pattern exist in some real-world applications?

@elhewaty
Contributor Author

@k-arrows @elhewaty Does this pattern exist in some real-world applications?

@dtcxzyw, I am not sure, but I think @XChy is best placed to answer this, as he suggested changing
the pattern from a ? (a < 0) : 1 --> (a <= 0) to A pred C ? (A >> BW - 1) : 1 -> ZExt(A pred C ? A < 0 : 1).

Also, here @nikic said that he approved this pattern.

@k-arrows

k-arrows commented Jan 31, 2024

Does this pattern exist in some real-world applications?

Probably not, so I won't insist on this pattern if there is no particular benefit. I remember clang was sometimes suboptimal for expressions that contain (a < 0). (See #67916 #63751 #62586 for example.) This fold was probably found around the same time as these.

By the way, shouldn't we at least change the title of this PR? We're now dealing with a more general pattern.

@elhewaty elhewaty changed the title [InstCombine] Fold selection between less than zero and one [InstCombine] Fold A pred C ? (A >> BW - 1) : 1 --> ZExt(A pred C ? A < 0 : 1) Jan 31, 2024
@elhewaty
Contributor Author

elhewaty commented Feb 3, 2024

@nikic, @dtcxzyw ping

@elhewaty
Contributor Author

@dtcxzyw, @XChy , @nikic ping.

@@ -3409,6 +3409,25 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
return replaceOperand(SI, 2, S);
}

{
// A pred C ? (A >> BW - 1) : 1 --> ZExt(A pred C ? A < 0 : 1)
Member

Should be generalized to Cond ? (A >> (BW - 1)) : 1 --> ZExt(Cond ? A < 0 : 1).

I am sorry, but I cannot approve this patch unless you can demonstrate that it benefits some real-world applications.

Contributor

AFAICT the idea is that this is useful if A pred C implies something about A < 0, in which case it makes sense to fold A >> BW - 1 into A < 0, which is easier to reason about.
At least that's what the purpose looks like from the tests.

Contributor Author

@goldsteinn, should I fold A < 0 instead of A pred C, or keep the patch as a draft until we can demonstrate that it benefits some real-world applications?

Contributor

No real comment about the real-world impact. I'm generally pretty in favor of "get anything we can in" but think that's a minority opinion.

I would refactor to make the goal explicit. Use makeExactICmpRegion on the A Pred C and just do the complete simplification if the range simplifies A < 0 to a constant.
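
The suggestion above can be pictured with a brute-force model: collect the values of A for which A pred C holds, and see whether that region decides A < 0. The sketch below is an assumption-laden stand-in, not the LLVM implementation. It does not call `makeExactICmpRegion` or any other LLVM API; it enumerates all 8-bit values instead, and `Pred`, `Sign`, `evalPred` and `signUnderCond` are hypothetical names.

```cpp
#include <cstdint>

enum class Pred { EQ, SLT, SGT, ULT };
enum class Sign { AlwaysNegative, NeverNegative, Unknown };

// Evaluate "a pred c"; ULT compares the unsigned bit patterns.
constexpr bool evalPred(Pred p, std::int8_t a, std::int8_t c) {
  switch (p) {
  case Pred::EQ:  return a == c;
  case Pred::SLT: return a < c;
  case Pred::SGT: return a > c;
  case Pred::ULT:
    return static_cast<std::uint8_t>(a) < static_cast<std::uint8_t>(c);
  }
  return false;
}

// Brute-force stand-in for "does the exact region of A pred C decide A < 0?".
// An empty region is reported as NeverNegative, which is vacuously correct.
constexpr Sign signUnderCond(Pred p, std::int8_t c) {
  bool sawNeg = false;
  bool sawNonNeg = false;
  for (int v = -128; v <= 127; ++v) {
    std::int8_t a = static_cast<std::int8_t>(v);
    if (!evalPred(p, a, c))
      continue; // a lies outside the region where the condition holds
    if (a < 0)
      sawNeg = true;
    else
      sawNonNeg = true;
  }
  if (sawNeg && !sawNonNeg)
    return Sign::AlwaysNegative; // A >> (BW - 1) is 1 on the true arm
  if (!sawNeg)
    return Sign::NeverNegative;  // A >> (BW - 1) is 0 on the true arm
  return Sign::Unknown;          // the region straddles zero: no simplification
}

static_assert(signUnderCond(Pred::SLT, 0) == Sign::AlwaysNegative, "");
static_assert(signUnderCond(Pred::SGT, 0) == Sign::NeverNegative, "");
```

In this model, only the `AlwaysNegative` and `NeverNegative` outcomes correspond to the "range simplifies A < 0 to a constant" case from the comment; `Unknown` means the shift would have to stay as it is.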

Successfully merging this pull request may close these issues.

[InstCombine] a ? (a < 0) : 1 --> (a <= 0)
8 participants