
[InstCombine] fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b) while Binop is commutative. #75765

Merged (16 commits) on Dec 21, 2023

Conversation

@sun-jacobi (Member) commented Dec 18, 2023

This patch closes #73905

@llvmbot (Collaborator) commented Dec 18, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Chia (sun-jacobi)

Changes

This PR closes #73905


Full diff: https://github.com/llvm/llvm-project/pull/75765.diff

7 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+6)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+9)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+22)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+13)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+6)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+44)
  • (added) llvm/test/Transforms/InstCombine/commutative-operation-over-phis.ll (+360)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 719a2678fc189a..cef4cc5b1d52aa 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1505,6 +1505,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return Sub;
   }
 
+  if (Value *V = SimplifyPhiCommutativeBinaryOp(I, LHS, RHS))
+    return replaceInstUsesWith(I, V);
+
   // A + -B  -->  A - B
   if (match(RHS, m_Neg(m_Value(B))))
     return BinaryOperator::CreateSub(LHS, B);
@@ -1909,6 +1912,9 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
   if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS))
     return replaceInstUsesWith(I, V);
 
+  if (Value *V = SimplifyPhiCommutativeBinaryOp(I, LHS, RHS))
+    return replaceInstUsesWith(I, V);
+
   if (I.hasAllowReassoc() && I.hasNoSignedZeros()) {
     if (Instruction *F = factorizeFAddFSub(I, Builder))
       return F;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 31db1d3164b772..0398663f2c2794 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2202,6 +2202,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
+  if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
+    return replaceInstUsesWith(I, V);
+
   Value *X, *Y;
   if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) &&
       match(Op1, m_One())) {
@@ -3378,6 +3381,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *Concat = matchOrConcat(I, Builder))
     return replaceInstUsesWith(I, Concat);
 
+  if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
+    return replaceInstUsesWith(I, V);
+
   if (Instruction *R = foldBinOpShiftWithShift(I))
     return R;
 
@@ -4501,6 +4507,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
       return BinaryOperator::CreateXor(Or, ConstantExpr::getNot(C1));
     }
 
+    if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
+      return replaceInstUsesWith(I, V);
+
     // Convert xor ([trunc] (ashr X, BW-1)), C =>
     //   select(X >s -1, C, ~C)
     // The ashr creates "AllZeroOrAllOne's", which then optionally inverses the
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 1539fa9a3269e1..83bf40f0b4e358 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1539,6 +1539,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
       return I;
 
+    if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
+      return I;
+
     if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
       return NewCall;
   }
@@ -4237,3 +4240,22 @@ InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
 
   return nullptr;
 }
+
+Instruction *
+InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
+  assert(II.isCommutative());
+
+  PHINode *LHS = dyn_cast<PHINode>(II.getOperand(0));
+  PHINode *RHS = dyn_cast<PHINode>(II.getOperand(1));
+
+  if (!LHS || !RHS)
+    return nullptr;
+
+  if (matchSymmetricPhiNodesPair(LHS, RHS)) {
+    replaceOperand(II, 0, LHS->getIncomingValue(0));
+    replaceOperand(II, 1, LHS->getIncomingValue(1));
+    return &II;
+  }
+
+  return nullptr;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 1d50fa9b6bf74b..2a22905e17b133 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -278,6 +278,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                               IntrinsicInst &Tramp);
   Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);
 
+  // match a pair of Phi Nodes like
+  // phi [a, BB0], [b, BB1] & phi [b, BB0], [a, BB1]
+  bool matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS);
+
+  // Tries to fold (op phi(a, b) phi(b, a)) -> (op a, b)
+  // while op is a commutative intrinsic call
+  Instruction *foldCommutativeIntrinsicOverPhis(IntrinsicInst &II);
+
   Value *simplifyMaskedLoad(IntrinsicInst &II);
   Instruction *simplifyMaskedStore(IntrinsicInst &II);
   Instruction *simplifyMaskedGather(IntrinsicInst &II);
@@ -492,6 +500,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   /// X % (C0 * C1)
   Value *SimplifyAddWithRemainder(BinaryOperator &I);
 
