diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index a509ebf6a7e1b..af23813da569b 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -276,9 +276,12 @@ class LoopVectorizationLegality {
   bool canVectorizeFPMath(bool EnableStrictReductions);
 
   /// Return true if we can vectorize this loop while folding its tail by
-  /// masking, and mark all respective loads/stores for masking.
-  /// This object's state is only modified iff this function returns true.
-  bool prepareToFoldTailByMasking();
+  /// masking.
+  bool canFoldTailByMasking() const;
+
+  /// Mark all respective loads/stores for masking. Must only be called when
+  /// tail-folding is possible.
+  void prepareToFoldTailByMasking();
 
   /// Returns the primary induction variable.
   PHINode *getPrimaryInduction() { return PrimaryInduction; }
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 9de49d1bcfeac..569550991dcaa 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -1533,7 +1533,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) {
   return Result;
 }
 
-bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
+bool LoopVectorizationLegality::canFoldTailByMasking() const {
 
   LLVM_DEBUG(dbgs() << "LV: checking if tail can be folded by masking.\n");
 
@@ -1591,8 +1591,22 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
 
   LLVM_DEBUG(dbgs() << "LV: can fold tail by masking.\n");
 
-  MaskedOp.insert(TmpMaskedOp.begin(), TmpMaskedOp.end());
   return true;
 }
 
+void LoopVectorizationLegality::prepareToFoldTailByMasking() {
+  // The list of pointers that we can safely read and write to remains empty.
+  SmallPtrSet<Value *, 8> SafePointers;
+
+  // Mark all blocks for predication, including those that ordinarily do not
+  // need predication such as the header block. Masked ops are recorded in
+  // MaskedOp directly: predication is asserted to succeed for every block,
+  // so no temporary set is needed to guard against partial population.
+  for (BasicBlock *BB : TheLoop->blocks()) {
+    bool R = blockCanBePredicated(BB, SafePointers, MaskedOp);
+    (void)R;
+    assert(R && "Must be able to predicate block when tail-folding.");
+  }
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a9ee9f62197e1..350f9142b52e0 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1508,7 +1508,7 @@ class LoopVectorizationCostModel {
   /// \param UserIC User specific interleave count.
   void setTailFoldingStyles(bool IsScalableVF, unsigned UserIC) {
     assert(!ChosenTailFoldingStyle && "Tail folding must not be selected yet.");
-    if (!Legal->prepareToFoldTailByMasking()) {
+    if (!Legal->canFoldTailByMasking()) {
       ChosenTailFoldingStyle =
           std::make_pair(TailFoldingStyle::None, TailFoldingStyle::None);
       return;
@@ -7309,6 +7309,9 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
     CM.invalidateCostModelingDecisions();
   }
 
+  if (CM.foldTailByMasking())
+    Legal->prepareToFoldTailByMasking();
+
   ElementCount MaxUserVF =
       UserVF.isScalable() ? MaxFactors.ScalableVF : MaxFactors.FixedVF;
   bool UserVFIsLegal = ElementCount::isKnownLE(UserVF, MaxUserVF);