AMDGPU: split the "is it profitable to push this fneg into its source?" check
out of AMDGPUTargetLowering::performFNegCombine into a new static member,
shouldFoldFNegIntoSrc. performFNegCombine now bails out with SDValue() when the
helper returns false, at the points where it previously returned SDValue()
inline. Text ahead of the first "diff --git" header is ignored by patch tools.

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 62c331e6bcdda..07ce48bfac5a2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3862,12 +3862,9 @@ static unsigned inverseMinMax(unsigned Opc) {
   }
 }
 
-SDValue AMDGPUTargetLowering::performFNegCombine(SDNode *N,
-                                                 DAGCombinerInfo &DCI) const {
-  SelectionDAG &DAG = DCI.DAG;
-  SDValue N0 = N->getOperand(0);
-  EVT VT = N->getValueType(0);
-
+/// \return true if it's profitable to try to push an fneg into its source
+/// instruction.
+bool AMDGPUTargetLowering::shouldFoldFNegIntoSrc(SDNode *N, SDValue N0) {
   unsigned Opc = N0.getOpcode();
 
   // If the input has multiple uses and we can either fold the negate down, or
@@ -3878,13 +3875,27 @@ SDValue AMDGPUTargetLowering::performFNegCombine(SDNode *N,
     // This may be able to fold into the source, but at a code size cost. Don't
     // fold if the fold into the user is free.
     if (allUsesHaveSourceMods(N, 0))
-      return SDValue();
+      return false;
   } else {
     if (fnegFoldsIntoOp(Opc) &&
         (allUsesHaveSourceMods(N) || !allUsesHaveSourceMods(N0.getNode())))
-      return SDValue();
+      return false;
   }
 
+  return true;
+}
+
+SDValue AMDGPUTargetLowering::performFNegCombine(SDNode *N,
+                                                 DAGCombinerInfo &DCI) const {
+  SelectionDAG &DAG = DCI.DAG;
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+
+  unsigned Opc = N0.getOpcode();
+
+  if (!shouldFoldFNegIntoSrc(N, N0))
+    return SDValue();
+
   SDLoc SL(N);
   switch (Opc) {
   case ISD::FADD: {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index 0264a7ab950f3..969e4c9138532 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -156,6 +156,7 @@ class AMDGPUTargetLowering : public TargetLowering {
     return Val.getOpcode() == ISD::BITCAST ? Val.getOperand(0) : Val;
   }
 
+  static bool shouldFoldFNegIntoSrc(SDNode *FNeg, SDValue FNegSrc);
   static bool allUsesHaveSourceMods(const SDNode *N,
                                     unsigned CostThreshold = 4);
   bool isFAbsFree(EVT VT) const override;