diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index f6f187923e61f..a9dc094c2cfaf 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -51,6 +51,8 @@ class AMDGPULibCalls {
 
   const TargetMachine *TM;
 
+  bool UnsafeFPMath = false;
+
   // -fuse-native.
   bool AllNative = false;
 
@@ -73,10 +75,10 @@ class AMDGPULibCalls {
   bool fold_divide(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // pow/powr/pown
-  bool fold_pow(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // rootn
-  bool fold_rootn(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // fma/mad
   bool fold_fma_mad(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
@@ -90,10 +92,10 @@ class AMDGPULibCalls {
   bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);
 
   // sqrt
-  bool fold_sqrt(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // sin/cos
-  bool fold_sincos(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo,
+  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo,
                    AliasAnalysis *AA);
 
   // __read_pipe/__write_pipe
@@ -113,7 +115,9 @@ class AMDGPULibCalls {
 protected:
   CallInst *CI;
 
-  bool isUnsafeMath(const CallInst *CI) const;
+  bool isUnsafeMath(const FPMathOperator *FPOp) const;
+
+  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;
 
   void replaceCall(Value *With) {
     CI->replaceAllUsesWith(With);
@@ -125,6 +129,7 @@ class AMDGPULibCalls {
 
   bool fold(CallInst *CI, AliasAnalysis *AA = nullptr);
 
+  void initFunction(const Function &F);
   void initNativeFuncs();
 
   // Replace a normal math function call with that native version
@@ -445,13 +450,18 @@ bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
   return AMDGPULibFunc::parse(FMangledName, FInfo);
 }
 
-bool AMDGPULibCalls::isUnsafeMath(const CallInst *CI) const {
-  if (auto Op = dyn_cast<FPMathOperator>(CI))
-    if (Op->isFast())
-      return true;
-  const Function *F = CI->getParent()->getParent();
-  Attribute Attr = F->getFnAttribute("unsafe-fp-math");
-  return Attr.getValueAsBool();
+bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
+  return UnsafeFPMath || FPOp->isFast();
+}
+
+bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
+    const FPMathOperator *FPOp) const {
+  // TODO: Refine to approxFunc or contract
+  return isUnsafeMath(FPOp);
+}
+
+void AMDGPULibCalls::initFunction(const Function &F) {
+  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
 }
 
 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
@@ -620,65 +630,61 @@ bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
   if (TDOFold(CI, FInfo))
     return true;
 
-  // Under unsafe-math, evaluate calls if possible.
-  // According to Brian Sumner, we can do this for all f32 function calls
-  // using host's double function calls.
-  if (isUnsafeMath(CI) && evaluateCall(CI, FInfo))
-    return true;
+  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
+    // Under unsafe-math, evaluate calls if possible.
+    // According to Brian Sumner, we can do this for all f32 function calls
+    // using host's double function calls.
+    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
+      return true;
 
-  // Copy fast flags from the original call.
-  if (const FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI))
+    // Copy fast flags from the original call.
     B.setFastMathFlags(FPOp->getFastMathFlags());
 
-  // Specialized optimizations for each function call
-  switch (FInfo.getId()) {
-  case AMDGPULibFunc::EI_RECIP:
-    // skip vector function
-    assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
-             FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
-            "recip must be an either native or half function");
-    return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_DIVIDE:
-    // skip vector function
-    assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
-             FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
-            "divide must be an either native or half function");
-    return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_POW:
-  case AMDGPULibFunc::EI_POWR:
-  case AMDGPULibFunc::EI_POWN:
-    return fold_pow(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_ROOTN:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_rootn(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_FMA:
-  case AMDGPULibFunc::EI_MAD:
-  case AMDGPULibFunc::EI_NFMA:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_SQRT:
-    return isUnsafeMath(CI) && fold_sqrt(CI, B, FInfo);
-  case AMDGPULibFunc::EI_COS:
-  case AMDGPULibFunc::EI_SIN:
-    if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
-         getArgType(FInfo) == AMDGPULibFunc::F64)
-        && (FInfo.getPrefix() == AMDGPULibFunc::NOPFX))
-      return fold_sincos(CI, B, FInfo, AA);
-
-    break;
-  case AMDGPULibFunc::EI_READ_PIPE_2:
-  case AMDGPULibFunc::EI_READ_PIPE_4:
-  case AMDGPULibFunc::EI_WRITE_PIPE_2:
-  case AMDGPULibFunc::EI_WRITE_PIPE_4:
-    return fold_read_write_pipe(CI, B, FInfo);
-
-  default:
-    break;
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_POW:
+    case AMDGPULibFunc::EI_POWR:
+    case AMDGPULibFunc::EI_POWN:
+      return fold_pow(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_ROOTN:
+      return fold_rootn(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_SQRT:
+      return fold_sqrt(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_COS:
+    case AMDGPULibFunc::EI_SIN:
+      return fold_sincos(FPOp, B, FInfo, AA);
+    case AMDGPULibFunc::EI_RECIP:
+      // skip vector function
+      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+             "recip must be an either native or half function");
+      return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
+
+    case AMDGPULibFunc::EI_DIVIDE:
+      // skip vector function
+      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+             "divide must be an either native or half function");
+      return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
+    case AMDGPULibFunc::EI_FMA:
+    case AMDGPULibFunc::EI_MAD:
+    case AMDGPULibFunc::EI_NFMA:
+      // skip vector function
+      return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
+    default:
+      break;
+    }
+  } else {
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_READ_PIPE_2:
+    case AMDGPULibFunc::EI_READ_PIPE_4:
+    case AMDGPULibFunc::EI_WRITE_PIPE_2:
+    case AMDGPULibFunc::EI_WRITE_PIPE_4:
+      return fold_read_write_pipe(CI, B, FInfo);
+    default:
+      break;
+    }
   }
 
   return false;
@@ -796,7 +802,7 @@ static double log2(double V) {
 }
 }
 
-bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
                               const FuncInfo &FInfo) {
   assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
           FInfo.getId() == AMDGPULibFunc::EI_POWR ||
@@ -827,7 +833,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
   }
 
   // No unsafe math , no constant argument, do nothing
-  if (!isUnsafeMath(CI) && !CF && !CINT && !CZero)
+  if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
     return false;
 
   // 0x1111111 means that we don't do anything for this call.
@@ -885,7 +891,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
     }
   }
 
-  if (!isUnsafeMath(CI))
+  if (!isUnsafeMath(FPOp))
     return false;
 
   // Unsafe Math optimization
@@ -1079,10 +1085,14 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
   return true;
 }
 
-bool AMDGPULibCalls::fold_rootn(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &FInfo) {
-  Value *opr0 = CI->getArgOperand(0);
-  Value *opr1 = CI->getArgOperand(1);
+  // skip vector function
+  if (getVecSize(FInfo) != 1)
+    return false;
+
+  Value *opr0 = FPOp->getOperand(0);
+  Value *opr1 = FPOp->getOperand(1);
 
   ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
   if (!CINT) {
@@ -1188,8 +1198,11 @@ FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M,
 }
 
 // fold sqrt -> native_sqrt (x)
-bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
                                const FuncInfo &FInfo) {
+  if (!isUnsafeMath(FPOp))
+    return false;
+
   if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
     if (FunctionCallee FPExpr = getNativeFunction(
@@ -1206,10 +1219,16 @@ bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
 }
 
 // fold sin, cos -> sincos.
-bool AMDGPULibCalls::fold_sincos(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                  const FuncInfo &fInfo, AliasAnalysis *AA) {
   assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
          fInfo.getId() == AMDGPULibFunc::EI_COS);
+
+  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
+       getArgType(fInfo) != AMDGPULibFunc::F64) ||
+      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
+    return false;
+
   bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;
 
   Value *CArgVal = CI->getArgOperand(0);
@@ -1651,6 +1670,8 @@ bool AMDGPUSimplifyLibCalls::runOnFunction(Function &F) {
   if (skipFunction(F))
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
 
@@ -1675,6 +1696,7 @@ PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F,
                                                   FunctionAnalysisManager &AM) {
   AMDGPULibCalls Simplifier(&TM);
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   auto AA = &AM.getResult<AAManager>(F);
@@ -1701,6 +1723,8 @@ bool AMDGPUUseNativeCalls::runOnFunction(Function &F) {
   if (skipFunction(F) || UseNative.empty())
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   for (auto &BB : F) {
     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
@@ -1721,6 +1745,7 @@ PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F,
 
   AMDGPULibCalls Simplifier;
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   for (auto &BB : F) {
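A note on the pattern this patch applies: instead of walking from every call site back to its parent function and re-reading the "unsafe-fp-math" string attribute on each fold attempt, the attribute is read once per function in initFunction() and cached in the UnsafeFPMath member; isUnsafeMath() then only combines that cached bit with the per-instruction fast-math flags. The following is a minimal, self-contained sketch of that caching pattern; MockFunction, MockFPOp, and LibCallSimplifier are illustrative stand-ins, not LLVM's real classes.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for llvm::Function: holds named boolean function attributes.
struct MockFunction {
  std::map<std::string, bool> Attrs;
};

// Stand-in for llvm::FPMathOperator: carries a per-instruction fast flag,
// analogous to FPMathOperator::isFast().
struct MockFPOp {
  bool Fast;
};

class LibCallSimplifier {
  // Cached once per function instead of re-queried for every call site.
  bool UnsafeFPMath = false;

public:
  void initFunction(const MockFunction &F) {
    auto It = F.Attrs.find("unsafe-fp-math");
    UnsafeFPMath = It != F.Attrs.end() && It->second;
  }

  // Mirrors the patched isUnsafeMath(): function-level attribute OR
  // instruction-level fast-math flags.
  bool isUnsafeMath(const MockFPOp &Op) const {
    return UnsafeFPMath || Op.Fast;
  }
};

int main() {
  MockFunction F;
  F.Attrs["unsafe-fp-math"] = false;

  LibCallSimplifier S;
  S.initFunction(F); // one attribute lookup per function...

  std::vector<MockFPOp> Calls = {{false}, {true}, {false}};
  for (const MockFPOp &Op : Calls) // ...then a cheap check per call
    std::cout << std::boolalpha << S.isUnsafeMath(Op) << '\n';
}

The same trade-off motivates hoisting the fold_sqrt and fold_sincos argument-type checks into the callees here: the per-call work in fold() stays a simple switch, while each fold routine guards its own preconditions.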