Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SLP]Fix PR87011: missing sign extension of demoted type before zero #87054

Conversation

alexey-bataev
Copy link
Member

extension.

Corner case, where sext/zext node cannot be directly promoted because of
the signedness switching. In this case, at first need to cast operand
value to the original type with the its signedness and only after this
cast the result to the new type with the new signedness. Also, need to
adjust cost model to handle this kind of transformation.

Created using spr 1.3.5
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 29, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

extension.

Corner case, where sext/zext node cannot be directly promoted because of
the signedness switching. In this case, at first need to cast operand
value to the original type with the its signedness and only after this
cast the result to the new type with the new signedness. Also, need to
adjust cost model to handle this kind of transformation.


Full diff: https://github.com/llvm/llvm-project/pull/87054.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+87-34)
  • (modified) llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll (+1-1)
  • (modified) llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll (+1-1)
  • (modified) llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll (+1-1)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2875e71081d928..579db52921e676 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8788,23 +8788,52 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     auto *SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
     unsigned Opcode = ShuffleOrOp;
     unsigned VecOpcode = Opcode;
+    TTI::CastContextHint VecCCH = GetCastContextHint(VL0->getOperand(0));
+    Instruction *VI = VL0;
     if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
         (SrcIt != MinBWs.end() || It != MinBWs.end())) {
+      VI = nullptr;
       // Check if the values are candidates to demote.
       unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
-      if (SrcIt != MinBWs.end()) {
-        SrcBWSz = SrcIt->second.first;
-        SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz);
-        SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
-      }
-      unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
-      if (BWSz == SrcBWSz) {
-        VecOpcode = Instruction::BitCast;
-      } else if (BWSz < SrcBWSz) {
-        VecOpcode = Instruction::Trunc;
-      } else if (It != MinBWs.end()) {
-        assert(BWSz > SrcBWSz && "Invalid cast!");
-        VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+      if (It == MinBWs.end() && SrcIt != MinBWs.end() &&
+          SrcBWSz != SrcIt->second.first &&
+          all_of(VL, [&](Value *V) {
+            return !isKnownNonNegative(V, SimplifyQuery(*DL));
+          }) != SrcIt->second.second) {
+        // Neeed to perform first cast src to original src type.
+        if (SrcBWSz != SrcIt->second.first) {
+          CommonCost += TTI->getCastInstrCost(
+              SrcBWSz < SrcIt->second.first
+                  ? Instruction::Trunc
+                  : (SrcIt->second.second ? Instruction::SExt
+                                          : Instruction::ZExt),
+              SrcVecTy,
+              FixedVectorType::get(
+                  IntegerType::get(F->getContext(), SrcIt->second.first),
+                  VL.size()),
+              VecCCH, CostKind);
+          VecCCH = TTI::CastContextHint::None;
+        }
+      } else {
+        bool Signedness = false;
+        if (SrcIt != MinBWs.end()) {
+          SrcBWSz = SrcIt->second.first;
+          SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz);
+          SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
+          Signedness = SrcIt->second.second;
+        } else {
+          assert(It != MinBWs.end() && "Expected node in MinBWs.");
+          Signedness = It->second.second;
+        }
+        unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+        if (BWSz == SrcBWSz) {
+          VecOpcode = Instruction::BitCast;
+        } else if (BWSz < SrcBWSz) {
+          VecOpcode = Instruction::Trunc;
+        } else {
+          assert(BWSz > SrcBWSz && "Invalid cast!");
+          VecOpcode = Signedness ? Instruction::SExt : Instruction::ZExt;
+        }
       }
     }
     auto GetScalarCost = [&](unsigned Idx) -> InstructionCost {
@@ -8814,15 +8843,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
                                    TTI::getCastContextHint(VI), CostKind, VI);
     };
     auto GetVectorCost = [=](InstructionCost CommonCost) {
-      // Do not count cost here if minimum bitwidth is in effect and it is just
-      // a bitcast (here it is just a noop).
+      // Do not count cost here if minimum bitwidth is in effect and it is
+      // just a bitcast (here it is just a noop).
       if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast)
         return CommonCost;
