diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index 2967532226197..be90250b068f6 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -578,6 +578,18 @@ m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) { return TernaryOpc_match(ISD::INSERT_SUBVECTOR, Base, Sub, Idx); } +template +inline auto m_SelectCC(const LTy &L, const RTy &R, const TTy &T, const FTy &F, + const CCTy &CC) { + return m_Node(ISD::SELECT_CC, L, R, T, F, CC); +} + +template +inline auto m_SelectCCLike(const LTy &L, const RTy &R, const TTy &T, + const FTy &F, const CCTy &CC) { + return m_AnyOf(m_Select(m_SetCC(L, R, CC), T, F), m_SelectCC(L, R, T, F, CC)); +} + // === Binary operations === template diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index a43020ee62281..0981a04160928 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -4100,18 +4100,17 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { // (sub x, ([v]select (uge x, y), y, 0)) -> (umin x, (sub x, y)) if (N1.hasOneUse() && hasUMin(VT)) { SDValue Y; - if (sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETULT)), - m_Zero(), m_Deferred(Y))) || - sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETUGE)), - m_Deferred(Y), m_Zero())) || - sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETULT)), - m_Zero(), m_Deferred(Y))) || - sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETUGE)), - m_Deferred(Y), m_Zero()))) + auto MS0 = m_Specific(N0); + auto MVY = m_Value(Y); + auto MZ = m_Zero(); + auto MCC1 = m_SpecificCondCode(ISD::SETULT); + auto MCC2 = m_SpecificCondCode(ISD::SETUGE); + + if (sd_match(N1, m_SelectCCLike(MS0, MVY, MZ, m_Deferred(Y), MCC1)) || + sd_match(N1, m_SelectCCLike(MS0, MVY, m_Deferred(Y), MZ, MCC2)) || + sd_match(N1, m_VSelect(m_SetCC(MS0, MVY, MCC1), MZ, m_Deferred(Y))) || + sd_match(N1, m_VSelect(m_SetCC(MS0, MVY, MCC2), m_Deferred(Y), MZ))) + return DAG.getNode(ISD::UMIN, DL, VT, N0, DAG.getNode(ISD::SUB, DL, VT, N0, Y)); } diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index 4e0bf385d72b2..16b997901dc1c 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -859,3 +859,35 @@ TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) { EXPECT_TRUE(sd_match(Vec, DAG.get(), m_AllOnes(true))); } } + +TEST_F(SelectionDAGPatternMatchTest, MatchSelectCCLike) { + using namespace SDPatternMatch; + + SDValue LHS = DAG->getConstant(1, SDLoc(), MVT::i32); + SDValue RHS = DAG->getConstant(2, SDLoc(), MVT::i32); + SDValue TVal = DAG->getConstant(3, SDLoc(), MVT::i32); + SDValue FVal = DAG->getConstant(4, SDLoc(), MVT::i32); + SDValue Select = DAG->getNode(ISD::SELECT_CC, SDLoc(), MVT::i32, LHS, RHS, + TVal, FVal, DAG->getCondCode(ISD::SETLT)); + + ISD::CondCode CC = ISD::SETLT; + EXPECT_TRUE(sd_match( + Select, m_SelectCCLike(m_Specific(LHS), m_Specific(RHS), m_Specific(TVal), + m_Specific(FVal), m_CondCode(CC)))); +} + +TEST_F(SelectionDAGPatternMatchTest, MatchSelectCC) { + using namespace SDPatternMatch; + + SDValue LHS = DAG->getConstant(1, SDLoc(), MVT::i32); + SDValue RHS = DAG->getConstant(2, SDLoc(), MVT::i32); + SDValue TVal = DAG->getConstant(3, SDLoc(), MVT::i32); + SDValue FVal = DAG->getConstant(4, SDLoc(), MVT::i32); + SDValue Select = DAG->getNode(ISD::SELECT_CC, SDLoc(), MVT::i32, LHS, RHS, + TVal, FVal, DAG->getCondCode(ISD::SETLT)); + + ISD::CondCode CC = ISD::SETLT; + EXPECT_TRUE(sd_match(Select, m_SelectCC(m_Specific(LHS), m_Specific(RHS), + m_Specific(TVal), m_Specific(FVal), + m_CondCode(CC)))); +}