Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TLI] Fix replace-with-veclib crash with invalid arguments #77112

Merged
merged 9 commits into from
Jan 12, 2024

Conversation

paschalis-mpeis
Copy link
Member

@paschalis-mpeis paschalis-mpeis commented Jan 5, 2024

Fix a crash of replace-with-veclib pass, when the arguments of the TLI mapping
do not match the original call. After this patch, it will simply ignores such cases.

REVERTED:

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 5, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Paschalis Mpeis (paschalis-mpeis)

Changes

Fix a crash of replace-with-veclib pass, when the arguments of the TLI mapping
do not match the original call. After this patch, it will simply ignores such cases.

Stacked PR:

  • (to be updated)

Patch is 24.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77112.diff

6 Files Affected:

  • (modified) llvm/lib/CodeGen/ReplaceWithVeclib.cpp (+109-69)
  • (renamed) llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll (+41-1)
  • (renamed) llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll (+19-1)
  • (renamed) llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll (+19-1)
  • (modified) llvm/unittests/Analysis/CMakeLists.txt (+1)
  • (added) llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp (+85)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 893aa4a91828d3..92f2d006fd79c2 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
-// with vector operands) with matching calls to functions from a vector
-// library (e.g., libmvec, SVML) according to TargetLibraryInfo.
+// Replaces LLVM IR instructions with vector operands (i.e., the frem
+// instruction or calls to LLVM intrinsics) with matching calls to functions
+// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
 //
 //===----------------------------------------------------------------------===//
 
@@ -69,88 +69,100 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
   return TLIFunc;
 }
 
-/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
-/// the corresponding function from the vector library ( \p TLIVecFunc ).
-static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info,
+/// Replace the instruction \p I with a call to the corresponding function from
+/// the vector library (\p TLIVecFunc).
+static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
                                    Function *TLIVecFunc) {
-  IRBuilder<> IRBuilder(&CalltoReplace);
-  SmallVector<Value *> Args(CalltoReplace.args());
+  IRBuilder<> IRBuilder(&I);
+  auto *CI = dyn_cast<CallInst>(&I);
+  SmallVector<Value *> Args(CI ? CI->args() : I.operands());
   if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
-    auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()),
-                                   Info.Shape.VF);
+    auto *MaskTy =
+        VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
     Args.insert(Args.begin() + OptMaskpos.value(),
                 Constant::getAllOnesValue(MaskTy));
   }
 
-  // Preserve the operand bundles.
+  // If it is a call instruction, preserve the operand bundles.
   SmallVector<OperandBundleDef, 1> OpBundles;
-  CalltoReplace.getOperandBundlesAsDefs(OpBundles);
-  CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
-  CalltoReplace.replaceAllUsesWith(Replacement);
+  if (CI)
+    CI->getOperandBundlesAsDefs(OpBundles);
+
+  auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
+  I.replaceAllUsesWith(Replacement);
   // Preserve fast math flags for FP math.
   if (isa<FPMathOperator>(Replacement))
-    Replacement->copyFastMathFlags(&CalltoReplace);
+    Replacement->copyFastMathFlags(&I);
 }
 
-/// Returns true when successfully replaced \p CallToReplace with a suitable
-/// function taking vector arguments, based on available mappings in the \p TLI.
-/// Currently only works when \p CallToReplace is a call to vectorized
-/// intrinsic.
+/// Returns true when successfully replaced \p I with a suitable function taking
+/// vector arguments, based on available mappings in the \p TLI. Currently only
+/// works when \p I is a call to vectorized intrinsic or the frem instruction.
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
-                                    CallInst &CallToReplace) {
-  if (!CallToReplace.getCalledFunction())
-    return false;
-
-  auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID();
-  // Replacement is only performed for intrinsic functions.
-  if (IntrinsicID == Intrinsic::not_intrinsic)
-    return false;
+                                    Instruction &I) {
+  // At the moment VFABI assumes the return type is always widened unless it is
+  // a void type.
+  auto *VTy = dyn_cast<VectorType>(I.getType());
+  ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
 
-  // Compute arguments types of the corresponding scalar call. Additionally
-  // checks if in the vector call, all vector operands have the same EC.
-  ElementCount VF = ElementCount::getFixed(0);
-  SmallVector<Type *> ScalarArgTypes;
-  for (auto Arg : enumerate(CallToReplace.args())) {
-    auto *ArgTy = Arg.value()->getType();
-    if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
-      ScalarArgTypes.push_back(ArgTy);
-    } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
-      ScalarArgTypes.push_back(ArgTy->getScalarType());
-      // Disallow vector arguments with different VFs. When processing the first
-      // vector argument, store it's VF, and for the rest ensure that they match
-      // it.
-      if (VF.isZero())
-        VF = VectorArgTy->getElementCount();
-      else if (VF != VectorArgTy->getElementCount())
+  // Compute the argument types of the corresponding scalar call and the scalar
+  // function name. For calls, it additionally finds the function to replace
+  // and checks that all vector operands match the previously found EC.
+  SmallVector<Type *, 8> ScalarArgTypes, OrigArgTypes;
+  std::string ScalarName;
+  Function *FuncToReplace = nullptr;
+  auto *CI = dyn_cast<CallInst>(&I);
+  if (CI) {
+    FuncToReplace = CI->getCalledFunction();
+    Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
+    assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
+    for (auto Arg : enumerate(CI->args())) {
+      auto *ArgTy = Arg.value()->getType();
+      OrigArgTypes.push_back(ArgTy);
+      if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+        ScalarArgTypes.push_back(ArgTy);
+      } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
+        ScalarArgTypes.push_back(VectorArgTy->getElementType());
+        // When return type is void, set EC to the first vector argument, and
+        // disallow vector arguments with different ECs.
+        if (EC.isZero())
+          EC = VectorArgTy->getElementCount();
+        else if (EC != VectorArgTy->getElementCount())
+          return false;
+      } else
+        // Exit when it is supposed to be a vector argument but it isn't.
         return false;
-    } else
-      // Exit when it is supposed to be a vector argument but it isn't.
+    }
+    // Try to reconstruct the name for the scalar version of the instruction,
+    // using scalar argument types.
+    ScalarName = Intrinsic::isOverloaded(IID)
+                     ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
+                     : Intrinsic::getName(IID).str();
+  } else {
+    assert(VTy && "Return type must be a vector");
+    auto *ScalarTy = VTy->getScalarType();
+    LibFunc Func;
+    if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
       return false;
+    ScalarName = TLI.getName(Func);
+    ScalarArgTypes = {ScalarTy, ScalarTy};
   }
 
-  // Try to reconstruct the name for the scalar version of this intrinsic using
-  // the intrinsic ID and the argument types converted to scalar above.
-  std::string ScalarName =
-      (Intrinsic::isOverloaded(IntrinsicID)
-           ? Intrinsic::getName(IntrinsicID, ScalarArgTypes,
-                                CallToReplace.getModule())
-           : Intrinsic::getName(IntrinsicID).str());
-
   // Try to find the mapping for the scalar version of this intrinsic and the
   // exact vector width of the call operands in the TargetLibraryInfo. First,
   // check with a non-masked variant, and if that fails try with a masked one.
   const VecDesc *VD =
-      TLI.getVectorMappingInfo(ScalarName, VF, /*Masked*/ false);
-  if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, VF, /*Masked*/ true)))
+      TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
+  if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
     return false;
 
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
-                    << "` and vector width " << VF << " to: `"
+                    << "` and vector width " << EC << " to: `"
                     << VD->getVectorFnName() << "`.\n");
 
   // Replace the call to the intrinsic with a call to the vector library
   // function.
-  Type *ScalarRetTy = CallToReplace.getType()->getScalarType();
+  Type *ScalarRetTy = I.getType()->getScalarType();
   FunctionType *ScalarFTy =
       FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
   const std::string MangledName = VD->getVectorFunctionABIVariantString();
@@ -162,27 +174,55 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   if (!VectorFTy)
     return false;
 
-  Function *FuncToReplace = CallToReplace.getCalledFunction();
-  Function *TLIFunc = getTLIFunction(CallToReplace.getModule(), VectorFTy,
+  Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
                                      VD->getVectorFnName(), FuncToReplace);
-  replaceWithTLIFunction(CallToReplace, *OptInfo, TLIFunc);
 
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
-                    << FuncToReplace->getName() << "` with call to `"
-                    << TLIFunc->getName() << "`.\n");
+  // For calls, bail out when their arguments do not match with the TLI mapping.
+  if (CI) {
+    int IdxNonPred = 0;
+    for (auto [OrigTy, VFParam] :
+         zip(OrigArgTypes, OptInfo->Shape.Parameters)) {
+      if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
+        continue;
+      ++IdxNonPred;
+      if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE
+                          << ": Will not replace: wrong type at index: "
+                          << IdxNonPred << ": " << *OrigTy << "\n");
+        return false;
+      }
+    }
+  }
+
+  replaceWithTLIFunction(I, *OptInfo, TLIFunc);
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
+                    << "` with call to `" << TLIFunc->getName() << "`.\n");
   ++NumCallsReplaced;
   return true;
 }
 
+/// Supported instruction \p I must be a vectorized frem or a call to an
+/// intrinsic that returns either void or a vector.
+static bool isSupportedInstruction(Instruction *I) {
+  Type *Ty = I->getType();
+  if (auto *CI = dyn_cast<CallInst>(I))
+    return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
+           CI->getCalledFunction()->getIntrinsicID() !=
+               Intrinsic::not_intrinsic;
+  if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
+    return true;
+  return false;
+}
+
 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
   bool Changed = false;
-  SmallVector<CallInst *> ReplacedCalls;
+  SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
-    if (auto *CI = dyn_cast<CallInst>(&I)) {
-      if (replaceWithCallToVeclib(TLI, *CI)) {
-        ReplacedCalls.push_back(CI);
-        Changed = true;
-      }
+    if (!isSupportedInstruction(&I))
+      continue;
+    if (replaceWithCallToVeclib(TLI, I)) {
+      ReplacedCalls.push_back(&I);
+      Changed = true;
     }
   }
   // Erase the calls to the intrinsics that have been replaced
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
similarity index 91%
rename from llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
rename to llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index d41870ec6e7915..4480a90a2728d3 100644
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -15,7 +15,7 @@ declare <vscale x 2 x double> @llvm.cos.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.cos.nxv4f32(<vscale x 4 x float>)
 
 ;.
-; CHECK: @llvm.compiler.used = appending global [32 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [36 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vfmodq_f64, ptr @armpl_vfmodq_f32, ptr @armpl_svfmod_f64_x, ptr @armpl_svfmod_f32_x], section "llvm.metadata"
 ;.
 define <2 x double> @llvm_cos_f64(<2 x double> %in) {
 ; CHECK-LABEL: define <2 x double> @llvm_cos_f64
@@ -424,6 +424,46 @@ define <vscale x 4 x float> @llvm_pow_vscale_f32(<vscale x 4 x float> %in, <vsca
   ret <vscale x 4 x float> %1
 }
 
+define <2 x double> @frem_f64(<2 x double> %in) {
+; CHECK-LABEL: define <2 x double> @frem_f64
+; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+;
+  %1= frem <2 x double> %in, %in
+  ret <2 x double> %1
+}
+
+define <4 x float> @frem_f32(<4 x float> %in) {
+; CHECK-LABEL: define <4 x float> @frem_f32
+; CHECK-SAME: (<4 x float> [[IN:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @armpl_vfmodq_f32(<4 x float> [[IN]], <4 x float> [[IN]])
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+;
+  %1= frem <4 x float> %in, %in
+  ret <4 x float> %1
+}
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in) #0 {
+; CHECK-LABEL: define <vscale x 2 x double> @frem_vscale_f64
+; CHECK-SAME: (<vscale x 2 x double> [[IN:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x double> @armpl_svfmod_f64_x(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[IN]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %1= frem <vscale x 2 x double> %in, %in
+  ret <vscale x 2 x double> %1
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in) #0 {
+; CHECK-LABEL: define <vscale x 4 x float> @frem_vscale_f32
+; CHECK-SAME: (<vscale x 4 x float> [[IN:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @armpl_svfmod_f32_x(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[IN]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %1= frem <vscale x 4 x float> %in, %in
+  ret <vscale x 4 x float> %1
+}
+
 attributes #0 = { "target-features"="+sve" }
 ;.
 ; CHECK: attributes #[[ATTR0:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll
similarity index 95%
rename from llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
rename to llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll
index c2ff6014bc6944..590dd9effac0ea 100644
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll
@@ -4,7 +4,7 @@
 target triple = "aarch64-unknown-linux-gnu"
 
 ;.
-; CHECK: @llvm.compiler.used = appending global [16 x ptr] [ptr @_ZGVsMxv_cos, ptr @_ZGVsMxv_cosf, ptr @_ZGVsMxv_exp, ptr @_ZGVsMxv_expf, ptr @_ZGVsMxv_exp2, ptr @_ZGVsMxv_exp2f, ptr @_ZGVsMxv_exp10, ptr @_ZGVsMxv_exp10f, ptr @_ZGVsMxv_log, ptr @_ZGVsMxv_logf, ptr @_ZGVsMxv_log10, ptr @_ZGVsMxv_log10f, ptr @_ZGVsMxv_log2, ptr @_ZGVsMxv_log2f, ptr @_ZGVsMxv_sin, ptr @_ZGVsMxv_sinf], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [18 x ptr] [ptr @_ZGVsMxv_cos, ptr @_ZGVsMxv_cosf, ptr @_ZGVsMxv_exp, ptr @_ZGVsMxv_expf, ptr @_ZGVsMxv_exp2, ptr @_ZGVsMxv_exp2f, ptr @_ZGVsMxv_exp10, ptr @_ZGVsMxv_exp10f, ptr @_ZGVsMxv_log, ptr @_ZGVsMxv_logf, ptr @_ZGVsMxv_log10, ptr @_ZGVsMxv_log10f, ptr @_ZGVsMxv_log2, ptr @_ZGVsMxv_log2f, ptr @_ZGVsMxv_sin, ptr @_ZGVsMxv_sinf, ptr @_ZGVsMxvv_fmod, ptr @_ZGVsMxvv_fmodf], section "llvm.metadata"
 ;.
 define <vscale x 2 x double> @llvm_ceil_vscale_f64(<vscale x 2 x double> %in) {
 ; CHECK-LABEL: @llvm_ceil_vscale_f64(
@@ -384,6 +384,24 @@ define <vscale x 4 x float> @llvm_trunc_vscale_f32(<vscale x 4 x float> %in) {
   ret <vscale x 4 x float> %1
 }
 
+define <vscale x 2 x double> @frem_f64(<vscale x 2 x double> %in) {
+; CHECK-LABEL: @frem_f64(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[IN:%.*]], <vscale x 2 x double> [[IN]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %1= frem <vscale x 2 x double> %in, %in
+  ret <vscale x 2 x double> %1
+}
+
+define <vscale x 4 x float> @frem_f32(<vscale x 4 x float> %in) {
+; CHECK-LABEL: @frem_f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[IN:%.*]], <vscale x 4 x float> [[IN]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %1= frem <vscale x 4 x float> %in, %in
+  ret <vscale x 4 x float> %1
+}
+
 declare <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.copysign.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
similarity index 95%
rename from llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll
rename to llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
index be247de368056e..865a46009b205f 100644
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
@@ -4,7 +4,7 @@
 target triple = "aarch64-unknown-linux-gnu"
 
 ;.
-; CHECK: @llvm.compiler.used = appending global [16 x ptr] [ptr @_ZGVnN2v_cos, ptr @_ZGVnN4v_cosf, ptr @_ZGVnN2v_exp, ptr @_ZGVnN4v_expf, ptr @_ZGVnN2v_exp2, ptr @_ZGVnN4v_exp2f, ptr @_ZGVnN2v_exp10, ptr @_ZGVnN4v_exp10f, ptr @_ZGVnN2v_log, ptr @_ZGVnN4v_logf, ptr @_ZGVnN2v_log10, ptr @_ZGVnN4v_log10f, ptr @_ZGVnN2v_log2, ptr @_ZGVnN4v_log2f, ptr @_ZGVnN2v_sin, ptr @_ZGVnN4v_sinf], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [18 x ptr] [ptr @_ZGVnN2v_cos, ptr @_ZGVnN4v_cosf, ptr @_ZGVnN2v_exp, ptr @_ZGVnN4v_expf, ptr @_ZGVnN2v_exp2, ptr @_ZGVnN4v_exp2f, ptr @_ZGVnN2v_exp10, ptr @_ZGVnN4v_exp10f, ptr @_ZGVnN2v_log, ptr @_ZGVnN4v_logf, ptr @_ZGVnN2v_log10, ptr @_ZGVnN4v_log10f, ptr @_ZGVnN2v_log2, ptr @_ZGVnN4v_log2f, ptr @_ZGVnN2v_sin, ptr @_ZGVnN4v_sinf, ptr @_ZGVnN2vv_fmod, ptr @_ZGVnN4vv_fmodf], section "llvm.metadata"
 ;.
 define <2 ...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 5, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Paschalis Mpeis (paschalis-mpeis)

Changes

Fix a crash of replace-with-veclib pass, when the arguments of the TLI mapping
do not match the original call. After this patch, it will simply ignores such cases.

Stacked PR:

  • (to be updated)

Patch is 24.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77112.diff

6 Files Affected:

  • (modified) llvm/lib/CodeGen/ReplaceWithVeclib.cpp (+109-69)
  • (renamed) llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll (+41-1)
  • (renamed) llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll (+19-1)
  • (renamed) llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll (+19-1)
  • (modified) llvm/unittests/Analysis/CMakeLists.txt (+1)
  • (added) llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp (+85)
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 893aa4a91828d3..92f2d006fd79c2 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
-// with vector operands) with matching calls to functions from a vector
-// library (e.g., libmvec, SVML) according to TargetLibraryInfo.
+// Replaces LLVM IR instructions with vector operands (i.e., the frem
+// instruction or calls to LLVM intrinsics) with matching calls to functions
+// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
 //
 //===----------------------------------------------------------------------===//
 
@@ -69,88 +69,100 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
   return TLIFunc;
 }
 
-/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
-/// the corresponding function from the vector library ( \p TLIVecFunc ).
-static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info,
+/// Replace the instruction \p I with a call to the corresponding function from
+/// the vector library (\p TLIVecFunc).
+static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
                                    Function *TLIVecFunc) {
-  IRBuilder<> IRBuilder(&CalltoReplace);
-  SmallVector<Value *> Args(CalltoReplace.args());
+  IRBuilder<> IRBuilder(&I);
+  auto *CI = dyn_cast<CallInst>(&I);
+  SmallVector<Value *> Args(CI ? CI->args() : I.operands());
   if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
-    auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()),
-                                   Info.Shape.VF);
+    auto *MaskTy =
+        VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
     Args.insert(Args.begin() + OptMaskpos.value(),
                 Constant::getAllOnesValue(MaskTy));
   }
 
-  // Preserve the operand bundles.
+  // If it is a call instruction, preserve the operand bundles.
   SmallVector<OperandBundleDef, 1> OpBundles;
-  CalltoReplace.getOperandBundlesAsDefs(OpBundles);
-  CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
-  CalltoReplace.replaceAllUsesWith(Replacement);
+  if (CI)
+    CI->getOperandBundlesAsDefs(OpBundles);
+
+  auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
+  I.replaceAllUsesWith(Replacement);
   // Preserve fast math flags for FP math.
   if (isa<FPMathOperator>(Replacement))
-    Replacement->copyFastMathFlags(&CalltoReplace);
+    Replacement->copyFastMathFlags(&I);
 }
 
-/// Returns true when successfully replaced \p CallToReplace with a suitable
-/// function taking vector arguments, based on available mappings in the \p TLI.
-/// Currently only works when \p CallToReplace is a call to vectorized
-/// intrinsic.
+/// Returns true when successfully replaced \p I with a suitable function taking
+/// vector arguments, based on available mappings in the \p TLI. Currently only
+/// works when \p I is a call to vectorized intrinsic or the frem instruction.
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
-                                    CallInst &CallToReplace) {
-  if (!CallToReplace.getCalledFunction())
-    return false;
-
-  auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID();
-  // Replacement is only performed for intrinsic functions.
-  if (IntrinsicID == Intrinsic::not_intrinsic)
-    return false;
+                                    Instruction &I) {
+  // At the moment VFABI assumes the return type is always widened unless it is
+  // a void type.
+  auto *VTy = dyn_cast<VectorType>(I.getType());
+  ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
 
-  // Compute arguments types of the corresponding scalar call. Additionally
-  // checks if in the vector call, all vector operands have the same EC.
-  ElementCount VF = ElementCount::getFixed(0);
-  SmallVector<Type *> ScalarArgTypes;
-  for (auto Arg : enumerate(CallToReplace.args())) {
-    auto *ArgTy = Arg.value()->getType();
-    if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
-      ScalarArgTypes.push_back(ArgTy);
-    } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
-      ScalarArgTypes.push_back(ArgTy->getScalarType());
-      // Disallow vector arguments with different VFs. When processing the first
-      // vector argument, store it's VF, and for the rest ensure that they match
-      // it.
-      if (VF.isZero())
-        VF = VectorArgTy->getElementCount();
-      else if (VF != VectorArgTy->getElementCount())
+  // Compute the argument types of the corresponding scalar call and the scalar
+  // function name. For calls, it additionally finds the function to replace
+  // and checks that all vector operands match the previously found EC.
+  SmallVector<Type *, 8> ScalarArgTypes, OrigArgTypes;
+  std::string ScalarName;
+  Function *FuncToReplace = nullptr;
+  auto *CI = dyn_cast<CallInst>(&I);
+  if (CI) {
+    FuncToReplace = CI->getCalledFunction();
+    Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
+    assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
+    for (auto Arg : enumerate(CI->args())) {
+      auto *ArgTy = Arg.value()->getType();
+      OrigArgTypes.push_back(ArgTy);
+      if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+        ScalarArgTypes.push_back(ArgTy);
+      } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
+        ScalarArgTypes.push_back(VectorArgTy->getElementType());
+        // When return type is void, set EC to the first vector argument, and
+        // disallow vector arguments with different ECs.
+        if (EC.isZero())
+          EC = VectorArgTy->getElementCount();
+        else if (EC != VectorArgTy->getElementCount())
+          return false;
+      } else
+        // Exit when it is supposed to be a vector argument but it isn't.
         return false;
-    } else
-      // Exit when it is supposed to be a vector argument but it isn't.
+    }
+    // Try to reconstruct the name for the scalar version of the instruction,
+    // using scalar argument types.
+    ScalarName = Intrinsic::isOverloaded(IID)
+                     ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
+                     : Intrinsic::getName(IID).str();
+  } else {
+    assert(VTy && "Return type must be a vector");
+    auto *ScalarTy = VTy->getScalarType();
+    LibFunc Func;
+    if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
       return false;
+    ScalarName = TLI.getName(Func);
+    ScalarArgTypes = {ScalarTy, ScalarTy};
   }
 
-  // Try to reconstruct the name for the scalar version of this intrinsic using
-  // the intrinsic ID and the argument types converted to scalar above.
-  std::string ScalarName =
-      (Intrinsic::isOverloaded(IntrinsicID)
-           ? Intrinsic::getName(IntrinsicID, ScalarArgTypes,
-                                CallToReplace.getModule())
-           : Intrinsic::getName(IntrinsicID).str());
-
   // Try to find the mapping for the scalar version of this intrinsic and the
   // exact vector width of the call operands in the TargetLibraryInfo. First,
   // check with a non-masked variant, and if that fails try with a masked one.
   const VecDesc *VD =
-      TLI.getVectorMappingInfo(ScalarName, VF, /*Masked*/ false);
-  if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, VF, /*Masked*/ true)))
+      TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
+  if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
     return false;
 
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
-                    << "` and vector width " << VF << " to: `"
+                    << "` and vector width " << EC << " to: `"
                     << VD->getVectorFnName() << "`.\n");
 
   // Replace the call to the intrinsic with a call to the vector library
   // function.
-  Type *ScalarRetTy = CallToReplace.getType()->getScalarType();
+  Type *ScalarRetTy = I.getType()->getScalarType();
   FunctionType *ScalarFTy =
       FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
   const std::string MangledName = VD->getVectorFunctionABIVariantString();
@@ -162,27 +174,55 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
   if (!VectorFTy)
     return false;
 
-  Function *FuncToReplace = CallToReplace.getCalledFunction();
-  Function *TLIFunc = getTLIFunction(CallToReplace.getModule(), VectorFTy,
+  Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
                                      VD->getVectorFnName(), FuncToReplace);
-  replaceWithTLIFunction(CallToReplace, *OptInfo, TLIFunc);
 
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
-                    << FuncToReplace->getName() << "` with call to `"
-                    << TLIFunc->getName() << "`.\n");
+  // For calls, bail out when their arguments do not match with the TLI mapping.
+  if (CI) {
+    int IdxNonPred = 0;
+    for (auto [OrigTy, VFParam] :
+         zip(OrigArgTypes, OptInfo->Shape.Parameters)) {
+      if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
+        continue;
+      ++IdxNonPred;
+      if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE
+                          << ": Will not replace: wrong type at index: "
+                          << IdxNonPred << ": " << *OrigTy << "\n");
+        return false;
+      }
+    }
+  }
+
+  replaceWithTLIFunction(I, *OptInfo, TLIFunc);
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
+                    << "` with call to `" << TLIFunc->getName() << "`.\n");
   ++NumCallsReplaced;
   return true;
 }
 
+/// Supported instruction \p I must be a vectorized frem or a call to an
+/// intrinsic that returns either void or a vector.
+static bool isSupportedInstruction(Instruction *I) {
+  Type *Ty = I->getType();
+  if (auto *CI = dyn_cast<CallInst>(I))
+    return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
+           CI->getCalledFunction()->getIntrinsicID() !=
+               Intrinsic::not_intrinsic;
+  if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
+    return true;
+  return false;
+}
+
 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
   bool Changed = false;
-  SmallVector<CallInst *> ReplacedCalls;
+  SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
-    if (auto *CI = dyn_cast<CallInst>(&I)) {
-      if (replaceWithCallToVeclib(TLI, *CI)) {
-        ReplacedCalls.push_back(CI);
-        Changed = true;
-      }
+    if (!isSupportedInstruction(&I))
+      continue;
+    if (replaceWithCallToVeclib(TLI, I)) {
+      ReplacedCalls.push_back(&I);
+      Changed = true;
     }
   }
   // Erase the calls to the intrinsics that have been replaced
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
similarity index 91%
rename from llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
rename to llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index d41870ec6e7915..4480a90a2728d3 100644
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -15,7 +15,7 @@ declare <vscale x 2 x double> @llvm.cos.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.cos.nxv4f32(<vscale x 4 x float>)
 
 ;.
-; CHECK: @llvm.compiler.used = appending global [32 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [36 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_svcos_f64_x, ptr @armpl_svcos_f32_x, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_svsin_f64_x, ptr @armpl_svsin_f32_x, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_svexp_f64_x, ptr @armpl_svexp_f32_x, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_svexp2_f64_x, ptr @armpl_svexp2_f32_x, ptr @armpl_vexp10q_f64, ptr @armpl_vexp10q_f32, ptr @armpl_svexp10_f64_x, ptr @armpl_svexp10_f32_x, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_svlog_f64_x, ptr @armpl_svlog_f32_x, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_svlog2_f64_x, ptr @armpl_svlog2_f32_x, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svlog10_f64_x, ptr @armpl_svlog10_f32_x, ptr @armpl_vfmodq_f64, ptr @armpl_vfmodq_f32, ptr @armpl_svfmod_f64_x, ptr @armpl_svfmod_f32_x], section "llvm.metadata"
 ;.
 define <2 x double> @llvm_cos_f64(<2 x double> %in) {
 ; CHECK-LABEL: define <2 x double> @llvm_cos_f64
@@ -424,6 +424,46 @@ define <vscale x 4 x float> @llvm_pow_vscale_f32(<vscale x 4 x float> %in, <vsca
   ret <vscale x 4 x float> %1
 }
 
+define <2 x double> @frem_f64(<2 x double> %in) {
+; CHECK-LABEL: define <2 x double> @frem_f64
+; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+;
+  %1= frem <2 x double> %in, %in
+  ret <2 x double> %1
+}
+
+define <4 x float> @frem_f32(<4 x float> %in) {
+; CHECK-LABEL: define <4 x float> @frem_f32
+; CHECK-SAME: (<4 x float> [[IN:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @armpl_vfmodq_f32(<4 x float> [[IN]], <4 x float> [[IN]])
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+;
+  %1= frem <4 x float> %in, %in
+  ret <4 x float> %1
+}
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in) #0 {
+; CHECK-LABEL: define <vscale x 2 x double> @frem_vscale_f64
+; CHECK-SAME: (<vscale x 2 x double> [[IN:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x double> @armpl_svfmod_f64_x(<vscale x 2 x double> [[IN]], <vscale x 2 x double> [[IN]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %1= frem <vscale x 2 x double> %in, %in
+  ret <vscale x 2 x double> %1
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in) #0 {
+; CHECK-LABEL: define <vscale x 4 x float> @frem_vscale_f32
+; CHECK-SAME: (<vscale x 4 x float> [[IN:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @armpl_svfmod_f32_x(<vscale x 4 x float> [[IN]], <vscale x 4 x float> [[IN]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %1= frem <vscale x 4 x float> %in, %in
+  ret <vscale x 4 x float> %1
+}
+
 attributes #0 = { "target-features"="+sve" }
 ;.
 ; CHECK: attributes #[[ATTR0:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll
similarity index 95%
rename from llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
rename to llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll
index c2ff6014bc6944..590dd9effac0ea 100644
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef-scalable.ll
@@ -4,7 +4,7 @@
 target triple = "aarch64-unknown-linux-gnu"
 
 ;.
-; CHECK: @llvm.compiler.used = appending global [16 x ptr] [ptr @_ZGVsMxv_cos, ptr @_ZGVsMxv_cosf, ptr @_ZGVsMxv_exp, ptr @_ZGVsMxv_expf, ptr @_ZGVsMxv_exp2, ptr @_ZGVsMxv_exp2f, ptr @_ZGVsMxv_exp10, ptr @_ZGVsMxv_exp10f, ptr @_ZGVsMxv_log, ptr @_ZGVsMxv_logf, ptr @_ZGVsMxv_log10, ptr @_ZGVsMxv_log10f, ptr @_ZGVsMxv_log2, ptr @_ZGVsMxv_log2f, ptr @_ZGVsMxv_sin, ptr @_ZGVsMxv_sinf], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [18 x ptr] [ptr @_ZGVsMxv_cos, ptr @_ZGVsMxv_cosf, ptr @_ZGVsMxv_exp, ptr @_ZGVsMxv_expf, ptr @_ZGVsMxv_exp2, ptr @_ZGVsMxv_exp2f, ptr @_ZGVsMxv_exp10, ptr @_ZGVsMxv_exp10f, ptr @_ZGVsMxv_log, ptr @_ZGVsMxv_logf, ptr @_ZGVsMxv_log10, ptr @_ZGVsMxv_log10f, ptr @_ZGVsMxv_log2, ptr @_ZGVsMxv_log2f, ptr @_ZGVsMxv_sin, ptr @_ZGVsMxv_sinf, ptr @_ZGVsMxvv_fmod, ptr @_ZGVsMxvv_fmodf], section "llvm.metadata"
 ;.
 define <vscale x 2 x double> @llvm_ceil_vscale_f64(<vscale x 2 x double> %in) {
 ; CHECK-LABEL: @llvm_ceil_vscale_f64(
@@ -384,6 +384,24 @@ define <vscale x 4 x float> @llvm_trunc_vscale_f32(<vscale x 4 x float> %in) {
   ret <vscale x 4 x float> %1
 }
 
+define <vscale x 2 x double> @frem_f64(<vscale x 2 x double> %in) {
+; CHECK-LABEL: @frem_f64(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[IN:%.*]], <vscale x 2 x double> [[IN]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %1= frem <vscale x 2 x double> %in, %in
+  ret <vscale x 2 x double> %1
+}
+
+define <vscale x 4 x float> @frem_f32(<vscale x 4 x float> %in) {
+; CHECK-LABEL: @frem_f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[IN:%.*]], <vscale x 4 x float> [[IN]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %1= frem <vscale x 4 x float> %in, %in
+  ret <vscale x 4 x float> %1
+}
+
 declare <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.copysign.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
similarity index 95%
rename from llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll
rename to llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
index be247de368056e..865a46009b205f 100644
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
@@ -4,7 +4,7 @@
 target triple = "aarch64-unknown-linux-gnu"
 
 ;.
-; CHECK: @llvm.compiler.used = appending global [16 x ptr] [ptr @_ZGVnN2v_cos, ptr @_ZGVnN4v_cosf, ptr @_ZGVnN2v_exp, ptr @_ZGVnN4v_expf, ptr @_ZGVnN2v_exp2, ptr @_ZGVnN4v_exp2f, ptr @_ZGVnN2v_exp10, ptr @_ZGVnN4v_exp10f, ptr @_ZGVnN2v_log, ptr @_ZGVnN4v_logf, ptr @_ZGVnN2v_log10, ptr @_ZGVnN4v_log10f, ptr @_ZGVnN2v_log2, ptr @_ZGVnN4v_log2f, ptr @_ZGVnN2v_sin, ptr @_ZGVnN4v_sinf], section "llvm.metadata"
+; CHECK: @llvm.compiler.used = appending global [18 x ptr] [ptr @_ZGVnN2v_cos, ptr @_ZGVnN4v_cosf, ptr @_ZGVnN2v_exp, ptr @_ZGVnN4v_expf, ptr @_ZGVnN2v_exp2, ptr @_ZGVnN4v_exp2f, ptr @_ZGVnN2v_exp10, ptr @_ZGVnN4v_exp10f, ptr @_ZGVnN2v_log, ptr @_ZGVnN4v_logf, ptr @_ZGVnN2v_log10, ptr @_ZGVnN4v_log10f, ptr @_ZGVnN2v_log2, ptr @_ZGVnN4v_log2f, ptr @_ZGVnN2v_sin, ptr @_ZGVnN4v_sinf, ptr @_ZGVnN2vv_fmod, ptr @_ZGVnN4vv_fmodf], section "llvm.metadata"
 ;.
 define <2 ...
[truncated]

replace-with-veclib used to crash when the arguments of the TLI mapping
did not match the arguments of the mapping. Now, it simply ignores such
cases.
@paschalis-mpeis paschalis-mpeis force-pushed the users/paschalis-mpeis/fix-replace-veclib-crash branch from af5c683 to 9dfc955 Compare January 8, 2024 09:29
@paschalis-mpeis
Copy link
Member Author

Rebased to main (PR is no longer stacked).

llvm/lib/CodeGen/ReplaceWithVeclib.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/ReplaceWithVeclib.cpp Show resolved Hide resolved
llvm/lib/CodeGen/ReplaceWithVeclib.cpp Outdated Show resolved Hide resolved
llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp Outdated Show resolved Hide resolved
llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp Outdated Show resolved Hide resolved
llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp Outdated Show resolved Hide resolved
llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/ReplaceWithVeclib.cpp Outdated Show resolved Hide resolved
Copy link
Collaborator

@labrinea labrinea left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright now that I understand the complications I am happy with the unittest to stay as is.

Copy link
Contributor

@mgabka mgabka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for improving the test!

llvm/unittests/Analysis/ReplaceWithVecLibTest.cpp Outdated Show resolved Hide resolved
@mgabka
Copy link
Contributor

mgabka commented Jan 12, 2024

thanks for addressing the remaining, review comments it LGTM

@paschalis-mpeis paschalis-mpeis merged commit 9fdc568 into main Jan 12, 2024
3 of 4 checks passed
@paschalis-mpeis paschalis-mpeis deleted the users/paschalis-mpeis/fix-replace-veclib-crash branch January 12, 2024 15:19
paschalis-mpeis added a commit that referenced this pull request Jan 12, 2024
…77112)"

This reverts commit 9fdc568,
as it linker crashes on some platforms.
@paschalis-mpeis
Copy link
Member Author

paschalis-mpeis commented Jan 12, 2024

REVERTED:
Causes some linker crash on builder: llvm-nvptx64-nvidia-ubuntu (more..

Reapplied by:

paschalis-mpeis added a commit that referenced this pull request Jan 12, 2024
…77112)

Fix a crash of `replace-with-veclib` pass, when the arguments of the TLI
mapping do not match the original call.
Now, it simply ignores such cases.

Test require assertions as it accesses programmatically the debug log.

NOTE: Originally submitted by commit 9fdc568, which was reverted by
a300b24, as it was causing some linking issues:
https://lab.llvm.org/buildbot/#/builders/234/builds/17734
paschalis-mpeis added a commit that referenced this pull request Jan 16, 2024
…77945)

Fix a crash of `replace-with-veclib` pass, when the arguments of the TLI
mapping do not match the original call.
Now, it simply ignores such cases.

Test require assertions as it accesses programmatically the debug log.

Reapplies reverted PR #77112
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Fix a crash of `replace-with-veclib` pass, when the arguments of the TLI
mapping do not match the original call.
Now, it simply ignores such cases.

Test require assertions as it accesses programmatically the debug log.
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…lvm#77112)"

This reverts commit 9fdc568,
as it linker crashes on some platforms.
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…lvm#77945)

Fix a crash of `replace-with-veclib` pass, when the arguments of the TLI
mapping do not match the original call.
Now, it simply ignores such cases.

Test require assertions as it accesses programmatically the debug log.

Reapplies reverted PR llvm#77112
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants