diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index 1b37aabaafae8..bd0d8882a0b2a 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -435,10 +435,13 @@ class LoopVectorizationLegality {
     return LAI->getDepChecker().getStoreLoadForwardSafeDistanceInBits();
   }

-  /// Returns true if vector representation of the instruction \p I
-  /// requires mask.
-  bool isMaskRequired(const Instruction *I) const {
-    return MaskedOp.contains(I);
+  /// Returns true if instruction \p I requires a mask for vectorization.
+  /// This accounts for both control flow masking (conditionally executed
+  /// blocks) and tail-folding masking (predicated loop vectorization).
+  bool isMaskRequired(const Instruction *I, bool LoopPredicated) const {
+    if (LoopPredicated)
+      return PredMaskedOps.contains(I);
+    return UnpredMaskedOps.contains(I);
   }

   /// Returns true if there is at least one function call in the loop which
@@ -714,9 +717,16 @@ class LoopVectorizationLegality {
   AssumptionCache *AC;

   /// While vectorizing these instructions we have to generate a
-  /// call to the appropriate masked intrinsic or drop them in case of
-  /// conditional assumes.
-  SmallPtrSet<const Instruction *, 8> MaskedOp;
+  /// call to the appropriate masked intrinsic or drop them.
+  /// To differentiate between control flow introduced at the source level
+  /// and control flow introduced by the loop vectorizer during tail-folding,
+  /// we keep two sets:
+  /// 1) UnpredMaskedOps - instructions that need masking when they appear in
+  ///    a conditionally executed block.
+  /// 2) PredMaskedOps - instructions that need masking when the loop is
+  ///    predicated (tail-folded).
+  SmallPtrSet<const Instruction *, 8> UnpredMaskedOps;
+  SmallPtrSet<const Instruction *, 8> PredMaskedOps;

   /// Contains all identified histogram operations, which are sequences of
   /// load -> update -> store instructions where multiple lanes in a vector
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 26e2d44bdc9e6..e2dd2ff56649b 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -1615,7 +1615,7 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {

     // We must be able to predicate all blocks that need to be predicated.
     if (blockNeedsPredication(BB) &&
-        !blockCanBePredicated(BB, SafePointers, MaskedOp)) {
+        !blockCanBePredicated(BB, SafePointers, UnpredMaskedOps)) {
       reportVectorizationFailure(
           "Control flow cannot be substituted for a select", "NoCFGForSelect",
           ORE, TheLoop, BB->getTerminator());
@@ -2158,7 +2158,8 @@ void LoopVectorizationLegality::prepareToFoldTailByMasking() {
   // Mark all blocks for predication, including those that ordinarily do not
   // need predication such as the header block.
   for (BasicBlock *BB : TheLoop->blocks()) {
-    [[maybe_unused]] bool R = blockCanBePredicated(BB, SafePointers, MaskedOp);
+    [[maybe_unused]] bool R =
+        blockCanBePredicated(BB, SafePointers, PredMaskedOps);
     assert(R && "Must be able to predicate block when tail-folding.");
   }
 }
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 9a94d29ba3307..91eb731aa1fa5 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1221,6 +1221,10 @@ class LoopVectorizationCostModel {
   /// \p VF is the vectorization factor that will be used to vectorize \p I.
   bool isScalarWithPredication(Instruction *I, ElementCount VF) const;

+  /// Wrapper around LoopVectorizationLegality::isMaskRequired that passes
+  /// \p I together with whether the tail is folded by masking.
+  bool isMaskRequired(Instruction *I) const;
+
   /// Returns true if \p I is an instruction that needs to be predicated
   /// at runtime. The result is independent of the predication mechanism.
   /// Superset of instructions that return true for isScalarWithPredication.
@@ -2833,12 +2837,16 @@ bool LoopVectorizationCostModel::isScalarWithPredication(
   }
 }

+bool LoopVectorizationCostModel::isMaskRequired(Instruction *I) const {
+  return Legal->isMaskRequired(I, foldTailByMasking());
+}
+
 // TODO: Fold into LoopVectorizationLegality::isMaskRequired.
 bool LoopVectorizationCostModel::isPredicatedInst(Instruction *I) const {
   // TODO: We can use the loop-preheader as context point here and get
   // context sensitive reasoning for isSafeToSpeculativelyExecute.
   if (isSafeToSpeculativelyExecute(I) ||
-      (isa<LoadInst, StoreInst, CallInst>(I) && !Legal->isMaskRequired(I)) ||
+      (isa<LoadInst, StoreInst, CallInst>(I) && !isMaskRequired(I)) ||
       isa<BranchInst, SwitchInst, PHINode, AllocaInst>(I))
     return false;

@@ -2863,7 +2871,7 @@ bool LoopVectorizationCostModel::isPredicatedInst(Instruction *I) const {
   case Instruction::Call:
     // Side-effects of a Call are assumed to be non-invariant, needing a
     // (fold-tail) mask.
-    assert(Legal->isMaskRequired(I) &&
+    assert(isMaskRequired(I) &&
            "should have returned earlier for calls not needing a mask");
     return true;
   case Instruction::Load:
@@ -2990,8 +2998,7 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(
   // (either a gap at the end of a load-access that may result in a speculative
   // load, or any gaps in a store-access).
   bool PredicatedAccessRequiresMasking =
-      blockNeedsPredicationForAnyReason(I->getParent()) &&
-      Legal->isMaskRequired(I);
+      blockNeedsPredicationForAnyReason(I->getParent()) && isMaskRequired(I);
   bool LoadAccessWithGapsRequiresEpilogMasking =
       isa<LoadInst>(I) && Group->requiresScalarEpilogue() &&
       !isScalarEpilogueAllowed();
@@ -5260,7 +5267,7 @@ LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I,
          "Stride should be 1 or -1 for consecutive memory access");
   const Align Alignment = getLoadStoreAlignment(I);
   InstructionCost Cost = 0;
-  if (Legal->isMaskRequired(I)) {
+  if (isMaskRequired(I)) {
     unsigned IID = I->getOpcode() == Instruction::Load
                        ? Intrinsic::masked_load
                        : Intrinsic::masked_store;
@@ -5329,8 +5336,8 @@ LoopVectorizationCostModel::getGatherScatterCost(Instruction *I,
                      : Intrinsic::masked_scatter;
   return TTI.getAddressComputationCost(PtrTy, nullptr, nullptr, CostKind) +
          TTI.getMemIntrinsicInstrCost(
-             MemIntrinsicCostAttributes(IID, VectorTy, Ptr,
-                                        Legal->isMaskRequired(I), Alignment, I),
+             MemIntrinsicCostAttributes(IID, VectorTy, Ptr, isMaskRequired(I),
+                                        Alignment, I),
             CostKind);
 }

@@ -5360,12 +5367,11 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
       (isa<StoreInst>(I) && !Group->isFull());
   InstructionCost Cost = TTI.getInterleavedMemoryOpCost(
       InsertPos->getOpcode(), WideVecTy, Group->getFactor(), Indices,
-      Group->getAlign(), AS, CostKind, Legal->isMaskRequired(I),
-      UseMaskForGaps);
+      Group->getAlign(), AS, CostKind, isMaskRequired(I), UseMaskForGaps);

   if (Group->isReverse()) {
     // TODO: Add support for reversed masked interleaved access.
-    assert(!Legal->isMaskRequired(I) &&
+    assert(!isMaskRequired(I) &&
            "Reverse masked interleaved access not supported.");
     Cost += Group->getNumMembers() *
             TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy,
@@ -5903,7 +5909,7 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
       continue;
     }

-    bool MaskRequired = Legal->isMaskRequired(CI);
+    bool MaskRequired = isMaskRequired(CI);
     // Compute corresponding vector type for return value and arguments.
     Type *RetTy = toVectorizedTy(ScalarRetTy, VF);
     for (Type *ScalarTy : ScalarTys)
@@ -7610,7 +7616,7 @@ VPWidenMemoryRecipe *VPRecipeBuilder::tryToWidenMemory(VPInstruction *VPI,
     return nullptr;

   VPValue *Mask = nullptr;
-  if (Legal->isMaskRequired(I))
+  if (CM.isMaskRequired(I))
     Mask = getBlockInMask(Builder.getInsertBlock());

   // Determine if the pointer operand of the access is either consecutive or
@@ -7823,7 +7829,7 @@ VPSingleDefRecipe *VPRecipeBuilder::tryToWidenCall(VPInstruction *VPI,
       // vector variant at this VF requires a mask, so we synthesize an
       // all-true mask.
       VPValue *Mask = nullptr;
-      if (Legal->isMaskRequired(CI))
+      if (CM.isMaskRequired(CI))
         Mask = getBlockInMask(Builder.getInsertBlock());
       else
         Mask = Plan.getOrAddLiveIn(
@@ -7946,7 +7952,7 @@ VPHistogramRecipe *VPRecipeBuilder::tryToWidenHistogram(const HistogramInfo *HI,

   // In case of predicated execution (due to tail-folding, or conditional
   // execution, or both), pass the relevant mask.
-  if (Legal->isMaskRequired(HI->Store))
+  if (CM.isMaskRequired(HI->Store))
     HGramOps.push_back(getBlockInMask(Builder.getInsertBlock()));

   return new VPHistogramRecipe(Opcode, HGramOps, VPI->getDebugLoc());
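To make the new query's semantics concrete: an instruction in a conditionally executed block is recorded in both sets (if-conversion and tail-folding preparation each visit it), while an instruction that only needs a mask once the vectorizer folds the tail ends up in PredMaskedOps alone. The sketch below is a self-contained toy model of that behavior, not LLVM code; strings stand in for Instruction pointers, std::set for SmallPtrSet, and the op names are invented for illustration.

```cpp
// Toy model of the two-set isMaskRequired query from this patch.
#include <cassert>
#include <set>
#include <string>

struct ToyLegality {
  std::set<std::string> UnpredMaskedOps; // masking forced by source-level CFG
  std::set<std::string> PredMaskedOps;   // masking forced by tail-folding

  // Mirrors LoopVectorizationLegality::isMaskRequired above: which set
  // answers the query depends on whether the loop is predicated.
  bool isMaskRequired(const std::string &I, bool LoopPredicated) const {
    return LoopPredicated ? PredMaskedOps.count(I) != 0
                          : UnpredMaskedOps.count(I) != 0;
  }
};

int main() {
  ToyLegality Legal;
  // A store under an `if` inside the loop needs a mask either way, so both
  // population passes record it.
  Legal.UnpredMaskedOps.insert("cond.store");
  Legal.PredMaskedOps.insert("cond.store");
  // A load in the loop header needs a mask only once the tail is folded, so
  // only tail-folding preparation records it.
  Legal.PredMaskedOps.insert("header.load");

  assert(Legal.isMaskRequired("cond.store", /*LoopPredicated=*/false));
  assert(!Legal.isMaskRequired("header.load", /*LoopPredicated=*/false));
  assert(Legal.isMaskRequired("header.load", /*LoopPredicated=*/true));
  return 0;
}
```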