+  // Tries to fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b)
+  // while Binop is commutative.
+  Value *SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, Value *LHS,
+                                        Value *RHS);
+
   // Binary Op helper for select operations where the expression can be
   // efficiently reorganized.
   Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index e5566578869ddc..8e1b8b40e0ed63 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -207,6 +207,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
+  if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
+    return replaceInstUsesWith(I, V);
+
   Type *Ty = I.getType();
   const unsigned BitWidth = Ty->getScalarSizeInBits();
   const bool HasNSW = I.hasNoSignedWrap();
@@ -779,6 +782,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
   if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
     return replaceInstUsesWith(I, V);
 
+  if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
+    return replaceInstUsesWith(I, V);
+
   if (I.hasAllowReassoc())
     if (Instruction *FoldedMul = foldFMulReassoc(I))
       return FoldedMul;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index a7ddadc25de43c..b168018e5f89d3 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1096,6 +1096,50 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
   return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
 }
 
+bool InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
+
+  if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
+    return false;
+
+  BasicBlock *B0 = LHS->getIncomingBlock(0);
+  BasicBlock *B1 = LHS->getIncomingBlock(1);
+
+  bool RHSContainB0 = RHS->getBasicBlockIndex(B0) != -1;
+  bool RHSContainB1 = RHS->getBasicBlockIndex(B1) != -1;
+
+  if (!RHSContainB0 || !RHSContainB1)
+    return false;
+
+  Value *N1 = LHS->getIncomingValueForBlock(B0);
+  Value *N2 = LHS->getIncomingValueForBlock(B1);
+  Value *N3 = RHS->getIncomingValueForBlock(B0);
+  Value *N4 = RHS->getIncomingValueForBlock(B1);
+
+  return N1 == N4 && N2 == N3;
+}
+
+Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
+                                                        Value *Op0,
+                                                        Value *Op1) {
+  assert(I.isCommutative() && "Instruction should be commutative");
+
+  PHINode *LHS = dyn_cast<PHINode>(Op0);
+  PHINode *RHS = dyn_cast<PHINode>(Op1);
+
+  if (!LHS || !RHS)
+    return nullptr;
+
+  if (matchSymmetricPhiNodesPair(LHS, RHS)) {
+    Value *BI = Builder.CreateBinOp(I.getOpcode(), LHS->getIncomingValue(0),
+                                    LHS->getIncomingValue(1));
+    if (auto *BO = dyn_cast<BinaryOperator>(BI))
+      BO->copyIRFlags(&I);
+    return BI;
+  }
+
+  return nullptr;
+}
+
 Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
                                                         Value *LHS,
                                                         Value *RHS) {
diff --git a/llvm/test/Transforms/InstCombine/commutative-operation-over-phis.ll b/llvm/test/Transforms/InstCombine/commutative-operation-over-phis.ll
new file mode 100644
index 00000000000000..1729fb15a8dfe1
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/commutative-operation-over-phis.ll
@@ -0,0 +1,360 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+declare void @dummy()
+
+declare i32 @llvm.smax.i32(i32 %a, i32 %b)
+declare i32 @llvm.smin.i32(i32 %a, i32 %b)
+declare i32 @llvm.umax.i32(i32 %a, i32 %b)
+declare i32 @llvm.umin.i32(i32 %a, i32 %b)
+
+define i8 @fold_phi_mul(i1 %c, i8 %a, i8 %b)  {
+; CHECK-LABEL: define i8 @fold_phi_mul(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = mul i8 [[A]], [[B]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i8 [%a, %entry], [%b, %then]
+  %phi2 = phi i8 [%b, %entry], [%a, %then]
+  %ret = mul i8 %phi1, %phi2
+  ret i8 %ret
+}
+
+define i8 @fold_phi_mul_commute(i1 %c, i8 %a, i8 %b)  {
+; CHECK-LABEL: define i8 @fold_phi_mul_commute(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = mul i8 [[A]], [[B]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i8 [%a, %entry], [%b, %then]
+  %phi2 = phi i8 [%a, %then], [%b, %entry]
+  %ret = mul i8 %phi1, %phi2
+  ret i8 %ret
+}
+
+
+define i8 @fold_phi_mul_notopt(i1 %c, i8 %a, i8 %b, i8 %d)  {
+; CHECK-LABEL: define i8 @fold_phi_mul_notopt(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[PHI1:%.*]] = phi i8 [ [[A]], [[ENTRY:%.*]] ], [ [[B]], [[THEN]] ]
+; CHECK-NEXT:    [[PHI2:%.*]] = phi i8 [ [[B]], [[ENTRY]] ], [ [[D]], [[THEN]] ]
+; CHECK-NEXT:    [[RET:%.*]] = mul i8 [[PHI1]], [[PHI2]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i8 [%a, %entry], [%b, %then]
+  %phi2 = phi i8 [%b, %entry], [%d, %then]
+  %ret = mul i8 %phi1, %phi2
+  ret i8 %ret
+}
+
+
+define i8 @fold_phi_sub(i1 %c, i8 %a, i8 %b)  {
+; CHECK-LABEL: define i8 @fold_phi_sub(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[PHI1:%.*]] = phi i8 [ [[A]], [[ENTRY:%.*]] ], [ [[B]], [[THEN]] ]
+; CHECK-NEXT:    [[PHI2:%.*]] = phi i8 [ [[B]], [[ENTRY]] ], [ [[A]], [[THEN]] ]
+; CHECK-NEXT:    [[RET:%.*]] = sub i8 [[PHI1]], [[PHI2]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i8 [%a, %entry], [%b, %then]
+  %phi2 = phi i8 [%b, %entry], [%a, %then]
+  %ret = sub i8 %phi1, %phi2
+  ret i8 %ret
+}
+
+
+define i8 @fold_phi_add(i1 %c, i8 %a, i8 %b)  {
+; CHECK-LABEL: define i8 @fold_phi_add(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = add i8 [[A]], [[B]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i8 [%a, %entry], [%b, %then]
+  %phi2 = phi i8 [%b, %entry], [%a, %then]
+  %ret = add i8 %phi1, %phi2
+  ret i8 %ret
+}
+
+define i8 @fold_phi_and(i1 %c, i8 %a, i8 %b)  {
+; CHECK-LABEL: define i8 @fold_phi_and(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = and i8 [[A]], [[B]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i8 [%a, %entry], [%b, %then]
+  %phi2 = phi i8 [%b, %entry], [%a, %then]
+  %ret = and i8 %phi1, %phi2
+  ret i8 %ret
+}
+
+define i8 @fold_phi_or(i1 %c, i8 %a, i8 %b)  {
+; CHECK-LABEL: define i8 @fold_phi_or(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = or i8 [[A]], [[B]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i8 [%a, %entry], [%b, %then]
+  %phi2 = phi i8 [%b, %entry], [%a, %then]
+  %ret = or i8 %phi1, %phi2
+  ret i8 %ret
+}
+
+
+define i8 @fold_phi_xor(i1 %c, i8 %a, i8 %b)  {
+; CHECK-LABEL: define i8 @fold_phi_xor(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[PHI1:%.*]] = phi i8 [ [[A]], [[ENTRY:%.*]] ], [ [[B]], [[THEN]] ]
+; CHECK-NEXT:    [[PHI2:%.*]] = phi i8 [ [[B]], [[ENTRY]] ], [ [[A]], [[THEN]] ]
+; CHECK-NEXT:    [[RET:%.*]] = xor i8 [[PHI1]], [[PHI2]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i8 [%a, %entry], [%b, %then]
+  %phi2 = phi i8 [%b, %entry], [%a, %then]
+  %ret = xor i8 %phi1, %phi2
+  ret i8 %ret
+}
+
+
+define float @fold_phi_fadd(i1 %c, float %a, float %b)  {
+; CHECK-LABEL: define float @fold_phi_fadd(
+; CHECK-SAME: i1 [[C:%.*]], float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = fadd float [[A]], [[B]]
+; CHECK-NEXT:    ret float [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi float [%a, %entry], [%b, %then]
+  %phi2 = phi float [%b, %entry], [%a, %then]
+  %ret = fadd float %phi1, %phi2
+  ret float %ret
+}
+
+
+define float @fold_phi_fmul(i1 %c, float %a, float %b)  {
+; CHECK-LABEL: define float @fold_phi_fmul(
+; CHECK-SAME: i1 [[C:%.*]], float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = fmul float [[A]], [[B]]
+; CHECK-NEXT:    ret float [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi float [%a, %entry], [%b, %then]
+  %phi2 = phi float [%b, %entry], [%a, %then]
+  %ret = fmul float %phi1, %phi2
+  ret float %ret
+}
+
+
+define i32 @fold_phi_smax(i1 %c, i32 %a, i32 %b)  {
+; CHECK-LABEL: define i32 @fold_phi_smax(
+; CHECK-SAME: i1 [[C:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = call i32 @llvm.smax.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i32 [%a, %entry], [%b, %then]
+  %phi2 = phi i32 [%b, %entry], [%a, %then]
+  %ret = call i32 @llvm.smax.i32(i32  %phi1, i32 %phi2)
+  ret i32 %ret
+}
+
+
+define i32 @fold_phi_smin(i1 %c, i32 %a, i32 %b)  {
+; CHECK-LABEL: define i32 @fold_phi_smin(
+; CHECK-SAME: i1 [[C:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i32 [%a, %entry], [%b, %then]
+  %phi2 = phi i32 [%b, %entry], [%a, %then]
+  %ret = call i32 @llvm.smin.i32(i32  %phi1, i32 %phi2)
+  ret i32 %ret
+}
+
+
+define i32 @fold_phi_umax(i1 %c, i32 %a, i32 %b)  {
+; CHECK-LABEL: define i32 @fold_phi_umax(
+; CHECK-SAME: i1 [[C:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = call i32 @llvm.umax.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i32 [%a, %entry], [%b, %then]
+  %phi2 = phi i32 [%b, %entry], [%a, %then]
+  %ret = call i32 @llvm.umax.i32(i32  %phi1, i32 %phi2)
+  ret i32 %ret
+}
+
+define i32 @fold_phi_umin(i1 %c, i32 %a, i32 %b)  {
+; CHECK-LABEL: define i32 @fold_phi_umin(
+; CHECK-SAME: i1 [[C:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C]], label [[THEN:%.*]], label [[END:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    call void @dummy()
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
+; CHECK-NEXT:    [[RET:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+entry:
+  br i1 %c, label %then, label %end
+then:
+  call void @dummy()
+  br label %end
+end:
+  %phi1 = phi i32 [%a, %entry], [%b, %then]
+  %phi2 = phi i32 [%b, %entry], [%a, %then]
+  %ret = call i32 @llvm.umin.i32(i32  %phi1, i32 %phi2)
+  ret i32 %ret
+}

@dtcxzyw (Member) left a comment:

LGTM. Please wait for additional approval from other reviewers.


github-actions bot commented Dec 18, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.


if (!LHS || !RHS)
return nullptr;

Member:

Do you need to restrict to 2 incoming values per phi here? matchSymmetricPhiNodesPair could return true for phis with 10^10 incoming values.

Member Author (sun-jacobi):
Could you explain more? I think the current matchSymmetricPhiNodesPair returns false unless each phi has exactly two incoming values.

@nikic (Contributor) left a comment:

Some more nits.

Resolved review threads (outdated) on:
  • llvm/lib/Transforms/InstCombine/InstCombineInternal.h (3 threads)
  • llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (3 threads)
@nikic (Contributor) left a comment:
LGTM

@XChy XChy merged commit 8674a02 into llvm:main Dec 21, 2023
4 checks passed
@sun-jacobi sun-jacobi deleted the symmetry-phi branch December 21, 2023 14:54

Successfully merging this pull request may close these issues.

[InstCombine] Missing optimization : fold mul (phi a, b), (phi b, a) to mul a, b
6 participants