diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 21121d71a5fddc..8317a2e146a08c 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4070,22 +4070,38 @@ class TargetLowering : public TargetLoweringBase {
                                        NegatibleCost &Cost,
                                        unsigned Depth = 0) const;
 
-  /// This is the helper function to return the newly negated expression only
-  /// when the cost is cheaper.
-  SDValue getCheaperNegatedExpression(SDValue Op, SelectionDAG &DAG,
-                                      bool LegalOps, bool OptForSize,
-                                      unsigned Depth = 0) const {
+  /// Return the negated form of \p Op only when negating it costs no more
+  /// than \p CostThreshold (NegatibleCost::Neutral by default). Otherwise
+  /// return a null SDValue, first deleting the speculatively built negation
+  /// if nothing uses it.
+  SDValue getCheaperOrNeutralNegatedExpression(
+      SDValue Op, SelectionDAG &DAG, bool LegalOps, bool OptForSize,
+      const NegatibleCost CostThreshold = NegatibleCost::Neutral,
+      unsigned Depth = 0) const {
     NegatibleCost Cost = NegatibleCost::Expensive;
     SDValue Neg =
         getNegatedExpression(Op, DAG, LegalOps, OptForSize, Cost, Depth);
-    if (Neg && Cost == NegatibleCost::Cheaper)
+    if (!Neg)
+      return SDValue();
+
+    if (Cost <= CostThreshold)
       return Neg;
+
     // Remove the new created node to avoid the side effect to the DAG.
-    if (Neg && Neg->use_empty())
+    if (Neg->use_empty())
       DAG.RemoveDeadNode(Neg.getNode());
     return SDValue();
   }
 
+  /// This is the helper function to return the newly negated expression only
+  /// when the cost is cheaper.
+  SDValue getCheaperNegatedExpression(SDValue Op, SelectionDAG &DAG,
+                                      bool LegalOps, bool OptForSize,
+                                      unsigned Depth = 0) const {
+    return getCheaperOrNeutralNegatedExpression(Op, DAG, LegalOps, OptForSize,
+                                                NegatibleCost::Cheaper, Depth);
+  }
+
   /// This is the helper function to return the newly negated expression if
   /// the cost is not expensive.
   SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps,
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8c87fc4acd3a66..73172bb5c1de1b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -387,6 +387,10 @@ namespace {
     SDValue PromoteExtend(SDValue Op);
     bool PromoteLoad(SDValue Op);
 
+    SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
+                                SDValue RHS, SDValue True, SDValue False,
+                                ISD::CondCode CC);
+
     /// Call the node-specific routine that knows how to fold each
     /// particular type of node. If that doesn't do anything, try the
     /// target-specific DAG combines.
@@ -10392,21 +10396,20 @@ static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
 }
 
 /// Generate Min/Max node
-static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
-                                   SDValue RHS, SDValue True, SDValue False,
-                                   ISD::CondCode CC, const TargetLowering &TLI,
-                                   SelectionDAG &DAG) {
+SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
+                                         SDValue RHS, SDValue True,
+                                         SDValue False, ISD::CondCode CC) {
   if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
     return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
 
   // If we can't directly match this, try to see if we can pull an fneg out of
   // the select.
-  if (True.getOpcode() != ISD::FNEG)
+  SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
+      True, DAG, LegalOperations, ForCodeSize);
+  if (!NegTrue)
     return SDValue();
 
-  ConstantFPSDNode *CRHS = dyn_cast<ConstantFPSDNode>(RHS);
-  ConstantFPSDNode *CFalse = dyn_cast<ConstantFPSDNode>(False);
-  SDValue NegTrue = True.getOperand(0);
+  HandleSDNode NegTrueHandle(NegTrue);
 
   // Try to unfold an fneg from the select if we are comparing the negated
   // constant.
@@ -10414,14 +10417,21 @@ static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
   // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
   //
   // TODO: Handle fabs
-  if (LHS == NegTrue && CFalse && CRHS) {
-    APFloat NegRHS = neg(CRHS->getValueAPF());
-    if (NegRHS == CFalse->getValueAPF()) {
-      SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
-                                                 False, CC, TLI, DAG);
-      if (Combined)
-        return DAG.getNode(ISD::FNEG, DL, VT, Combined);
-      return SDValue();
+  if (LHS == NegTrue) {
+    // True is the negation of the compare LHS; the fneg can be pulled out of
+    // the select only if False is likewise the negation of the compare RHS.
+    SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
+        RHS, DAG, LegalOperations, ForCodeSize);
+    if (NegRHS) {
+      HandleSDNode NegRHSHandle(NegRHS);
+      if (NegRHS == False) {
+        SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
+                                                   False, CC, TLI, DAG);
+        // The helper may fail to form a min/max (the pre-patch code guarded
+        // this too); never wrap a null value in an FNEG.
+        if (Combined)
+          return DAG.getNode(ISD::FNEG, DL, VT, Combined);
+      }
     }
   }
 
@@ -10812,8 +10822,8 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
     //
     // This is OK if we don't care what happens if either operand is a NaN.
     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
-      if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2,
-                                                CC, TLI, DAG))
+      if (SDValue FMinMax =
+              combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, CC))
         return FMinMax;
 
     // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
@@ -11325,8 +11335,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
     // NaN.
     //
     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
