
[SLP]Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics. #86135

Conversation

alexey-bataev
Member

@alexey-bataev alexey-bataev commented Mar 21, 2024

https://alive2.llvm.org/ce/z/ivPZ26 for the abs transformations.

Created using spr 1.3.5
@alexey-bataev alexey-bataev changed the title [SLP]Improve minbitwidth analysis for abs/smin/smax/umin/umax [SLP]Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics. Mar 21, 2024
@llvmbot
Collaborator

llvmbot commented Mar 21, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

[SLP]Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics.


Full diff: https://github.com/llvm/llvm-project/pull/86135.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+113-10)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll (+7-5)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll (+4-5)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 36b446962c4a63..7f680b7af9b565 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6994,14 +6994,11 @@ bool BoUpSLP::areAllUsersVectorized(
 
 static std::pair<InstructionCost, InstructionCost>
 getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
-                   TargetTransformInfo *TTI, TargetLibraryInfo *TLI) {
+                   TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
+                   ArrayRef<Type *> VecTys) {
   Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
   // Calculate the cost of the scalar and vector calls.
-  SmallVector<Type *, 4> VecTys;
-  for (Use &Arg : CI->args())
-    VecTys.push_back(
-        FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
   FastMathFlags FMF;
   if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
     FMF = FPCI->getFastMathFlags();
@@ -9009,7 +9006,25 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     };
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       auto *CI = cast<CallInst>(VL0);
-      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+      Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
+      SmallVector<Type *> VecTys;
+      for (auto [Idx, Arg] : enumerate(CI->args())) {
+        if (ID != Intrinsic::not_intrinsic) {
+          if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+            VecTys.push_back(Arg->getType());
+            continue;
+          }
+          if (It != MinBWs.end()) {
+            VecTys.push_back(FixedVectorType::get(
+                IntegerType::get(CI->getContext(), It->second.first),
+                VecTy->getNumElements()));
+            continue;
+          }
+        }
+        VecTys.push_back(
+            FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
+      }
+      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, VecTys);
       return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -12462,7 +12477,24 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
-      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+      SmallVector<Type *> VecTys;
+      for (auto [Idx, Arg] : enumerate(CI->args())) {
+        if (ID != Intrinsic::not_intrinsic) {
+          if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+            VecTys.push_back(Arg->getType());
+            continue;
+          }
+          if (It != MinBWs.end()) {
+            VecTys.push_back(FixedVectorType::get(
+                IntegerType::get(CI->getContext(), It->second.first),
+                VecTy->getNumElements()));
+            continue;
+          }
+        }
+        VecTys.push_back(
+            FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
+      }
+      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, VecTys);
       bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
                           VecCallCosts.first <= VecCallCosts.second;
 
@@ -12471,14 +12503,13 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       SmallVector<Type *, 2> TysForDecl;
       // Add return type if intrinsic is overloaded on it.
       if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
-        TysForDecl.push_back(
-            FixedVectorType::get(CI->getType(), E->Scalars.size()));
+        TysForDecl.push_back(VecTy);
+      auto *CEI = cast<CallInst>(VL0);
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
         ValueList OpVL;
         // Some intrinsics have scalar arguments. This argument should not be
         // vectorized.
         if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
-          CallInst *CEI = cast<CallInst>(VL0);
           ScalarArg = CEI->getArgOperand(I);
           OpVecs.push_back(CEI->getArgOperand(I));
           if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
@@ -12491,6 +12522,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
           LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
           return E->VectorizedValue;
         }
+        ScalarArg = CEI->getArgOperand(I);
+        if (cast<VectorType>(OpVec->getType())->getElementType() !=
+            ScalarArg->getType() && It == MinBWs.end()) {
+          auto *CastTy = FixedVectorType::get(ScalarArg->getType(),
+                                              VecTy->getNumElements());
+          OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I));
+        } else if (It != MinBWs.end()) {
+          OpVec = Builder.CreateIntCast(OpVec, VecTy, GetOperandSignedness(I));
+        }
         LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
         OpVecs.push_back(OpVec);
         if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
