Skip to content

Commit

Permalink
[InstCombine] fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b) while …
Browse files Browse the repository at this point in the history
…Binop is commutative. (#75765)

Alive2 proof: https://alive2.llvm.org/ce/z/2P8gq-
This patch closes #73905
  • Loading branch information
sun-jacobi committed Dec 21, 2023
1 parent 791200b commit 8674a02
Show file tree
Hide file tree
Showing 4 changed files with 735 additions and 0 deletions.
22 changes: 22 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
return I;

if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
return I;

if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
return NewCall;
}
Expand Down Expand Up @@ -4237,3 +4240,22 @@ InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {

return nullptr;
}

Instruction *
InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
assert(II.isCommutative() && "Instruction should be commutative");

PHINode *LHS = dyn_cast<PHINode>(II.getOperand(0));
PHINode *RHS = dyn_cast<PHINode>(II.getOperand(1));

if (!LHS || !RHS)
return nullptr;

if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
replaceOperand(II, 0, P->first);
replaceOperand(II, 1, P->second);
return &II;
}

return nullptr;
}
15 changes: 15 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,16 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
IntrinsicInst &Tramp);
Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);

// Match a pair of Phi Nodes like
// phi [a, BB0], [b, BB1] & phi [b, BB0], [a, BB1]
// Return the matched two operands.
std::optional<std::pair<Value *, Value *>>
matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS);

// Tries to fold (op phi(a, b) phi(b, a)) -> (op a, b)
// while op is a commutative intrinsic call.
Instruction *foldCommutativeIntrinsicOverPhis(IntrinsicInst &II);

Value *simplifyMaskedLoad(IntrinsicInst &II);
Instruction *simplifyMaskedStore(IntrinsicInst &II);
Instruction *simplifyMaskedGather(IntrinsicInst &II);
Expand Down Expand Up @@ -492,6 +502,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// X % (C0 * C1)
Value *SimplifyAddWithRemainder(BinaryOperator &I);

// Tries to fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b)
// while Binop is commutative.
Value *SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, Value *LHS,
Value *RHS);

// Binary Op helper for select operations where the expression can be
// efficiently reorganized.
Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
Expand Down
53 changes: 53 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,54 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
}

std::optional<std::pair<Value *, Value *>>
InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
if (LHS->getParent() != RHS->getParent())
return std::nullopt;

if (LHS->getNumIncomingValues() < 2)
return std::nullopt;

if (!equal(LHS->blocks(), RHS->blocks()))
return std::nullopt;

Value *L0 = LHS->getIncomingValue(0);
Value *R0 = RHS->getIncomingValue(0);

for (unsigned I = 1, E = LHS->getNumIncomingValues(); I != E; ++I) {
Value *L1 = LHS->getIncomingValue(I);
Value *R1 = RHS->getIncomingValue(I);

if ((L0 == L1 && R0 == R1) || (L0 == R1 && R0 == L1))
continue;

return std::nullopt;
}

return std::optional(std::pair(L0, R0));
}

Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
Value *Op0,
Value *Op1) {
assert(I.isCommutative() && "Instruction should be commutative");

PHINode *LHS = dyn_cast<PHINode>(Op0);
PHINode *RHS = dyn_cast<PHINode>(Op1);

if (!LHS || !RHS)
return nullptr;

if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second);
if (auto *BO = dyn_cast<BinaryOperator>(BI))
BO->copyIRFlags(&I);
return BI;
}

return nullptr;
}

Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
Value *LHS,
Value *RHS) {
Expand Down Expand Up @@ -1529,6 +1577,11 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
BO.getParent() != Phi1->getParent())
return nullptr;

if (BO.isCommutative()) {
if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1))
return replaceInstUsesWith(BO, V);
}

// Fold if there is at least one specific constant value in phi0 or phi1's
// incoming values that comes from the same block and this specific constant
// value can be used to do optimization for specific binary operator.
Expand Down

0 comments on commit 8674a02

Please sign in to comment.