diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 37c8925b04429..87b819a65138c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1193,6 +1193,7 @@ class BoUpSLP {
     InstrElementSize.clear();
     UserIgnoreList = nullptr;
     PostponedGathers.clear();
+    ValueToGatherNodes.clear();
   }
 
   unsigned getTreeSize() const { return VectorizableTree.size(); }
@@ -2955,6 +2956,12 @@ class BoUpSLP {
   /// handle order of the vector instructions and shuffles.
   SetVector<const TreeEntry *> PostponedGathers;
 
+  /// Maps each non-constant gathered scalar to the gather nodes that use it.
+  /// Rebuilt at the start of getTreeCost().
+  using ValueToGatherNodesMap =
+      DenseMap<Value *, SmallPtrSet<const TreeEntry *, 4>>;
+  ValueToGatherNodesMap ValueToGatherNodes;
+
   /// This POD struct describes one external user in the vectorized tree.
   struct ExternalUser {
     ExternalUser(Value *S, llvm::User *U, int L)
@@ -8148,6 +8155,16 @@ static T *performExtractsShuffleAction(
 }
 
 InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
+  // Build a map for gathered scalars to the nodes where they are used.
+  ValueToGatherNodes.clear();
+  for (const std::unique_ptr<TreeEntry> &EntryPtr : VectorizableTree) {
+    if (EntryPtr->State != TreeEntry::NeedToGather)
+      continue;
+    for (Value *V : EntryPtr->Scalars)
+      if (!isConstant(V))
+        ValueToGatherNodes.try_emplace(V).first->getSecond().insert(
+            EntryPtr.get());
+  }
   InstructionCost Cost = 0;
   LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size "
                     << VectorizableTree.size() << ".\n");
@@ -8424,51 +8441,6 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
       return false;
     return true;
   };
-  // Build a lists of values to tree entries.
-  DenseMap<Value *, SmallPtrSet<const TreeEntry *, 4>> ValueToTEs;
-  for (const std::unique_ptr<TreeEntry> &EntryPtr : VectorizableTree) {
-    if (EntryPtr.get() == TE)
-      continue;
-    if (EntryPtr->State != TreeEntry::NeedToGather)
-      continue;
-    if (!any_of(EntryPtr->Scalars, [&GatheredScalars](Value *V) {
-          return GatheredScalars.contains(V);
-        }))
-      continue;
-    assert(EntryPtr->UserTreeIndices.size() == 1 &&
-           "Expected only single user of the gather node.");
-    Instruction &EntryUserInst =
-        getLastInstructionInBundle(EntryPtr->UserTreeIndices.front().UserTE);
-    PHINode *EntryPHI = dyn_cast<PHINode>(
-        EntryPtr->UserTreeIndices.front().UserTE->getMainOp());
-    if (&UserInst == &EntryUserInst && !EntryPHI) {
-      // If 2 gathers are operands of the same entry, compare operands indices,
-      // use the earlier one as the base.
-      if (TE->UserTreeIndices.front().UserTE ==
-              EntryPtr->UserTreeIndices.front().UserTE &&
-          TE->UserTreeIndices.front().EdgeIdx <
-              EntryPtr->UserTreeIndices.front().EdgeIdx)
-        continue;
-    }
-    // Check if the user node of the TE comes after user node of EntryPtr,
-    // otherwise EntryPtr depends on TE.
-    auto *EntryI =
-        EntryPHI
-            ? EntryPHI
-                  ->getIncomingBlock(EntryPtr->UserTreeIndices.front().EdgeIdx)
-                  ->getTerminator()
-            : &EntryUserInst;
-    if (!CheckOrdering(EntryI) &&
-        (ParentBB != EntryI->getParent() ||
-         TE->UserTreeIndices.front().UserTE !=
-             EntryPtr->UserTreeIndices.front().UserTE ||
-         TE->UserTreeIndices.front().EdgeIdx <
-             EntryPtr->UserTreeIndices.front().EdgeIdx))
-      continue;
-    for (Value *V : EntryPtr->Scalars)
-      if (!isConstant(V))
-        ValueToTEs.try_emplace(V).first->getSecond().insert(EntryPtr.get());
-  }
   // Find all tree entries used by the gathered values. If no common entries
   // found - not a shuffle.
   // Here we build a set of tree nodes for each gathered value and trying to
@@ -8483,9 +8455,45 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL,
       continue;
     // Build a list of tree entries where V is used.
     SmallPtrSet<const TreeEntry *, 4> VToTEs;
-    auto It = ValueToTEs.find(V);
-    if (It != ValueToTEs.end())
-      VToTEs = It->second;
+    for (const TreeEntry *TEPtr : ValueToGatherNodes.lookup(V)) {
+      if (TEPtr == TE)
+        continue;
+      if (!any_of(TEPtr->Scalars, [&GatheredScalars](Value *V) {
+            return GatheredScalars.contains(V);
+          }))
+        continue;
+      assert(TEPtr->UserTreeIndices.size() == 1 &&
+             "Expected only single user of the gather node.");
+      Instruction &EntryUserInst =
+          getLastInstructionInBundle(TEPtr->UserTreeIndices.front().UserTE);
+      PHINode *EntryPHI =
+          dyn_cast<PHINode>(TEPtr->UserTreeIndices.front().UserTE->getMainOp());
+      if (&UserInst == &EntryUserInst && !EntryPHI) {
+        // If 2 gathers are operands of the same entry, compare operands
+        // indices, use the earlier one as the base.
+        if (TE->UserTreeIndices.front().UserTE ==
+                TEPtr->UserTreeIndices.front().UserTE &&
+            TE->UserTreeIndices.front().EdgeIdx <
+                TEPtr->UserTreeIndices.front().EdgeIdx)
+          continue;
+      }
+      // Check if the user node of the TE comes after user node of TEPtr,
+      // otherwise TEPtr depends on TE.
+      auto *EntryI = EntryPHI
+                         ? EntryPHI
+                               ->getIncomingBlock(
+                                   TEPtr->UserTreeIndices.front().EdgeIdx)
+                               ->getTerminator()
+                         : &EntryUserInst;
+      if (!CheckOrdering(EntryI) &&
+          (ParentBB != EntryI->getParent() ||
+           TE->UserTreeIndices.front().UserTE !=
+               TEPtr->UserTreeIndices.front().UserTE ||
+           TE->UserTreeIndices.front().EdgeIdx <
+               TEPtr->UserTreeIndices.front().EdgeIdx))
+        continue;
+      VToTEs.insert(TEPtr);
+    }
     if (const TreeEntry *VTE = getTreeEntry(V)) {
       Instruction &EntryUserInst = getLastInstructionInBundle(VTE);
       if (&EntryUserInst == &UserInst || !CheckOrdering(&EntryUserInst))