Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SLP][NFC] Unify code for cost estimation/codegen for buildvector, NFC. #73182

Merged
merged 1 commit into from
Nov 30, 2023

Conversation

alexey-bataev
Copy link
Member

This just moves towards reusing same function for both cost
estimation/codegen for buildvector.

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 22, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

This just moves towards reusing same function for both cost
estimation/codegen for buildvector.


Full diff: https://github.com/llvm/llvm-project/pull/73182.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+142-13)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index b73de2ed6ff9a38..226e5a20a2fadb6 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7344,6 +7344,32 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
       V2 = getAllOnesValue(
           *R.DL,
           FixedVectorType::get(E2->Scalars.front()->getType(), CommonVF));
+    } else if (!V1 && V2) {
+      // Shuffle vector and tree node.
+      unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements();
+      const TreeEntry *E1 = P1.get<const TreeEntry *>();
+      CommonVF = std::max(VF, E1->getVectorFactor());
+      assert(all_of(Mask,
+                    [=](int Idx) {
+                      return Idx < 2 * static_cast<int>(CommonVF);
+                    }) &&
+             "All elements in mask must be less than 2 * CommonVF.");
+      if (E1->Scalars.size() == VF && VF != CommonVF) {
+        SmallVector<int> E1Mask = E1->getCommonMask();
+        assert(!E1Mask.empty() && "Expected non-empty common mask.");
+        for (int &Idx : CommonMask) {
+          if (Idx == PoisonMaskElem)
+            continue;
+          if (Idx >= static_cast<int>(CommonVF))
+            Idx = E1Mask[Idx - CommonVF] + VF;
+        }
+        CommonVF = VF;
+      }
+      V1 = Constant::getNullValue(
+          FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF));
+      V2 = getAllOnesValue(
+          *R.DL,
+          FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF));
     } else {
       assert(V1 && V2 && "Expected both vectors.");
       unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
@@ -7380,7 +7406,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
         R(R), CheckedExtracts(CheckedExtracts) {}
   Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
                         ArrayRef<std::optional<TTI::ShuffleKind>> ShuffleKinds,
-                        unsigned NumParts) {
+                        unsigned NumParts, bool &UseVecBaseAsInput) {
+    UseVecBaseAsInput = false;
     if (Mask.empty())
       return nullptr;
     Value *VecBase = nullptr;
@@ -7403,6 +7430,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
                            Data.value() == VL[Data.index()]);
                  });
         });
+    SmallPtrSet<Value *, 4> UniqueBases;
     unsigned SliceSize = VL.size() / NumParts;
     for (unsigned Part = 0; Part < NumParts; ++Part) {
       ArrayRef<int> SubMask = Mask.slice(Part * SliceSize, SliceSize);
@@ -7417,13 +7445,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
         // vectorized tree.
         // Also, avoid adjusting the cost for extractelements with multiple uses
         // in different graph entries.
+        auto *EE = cast<ExtractElementInst>(V);
+        VecBase = EE->getVectorOperand();
+        UniqueBases.insert(VecBase);
         const TreeEntry *VE = R.getTreeEntry(V);
         if (!CheckedExtracts.insert(V).second ||
             !R.areAllUsersVectorized(cast<Instruction>(V), &VectorizedVals) ||
             (VE && VE != E))
           continue;
-        auto *EE = cast<ExtractElementInst>(V);
-        VecBase = EE->getVectorOperand();
         std::optional<unsigned> EEIdx = getExtractIndex(EE);
         if (!EEIdx)
           continue;
@@ -7462,6 +7491,11 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     CommonMask.assign(Mask.begin(), Mask.end());
     transformMaskAfterShuffle(CommonMask, CommonMask);
     SameNodesEstimated = false;
+    if (NumParts != 1 && UniqueBases.size() != 1) {
+      UseVecBaseAsInput = true;
+      VecBase = Constant::getNullValue(
+          FixedVectorType::get(VL.front()->getType(), CommonMask.size()));
+    }
     return VecBase;
   }
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
@@ -7511,19 +7545,70 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     if (!SameNodesEstimated && InVectors.size() == 1)
       InVectors.emplace_back(&E1);
   }
