diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 432c63fb65f49..97c5b9519ea92 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -9,6 +9,8 @@
 // Replaces LLVM IR instructions with vector operands (i.e., the frem
 // instruction or calls to LLVM intrinsics) with matching calls to functions
 // from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
+// This happens only when the cost of calling the vector library is not found to
+// be more than the cost of the original instruction.
 //
 //===----------------------------------------------------------------------===//
 
@@ -20,12 +22,16 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/VFABIDemangler.h"
+#include "llvm/Support/InstructionCost.h"
 #include "llvm/Support/TypeSize.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
@@ -96,15 +102,55 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
   Replacement->copyFastMathFlags(&I);
 }
 
+/// Returns whether the vector library call \p TLIFunc costs more than the
+/// original instruction \p I.
+static bool isVeclibCallSlower(const TargetLibraryInfo &TLI,
+                               const TargetTransformInfo &TTI, Instruction &I,
+                               VectorType *VectorTy, CallInst *CI,
+                               Function *TLIFunc) {
+  SmallVector<Type *> OpTypes;
+  for (auto &Op : CI ? CI->args() : I.operands())
+    OpTypes.push_back(Op->getType());
+
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost DefaultCost;
+  if (CI) {
+    FastMathFlags FMF;
+    if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
+      FMF = FPMO->getFastMathFlags();
+
+    SmallVector<const Value *> Args(CI->args());
+    IntrinsicCostAttributes CostAttrs(CI->getIntrinsicID(), VectorTy, Args,
+                                      OpTypes, FMF,
+                                      dyn_cast<IntrinsicInst>(CI));
+    DefaultCost = TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
+  } else {
+    assert((I.getOpcode() == Instruction::FRem) && "Only FRem is supported");
+    auto Op2Info = TTI.getOperandInfo(I.getOperand(1));
+    SmallVector<const Value *> OpValues(I.operand_values());
+    DefaultCost = TTI.getArithmeticInstrCost(
+        I.getOpcode(), VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        Op2Info, OpValues, &I);
+  }
+
+  InstructionCost VecLibCost =
+      TTI.getCallInstrCost(TLIFunc, VectorTy, OpTypes, CostKind);
+  return VecLibCost > DefaultCost;
+}
+
 /// Returns true when successfully replaced \p I with a suitable function taking
-/// vector arguments, based on available mappings in the \p TLI. Currently only
-/// works when \p I is a call to vectorized intrinsic or the frem instruction.
+/// vector arguments, based on available mappings in the \p TLI and costs.
+/// Currently only works when \p I is a call to vectorized intrinsic or the frem
+/// instruction.
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
+                                    const TargetTransformInfo &TTI,
                                     Instruction &I) {
   // At the moment VFABI assumes the return type is always widened unless it is
   // a void type.
-  auto *VTy = dyn_cast<VectorType>(I.getType());
-  ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
+  auto *VectorTy = dyn_cast<VectorType>(I.getType());
+  ElementCount EC(VectorTy ? VectorTy->getElementCount()
+                           : ElementCount::getFixed(0));
 
   // Compute the argument types of the corresponding scalar call and the scalar
   // function name. For calls, it additionally finds the function to replace
@@ -125,9 +171,10 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
       ScalarArgTypes.push_back(VectorArgTy->getElementType());
       // When return type is void, set EC to the first vector argument, and
       // disallow vector arguments with different ECs.
-      if (EC.isZero())
+      if (EC.isZero()) {
         EC = VectorArgTy->getElementCount();
-      else if (EC != VectorArgTy->getElementCount())
+        VectorTy = VectorArgTy;
+      } else if (EC != VectorArgTy->getElementCount())
         return false;
     } else
       // Exit when it is supposed to be a vector argument but it isn't.
@@ -139,8 +186,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
                      ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
                      : Intrinsic::getName(IID).str();
   } else {
-    assert(VTy && "Return type must be a vector");
-    auto *ScalarTy = VTy->getScalarType();
+    assert(VectorTy && "Return type must be a vector");
+    auto *ScalarTy = VectorTy->getScalarType();
     LibFunc Func;
     if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
       return false;
@@ -200,6 +247,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
                                      VD->getVectorFnName(), FuncToReplace);
 
+  if (isVeclibCallSlower(TLI, TTI, I, VectorTy, CI, TLIFunc))
+    return false;
+
   replaceWithTLIFunction(I, *OptInfo, TLIFunc);
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
                     << "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -220,13 +270,14 @@ static bool isSupportedInstruction(Instruction *I) {
   return false;
 }
 
-static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
+static bool runImpl(const TargetLibraryInfo &TLI,
+                    const TargetTransformInfo &TTI, Function &F) {
   bool Changed = false;
   SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
     if (!isSupportedInstruction(&I))
       continue;
-    if (replaceWithCallToVeclib(TLI, I)) {
+    if (replaceWithCallToVeclib(TLI, TTI, I)) {
       ReplacedCalls.push_back(&I);
       Changed = true;
     }
@@ -244,7 +295,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
                                          FunctionAnalysisManager &AM) {
   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
-  auto Changed = runImpl(TLI, F);
+  const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+  auto Changed = runImpl(TLI, TTI, F);
   if (Changed) {
     LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
                       << NumCallsReplaced << "\n");
@@ -252,6 +304,7 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
     PreservedAnalyses PA;
     PA.preserveSet<CFGAnalyses>();
     PA.preserve<TargetLibraryAnalysis>();
+    PA.preserve<TargetIRAnalysis>();
     PA.preserve<ScalarEvolutionAnalysis>();
     PA.preserve<LoopAccessAnalysis>();
     PA.preserve<DemandedBitsAnalysis>();
@@ -269,13 +322,17 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
   const TargetLibraryInfo &TLI =
       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
-  return runImpl(TLI, F);
+  const TargetTransformInfo &TTI =
+      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  return runImpl(TLI, TTI, F);
 }
 
 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesCFG();
   AU.addRequired<TargetLibraryInfoWrapperPass>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
   AU.addPreserved<TargetLibraryInfoWrapperPass>();
+  AU.addPreserved<TargetTransformInfoWrapperPass>();
   AU.addPreserved<ScalarEvolutionWrapperPass>();
   AU.addPreserved<AAResultsWrapperPass>();
   AU.addPreserved<LoopAccessLegacyAnalysis>();
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index 758df0493cc50..d3e1ae338f2ca 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 {
 define <2 x double> @frem_f64(<2 x double> %in) {
 ; CHECK-LABEL: define <2 x double> @frem_f64
 ; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <2 x double> [[IN]], [[IN]]
 ; CHECK-NEXT:    ret <2 x double> [[TMP1]]
 ;
   %1= frem <2 x double> %in, %in
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
index f408df570fdc0..69b16b02adaa2 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
@@ -386,7 +386,7 @@ define <4 x float> @llvm_trunc_f32(<4 x float> %in) {
 
 define <2 x double> @frem_f64(<2 x double> %in) {
 ; CHECK-LABEL: @frem_f64(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @_ZGVnN2vv_fmod(<2 x double> [[IN:%.*]], <2 x double> [[IN]])
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <2 x double> [[IN:%.*]], [[IN]]
 ; CHECK-NEXT:    ret <2 x double> [[TMP1]]
 ;
   %1= frem <2 x double> %in, %in