diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 0d6df72790632..fc4c0124d00b8 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -1716,7 +1716,8 @@ template struct TwoOps_match { }; /// Matches instructions with Opcode and three operands. -template +template struct ThreeOps_match { T0 Op1; T1 Op2; @@ -1728,8 +1729,12 @@ struct ThreeOps_match { template bool match(OpTy *V) { if (V->getValueID() == Value::InstructionVal + Opcode) { auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + if (!Op1.match(I->getOperand(0))) + return false; + if (Op2.match(I->getOperand(1)) && Op3.match(I->getOperand(2))) + return true; + return CommutableOp2Op3 && Op2.match(I->getOperand(2)) && + Op3.match(I->getOperand(1)); } return false; } @@ -1781,6 +1786,14 @@ m_SelectCst(const Cond &C) { return m_Select(C, m_ConstantInt(), m_ConstantInt()); } +/// Match Select(C, LHS, RHS) or Select(C, RHS, LHS) +template +inline ThreeOps_match +m_c_Select(const LHS &L, const RHS &R) { + return ThreeOps_match(m_Value(), L, R); +} + /// Matches FreezeInst. template inline OneOps_match m_Freeze(const OpTy &Op) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 46ce011c5f788..6fe9693581853 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -2245,9 +2245,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { const Instruction *UI = dyn_cast(U); if (!UI) return false; - return match(UI, - m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) || - match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1))); + return match(UI, m_c_Select(m_Specific(Op1), m_Specific(&I))); })) { if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation && I.hasNoSignedWrap(), diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 42c0acd1e45ec..fd38738e3be80 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1736,9 +1736,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *X; if (match(IIOperand, m_Neg(m_Value(X)))) return replaceOperand(*II, 0, X); - if (match(IIOperand, m_Select(m_Value(), m_Value(X), m_Neg(m_Deferred(X))))) - return replaceOperand(*II, 0, X); - if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X)))) + if (match(IIOperand, m_c_Select(m_Neg(m_Value(X)), m_Deferred(X)))) return replaceOperand(*II, 0, X); Value *Y; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index d602a907e72bc..acf01a8f1f7fc 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -8437,9 +8437,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { case Instruction::Select: // fcmp eq (cond ? x : -x), 0 --> fcmp eq x, 0 if (FCmpInst::isEquality(Pred) && match(RHSC, m_AnyZeroFP()) && - (match(LHSI, - m_Select(m_Value(), m_Value(X), m_FNeg(m_Deferred(X)))) || - match(LHSI, m_Select(m_Value(), m_FNeg(m_Value(X)), m_Deferred(X))))) + match(LHSI, m_c_Select(m_FNeg(m_Value(X)), m_Deferred(X)))) return replaceOperand(I, 0, X); if (Instruction *NV = FoldOpIntoSelect(I, cast(LHSI))) return NV; diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 1991ec82d1e1e..0c84e6fae496f 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -3814,10 +3814,7 @@ static bool foldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // These can often be turned into switches and other things. auto IsBinOpOrAnd = [](Value *V) { return match( - V, m_CombineOr( - m_BinOp(), - m_CombineOr(m_Select(m_Value(), m_ImmConstant(), m_Value()), - m_Select(m_Value(), m_Value(), m_ImmConstant())))); + V, m_CombineOr(m_BinOp(), m_c_Select(m_ImmConstant(), m_Value()))); }; if (PN->getType()->isIntegerTy(1) && (IsBinOpOrAnd(PN->getIncomingValue(0)) ||