diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 5476ef8797143..735cec8ecc062 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -38,6 +38,7 @@ #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineMemOperand.h" #include "llvm/CodeGen/RuntimeLibcalls.h" +#include "llvm/CodeGen/SDPatternMatch.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h" #include "llvm/CodeGen/SelectionDAGNodes.h" @@ -79,6 +80,7 @@ #include "MatchContext.h" using namespace llvm; +using namespace llvm::SDPatternMatch; #define DEBUG_TYPE "dagcombine" @@ -2697,52 +2699,45 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) { reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1)) return SD; } + + SDValue A, B, C; + // fold ((0-A) + B) -> B-A - if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0))) - return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1)); + if (sd_match(N0, m_Sub(m_Zero(), m_Value(A)))) + return DAG.getNode(ISD::SUB, DL, VT, N1, A); // fold (A + (0-B)) -> A-B - if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0))) - return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1)); + if (sd_match(N1, m_Sub(m_Zero(), m_Value(B)))) + return DAG.getNode(ISD::SUB, DL, VT, N0, B); // fold (A+(B-A)) -> B - if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1)) - return N1.getOperand(0); + if (sd_match(N1, m_Sub(m_Value(B), m_Specific(N0)))) + return B; // fold ((B-A)+A) -> B - if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1)) - return N0.getOperand(0); + if (sd_match(N0, m_Sub(m_Value(B), m_Specific(N1)))) + return B; // fold ((A-B)+(C-A)) -> (C-B) - if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB && - N0.getOperand(0) == N1.getOperand(1)) - return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0), - N0.getOperand(1)); + if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) && + sd_match(N1, m_Sub(m_Value(C), m_Specific(A)))) + return DAG.getNode(ISD::SUB, DL, VT, C, B); // fold ((A-B)+(B-C)) -> (A-C) - if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB && - N0.getOperand(1) == N1.getOperand(0)) - return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), - N1.getOperand(1)); + if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) && + sd_match(N1, m_Sub(m_Specific(B), m_Value(C)))) + return DAG.getNode(ISD::SUB, DL, VT, A, C); // fold (A+(B-(A+C))) to (B-C) - if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD && - N0 == N1.getOperand(1).getOperand(0)) - return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0), - N1.getOperand(1).getOperand(1)); - // fold (A+(B-(C+A))) to (B-C) - if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD && - N0 == N1.getOperand(1).getOperand(1)) - return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0), - N1.getOperand(1).getOperand(0)); + if (sd_match(N1, m_Sub(m_Value(B), m_Add(m_Specific(N0), m_Value(C))))) + return DAG.getNode(ISD::SUB, DL, VT, B, C); // fold (A+((B-A)+or-C)) to (B+or-C) - if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) && - N1.getOperand(0).getOpcode() == ISD::SUB && - N0 == N1.getOperand(0).getOperand(1)) - return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0), - N1.getOperand(1)); + if (sd_match(N1, + m_AnyOf(m_Add(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)), + m_Sub(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C))))) + return DAG.getNode(N1.getOpcode(), DL, VT, B, C); // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&