diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 9fa940cda8e98..272f151e05a4a 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -928,8 +928,8 @@ class LoopVectorizationCostModel { /// user options, for the given register kind. bool useMaxBandwidth(TargetTransformInfo::RegisterKind RegKind); - /// \return True if register pressure should be calculated for the given VF. - bool shouldCalculateRegPressureForVF(ElementCount VF); + /// \return True if register pressure should be considered for the given VF. + bool shouldConsiderRegPressureForVF(ElementCount VF); /// \return The size (in bits) of the smallest and widest types in the code /// that needs to be vectorized. We ignore values that remain scalar such as @@ -3700,7 +3700,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { return FixedScalableVFPair::getNone(); } -bool LoopVectorizationCostModel::shouldCalculateRegPressureForVF( +bool LoopVectorizationCostModel::shouldConsiderRegPressureForVF( ElementCount VF) { if (!useMaxBandwidth(VF.isScalable() ? TargetTransformInfo::RGK_ScalableVector @@ -4147,8 +4147,9 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() { P->vectorFactors().end()); SmallVector RUs; - if (CM.useMaxBandwidth(TargetTransformInfo::RGK_ScalableVector) || - CM.useMaxBandwidth(TargetTransformInfo::RGK_FixedWidthVector)) + if (any_of(VFs, [this](ElementCount VF) { + return CM.shouldConsiderRegPressureForVF(VF); + })) RUs = calculateRegisterUsageForPlan(*P, VFs, TTI, CM.ValuesToIgnore); for (unsigned I = 0; I < VFs.size(); I++) { @@ -4160,7 +4161,7 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() { /// If the register pressure needs to be considered for VF, /// don't consider the VF as valid if it exceeds the number /// of registers for the target. - if (CM.shouldCalculateRegPressureForVF(VF) && + if (CM.shouldConsiderRegPressureForVF(VF) && RUs[I].exceedsMaxNumRegs(TTI, ForceTargetNumVectorRegs)) continue; @@ -6996,8 +6997,9 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() { P->vectorFactors().end()); SmallVector RUs; - if (CM.useMaxBandwidth(TargetTransformInfo::RGK_ScalableVector) || - CM.useMaxBandwidth(TargetTransformInfo::RGK_FixedWidthVector)) + if (any_of(VFs, [this](ElementCount VF) { + return CM.shouldConsiderRegPressureForVF(VF); + })) RUs = calculateRegisterUsageForPlan(*P, VFs, TTI, CM.ValuesToIgnore); for (unsigned I = 0; I < VFs.size(); I++) { @@ -7023,7 +7025,7 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() { InstructionCost Cost = cost(*P, VF); VectorizationFactor CurrentFactor(VF, Cost, ScalarCost); - if (CM.shouldCalculateRegPressureForVF(VF) && + if (CM.shouldConsiderRegPressureForVF(VF) && RUs[I].exceedsMaxNumRegs(TTI, ForceTargetNumVectorRegs)) { LLVM_DEBUG(dbgs() << "LV(REG): Not considering vector loop of width " << VF << " because it uses too many registers\n"); diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/reg-usage.ll b/llvm/test/Transforms/LoopVectorize/AArch64/reg-usage.ll index e51a925040a49..01d103264fafe 100644 --- a/llvm/test/Transforms/LoopVectorize/AArch64/reg-usage.ll +++ b/llvm/test/Transforms/LoopVectorize/AArch64/reg-usage.ll @@ -14,7 +14,7 @@ define void @get_invariant_reg_usage(ptr %z) { ; CHECK-LABEL: LV: Checking a loop in 'get_invariant_reg_usage' -; CHECK: LV(REG): VF = vscale x 16 +; CHECK: LV(REG): VF = 16 ; CHECK-NEXT: LV(REG): Found max usage: 2 item ; CHECK-NEXT: LV(REG): RegisterClass: Generic::ScalarRC, 2 registers ; CHECK-NEXT: LV(REG): RegisterClass: Generic::VectorRC, 1 registers