Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JumpThreading] Convert s/zext i1 to select i1 for further unfolding #89345

Closed
wants to merge 1 commit into from

Conversation

FLZ101
Copy link
Contributor

@FLZ101 FLZ101 commented Apr 19, 2024

Convert a s/zext i1 (e.g. %b = zext i1 %a to i32) to a select i1 (e.g. %b = select i1 %a, i32 1, i32 0) if doing so helps jump threading.

See llvm/test/Transforms/JumpThreading/szext.ll for an example.

Convert "sext/zext i1" into "select i1" if it could help
jump threading
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 19, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Franklin Zhang (FLZ101)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/89345.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Scalar/JumpThreading.h (+2)
  • (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+81-1)
  • (added) llvm/test/Transforms/JumpThreading/szext.ll (+94)
diff --git a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
index 3364d7eaee4247..c7da9053b7abeb 100644
--- a/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
+++ b/llvm/include/llvm/Transforms/Scalar/JumpThreading.h
@@ -162,6 +162,8 @@ class JumpThreadingPass : public PassInfoMixin<JumpThreadingPass> {
   bool tryToUnfoldSelect(SwitchInst *SI, BasicBlock *BB);
   bool tryToUnfoldSelectInCurrBB(BasicBlock *BB);
 
+  bool tryToConvertSZExtToSelect(BasicBlock *BB);
+
   bool processGuards(BasicBlock *BB);
   bool threadGuard(BasicBlock *BB, IntrinsicInst *Guard, BranchInst *BI);
 
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index ffcb511e6a8312..c92519610aec18 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -971,6 +971,9 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) {
   if (maybeMergeBasicBlockIntoOnlyPred(BB))
     return true;
 
+  if (tryToConvertSZExtToSelect(BB))
+    return true;
+
   if (tryToUnfoldSelectInCurrBB(BB))
     return true;
 
@@ -2750,7 +2753,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
 // Pred is a predecessor of BB with an unconditional branch to BB. SI is
 // a Select instruction in Pred. BB has other predecessors and SI is used in
 // a PHI node in BB. SI has no other use.
-// A new basic block, NewBB, is created and SI is converted to compare and 
+// A new basic block, NewBB, is created and SI is converted to compare and
 // conditional branch. SI is erased from parent.
 void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB,
                                           SelectInst *SI, PHINode *SIUse,
@@ -2997,6 +3000,83 @@ bool JumpThreadingPass::tryToUnfoldSelectInCurrBB(BasicBlock *BB) {
   return false;
 }
 
+/// Try to convert "sext/zext i1" into "select i1" which could be further
+/// unfolded by tryToUnfoldSelect().
+///
+/// For example,
+///
+/// ; before the transformation
+/// BB1:
+///   %a = icmp ...
+///   %b = zext i1 %a to i32
+///   br label %BB2
+/// BB2:
+///   %c = phi i32 [ %b, %BB1 ], ...
+///   %d = icmp eq i32 %c, 0
+///   br i1 %d, ...
+///
+/// ------
+///
+/// ; after the transformation
+/// BB1:
+///   %a = icmp ...
+///   %b = select i1 %a, i32 1, i32 0
+///   br label %BB2
+/// BB2:
+///   %c = phi i32 [ %b, %BB1 ], ...
+///   %d = icmp eq i32 %c, 0
+///   br i1 %d, ...
+///
+bool JumpThreadingPass::tryToConvertSZExtToSelect(BasicBlock *BB) {
+  // tryToUnfoldSelect requires that Br is unconditional
+  BranchInst *Br = dyn_cast<BranchInst>(BB->getTerminator());
+  if (!Br || Br->isConditional())
+    return false;
+  BasicBlock *BBX = Br->getSuccessor(0);
+
+  SmallVector<Instruction *> ToConvert;
+  for (auto &I : *BB) {
+    using namespace PatternMatch;
+
+    Value *V;
+    if (!match(&I, m_ZExtOrSExt(m_Value(V))) || !V->getType()->isIntegerTy(1))
+      continue;
+
+    // I is only used by Phi
+    Use *U = I.getSingleUndroppableUse();
+    if (!U)
+      continue;
+    PHINode *Phi = dyn_cast<PHINode>(U->getUser());
+    if (!Phi || Phi->getParent() != BBX)
+      continue;
+
+    // tryToUnfoldSelect requires that Phi is used in the following way
+    ICmpInst::Predicate Pred;
+    if (!match(BBX->getTerminator(),
+               m_Br(m_ICmp(Pred, m_Specific(Phi), m_ConstantInt()),
+                    m_BasicBlock(), m_BasicBlock())))
+      continue;
+
+    ToConvert.push_back(&I);
+  }
+  if (ToConvert.empty())
+    return false;
+
+  LLVM_DEBUG(dbgs() << "\nconvert-szext-to-select:\n" << *BB << "\n");
+  for (Instruction *I : ToConvert) {
+    auto Ty = I->getType();
+    Value *V1 = isa<SExtInst>(I) ? ConstantInt::getAllOnesValue(Ty)
+                                 : ConstantInt::get(Ty, 1);
+    Value *V2 = ConstantInt::getNullValue(Ty);
+    SelectInst *SI =
+        SelectInst::Create(I->getOperand(0), V1, V2, I->getName(), I);
+    I->replaceAllUsesWith(SI);
+    I->eraseFromParent();
+  }
+  LLVM_DEBUG(dbgs() << *BB << "\n");
+  return true;
+}
+
 /// Try to propagate a guard from the current BB into one of its predecessors
 /// in case if another branch of execution implies that the condition of this
 /// guard is always true. Currently we only process the simplest case that
diff --git a/llvm/test/Transforms/JumpThreading/szext.ll b/llvm/test/Transforms/JumpThreading/szext.ll
new file mode 100644
index 00000000000000..290fe7ad0ca257
--- /dev/null
+++ b/llvm/test/Transforms/JumpThreading/szext.ll
@@ -0,0 +1,94 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=jump-threading < %s | FileCheck %s
+
+; void fun(int);
+;
+; int compare1(int a, int b, int c, int d)
+; {
+;     return a < b ? -1 :
+;            a > b ?  1 :
+;            c < d ? -1 :
+;            c > d ?  1 : 0;
+; }
+;
+; void test1(int a, int b, int c, int d) {
+;   int x = compare1(a, b, c, d);
+;   if (x < 0)
+;     fun(10);
+;   else if (x > 0)
+;     fun(20);
+;   else
+;     fun(30);
+; }
+
+declare void @fun(i32 noundef)
+
+define void @test1(i32 noundef %a, i32 noundef %b, i32 noundef %c, i32 noundef %d) {
+; CHECK-LABEL: define void @test1(
+; CHECK-SAME: i32 noundef [[A:%.*]], i32 noundef [[B:%.*]], i32 noundef [[C:%.*]], i32 noundef [[D:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP_I:%.*]] = icmp slt i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[CMP_I]], label [[IF_THEN:%.*]], label [[COND_FALSE_I:%.*]]
+; CHECK:       cond.false.i:
+; CHECK-NEXT:    [[CMP1_I:%.*]] = icmp sgt i32 [[A]], [[B]]
+; CHECK-NEXT:    br i1 [[CMP1_I]], label [[IF_THEN2:%.*]], label [[COND_FALSE3_I:%.*]]
+; CHECK:       cond.false3.i:
+; CHECK-NEXT:    [[CMP4_I:%.*]] = icmp slt i32 [[C]], [[D]]
+; CHECK-NEXT:    br i1 [[CMP4_I]], label [[IF_THEN]], label [[COND_FALSE6_I:%.*]]
+; CHECK:       cond.false6.i:
+; CHECK-NEXT:    [[CMP7_I:%.*]] = icmp sgt i32 [[C]], [[D]]
+; CHECK-NEXT:    br i1 [[CMP7_I]], label [[IF_THEN2]], label [[IF_ELSE3:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @fun(i32 noundef 10)
+; CHECK-NEXT:    br label [[IF_END4:%.*]]
+; CHECK:       if.then2:
+; CHECK-NEXT:    call void @fun(i32 noundef 20)
+; CHECK-NEXT:    br label [[IF_END4]]
+; CHECK:       if.else3:
+; CHECK-NEXT:    [[COND12_I:%.*]] = phi i32 [ 0, [[COND_FALSE6_I]] ]
+; CHECK-NEXT:    call void @fun(i32 noundef 30)
+; CHECK-NEXT:    br label [[IF_END4]]
+; CHECK:       if.end4:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp.i = icmp slt i32 %a, %b
+  br i1 %cmp.i, label %compare1.exit, label %cond.false.i
+
+cond.false.i:                                     ; preds = %entry
+  %cmp1.i = icmp sgt i32 %a, %b
+  br i1 %cmp1.i, label %compare1.exit, label %cond.false3.i
+
+cond.false3.i:                                    ; preds = %cond.false.i
+  %cmp4.i = icmp slt i32 %c, %d
+  br i1 %cmp4.i, label %compare1.exit, label %cond.false6.i
+
+cond.false6.i:                                    ; preds = %cond.false3.i
+  %cmp7.i = icmp sgt i32 %c, %d
+  %cond.i = zext i1 %cmp7.i to i32
+  br label %compare1.exit
+
+compare1.exit:                                    ; preds = %entry, %cond.false.i, %cond.false3.i, %cond.false6.i
+  %cond12.i = phi i32 [ -1, %entry ], [ 1, %cond.false.i ], [ %cond.i, %cond.false6.i ], [ -1, %cond.false3.i ]
+  %cmp = icmp slt i32 %cond12.i, 0
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %compare1.exit
+  call void @fun(i32 noundef 10)
+  br label %if.end4
+
+if.else:                                          ; preds = %compare1.exit
+  %cmp1.not = icmp eq i32 %cond12.i, 0
+  br i1 %cmp1.not, label %if.else3, label %if.then2
+
+if.then2:                                         ; preds = %if.else
+  call void @fun(i32 noundef 20)
+  br label %if.end4
+
+if.else3:                                         ; preds = %if.else
+  call void @fun(i32 noundef 30)
+  br label %if.end4
+
+if.end4:                                          ; preds = %if.then2, %if.else3, %if.then
+  ret void
+}

@dtcxzyw dtcxzyw requested a review from nikic April 19, 2024 06:26
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dtcxzyw Can you please test this? I'm not sure whether this does more damage than good...

continue;

// I is only used by Phi
Use *U = I.getSingleUndroppableUse();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use hasOneUse() and use_begin() here. You can't use this API like this.

SelectInst *SI =
SelectInst::Create(I->getOperand(0), V1, V2, I->getName(), I);
I->replaceAllUsesWith(SI);
I->eraseFromParent();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you use make_early_inc_range, I think you can merge both loops.


cond.false6.i: ; preds = %cond.false3.i
%cmp7.i = icmp sgt i32 %c, %d
%cond.i = zext i1 %cmp7.i to i32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test for sext?

This test does not look very minimal.

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Apr 19, 2024
@FLZ101
Copy link
Contributor Author

FLZ101 commented Apr 19, 2024

Thanks for your comments.

I found a better way to implement this feature and will create a new pull request soon.

@FLZ101 FLZ101 closed this Apr 19, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants