-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SLP]Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics. #86135
[SLP]Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics. #86135
Conversation
Created using spr 1.3.5
@llvm/pr-subscribers-llvm-transforms Author: Alexey Bataev (alexey-bataev) Changes: Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics. Full diff: https://github.com/llvm/llvm-project/pull/86135.diff 3 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 36b446962c4a63..7f680b7af9b565 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6994,14 +6994,11 @@ bool BoUpSLP::areAllUsersVectorized(
static std::pair<InstructionCost, InstructionCost>
getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
- TargetTransformInfo *TTI, TargetLibraryInfo *TLI) {
+ TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
+ ArrayRef<Type *> VecTys) {
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
// Calculate the cost of the scalar and vector calls.
- SmallVector<Type *, 4> VecTys;
- for (Use &Arg : CI->args())
- VecTys.push_back(
- FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
FastMathFlags FMF;
if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
FMF = FPCI->getFastMathFlags();
@@ -9009,7 +9006,25 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
};
auto GetVectorCost = [=](InstructionCost CommonCost) {
auto *CI = cast<CallInst>(VL0);
- auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+ Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
+ SmallVector<Type *> VecTys;
+ for (auto [Idx, Arg] : enumerate(CI->args())) {
+ if (ID != Intrinsic::not_intrinsic) {
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+ VecTys.push_back(Arg->getType());
+ continue;
+ }
+ if (It != MinBWs.end()) {
+ VecTys.push_back(FixedVectorType::get(
+ IntegerType::get(CI->getContext(), It->second.first),
+ VecTy->getNumElements()));
+ continue;
+ }
+ }
+ VecTys.push_back(
+ FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
+ }
+ auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, VecTys);
return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
};
return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -12462,7 +12477,24 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
- auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+ SmallVector<Type *> VecTys;
+ for (auto [Idx, Arg] : enumerate(CI->args())) {
+ if (ID != Intrinsic::not_intrinsic) {
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+ VecTys.push_back(Arg->getType());
+ continue;
+ }
+ if (It != MinBWs.end()) {
+ VecTys.push_back(FixedVectorType::get(
+ IntegerType::get(CI->getContext(), It->second.first),
+ VecTy->getNumElements()));
+ continue;
+ }
+ }
+ VecTys.push_back(
+ FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
+ }
+ auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, VecTys);
bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
VecCallCosts.first <= VecCallCosts.second;
@@ -12471,14 +12503,13 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
SmallVector<Type *, 2> TysForDecl;
// Add return type if intrinsic is overloaded on it.
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
- TysForDecl.push_back(
- FixedVectorType::get(CI->getType(), E->Scalars.size()));
+ TysForDecl.push_back(VecTy);
+ auto *CEI = cast<CallInst>(VL0);
for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
ValueList OpVL;
// Some intrinsics have scalar arguments. This argument should not be
// vectorized.
if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
- CallInst *CEI = cast<CallInst>(VL0);
ScalarArg = CEI->getArgOperand(I);
OpVecs.push_back(CEI->getArgOperand(I));
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
@@ -12491,6 +12522,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
+ ScalarArg = CEI->getArgOperand(I);
+ if (cast<VectorType>(OpVec->getType())->getElementType() !=
+ ScalarArg->getType() && It == MinBWs.end()) {
+ auto *CastTy = FixedVectorType::get(ScalarArg->getType(),
+ VecTy->getNumElements());
+ OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I));
+ } else if (It != MinBWs.end()) {
+ OpVec = Builder.CreateIntCast(OpVec, VecTy, GetOperandSignedness(I));
+ }
LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
OpVecs.push_back(OpVec);
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
@@ -14195,6 +14235,69 @@ bool BoUpSLP::collectValuesToDemote(
break;
}
+ case Instruction::Call: {
+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+ return false;
+ if (auto *IC = dyn_cast<IntrinsicInst>(I)) {
+ Intrinsic::ID ID = getVectorIntrinsicIDForCall(IC, TLI);
+ if (ID == Intrinsic::abs || ID == Intrinsic::smin ||
+ ID == Intrinsic::smax || ID == Intrinsic::umin ||
+ ID == Intrinsic::umax) {
+ InstructionCost BestCost =
+ std::numeric_limits<InstructionCost::CostType>::max();
+ unsigned BestBitWidth = BitWidth;
+ unsigned VF = ITE->Scalars.size();
+ // Choose the best bitwidth based on cost estimations.
+ (void)AttemptCheckBitwidth(
+ [&](unsigned BitWidth, unsigned) {
+ SmallVector<Type *> VecTys;
+ auto *ITy =
+ IntegerType::get(IC->getContext(), PowerOf2Ceil(BitWidth));
+ for (auto [Idx, Arg] : enumerate(IC->args())) {
+ if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+ VecTys.push_back(Arg->getType());
+ continue;
+ }
+ VecTys.push_back(FixedVectorType::get(ITy, VF));
+ }
+ auto VecCallCosts = getVectorCallCosts(
+ IC, FixedVectorType::get(ITy, VF), TTI, TLI, VecTys);
+ InstructionCost Cost =
+ std::min(VecCallCosts.first, VecCallCosts.second);
+ if (Cost < BestCost) {
+ BestCost = Cost;
+ BestBitWidth = BitWidth;
+ }
+ return false;
+ },
+ NeedToExit);
+ NeedToExit = false;
+ BitWidth = BestBitWidth;
+ switch (ID) {
+ case Intrinsic::abs:
+ End = 1;
+ if (!ProcessOperands(IC->getArgOperand(0), NeedToExit))
+ return false;
+ break;
+ case Intrinsic::smin:
+ case Intrinsic::smax:
+ case Intrinsic::umin:
+ case Intrinsic::umax:
+ End = 2;
+ if (!ProcessOperands({IC->getArgOperand(0), IC->getArgOperand(1)},
+ NeedToExit))
+ return false;
+ break;
+ default:
+ llvm_unreachable("Unexpected intrinsic.");
+ }
+ break;
+ }
+ }
+ MaxDepthLevel = 1;
+ return FinalAnalysis();
+ }
+
// Otherwise, conservatively give up.
default:
MaxDepthLevel = 1;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
index a05d4fdd6315b5..9fa88084aaa0af 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
@@ -5,12 +5,14 @@ define void @test() {
; CHECK-LABEL: define void @test(
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[TMP0:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> zeroinitializer, <2 x i32> zeroinitializer)
-; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i32> zeroinitializer, <2 x i32> [[TMP0]]
-; CHECK-NEXT: [[TMP2:%.*]] = or <2 x i32> [[TMP1]], zeroinitializer
-; CHECK-NEXT: [[ADD:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1
+; CHECK-NEXT: [[TMP0:%.*]] = call <2 x i2> @llvm.smin.v2i2(<2 x i2> zeroinitializer, <2 x i2> zeroinitializer)
+; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i2> zeroinitializer, <2 x i2> [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = or <2 x i2> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x i2> [[TMP2]], i32 1
+; CHECK-NEXT: [[ADD:%.*]] = zext i2 [[TMP3]] to i32
; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[ADD]], 0
-; CHECK-NEXT: [[ADD45:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x i2> [[TMP2]], i32 0
+; CHECK-NEXT: [[ADD45:%.*]] = zext i2 [[TMP5]] to i32
; CHECK-NEXT: [[ADD152:%.*]] = or i32 [[ADD45]], [[ADD]]
; CHECK-NEXT: [[IDXPROM153:%.*]] = sext i32 [[ADD152]] to i64
; CHECK-NEXT: [[ARRAYIDX154:%.*]] = getelementptr i8, ptr null, i64 [[IDXPROM153]]
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
index e8b854b7cea6cb..60bec6668d23ba 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
@@ -13,14 +13,13 @@ define i32 @test(ptr noalias %in, ptr noalias %inn, ptr %out) {
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <2 x i8> [[TMP3]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x i8> [[TMP2]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> [[TMP6]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i32>
+; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i16>
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <2 x i8> [[TMP1]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x i8> [[TMP4]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i32>
-; CHECK-NEXT: [[TMP13:%.*]] = sub <4 x i32> [[TMP12]], [[TMP8]]
-; CHECK-NEXT: [[TMP14:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP13]], i1 true)
-; CHECK-NEXT: [[TMP15:%.*]] = trunc <4 x i32> [[TMP14]] to <4 x i16>
+; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i16>
+; CHECK-NEXT: [[TMP13:%.*]] = sub <4 x i16> [[TMP12]], [[TMP8]]
+; CHECK-NEXT: [[TMP15:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[TMP13]], i1 true)
; CHECK-NEXT: store <4 x i16> [[TMP15]], ptr [[OUT:%.*]], align 2
; CHECK-NEXT: ret i32 undef
;
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Created using spr 1.3.5
✅ With the latest revision this PR passed the Python code formatter. |
Created using spr 1.3.5
; CHECK-NEXT: [[TMP15:%.*]] = trunc <4 x i32> [[TMP14]] to <4 x i16> | ||
; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i16> | ||
; CHECK-NEXT: [[TMP13:%.*]] = sub <4 x i16> [[TMP12]], [[TMP8]] | ||
; CHECK-NEXT: [[TMP15:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[TMP13]], i1 true) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you have abs
(i8 (trunc (i16 (abs (sext (i8 X)), true)))), will you turn it into (i8 (abs X, true)) or (i8 (abs X, false)), or will you leave it alone?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, will investigate it better!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the answer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Never mind I see you clear that operand now.
Created using spr 1.3.5
Ping! |
Ping! |
Ping! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a couple of minors
} | ||
VecTys.push_back( | ||
FixedVectorType::get(Arg->getType(), VecTy->getNumElements())); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Worth converting into a helper to avoid duplication?
Created using spr 1.3.5
Created using spr 1.3.5
Created using spr 1.3.5
This change broke a couple of tests in the libyuv testsuite (https://chromium.googlesource.com/libyuv/libyuv). I’ve reproduced the errors for mingw i686, x86_64 and armv7, but no errors on mingw aarch64. Also reproducible for x86_64 Linux. Build with e.g. |
Hi, thanks for letting me know. Feel free to revert the patch for now, will investigate it next week. |
… intrinsics." This reverts commit 66b5280. This commit caused miscompilations, breaking tests in the libyuv testsuite - see #86135 (comment) for more details.
Thanks, I pushed a revert in bd9486b. |
…ics. https://alive2.llvm.org/ce/z/ivPZ26 for the abs transformations. Reviewers: RKSimon Reviewed By: RKSimon Pull Request: #86135
https://alive2.llvm.org/ce/z/ivPZ26 for the abs transformations.