@@ -14195,6 +14235,69 @@ bool BoUpSLP::collectValuesToDemote(
     break;
   }
 
+  case Instruction::Call: {
+    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+      return false;
+    if (auto *IC = dyn_cast<IntrinsicInst>(I)) {
+      Intrinsic::ID ID = getVectorIntrinsicIDForCall(IC, TLI);
+      if (ID == Intrinsic::abs || ID == Intrinsic::smin ||
+          ID == Intrinsic::smax || ID == Intrinsic::umin ||
+          ID == Intrinsic::umax) {
+        InstructionCost BestCost =
+            std::numeric_limits<InstructionCost::CostType>::max();
+        unsigned BestBitWidth = BitWidth;
+        unsigned VF = ITE->Scalars.size();
+        // Choose the best bitwidth based on cost estimations.
+        (void)AttemptCheckBitwidth(
+            [&](unsigned BitWidth, unsigned) {
+              SmallVector<Type *> VecTys;
+              auto *ITy =
+                  IntegerType::get(IC->getContext(), PowerOf2Ceil(BitWidth));
+              for (auto [Idx, Arg] : enumerate(IC->args())) {
+                if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+                  VecTys.push_back(Arg->getType());
+                  continue;
+                }
+                VecTys.push_back(FixedVectorType::get(ITy, VF));
+              }
+              auto VecCallCosts = getVectorCallCosts(
+                  IC, FixedVectorType::get(ITy, VF), TTI, TLI, VecTys);
+              InstructionCost Cost =
+                  std::min(VecCallCosts.first, VecCallCosts.second);
+              if (Cost < BestCost) {
+                BestCost = Cost;
+                BestBitWidth = BitWidth;
+              }
+              return false;
+            },
+            NeedToExit);
+        NeedToExit = false;
+        BitWidth = BestBitWidth;
+        switch (ID) {
+        case Intrinsic::abs:
+          End = 1;
+          if (!ProcessOperands(IC->getArgOperand(0), NeedToExit))
+            return false;
+          break;
+        case Intrinsic::smin:
+        case Intrinsic::smax:
+        case Intrinsic::umin:
+        case Intrinsic::umax:
+          End = 2;
+          if (!ProcessOperands({IC->getArgOperand(0), IC->getArgOperand(1)},
+                               NeedToExit))
+            return false;
+          break;
+        default:
+          llvm_unreachable("Unexpected intrinsic.");
+        }
+        break;
+      }
+    }
+    MaxDepthLevel = 1;
+    return FinalAnalysis();
+  }
+
   // Otherwise, conservatively give up.
   default:
     MaxDepthLevel = 1;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
