Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SLP]Fix the cost of the reduction result to the final type. #87528

Conversation

alexey-bataev
Copy link
Member

Need to fix the way the cost is calculated, otherwise wrong cast opcode
can be selected and lead to the over-optimistic vector cost. Plus, need
to take into account reduction type size.

Created using spr 1.3.5
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 3, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

Need to fix the way the cost is calculated, otherwise wrong cast opcode
can be selected and lead to the over-optimistic vector cost. Plus, need
to take into account reduction type size.


Full diff: https://github.com/llvm/llvm-project/pull/87528.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+5-3)
  • (modified) llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll (+14-4)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll (+15-10)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll (+1-1)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 7928d29d6dfa7d..5522cd4f4b1a60 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9749,11 +9749,13 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
     if (BWIt != MinBWs.end()) {
       Type *DstTy = Root.Scalars.front()->getType();
       unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
-      if (OriginalSz != BWIt->second.first) {
+      unsigned SrcSz =
+          ReductionBitWidth == 0 ? BWIt->second.first : ReductionBitWidth;
+      if (OriginalSz != SrcSz) {
         unsigned Opcode = Instruction::Trunc;
-        if (OriginalSz < BWIt->second.first)
+        if (OriginalSz > SrcSz)
           Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
-        Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
+        Type *SrcTy = IntegerType::get(DstTy->getContext(), SrcSz);
         Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
                                       TTI::CastContextHint::None,
                                       TTI::TCK_RecipThroughput);
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
index 500f10659f04cb..1e7eb4a4167242 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
@@ -801,10 +801,20 @@ entry:
 define i64 @red_zext_ld_4xi64(ptr %ptr) {
 ; CHECK-LABEL: @red_zext_ld_4xi64(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16>
-; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TMP2]] to i64
+; CHECK-NEXT:    [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[LD0]] to i64
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1
+; CHECK-NEXT:    [[LD1:%.*]] = load i8, ptr [[GEP]], align 1
+; CHECK-NEXT:    [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64
+; CHECK-NEXT:    [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]]
+; CHECK-NEXT:    [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2
+; CHECK-NEXT:    [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1
+; CHECK-NEXT:    [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64
+; CHECK-NEXT:    [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]]
+; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3
+; CHECK-NEXT:    [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1
+; CHECK-NEXT:    [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]]
 ; CHECK-NEXT:    ret i64 [[TMP3]]
 ;
 entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll
index 44738aa1a67479..a8d481a3e28a5c 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll
@@ -5,17 +5,22 @@ define i32 @test() {
 ; CHECK-LABEL: define i32 @test() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A_PROMOTED:%.*]] = load i8, ptr null, align 1
-; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x i8> poison, i8 [[A_PROMOTED]], i32 0
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i8> [[TMP0]], <4 x i8> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = add <4 x i8> [[TMP1]], zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = or <4 x i8> [[TMP1]], zeroinitializer
-; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> [[TMP3]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i16>
-; CHECK-NEXT:    [[TMP6:%.*]] = add <4 x i16> [[TMP5]], <i16 -1, i16 0, i16 0, i16 0>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP6]])
-; CHECK-NEXT:    [[TMP8:%.*]] = zext i16 [[TMP7]] to i32
+; CHECK-NEXT:    [[DEC_4:%.*]] = add i8 [[A_PROMOTED]], 0
+; CHECK-NEXT:    [[CONV_I_4:%.*]] = zext i8 [[DEC_4]] to i32
+; CHECK-NEXT:    [[SUB_I_4:%.*]] = add nuw nsw i32 [[CONV_I_4]], 0
+; CHECK-NEXT:    [[DEC_5:%.*]] = add i8 [[A_PROMOTED]], 0
+; CHECK-NEXT:    [[CONV_I_5:%.*]] = zext i8 [[DEC_5]] to i32
+; CHECK-NEXT:    [[SUB_I_5:%.*]] = add nuw nsw i32 [[CONV_I_5]], 65535
+; CHECK-NEXT:    [[TMP0:%.*]] = or i32 [[SUB_I_4]], [[SUB_I_5]]
+; CHECK-NEXT:    [[DEC_6:%.*]] = or i8 [[A_PROMOTED]], 0
+; CHECK-NEXT:    [[CONV_I_6:%.*]] = zext i8 [[DEC_6]] to i32
+; CHECK-NEXT:    [[SUB_I_6:%.*]] = add nuw nsw i32 [[CONV_I_6]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[TMP0]], [[SUB_I_6]]
+; CHECK-NEXT:    [[TMP10:%.*]] = or i8 [[A_PROMOTED]], 0
+; CHECK-NEXT:    [[CONV_I_7:%.*]] = zext i8 [[TMP10]] to i32
+; CHECK-NEXT:    [[SUB_I_7:%.*]] = add nuw nsw i32 [[CONV_I_7]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = or i32 [[TMP1]], [[SUB_I_7]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = and i32 [[TMP8]], 65535
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i8> [[TMP4]], i32 3
 ; CHECK-NEXT:    store i8 [[TMP10]], ptr null, align 1
 ; CHECK-NEXT:    [[CALL3:%.*]] = tail call i32 (ptr, ...) null(ptr null, i32 [[TMP9]])
 ; CHECK-NEXT:    ret i32 0
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll
index 4acd63078b82ef..4af69dff179e26 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt -passes=slp-vectorizer -S -slp-threshold=-6 -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s
+; RUN: opt -passes=slp-vectorizer -S -slp-threshold=-7 -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s
 
 define void @test(i64 %d.promoted.i) {
 ; CHECK-LABEL: define void @test(

@alexey-bataev
Copy link
Member Author

Ping!

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@alexey-bataev alexey-bataev merged commit a612524 into main Apr 7, 2024
7 checks passed
@alexey-bataev alexey-bataev deleted the users/alexey-bataev/spr/slpfix-the-cost-of-the-reduction-result-to-the-final-type branch April 7, 2024 13:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants