[TLI] ReplaceWithVecLib pass uses CostModel #78688

paschalis-mpeis
Member

The replace-with-veclib pass now replaces an instruction with a vector library call only when the cost of that call is not found to be higher than the cost of the original instruction.
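
As a rough illustration of the rule in the description, the decision can be modelled outside LLVM with plain integer costs. This is a minimal sketch and not the patch itself: `ToyCost`, `isVeclibCallSlowerToy` and `shouldReplaceToy` are hypothetical names, whereas the real code compares `InstructionCost` values obtained from `TargetTransformInfo`. What it shows is that only a strictly higher library cost blocks the replacement, so a tie still replaces.

```cpp
// Hypothetical stand-in for a cost value; the patch uses llvm::InstructionCost.
using ToyCost = long;

// Same shape as `return VecLibCost > DefaultCost;` in the patch:
// the call counts as slower only when its cost is strictly higher.
bool isVeclibCallSlowerToy(ToyCost VecLibCost, ToyCost DefaultCost) {
  return VecLibCost > DefaultCost;
}

// The replacement goes ahead unless the library call is found to be slower.
bool shouldReplaceToy(ToyCost VecLibCost, ToyCost DefaultCost) {
  return !isVeclibCallSlowerToy(VecLibCost, DefaultCost);
}
```

With these toy numbers, `shouldReplaceToy(4, 10)` and `shouldReplaceToy(10, 10)` both allow the replacement, while `shouldReplaceToy(11, 10)` keeps the original instruction. The updated `frem_f64` tests in this PR correspond to the last case: the `frem` is left in place.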

@llvmbot
Collaborator

llvmbot commented Jan 19, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Paschalis Mpeis (paschalis-mpeis)

Changes

The replace-with-veclib pass now replaces an instruction with a vector library call only when the cost of that call is not found to be higher than the cost of the original instruction.


Full diff: https://github.com/llvm/llvm-project/pull/78688.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/ReplaceWithVeclib.cpp (+68-12)
  • (modified) llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll (+1-1)
  • (modified) llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll (+1-1)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 7b0215535a92c8..c57156c00a74e8 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -9,6 +9,8 @@
 // Replaces LLVM IR instructions with vector operands (i.e., the frem
 // instruction or calls to LLVM intrinsics) with matching calls to functions
 // from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
+// This happens only when the cost of calling the vector library is not found to
+// be more than the cost of the original instruction.
 //
 //===----------------------------------------------------------------------===//
 
@@ -20,11 +22,15 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/InstructionCost.h"
 #include "llvm/Support/TypeSize.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
@@ -95,15 +101,54 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
     Replacement->copyFastMathFlags(&I);
 }
 
+/// Returns whether the vector library call \p TLIFunc costs more than the
+/// original instruction \p I.
+static bool isVeclibCallSlower(const TargetLibraryInfo &TLI,
+                               const TargetTransformInfo &TTI, Instruction &I,
+                               VectorType *VectorTy, CallInst *CI,
+                               Function *TLIFunc) {
+  SmallVector<Type *, 4> OpTypes;
+  for (auto &Op : CI ? CI->args() : I.operands())
+    OpTypes.push_back(Op->getType());
+
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost DefaultCost;
+  if (CI) {
+    FastMathFlags FMF;
+    if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
+      FMF = FPMO->getFastMathFlags();
+
+    SmallVector<const Value *> Args(CI->args());
+    IntrinsicCostAttributes CostAttrs(CI->getIntrinsicID(), VectorTy, Args,
+                                      OpTypes, FMF,
+                                      dyn_cast<IntrinsicInst>(CI));
+    DefaultCost = TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
+  } else {
+    assert((I.getOpcode() == Instruction::FRem) && "Only FRem is supported");
+    auto Op2Info = TTI.getOperandInfo(I.getOperand(1));
+    SmallVector<const Value *, 4> OpValues(I.operand_values());
+    DefaultCost = TTI.getArithmeticInstrCost(
+        I.getOpcode(), VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        Op2Info, OpValues, &I);
+  }
+
+  InstructionCost VecLibCost =
+      TTI.getCallInstrCost(TLIFunc, VectorTy, OpTypes, CostKind);
+  return VecLibCost > DefaultCost;
+}
+
 /// Returns true when successfully replaced \p I with a suitable function taking
-/// vector arguments, based on available mappings in the \p TLI. Currently only
-/// works when \p I is a call to vectorized intrinsic or the frem instruction.
+/// vector arguments, based on available mappings in the \p TLI and costs.
+/// Currently only works when \p I is a call to vectorized intrinsic or the frem
+/// instruction.
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
+                                    const TargetTransformInfo &TTI,
                                     Instruction &I) {
   // At the moment VFABI assumes the return type is always widened unless it is
   // a void type.
-  auto *VTy = dyn_cast<VectorType>(I.getType());
-  ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
+  auto *VectorTy = dyn_cast<VectorType>(I.getType());
+  ElementCount EC(VectorTy ? VectorTy->getElementCount() : ElementCount::getFixed(0));
 
   // Compute the argument types of the corresponding scalar call and the scalar
   // function name. For calls, it additionally finds the function to replace
@@ -124,9 +169,10 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
         ScalarArgTypes.push_back(VectorArgTy->getElementType());
         // When return type is void, set EC to the first vector argument, and
         // disallow vector arguments with different ECs.
-        if (EC.isZero())
+        if (EC.isZero()) {
           EC = VectorArgTy->getElementCount();
-        else if (EC != VectorArgTy->getElementCount())
+          VectorTy = VectorArgTy;
+        } else if (EC != VectorArgTy->getElementCount())
           return false;
       } else
         // Exit when it is supposed to be a vector argument but it isn't.
@@ -138,8 +184,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
                      ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
                      : Intrinsic::getName(IID).str();
   } else {
-    assert(VTy && "Return type must be a vector");
-    auto *ScalarTy = VTy->getScalarType();
+    assert(VectorTy && "Return type must be a vector");
+    auto *ScalarTy = VectorTy->getScalarType();
     LibFunc Func;
     if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
       return false;
@@ -199,6 +245,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
                                      VD->getVectorFnName(), FuncToReplace);
 
+  if (isVeclibCallSlower(TLI, TTI, I, VectorTy, CI, TLIFunc))
+    return false;
+
   replaceWithTLIFunction(I, *OptInfo, TLIFunc);
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
                     << "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -219,13 +268,14 @@ static bool isSupportedInstruction(Instruction *I) {
   return false;
 }
 
-static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
+static bool runImpl(const TargetLibraryInfo &TLI,
+                    const TargetTransformInfo &TTI, Function &F) {
   bool Changed = false;
   SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
     if (!isSupportedInstruction(&I))
       continue;
-    if (replaceWithCallToVeclib(TLI, I)) {
+    if (replaceWithCallToVeclib(TLI, TTI, I)) {
       ReplacedCalls.push_back(&I);
       Changed = true;
     }
@@ -243,7 +293,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
                                          FunctionAnalysisManager &AM) {
   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
-  auto Changed = runImpl(TLI, F);
+  const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+  auto Changed = runImpl(TLI, TTI, F);
   if (Changed) {
     LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
                       << NumCallsReplaced << "\n");
@@ -251,6 +302,7 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
     PreservedAnalyses PA;
     PA.preserveSet<CFGAnalyses>();
     PA.preserve<TargetLibraryAnalysis>();
+    PA.preserve<TargetIRAnalysis>();
     PA.preserve<ScalarEvolutionAnalysis>();
     PA.preserve<LoopAccessAnalysis>();
     PA.preserve<DemandedBitsAnalysis>();
@@ -268,13 +320,17 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
   const TargetLibraryInfo &TLI =
       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
-  return runImpl(TLI, F);
+  const TargetTransformInfo &TTI =
+      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  return runImpl(TLI, TTI, F);
 }
 
 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesCFG();
   AU.addRequired<TargetLibraryInfoWrapperPass>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
   AU.addPreserved<TargetLibraryInfoWrapperPass>();
+  AU.addPreserved<TargetTransformInfoWrapperPass>();
   AU.addPreserved<ScalarEvolutionWrapperPass>();
   AU.addPreserved<AAResultsWrapperPass>();
   AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index 758df0493cc504..d3e1ae338f2caa 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 {
 define <2 x double> @frem_f64(<2 x double> %in) {
 ; CHECK-LABEL: define <2 x double> @frem_f64
 ; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <2 x double> [[IN]], [[IN]]
 ; CHECK-NEXT:    ret <2 x double> [[TMP1]]
 ;
   %1= frem <2 x double> %in, %in
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
index f408df570fdc00..69b16b02adaa2d 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
@@ -386,7 +386,7 @@ define <4 x float> @llvm_trunc_f32(<4 x float> %in) {
 
 define <2 x double> @frem_f64(<2 x double> %in) {
 ; CHECK-LABEL: @frem_f64(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @_ZGVnN2vv_fmod(<2 x double> [[IN:%.*]], <2 x double> [[IN]])
+; CHECK-NEXT:    [[TMP1:%.*]] = frem <2 x double> [[IN:%.*]], [[IN]]
 ; CHECK-NEXT:    ret <2 x double> [[TMP1]]
 ;
   %1= frem <2 x double> %in, %in

github-actions bot commented Jan 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

The replace-with-veclib pass replaces an instruction with a veclib call only when the call's cost
is not found to be higher than the cost of the original instruction.
@paschalis-mpeis paschalis-mpeis force-pushed the users/paschalis-mpeis/replace-with-veclib-uses-costmodel branch from 2fbe3e0 to 2a85ed1 Compare January 23, 2024 10:32
/// original instruction \p I.
static bool isVeclibCallSlower(const TargetLibraryInfo &TLI,
                               const TargetTransformInfo &TTI, Instruction &I,
                               VectorType *VectorTy, CallInst *CI,
Collaborator

I think you can just pass the instruction and re-cast it here (possibly directly to IntrinsicInst?) instead of passing both.
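
One way to read this suggestion, sketched with a toy class hierarchy rather than LLVM's own types: the helper receives only the instruction and works out for itself whether it is an intrinsic call. Everything here (`ToyInstruction`, `ToyFRem`, `ToyIntrinsicCall`, `describeCostPath`) is hypothetical, and `dynamic_cast` stands in for what would be `dyn_cast<IntrinsicInst>(&I)` in LLVM code.

```cpp
#include <string>

// Toy hierarchy standing in for Instruction and its subclasses.
struct ToyInstruction {
  virtual ~ToyInstruction() = default;
};
struct ToyFRem : ToyInstruction {};
struct ToyIntrinsicCall : ToyInstruction {
  unsigned IntrinsicID = 0;
};

// The helper takes a single instruction parameter and re-casts it locally,
// instead of receiving both the instruction and a pre-cast call pointer.
std::string describeCostPath(const ToyInstruction &I) {
  if (const auto *II = dynamic_cast<const ToyIntrinsicCall *>(&I))
    return "intrinsic cost, id " + std::to_string(II->IntrinsicID);
  return "arithmetic cost";
}
```

The caller then no longer has to keep the two parameters consistent with each other, which is the redundancy the comment points at.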

                               VectorType *VectorTy, CallInst *CI,
                               Function *TLIFunc) {
  SmallVector<Type *, 4> OpTypes;
  for (auto &Op : CI ? CI->args() : I.operands())
Collaborator

Should only need the I.operands(); also consider using Value * instead of auto &

    SmallVector<const Value *, 4> OpValues(I.operand_values());
    DefaultCost = TTI.getArithmeticInstrCost(
        I.getOpcode(), VectorTy, CostKind,
        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
Collaborator

Are the arguments after CostKind actually needed? (They have defaults)
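
The question rests on how defaulted trailing parameters behave in C++. A minimal sketch with a toy function (all names hypothetical, not the `TargetTransformInfo` signature): passing values that equal the defaults gives the same result as omitting them, so such arguments can be dropped without changing behaviour.

```cpp
// Toy operand descriptor; hypothetical, loosely modelled on the idea of
// per-operand hints passed to a cost query.
enum class OperandKind { AnyValue, UniformConstant };

struct OperandHint {
  OperandKind Kind;
};

// Toy cost query with defaulted trailing parameters.
int toyArithmeticCost(int BaseCost,
                      OperandHint Op1 = OperandHint{OperandKind::AnyValue},
                      OperandHint Op2 = OperandHint{OperandKind::AnyValue}) {
  int Cost = BaseCost;
  // In this toy model a constant operand makes the operation cheaper.
  if (Op1.Kind == OperandKind::UniformConstant)
    Cost -= 1;
  if (Op2.Kind == OperandKind::UniformConstant)
    Cost -= 1;
  return Cost;
}
```

Omitting the hints is only equivalent when the explicit values match the defaults. In the patch, `Op2Info` is computed from the second operand via `TTI.getOperandInfo`, so whether it can be dropped depends on whether that result can differ from the default.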

@@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 {
 define <2 x double> @frem_f64(<2 x double> %in) {
 ; CHECK-LABEL: define <2 x double> @frem_f64
 ; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
Collaborator

Might be a good idea to add a flag to ReplaceWithVeclib.cpp to override the cost for testing purposes. Then filter for call or frem with the autogenerator, and add a second runline using the flag to make sure we still perform the transformation.
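
A sketch of how such an override could interact with the cost check, written as standalone C++ so it does not depend on LLVM. The global `ForceVeclibReplacement` is a hypothetical stand-in for a command-line option (in LLVM this would typically be a `cl::opt<bool>`), and `shouldReplaceWithOverride` is likewise an illustrative name, not something in the patch.

```cpp
// Hypothetical testing-only switch; stands in for a command-line flag.
static bool ForceVeclibReplacement = false;

// When the override is set, the cost comparison is bypassed so that tests
// can still exercise the replacement itself.
bool shouldReplaceWithOverride(long VecLibCost, long DefaultCost) {
  if (ForceVeclibReplacement)
    return true;
  return !(VecLibCost > DefaultCost);
}
```

With the switch off, a library call costed at 11 against an instruction costed at 10 is rejected. With it on, the same pair is accepted, which is what a second RUN line using the flag would rely on.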

@paschalis-mpeis
Member Author

This is no longer needed.