index a05d4fdd6315b5..9fa88084aaa0af 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
@@ -5,12 +5,14 @@ define void @test() {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> zeroinitializer, <2 x i32> zeroinitializer)
-; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i32> zeroinitializer, <2 x i32> [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i32> [[TMP1]], zeroinitializer
-; CHECK-NEXT:    [[ADD:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = call <2 x i2> @llvm.smin.v2i2(<2 x i2> zeroinitializer, <2 x i2> zeroinitializer)
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i2> zeroinitializer, <2 x i2> [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i2> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i2> [[TMP2]], i32 1
+; CHECK-NEXT:    [[ADD:%.*]] = zext i2 [[TMP3]] to i32
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[ADD]], 0
-; CHECK-NEXT:    [[ADD45:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i2> [[TMP2]], i32 0
+; CHECK-NEXT:    [[ADD45:%.*]] = zext i2 [[TMP5]] to i32
 ; CHECK-NEXT:    [[ADD152:%.*]] = or i32 [[ADD45]], [[ADD]]
 ; CHECK-NEXT:    [[IDXPROM153:%.*]] = sext i32 [[ADD152]] to i64
 ; CHECK-NEXT:    [[ARRAYIDX154:%.*]] = getelementptr i8, ptr null, i64 [[IDXPROM153]]
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
index e8b854b7cea6cb..60bec6668d23ba 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
@@ -13,14 +13,13 @@ define i32 @test(ptr noalias %in, ptr noalias %inn, ptr %out) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <2 x i8> [[TMP3]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x i8> [[TMP2]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> [[TMP6]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i32>
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i8> [[TMP1]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x i8> [[TMP4]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i32>
-; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i32> [[TMP12]], [[TMP8]]
-; CHECK-NEXT:    [[TMP14:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP13]], i1 true)
-; CHECK-NEXT:    [[TMP15:%.*]] = trunc <4 x i32> [[TMP14]] to <4 x i16>
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i16>
+; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i16> [[TMP12]], [[TMP8]]
+; CHECK-NEXT:    [[TMP15:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[TMP13]], i1 true)
 ; CHECK-NEXT:    store <4 x i16> [[TMP15]], ptr [[OUT:%.*]], align 2
 ; CHECK-NEXT:    ret i32 undef
 ;


github-actions bot commented Mar 21, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

alexey-bataev and others added 2 commits March 21, 2024 15:38
Created using spr 1.3.5
Created using spr 1.3.5

✅ With the latest revision this PR passed the Python code formatter.

Created using spr 1.3.5
; CHECK-NEXT: [[TMP15:%.*]] = trunc <4 x i32> [[TMP14]] to <4 x i16>
; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i16>
; CHECK-NEXT: [[TMP13:%.*]] = sub <4 x i16> [[TMP12]], [[TMP8]]
; CHECK-NEXT: [[TMP15:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[TMP13]], i1 true)
Collaborator

If you have abs:

(i8 (trunc (i16 (abs (sext (i8 X)), true)))), will you turn it into (i8 (abs X, true)) or (i8 (abs X, false)), or will you leave it alone?

Member Author

Good point, will investigate it better!

Collaborator

What's the answer?

Collaborator

Never mind I see you clear that operand now.

Created using spr 1.3.5
@alexey-bataev
Member Author

Ping!

Created using spr 1.3.5
Created using spr 1.3.5
@alexey-bataev
Member Author

Ping!

Created using spr 1.3.5
@alexey-bataev
Member Author

Ping!

Collaborator

@RKSimon left a comment

LGTM with a couple of minors

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (comment on an outdated diff, resolved)
}
VecTys.push_back(
FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
}
Collaborator

Worth converting into a helper to avoid duplication?

Created using spr 1.3.5
Created using spr 1.3.5
Created using spr 1.3.5
@alexey-bataev alexey-bataev merged commit 66b5280 into main Apr 5, 2024
3 of 4 checks passed
@alexey-bataev alexey-bataev deleted the users/alexey-bataev/spr/slpimprove-minbitwidth-analysis-for-abssminsmaxuminumax branch April 5, 2024 18:29
@mstorsjo
Member

mstorsjo commented Apr 6, 2024

This change broke a couple of tests in the libyuv testsuite (https://chromium.googlesource.com/libyuv/libyuv). I’ve reproduced the errors for mingw i686, x86_64 and armv7, but not for mingw aarch64; it is also reproducible on x86_64 Linux. Build with e.g. cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DUNIT_TEST=ON (this might require setting up/fetching gtest manually and passing -DGTEST_SRC_DIR=$(pwd)/gtest/googletest or similar), then build and run ./libyuv_unittest.

@alexey-bataev
Member Author

Hi, thanks for letting me know. Feel free to revert the patch for now, will investigate it next week.

mstorsjo added a commit that referenced this pull request Apr 6, 2024
… intrinsics."

This reverts commit 66b5280.

This commit caused miscompilations, breaking tests in the libyuv
testsuite - see
#86135 (comment)
for more details.
@mstorsjo
Member

mstorsjo commented Apr 6, 2024

Hi, thanks for letting me know. Feel free to revert the patch for now, will investigate it next week.

Thanks, I pushed a revert in bd9486b.

alexey-bataev added a commit that referenced this pull request Apr 8, 2024
…ics.

https://alive2.llvm.org/ce/z/ivPZ26 for the abs transformations.

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: #86135