diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index e70d3f98e5e04..8055eb9a82d64 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2549,6 +2549,19 @@ SDValue DAGCombiner::visitADDSAT(SDNode *N) { EVT VT = N0.getValueType(); SDLoc DL(N); + // fold (add_sat x, undef) -> -1 + if (N0.isUndef() || N1.isUndef()) + return DAG.getAllOnesConstant(DL, VT); + + // fold (add_sat c1, c2) -> c3 + if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1})) + return C; + + // canonicalize constant to RHS + if (DAG.isConstantIntBuildVectorOrConstantInt(N0) && + !DAG.isConstantIntBuildVectorOrConstantInt(N1)) + return DAG.getNode(Opcode, DL, VT, N1, N0); + // fold vector ops if (VT.isVector()) { // TODO SimplifyVBinOp @@ -2556,20 +2569,6 @@ SDValue DAGCombiner::visitADDSAT(SDNode *N) { // fold (add_sat x, 0) -> x, vector edition if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) return N0; - if (ISD::isConstantSplatVectorAllZeros(N0.getNode())) - return N1; - } - - // fold (add_sat x, undef) -> -1 - if (N0.isUndef() || N1.isUndef()) - return DAG.getAllOnesConstant(DL, VT); - - if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) { - // canonicalize constant to RHS - if (!DAG.isConstantIntBuildVectorOrConstantInt(N1)) - return DAG.getNode(Opcode, DL, VT, N1, N0); - // fold (add_sat c1, c2) -> c3 - return DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}); } // fold (add_sat x, 0) -> x @@ -3606,15 +3605,6 @@ SDValue DAGCombiner::visitSUBSAT(SDNode *N) { EVT VT = N0.getValueType(); SDLoc DL(N); - // fold vector ops - if (VT.isVector()) { - // TODO SimplifyVBinOp - - // fold (sub_sat x, 0) -> x, vector edition - if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) - return N0; - } - // fold (sub_sat x, undef) -> 0 if (N0.isUndef() || N1.isUndef()) return DAG.getConstant(0, DL, VT); @@ -3627,6 +3617,15 @@ SDValue DAGCombiner::visitSUBSAT(SDNode *N) { if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1})) return C; + // fold vector ops + if (VT.isVector()) { + // TODO SimplifyVBinOp + + // fold (sub_sat x, 0) -> x, vector edition + if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) + return N0; + } + // fold (sub_sat x, 0) -> x if (isNullConstant(N1)) return N0;