diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index cc168a2266ef69..f026dbe2b2de72 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -197,7 +197,7 @@ VPBlockBase *VPBlockBase::getEnclosingBlockWithPredecessors() { } void VPBlockBase::deleteCFG(VPBlockBase *Entry) { - for (VPBlockBase *Block : to_vector(depth_first(Entry))) + for (VPBlockBase *Block : to_vector(vp_depth_first_shallow(Entry))) delete Block; } @@ -504,14 +504,15 @@ void VPBasicBlock::print(raw_ostream &O, const Twine &Indent, #endif void VPRegionBlock::dropAllReferences(VPValue *NewValue) { - for (VPBlockBase *Block : depth_first(Entry)) + for (VPBlockBase *Block : vp_depth_first_shallow(Entry)) // Drop all references in VPBasicBlocks and replace all uses with // DummyValue. Block->dropAllReferences(NewValue); } void VPRegionBlock::execute(VPTransformState *State) { - ReversePostOrderTraversal RPOT(Entry); + ReversePostOrderTraversal> + RPOT(Entry); if (!isReplicator()) { // Create and register the new vector loop. @@ -565,7 +566,7 @@ void VPRegionBlock::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << (isReplicator() ? " " : " ") << getName() << ": {"; auto NewIndent = Indent + " "; - for (auto *BlockBase : depth_first(Entry)) { + for (auto *BlockBase : vp_depth_first_shallow(Entry)) { O << '\n'; BlockBase->print(O, NewIndent, SlotTracker); } @@ -580,7 +581,7 @@ VPlan::~VPlan() { if (Entry) { VPValue DummyValue; - for (VPBlockBase *Block : depth_first(Entry)) + for (VPBlockBase *Block : vp_depth_first_shallow(Entry)) Block->dropAllReferences(&DummyValue); VPBlockBase::deleteCFG(Entry); @@ -670,7 +671,7 @@ void VPlan::execute(VPTransformState *State) { State->Builder.SetInsertPoint(VectorPreHeader->getTerminator()); // Generate code in the loop pre-header and body. - for (VPBlockBase *Block : depth_first(Entry)) + for (VPBlockBase *Block : vp_depth_first_shallow(Entry)) Block->execute(State); VPBasicBlock *LatchVPBB = getVectorLoopRegion()->getExitingBasicBlock(); @@ -756,7 +757,7 @@ void VPlan::print(raw_ostream &O) const { O << " = backedge-taken count\n"; } - for (const VPBlockBase *Block : depth_first(getEntry())) { + for (const VPBlockBase *Block : vp_depth_first_shallow(getEntry())) { O << '\n'; Block->print(O, "", SlotTracker); } @@ -881,7 +882,7 @@ void VPlanPrinter::dump() { OS << "edge [fontname=Courier, fontsize=30]\n"; OS << "compound=true\n"; - for (const VPBlockBase *Block : depth_first(Plan.getEntry())) + for (const VPBlockBase *Block : vp_depth_first_shallow(Plan.getEntry())) dumpBlock(Block); OS << "}\n"; @@ -966,7 +967,7 @@ void VPlanPrinter::dumpRegion(const VPRegionBlock *Region) { << DOT::EscapeString(Region->getName()) << "\"\n"; // Dump the blocks of the region. assert(Region->getEntry() && "Region contains no inner blocks."); - for (const VPBlockBase *Block : depth_first(Region->getEntry())) + for (const VPBlockBase *Block : vp_depth_first_shallow(Region->getEntry())) dumpBlock(Block); bumpIndent(-1); OS << Indent << "}\n"; @@ -1035,7 +1036,8 @@ void VPUser::printOperands(raw_ostream &O, VPSlotTracker &SlotTracker) const { void VPInterleavedAccessInfo::visitRegion(VPRegionBlock *Region, Old2NewTy &Old2New, InterleavedAccessInfo &IAI) { - ReversePostOrderTraversal RPOT(Region->getEntry()); + ReversePostOrderTraversal> + RPOT(Region->getEntry()); for (VPBlockBase *Base : RPOT) { visitBlock(Base, Old2New, IAI); } diff --git a/llvm/lib/Transforms/Vectorize/VPlanCFG.h b/llvm/lib/Transforms/Vectorize/VPlanCFG.h index 5234c9ff650af3..8ee949d2f55297 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanCFG.h +++ b/llvm/lib/Transforms/Vectorize/VPlanCFG.h @@ -271,6 +271,65 @@ struct GraphTraits> { } }; +/// Helper for GraphTraits specialization that does not traverses through +/// VPRegionBlocks. +template class VPBlockShallowTraversalWrapper { + BlockTy Entry; + +public: + VPBlockShallowTraversalWrapper(BlockTy Entry) : Entry(Entry) {} + BlockTy getEntry() { return Entry; } +}; + +template <> struct GraphTraits> { + using NodeRef = VPBlockBase *; + using ChildIteratorType = SmallVectorImpl::iterator; + + static NodeRef getEntryNode(VPBlockShallowTraversalWrapper N) { + return N.getEntry(); + } + + static inline ChildIteratorType child_begin(NodeRef N) { + return N->getSuccessors().begin(); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return N->getSuccessors().end(); + } +}; + +template <> +struct GraphTraits> { + using NodeRef = const VPBlockBase *; + using ChildIteratorType = SmallVectorImpl::const_iterator; + + static NodeRef + getEntryNode(VPBlockShallowTraversalWrapper N) { + return N.getEntry(); + } + + static inline ChildIteratorType child_begin(NodeRef N) { + return N->getSuccessors().begin(); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return N->getSuccessors().end(); + } +}; + +/// Returns an iterator range to traverse the graph starting at \p G in +/// depth-first order. The iterator won't traverse through region blocks. +inline iterator_range< + df_iterator>> +vp_depth_first_shallow(VPBlockBase *G) { + return depth_first(VPBlockShallowTraversalWrapper(G)); +} +inline iterator_range< + df_iterator>> +vp_depth_first_shallow(const VPBlockBase *G) { + return depth_first(VPBlockShallowTraversalWrapper(G)); +} + } // namespace llvm #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANCFG_H diff --git a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp index 75e53db3c8771c..56a1b563eeeb6b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp @@ -44,9 +44,7 @@ static bool hasDuplicates(const SmallVectorImpl &VPBlockVec) { /// \p Region. Checks in this function are generic for VPBlockBases. They are /// not specific for VPBasicBlocks or VPRegionBlocks. static void verifyBlocksInRegion(const VPRegionBlock *Region) { - for (const VPBlockBase *VPB : make_range( - df_iterator::begin(Region->getEntry()), - df_iterator::end(Region->getExiting()))) { + for (const VPBlockBase *VPB : vp_depth_first_shallow(Region->getEntry())) { // Check block's parent. assert(VPB->getParent() == Region && "VPBlockBase has wrong parent");