-      auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
-      TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
-      return CommonCost +
-             TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
-                                   VecOpcode == Opcode ? VI : nullptr);
+      return CommonCost + TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy,
+                                                VecCCH, CostKind, VI);
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
@@ -12145,18 +12171,37 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
           (SrcIt != MinBWs.end() || It != MinBWs.end() ||
            SrcScalarTy != CI->getOperand(0)->getType())) {
         // Check if the values are candidates to demote.
-        unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
-        if (SrcIt != MinBWs.end())
-          SrcBWSz = SrcIt->second.first;
+        unsigned OrigSrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
+        unsigned SrcBWSz = OrigSrcBWSz;
         unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
-        if (BWSz == SrcBWSz) {
-          VecOpcode = Instruction::BitCast;
-        } else if (BWSz < SrcBWSz) {
-          VecOpcode = Instruction::Trunc;
-        } else if (SrcIt != MinBWs.end()) {
-          assert(BWSz > SrcBWSz && "Invalid cast!");
-          VecOpcode =
-              SrcIt->second.second ? Instruction::SExt : Instruction::ZExt;
+        if (It == MinBWs.end() && SrcIt != MinBWs.end() &&
+            all_of(E->Scalars, [&](Value *V) {
+              return !isKnownNonNegative(V, SimplifyQuery(*DL));
+            }) != SrcIt->second.second) {
+          // Neeed to perform first cast.
+          InVec = Builder.CreateIntCast(
+              InVec,
+              VectorType::get(
+                  CI->getOperand(0)->getType(),
+                  cast<VectorType>(InVec->getType())->getElementCount()),
+              SrcIt->second.second);
+        } else {
+          bool Signedness = false;
+          if (SrcIt != MinBWs.end()) {
+            SrcBWSz = SrcIt->second.first;
+            Signedness = SrcIt->second.second;
+          } else {
+            assert(It != MinBWs.end() && "Expected node in MinBWs.");
+            Signedness = It->second.second;
+          }
+          if (BWSz == SrcBWSz) {
+            VecOpcode = Instruction::BitCast;
+          } else if (BWSz < SrcBWSz) {
+            VecOpcode = Instruction::Trunc;
+          } else {
+            assert(BWSz > SrcBWSz && "Invalid cast!");
+            VecOpcode = Signedness ? Instruction::SExt : Instruction::ZExt;
+          }
         }
       }
       Value *V = (VecOpcode != ShuffleOrOp && VecOpcode == Instruction::BitCast)
@@ -14454,10 +14499,18 @@ void BoUpSLP::computeMinimumValueSizes() {
       Value *V = VectorizableTree[Idx]->Scalars.front();
       uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
       if (OrigBitWidth > MaxBitWidth) {
-      APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, MaxBitWidth);
-      if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
-        ToDemote.push_back(V);
+        APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, MaxBitWidth);
+        if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL))) {
+          ToDemote.push_back(V);
+          continue;
+        }
       }
+      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+      unsigned BitWidth = OrigBitWidth - NumSignBits;
+      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+        ++BitWidth;
+      if (BitWidth <= MaxBitWidth)
+        ToDemote.push_back(V);
     }
     RootDemotes.clear();
     IsTopRoot = false;
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll
index 1cce52060c479f..866afeea50108c 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll
@@ -14,7 +14,7 @@
 ; YAML-NEXT:  Function:        test_i16_extend
 ; YAML-NEXT:  Args:
 ; YAML-NEXT:    - String:          'SLP vectorized with cost '
-; YAML-NEXT:    - Cost:            '-20'
+; YAML-NEXT:    - Cost:            '-16'
 ; YAML-NEXT:    - String:          ' and with tree size '
 ; YAML-NEXT:    - TreeSize:        '5'
 ; YAML-NEXT:  ...
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
index 7c5f9847db1f41..21d4383b3e3563 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
@@ -228,7 +228,7 @@ for.end:                                          ; preds = %for.end.loopexit, %
 ; YAML-NEXT: Function:        test_unrolled_select
 ; YAML-NEXT: Args:
 ; YAML-NEXT:   - String:          'Vectorized horizontal reduction with cost '
-; YAML-NEXT:   - Cost:            '-41'
+; YAML-NEXT:   - Cost:            '-39'
 ; YAML-NEXT:   - String:          ' and with tree size '
 ; YAML-NEXT:   - TreeSize:        '10'
 
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll
index 436fba3261d602..1166b1fca826b6 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll
@@ -7,7 +7,7 @@ define void @test() {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    store <2 x i64> <i64 -1, i64 0>, ptr @h, align 8
+; CHECK-NEXT:    store <2 x i64> <i64 4294967295, i64 0>, ptr @h, align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants