diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 6b8f12891cd5e..b4869512725f6 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -8135,6 +8135,13 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, CM.getInLoopReductions(), Hints.allowReordering()); VPlanTransforms::simplifyRecipes(*VPlan0); + VPlanTransforms::handleEarlyExits(*VPlan0, Legal->hasUncountableEarlyExit()); + VPlanTransforms::addMiddleCheck(*VPlan0, CM.foldTailByMasking()); + RUN_VPLAN_PASS_NO_VERIFY(VPlanTransforms::createLoopRegions, *VPlan0); + if (CM.foldTailByMasking()) + RUN_VPLAN_PASS_NO_VERIFY(VPlanTransforms::foldTailByMasking, *VPlan0); + RUN_VPLAN_PASS_NO_VERIFY(VPlanTransforms::introduceMasksAndLinearize, + *VPlan0); auto MaxVFTimes2 = MaxVF * 2; for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFTimes2);) { @@ -8181,13 +8188,14 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( return !CM.requiresScalarEpilogue(VF.isVector()); }, Range); - VPlanTransforms::handleEarlyExits(*Plan, Legal->hasUncountableEarlyExit()); - VPlanTransforms::addMiddleCheck(*Plan, RequiresScalarEpilogueCheck, - CM.foldTailByMasking()); - - RUN_VPLAN_PASS_NO_VERIFY(VPlanTransforms::createLoopRegions, *Plan); - if (CM.foldTailByMasking()) - RUN_VPLAN_PASS_NO_VERIFY(VPlanTransforms::foldTailByMasking, *Plan); + // Update the branch in the middle block if a scalar epilogue is required. + VPBasicBlock *MiddleVPBB = Plan->getMiddleBlock(); + if (!RequiresScalarEpilogueCheck && MiddleVPBB->getNumSuccessors() == 2) { + auto *BranchOnCond = cast(MiddleVPBB->getTerminator()); + assert(MiddleVPBB->getSuccessors()[1] == Plan->getScalarPreheader() && + "second successor must be scalar preheader"); + BranchOnCond->setOperand(0, Plan->getFalse()); + } // Don't use getDecisionAndClampRange here, because we don't know the UF // so this function is better to be conservative, rather than to split @@ -8239,11 +8247,6 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( InterleaveGroups.insert(IG); } - // --------------------------------------------------------------------------- - // Predicate and linearize the top-level loop region. - // --------------------------------------------------------------------------- - RUN_VPLAN_PASS_NO_VERIFY(VPlanTransforms::introduceMasksAndLinearize, *Plan); - // --------------------------------------------------------------------------- // Construct wide recipes and apply predication for original scalar // VPInstructions in the loop. @@ -8256,7 +8259,6 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( ReversePostOrderTraversal> RPOT( HeaderVPBB); - auto *MiddleVPBB = Plan->getMiddleBlock(); VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi(); // Collect blocks that need predication for in-loop reduction recipes. @@ -8433,8 +8435,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) { /*AllowReordering=*/false); VPlanTransforms::handleEarlyExits(*Plan, /*HasUncountableExit*/ false); - VPlanTransforms::addMiddleCheck(*Plan, /*RequiresScalarEpilogue*/ true, - /*TailFolded*/ false); + VPlanTransforms::addMiddleCheck(*Plan, /*TailFolded*/ false); VPlanTransforms::createLoopRegions(*Plan); diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp index 468193d9e10eb..ab0d81ea1b451 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp @@ -971,9 +971,7 @@ void VPlanTransforms::handleEarlyExits(VPlan &Plan, } } -void VPlanTransforms::addMiddleCheck(VPlan &Plan, - bool RequiresScalarEpilogueCheck, - bool TailFolded) { +void VPlanTransforms::addMiddleCheck(VPlan &Plan, bool TailFolded) { auto *MiddleVPBB = cast( Plan.getScalarHeader()->getSinglePredecessor()->getPredecessors()[0]); // If MiddleVPBB has a single successor then the original loop does not exit @@ -1006,9 +1004,7 @@ void VPlanTransforms::addMiddleCheck(VPlan &Plan, DebugLoc LatchDL = LatchVPBB->getTerminator()->getDebugLoc(); VPBuilder Builder(MiddleVPBB); VPValue *Cmp; - if (!RequiresScalarEpilogueCheck) - Cmp = Plan.getFalse(); - else if (TailFolded) + if (TailFolded) Cmp = Plan.getTrue(); else Cmp = Builder.createICmp(CmpInst::ICMP_EQ, Plan.getTripCount(), diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index 5f060b32da847..d10ef23dd05b2 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -155,9 +155,7 @@ struct VPlanTransforms { /// If a check is needed to guard executing the scalar epilogue loop, it will /// be added to the middle block. - LLVM_ABI_FOR_TEST static void addMiddleCheck(VPlan &Plan, - bool RequiresScalarEpilogueCheck, - bool TailFolded); + LLVM_ABI_FOR_TEST static void addMiddleCheck(VPlan &Plan, bool TailFolded); // Create a check to \p Plan to see if the vector loop should be executed. static void addMinimumIterationCheck( diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h index 2322cf340ff07..472c04b17863b 100644 --- a/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h +++ b/llvm/unittests/Transforms/Vectorize/VPlanTestBase.h @@ -76,7 +76,7 @@ class VPlanTestIRBase : public testing::Test { {}, PSE); VPlanTransforms::handleEarlyExits(*Plan, HasUncountableExit); - VPlanTransforms::addMiddleCheck(*Plan, true, false); + VPlanTransforms::addMiddleCheck(*Plan, false); VPlanTransforms::createLoopRegions(*Plan); return Plan;