-      if (SDValue FMinMax =
-              combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC, TLI, DAG))
+      if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC))
         return FMinMax;
     }
 
diff --git a/llvm/test/CodeGen/ARM/unsafe-fneg-select-minnum-maxnum-combine.ll b/llvm/test/CodeGen/ARM/unsafe-fneg-select-minnum-maxnum-combine.ll
index 23fdf07084705f..664272ef8c0989 100644
--- a/llvm/test/CodeGen/ARM/unsafe-fneg-select-minnum-maxnum-combine.ll
+++ b/llvm/test/CodeGen/ARM/unsafe-fneg-select-minnum-maxnum-combine.ll
@@ -67,13 +67,10 @@ define float @select_fsub0_or_8_cmp_olt_fsub1_neg8_f32(float %a, float %b) #0 {
 ; CHECK-NEXT:    vmov.f32 s0, #4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
 ; CHECK-NEXT:    vmov.f32 s4, #-8.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #8.000000e+00
-; CHECK-NEXT:    vsub.f32 s6, s0, s2
-; CHECK-NEXT:    vsub.f32 s0, s2, s0
-; CHECK-NEXT:    vcmp.f32 s4, s6
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s0, s8
+; CHECK-NEXT:    vsub.f32 s0, s0, s2
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %sub.0 = fsub nnan nsz float 4.0, %a
   %sub.1 = fsub nnan nsz float %a, 4.0
@@ -88,13 +85,10 @@ define float @select_fsub0_or_neg8_cmp_olt_fsub1_8_f32(float %a, float %b) #0 {
 ; CHECK-NEXT:    vmov.f32 s0, #4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
 ; CHECK-NEXT:    vmov.f32 s4, #8.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #-8.000000e+00
-; CHECK-NEXT:    vsub.f32 s6, s0, s2
-; CHECK-NEXT:    vsub.f32 s0, s2, s0
-; CHECK-NEXT:    vcmp.f32 s4, s6
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s0, s8
+; CHECK-NEXT:    vsub.f32 s0, s0, s2
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %sub.0 = fsub nnan nsz float 4.0, %a
   %sub.1 = fsub nnan nsz float %a, 4.0
@@ -108,15 +102,11 @@ define float @select_mul4_or_neg8_cmp_olt_mulneg4_8_f32(float %a, float %b) #0 {
 ; CHECK:       @ %bb.0:
 ; CHECK-NEXT:    vmov.f32 s0, #-4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
-; CHECK-NEXT:    vmov.f32 s6, #8.000000e+00
-; CHECK-NEXT:    vmov.f32 s4, #4.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #-8.000000e+00
+; CHECK-NEXT:    vmov.f32 s4, #8.000000e+00
 ; CHECK-NEXT:    vmul.f32 s0, s2, s0
-; CHECK-NEXT:    vmul.f32 s2, s2, s4
-; CHECK-NEXT:    vcmp.f32 s6, s0
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s2, s8
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %mul.0 = fmul nnan nsz float %a, 4.0
   %mul.1 = fmul nnan nsz float %a, -4.0
@@ -130,15 +120,11 @@ define float @select_mul4_or_8_cmp_olt_mulneg4_neg8_f32(float %a, float %b) #0 {
 ; CHECK:       @ %bb.0:
 ; CHECK-NEXT:    vmov.f32 s0, #-4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
-; CHECK-NEXT:    vmov.f32 s6, #-8.000000e+00
-; CHECK-NEXT:    vmov.f32 s4, #4.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #8.000000e+00
+; CHECK-NEXT:    vmov.f32 s4, #-8.000000e+00
 ; CHECK-NEXT:    vmul.f32 s0, s2, s0
-; CHECK-NEXT:    vmul.f32 s2, s2, s4
-; CHECK-NEXT:    vcmp.f32 s6, s0
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s2, s8
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %mul.0 = fmul nnan nsz float %a, 4.0
   %mul.1 = fmul nnan nsz float %a, -4.0
@@ -194,15 +180,11 @@ define float @select_mulneg4_or_neg8_cmp_olt_mul4_8_f32(float %a, float %b) #0 {
 ; CHECK:       @ %bb.0:
 ; CHECK-NEXT:    vmov.f32 s0, #4.000000e+00
 ; CHECK-NEXT:    vmov s2, r0
-; CHECK-NEXT:    vmov.f32 s6, #8.000000e+00
-; CHECK-NEXT:    vmov.f32 s4, #-4.000000e+00
-; CHECK-NEXT:    vmov.f32 s8, #-8.000000e+00
+; CHECK-NEXT:    vmov.f32 s4, #8.000000e+00
 ; CHECK-NEXT:    vmul.f32 s0, s2, s0
-; CHECK-NEXT:    vmul.f32 s2, s2, s4
-; CHECK-NEXT:    vcmp.f32 s6, s0
-; CHECK-NEXT:    vmrs APSR_nzcv, fpscr
-; CHECK-NEXT:    vselgt.f32 s0, s2, s8
+; CHECK-NEXT:    vminnm.f32 s0, s0, s4
 ; CHECK-NEXT:    vmov r0, s0
+; CHECK-NEXT:    eor r0, r0, #-2147483648
 ; CHECK-NEXT:    mov pc, lr
   %mul.0 = fmul nnan nsz float %a, -4.0
   %mul.1 = fmul nnan nsz float %a, 4.0