-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TLI] ReplaceWithVecLib pass uses CostModel #78688
[TLI] ReplaceWithVecLib pass uses CostModel #78688
Conversation
@llvm/pr-subscribers-backend-aarch64 Author: Paschalis Mpeis (paschalis-mpeis) ChangesPass Full diff: https://github.com/llvm/llvm-project/pull/78688.diff 3 Files Affected:
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 7b0215535a92c8..c57156c00a74e8 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -9,6 +9,8 @@
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
+// This happens only when the cost of calling the vector library is not found to
+// be more than the cost of the original instruction.
//
//===----------------------------------------------------------------------===//
@@ -20,11 +22,15 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -95,15 +101,54 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Replacement->copyFastMathFlags(&I);
}
+/// Returns whether the vector library call \p TLIFunc costs more than the
+/// original instruction \p I.
+static bool isVeclibCallSlower(const TargetLibraryInfo &TLI,
+ const TargetTransformInfo &TTI, Instruction &I,
+ VectorType *VectorTy, CallInst *CI,
+ Function *TLIFunc) {
+ SmallVector<Type *, 4> OpTypes;
+ for (auto &Op : CI ? CI->args() : I.operands())
+ OpTypes.push_back(Op->getType());
+
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+ InstructionCost DefaultCost;
+ if (CI) {
+ FastMathFlags FMF;
+ if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
+ FMF = FPMO->getFastMathFlags();
+
+ SmallVector<const Value *> Args(CI->args());
+ IntrinsicCostAttributes CostAttrs(CI->getIntrinsicID(), VectorTy, Args,
+ OpTypes, FMF,
+ dyn_cast<IntrinsicInst>(CI));
+ DefaultCost = TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
+ } else {
+ assert((I.getOpcode() == Instruction::FRem) && "Only FRem is supported");
+ auto Op2Info = TTI.getOperandInfo(I.getOperand(1));
+ SmallVector<const Value *, 4> OpValues(I.operand_values());
+ DefaultCost = TTI.getArithmeticInstrCost(
+ I.getOpcode(), VectorTy, CostKind,
+ {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+ Op2Info, OpValues, &I);
+ }
+
+ InstructionCost VecLibCost =
+ TTI.getCallInstrCost(TLIFunc, VectorTy, OpTypes, CostKind);
+ return VecLibCost > DefaultCost;
+}
+
/// Returns true when successfully replaced \p I with a suitable function taking
-/// vector arguments, based on available mappings in the \p TLI. Currently only
-/// works when \p I is a call to vectorized intrinsic or the frem instruction.
+/// vector arguments, based on available mappings in the \p TLI and costs.
+/// Currently only works when \p I is a call to vectorized intrinsic or the frem
+/// instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
+ const TargetTransformInfo &TTI,
Instruction &I) {
// At the moment VFABI assumes the return type is always widened unless it is
// a void type.
- auto *VTy = dyn_cast<VectorType>(I.getType());
- ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
+ auto *VectorTy = dyn_cast<VectorType>(I.getType());
+ ElementCount EC(VectorTy ? VectorTy->getElementCount() : ElementCount::getFixed(0));
// Compute the argument types of the corresponding scalar call and the scalar
// function name. For calls, it additionally finds the function to replace
@@ -124,9 +169,10 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
ScalarArgTypes.push_back(VectorArgTy->getElementType());
// When return type is void, set EC to the first vector argument, and
// disallow vector arguments with different ECs.
- if (EC.isZero())
+ if (EC.isZero()) {
EC = VectorArgTy->getElementCount();
- else if (EC != VectorArgTy->getElementCount())
+ VectorTy = VectorArgTy;
+ } else if (EC != VectorArgTy->getElementCount())
return false;
} else
// Exit when it is supposed to be a vector argument but it isn't.
@@ -138,8 +184,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
: Intrinsic::getName(IID).str();
} else {
- assert(VTy && "Return type must be a vector");
- auto *ScalarTy = VTy->getScalarType();
+ assert(VectorTy && "Return type must be a vector");
+ auto *ScalarTy = VectorTy->getScalarType();
LibFunc Func;
if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
return false;
@@ -199,6 +245,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);
+ if (isVeclibCallSlower(TLI, TTI, I, VectorTy, CI, TLIFunc))
+ return false;
+
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -219,13 +268,14 @@ static bool isSupportedInstruction(Instruction *I) {
return false;
}
-static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
+static bool runImpl(const TargetLibraryInfo &TLI,
+ const TargetTransformInfo &TTI, Function &F) {
bool Changed = false;
SmallVector<Instruction *> ReplacedCalls;
for (auto &I : instructions(F)) {
if (!isSupportedInstruction(&I))
continue;
- if (replaceWithCallToVeclib(TLI, I)) {
+ if (replaceWithCallToVeclib(TLI, TTI, I)) {
ReplacedCalls.push_back(&I);
Changed = true;
}
@@ -243,7 +293,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
PreservedAnalyses ReplaceWithVeclib::run(Function &F,
FunctionAnalysisManager &AM) {
const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
- auto Changed = runImpl(TLI, F);
+ const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+ auto Changed = runImpl(TLI, TTI, F);
if (Changed) {
LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
<< NumCallsReplaced << "\n");
@@ -251,6 +302,7 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
PA.preserve<TargetLibraryAnalysis>();
+ PA.preserve<TargetIRAnalysis>();
PA.preserve<ScalarEvolutionAnalysis>();
PA.preserve<LoopAccessAnalysis>();
PA.preserve<DemandedBitsAnalysis>();
@@ -268,13 +320,17 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
const TargetLibraryInfo &TLI =
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- return runImpl(TLI, F);
+ const TargetTransformInfo &TTI =
+ getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ return runImpl(TLI, TTI, F);
}
void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addPreserved<TargetLibraryInfoWrapperPass>();
+ AU.addPreserved<TargetTransformInfoWrapperPass>();
AU.addPreserved<ScalarEvolutionWrapperPass>();
AU.addPreserved<AAResultsWrapperPass>();
AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index 758df0493cc504..d3e1ae338f2caa 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 {
define <2 x double> @frem_f64(<2 x double> %in) {
; CHECK-LABEL: define <2 x double> @frem_f64
; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
+; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN]], [[IN]]
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1= frem <2 x double> %in, %in
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
index f408df570fdc00..69b16b02adaa2d 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
@@ -386,7 +386,7 @@ define <4 x float> @llvm_trunc_f32(<4 x float> %in) {
define <2 x double> @frem_f64(<2 x double> %in) {
; CHECK-LABEL: @frem_f64(
-; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @_ZGVnN2vv_fmod(<2 x double> [[IN:%.*]], <2 x double> [[IN]])
+; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN:%.*]], [[IN]]
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1= frem <2 x double> %in, %in
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
e476a0e
to
2fbe3e0
Compare
Pass replace-with-veclib only replaces to veclib calls when their cost is not found to be higher than the cost of the original instruction.
2fbe3e0
to
2a85ed1
Compare
/// original instruction \p I. | ||
static bool isVeclibCallSlower(const TargetLibraryInfo &TLI, | ||
const TargetTransformInfo &TTI, Instruction &I, | ||
VectorType *VectorTy, CallInst *CI, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can just pass the instruction and re-cast it here (possibly directly to IntrinsicInst?) instead of passing both.
VectorType *VectorTy, CallInst *CI, | ||
Function *TLIFunc) { | ||
SmallVector<Type *, 4> OpTypes; | ||
for (auto &Op : CI ? CI->args() : I.operands()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should only need the I.operands()
; also consider using Value *
instead of auto &
SmallVector<const Value *, 4> OpValues(I.operand_values()); | ||
DefaultCost = TTI.getArithmeticInstrCost( | ||
I.getOpcode(), VectorTy, CostKind, | ||
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are the arguments after CostKind actually needed? (They have defaults)
@@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 { | |||
define <2 x double> @frem_f64(<2 x double> %in) { | |||
; CHECK-LABEL: define <2 x double> @frem_f64 | |||
; CHECK-SAME: (<2 x double> [[IN:%.*]]) { | |||
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be a good idea to add a flag to ReplaceWithVeclib.cpp to override the cost for testing purposes. Then filter for call or frem with the autogenerator, and add a second runline using the flag to make sure we still perform the transformation.
This is no longer needed. |
Pass
replace-with-veclib
only replaces to veclib calls when their cost is not found to be higher than the cost of the original instruction.