+  /// Adds 2 input vectors and the mask for their shuffling.
+  void add(Value *V1, Value *V2, ArrayRef<int> Mask) {
+    // May come only for shuffling of 2 vectors with extractelements, already
+    // handled in adjustExtracts.
+    assert(InVectors.size() == 1 &&
+           all_of(enumerate(CommonMask),
+                  [&](auto P) {
+                    if (P.value() == PoisonMaskElem)
+                      return Mask[P.index()] == PoisonMaskElem;
+                    auto *EI =
+                        cast<ExtractElementInst>(InVectors.front()
+                                                     .get<const TreeEntry *>()
+                                                     ->Scalars[P.index()]);
+                    return EI->getVectorOperand() == V1 ||
+                           EI->getVectorOperand() == V2;
+                  }) &&
+           "Expected extractelement vectors.");
+  }
   /// Adds another one input vector and the mask for the shuffling.
-  void add(Value *V1, ArrayRef<int> Mask) {
+  void add(Value *V1, ArrayRef<int> Mask, bool ForExtracts = false) {
     if (InVectors.empty()) {
-      assert(CommonMask.empty() && "Expected empty input mask/vectors.");
+      assert(CommonMask.empty() && !ForExtracts &&
+             "Expected empty input mask/vectors.");
       CommonMask.assign(Mask.begin(), Mask.end());
       InVectors.assign(1, V1);
       return;
     }
-    assert(InVectors.size() == 1 && InVectors.front().is<const TreeEntry *>() &&
-           !CommonMask.empty() && "Expected only single entry from extracts.");
+    if (ForExtracts) {
+      // No need to add vectors here, already handled them in adjustExtracts.
+      assert(InVectors.size() == 1 &&
+             InVectors.front().is<const TreeEntry *>() && !CommonMask.empty() &&
+             all_of(enumerate(CommonMask),
+                    [&](auto P) {
+                      Value *Scalar = InVectors.front()
+                                          .get<const TreeEntry *>()
+                                          ->Scalars[P.index()];
+                      if (P.value() == PoisonMaskElem)
+                        return P.value() == Mask[P.index()] ||
+                               isa<UndefValue>(Scalar);
+                      if (isa<Constant>(V1))
+                        return true;
+                      auto *EI = cast<ExtractElementInst>(Scalar);
+                      return EI->getVectorOperand() == V1;
+                    }) &&
+             "Expected only tree entry for extractelement vectors.");
+      return;
+    }
+    assert(!InVectors.empty() && !CommonMask.empty() &&
+           "Expected only tree entries from extracts/reused buildvectors.");
+    unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+    if (InVectors.size() == 2) {
+      Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask);
+      transformMaskAfterShuffle(CommonMask, CommonMask);
+      VF = std::max<unsigned>(VF, CommonMask.size());
+    } else if (const auto *InTE =
+                   InVectors.front().dyn_cast<const TreeEntry *>()) {
+      VF = std::max(VF, InTE->getVectorFactor());
+    } else {
+      VF = std::max(
+          VF, cast<FixedVectorType>(InVectors.front().get<Value *>()->getType())
+                  ->getNumElements());
+    }
     InVectors.push_back(V1);
-    unsigned VF = CommonMask.size();
-    for (unsigned Idx = 0; Idx < VF; ++Idx)
+    for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
       if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
         CommonMask[Idx] = Mask[Idx] + VF;
   }
@@ -7640,6 +7725,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       reorderScalars(GatheredScalars, ReorderMask);
     SmallVector<int> Mask;
     SmallVector<int> ExtractMask;
+    Value *ExtractVecBase = nullptr;
+    bool UseVecBaseAsInput = false;
     SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> GatherShuffles;
     SmallVector<SmallVector<const TreeEntry *>> Entries;
     SmallVector<std::optional<TTI::ShuffleKind>> ExtractShuffles;
@@ -7653,7 +7740,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
           tryToGatherExtractElements(GatheredScalars, ExtractMask, NumParts);
       if (!ExtractShuffles.empty()) {
         if (Value *VecBase = Estimator.adjustExtracts(
-                E, ExtractMask, ExtractShuffles, NumParts)) {
+                E, ExtractMask, ExtractShuffles, NumParts, UseVecBaseAsInput)) {
           if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType()))
             if (VF == VecBaseTy->getNumElements() &&
                 GatheredScalars.size() != VF) {
@@ -7748,6 +7835,48 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
                                        ScalarTy, GatheredScalars.size())));
           });
     }
+    if (!ExtractShuffles.empty()) {
+      Value *Vec1 = nullptr;
+      // Gather of extractelements can be represented as just a shuffle of
+      // a single/two vectors the scalars are extracted from.
+      // Find input vectors.
+      Value *Vec2 = nullptr;
+      for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) {
+        if (!Mask.empty() && Mask[I] != PoisonMaskElem)
+          ExtractMask[I] = PoisonMaskElem;
+      }
+      if (UseVecBaseAsInput) {
+        Vec1 = ExtractVecBase;
+      } else {
+        for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) {
+          if (ExtractMask[I] == PoisonMaskElem)
+            continue;
+          if (isa<UndefValue>(E->Scalars[I]))
+            continue;
+          auto *EI = cast<ExtractElementInst>(E->Scalars[I]);
+          Value *VecOp = EI->getVectorOperand();
+          if (const auto *TE = getTreeEntry(VecOp))
+            if (TE->VectorizedValue)
+              VecOp = TE->VectorizedValue;
+          if (!Vec1) {
+            Vec1 = VecOp;
+          } else if (Vec1 != EI->getVectorOperand()) {
+            assert((!Vec2 || Vec2 == EI->getVectorOperand()) &&
+                   "Expected only 1 or 2 vectors shuffle.");
+            Vec2 = VecOp;
+          }
+        }
+      }
+      if (Vec2) {
+        Estimator.add(Vec1, Vec2, ExtractMask);
+      } else if (Vec1) {
+        Estimator.add(Vec1, ExtractMask, /*ForExtracts=*/true);
+      } else {
+        Estimator.add(PoisonValue::get(FixedVectorType::get(
+                          ScalarTy, GatheredScalars.size())),
+                      ExtractMask, /*ForExtracts=*/true);
+      }
+    }
     if (!all_of(GatheredScalars, PoisonValue::classof)) {
       auto Gathers = ArrayRef(GatheredScalars).take_front(VL.size());
       bool SameGathers = VL.equals(Gathers);
@@ -10341,7 +10470,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
       InVectors.push_back(V1);
   }
   /// Adds another one input vector and the mask for the shuffling.
-  void add(Value *V1, ArrayRef<int> Mask) {
+  void add(Value *V1, ArrayRef<int> Mask, bool = false) {
     if (InVectors.empty()) {
       if (!isa<FixedVectorType>(V1->getType())) {
         V1 = createShuffle(V1, nullptr, CommonMask);
@@ -10880,13 +11009,13 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
         IsUsedInExpr &= FindReusedSplat(
             ExtractMask,
             cast<FixedVectorType>(Vec1->getType())->getNumElements());
-        ShuffleBuilder.add(Vec1, ExtractMask);
+        ShuffleBuilder.add(Vec1, ExtractMask, /*ForExtracts=*/true);
         IsNonPoisoned &= isGuaranteedNotToBePoison(Vec1);
       } else {
         IsUsedInExpr = false;
         ShuffleBuilder.add(PoisonValue::get(FixedVectorType::get(
                                ScalarTy, GatheredScalars.size())),
-                           ExtractMask);
+                           ExtractMask, /*ForExtracts=*/true);
       }
     }
     if (!GatherShuffles.empty()) {

@dtcxzyw dtcxzyw changed the title y [SLP][NFC] Unify code for cost estimation/codegen for buildvector, NFC. Nov 23, 2023
This just moves towards reusing same function for both cost
estimation/codegen for buildvector.
@@ -10341,7 +10470,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
InVectors.push_back(V1);
}
/// Adds another one input vector and the mask for the shuffling.
void add(Value *V1, ArrayRef<int> Mask) {
void add(Value *V1, ArrayRef<int> Mask, bool = false) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of this bool argument?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a unification with the cost estimator interface, check void add(Value *V1, ArrayRef<int> Mask, bool ForExtracts = false) from ShuffleCostEstimator. Same functions with the same args will be used for both code gen/cost estimation.

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM

@alexey-bataev alexey-bataev merged commit ba52310 into llvm:main Nov 30, 2023
3 checks passed
@alexey-bataev alexey-bataev deleted the SLPUnifyCostCodegen branch November 30, 2023 15:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants