diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 07d5bfaa31e22..86823ddab19eb 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2202,45 +2202,16 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG, SDValue TVal = N1.getOperand(1); SDValue FVal = N1.getOperand(2); - // TODO: The cases should match with IR's ConstantExpr::getBinOpIdentity(). - // TODO: Target-specific opcodes could be added. Ex: "isCommutativeBinOp()". - // TODO: With fast-math (NSZ), allow the opposite-sign form of zero? - auto isIdentityConstantForOpcode = [](unsigned Opcode, SDValue V) { - if (ConstantFPSDNode *C = isConstOrConstSplatFP(V)) { - switch (Opcode) { - case ISD::FADD: // X + -0.0 --> X - return C->isZero() && C->isNegative(); - case ISD::FSUB: // X - 0.0 --> X - return C->isZero() && !C->isNegative(); - case ISD::FMUL: // X * 1.0 --> X - case ISD::FDIV: // X / 1.0 --> X - return C->isExactlyValue(1.0); - } - } - if (ConstantSDNode *C = isConstOrConstSplat(V)) { - switch (Opcode) { - case ISD::ADD: // X + 0 --> X - case ISD::SUB: // X - 0 --> X - case ISD::SHL: // X << 0 --> X - case ISD::SRA: // X s>> 0 --> X - case ISD::SRL: // X u>> 0 --> X - return C->isZero(); - case ISD::MUL: // X * 1 --> X - return C->isOne(); - } - } - return false; - }; - // This transform increases uses of N0, so freeze it to be safe. // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal) - if (isIdentityConstantForOpcode(Opcode, TVal)) { + unsigned OpNo = ShouldCommuteOperands ? 0 : 1; + if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) { SDValue F0 = DAG.getFreeze(N0); SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags()); return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO); } // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0 - if (isIdentityConstantForOpcode(Opcode, FVal)) { + if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) { SDValue F0 = DAG.getFreeze(N0); SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags()); return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 3b6f6c48b1840..d866e9d4a1dbb 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -10749,7 +10749,9 @@ bool llvm::isMinSignedConstant(SDValue V) { bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V, unsigned OperandNo) { - if (auto *Const = dyn_cast(V)) { + // NOTE: The cases should match with IR's ConstantExpr::getBinOpIdentity(). + // TODO: Target-specific opcodes could be added. + if (auto *Const = isConstOrConstSplat(V)) { switch (Opcode) { case ISD::ADD: case ISD::OR: @@ -10774,7 +10776,7 @@ bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V, case ISD::SDIV: return OperandNo == 1 && Const->isOne(); } - } else if (auto *ConstFP = dyn_cast(V)) { + } else if (auto *ConstFP = isConstOrConstSplatFP(V)) { switch (Opcode) { case ISD::FADD: return ConstFP->isZero() &&