From d2cbd8f3becb58511579afd6388f9e0bc1c4ba13 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Mon, 3 Nov 2025 15:57:31 +0000 Subject: [PATCH] [VPlan] Detect and create partial reductions in VPlan. (NFCI) As a first step, move the existing partial reduction detection logic to VPlan, trying to preserve the existing code structure & behavior as closely as possible. With this, partial reductions are detected and created together in a single step. This allows forming partial reductions and bundling them up if profitable together in a follow-up. --- .../Transforms/Vectorize/LoopVectorize.cpp | 245 +---------- .../Transforms/Vectorize/VPRecipeBuilder.h | 60 +-- llvm/lib/Transforms/Vectorize/VPlan.h | 6 +- .../Vectorize/VPlanConstruction.cpp | 383 ++++++++++++++++++ .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 2 +- .../Transforms/Vectorize/VPlanTransforms.h | 10 + 6 files changed, 413 insertions(+), 293 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 356d759b94799..f42a6d8c8714c 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7985,178 +7985,6 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(VPInstruction *VPI, return Recipe; } -/// Find all possible partial reductions in the loop and track all of those that -/// are valid so recipes can be formed later. -void VPRecipeBuilder::collectScaledReductions(VFRange &Range) { - // Find all possible partial reductions. - SmallVector> - PartialReductionChains; - for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) { - getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range, - PartialReductionChains); - } - - // A partial reduction is invalid if any of its extends are used by - // something that isn't another partial reduction. This is because the - // extends are intended to be lowered along with the reduction itself. - - // Build up a set of partial reduction ops for efficient use checking. - SmallPtrSet PartialReductionOps; - for (const auto &[PartialRdx, _] : PartialReductionChains) - PartialReductionOps.insert(PartialRdx.ExtendUser); - - auto ExtendIsOnlyUsedByPartialReductions = - [&PartialReductionOps](Instruction *Extend) { - return all_of(Extend->users(), [&](const User *U) { - return PartialReductionOps.contains(U); - }); - }; - - // Check if each use of a chain's two extends is a partial reduction - // and only add those that don't have non-partial reduction users. - for (auto Pair : PartialReductionChains) { - PartialReductionChain Chain = Pair.first; - if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) && - (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))) - ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second); - } - - // Check that all partial reductions in a chain are only used by other - // partial reductions with the same scale factor. Otherwise we end up creating - // users of scaled reductions where the types of the other operands don't - // match. - for (const auto &[Chain, Scale] : PartialReductionChains) { - auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) { - auto *UI = cast(U); - if (isa(UI) && UI->getParent() == OrigLoop->getHeader()) { - return all_of(UI->users(), [ScaleVal, this](const User *U) { - auto *UI = cast(U); - return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal; - }); - } - return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal || - !OrigLoop->contains(UI->getParent()); - }; - if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx)) - ScaledReductionMap.erase(Chain.Reduction); - } -} - -bool VPRecipeBuilder::getScaledReductions( - Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range, - SmallVectorImpl> &Chains) { - if (!CM.TheLoop->contains(RdxExitInstr)) - return false; - - auto *Update = dyn_cast(RdxExitInstr); - if (!Update) - return false; - - Value *Op = Update->getOperand(0); - Value *PhiOp = Update->getOperand(1); - if (Op == PHI) - std::swap(Op, PhiOp); - - // Try and get a scaled reduction from the first non-phi operand. - // If one is found, we use the discovered reduction instruction in - // place of the accumulator for costing. - if (auto *OpInst = dyn_cast(Op)) { - if (getScaledReductions(PHI, OpInst, Range, Chains)) { - PHI = Chains.rbegin()->first.Reduction; - - Op = Update->getOperand(0); - PhiOp = Update->getOperand(1); - if (Op == PHI) - std::swap(Op, PhiOp); - } - } - if (PhiOp != PHI) - return false; - - using namespace llvm::PatternMatch; - - // If the update is a binary operator, check both of its operands to see if - // they are extends. Otherwise, see if the update comes directly from an - // extend. - Instruction *Exts[2] = {nullptr}; - BinaryOperator *ExtendUser = dyn_cast(Op); - std::optional BinOpc; - Type *ExtOpTypes[2] = {nullptr}; - TTI::PartialReductionExtendKind ExtKinds[2] = {TTI::PR_None}; - - auto CollectExtInfo = [this, &Exts, &ExtOpTypes, - &ExtKinds](SmallVectorImpl &Ops) -> bool { - for (const auto &[I, OpI] : enumerate(Ops)) { - const APInt *C; - if (I > 0 && match(OpI, m_APInt(C)) && - canConstantBeExtended(C, ExtOpTypes[0], ExtKinds[0])) { - ExtOpTypes[I] = ExtOpTypes[0]; - ExtKinds[I] = ExtKinds[0]; - continue; - } - Value *ExtOp; - if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp)))) - return false; - Exts[I] = cast(OpI); - - // TODO: We should be able to support live-ins. - if (!CM.TheLoop->contains(Exts[I])) - return false; - - ExtOpTypes[I] = ExtOp->getType(); - ExtKinds[I] = TTI::getPartialReductionExtendKind(Exts[I]); - } - return true; - }; - - if (ExtendUser) { - if (!ExtendUser->hasOneUse()) - return false; - - // Use the side-effect of match to replace BinOp only if the pattern is - // matched, we don't care at this point whether it actually matched. - match(ExtendUser, m_Neg(m_BinOp(ExtendUser))); - - SmallVector Ops(ExtendUser->operands()); - if (!CollectExtInfo(Ops)) - return false; - - BinOpc = std::make_optional(ExtendUser->getOpcode()); - } else if (match(Update, m_Add(m_Value(), m_Value()))) { - // We already know the operands for Update are Op and PhiOp. - SmallVector Ops({Op}); - if (!CollectExtInfo(Ops)) - return false; - - ExtendUser = Update; - BinOpc = std::nullopt; - } else - return false; - - PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser); - - TypeSize PHISize = PHI->getType()->getPrimitiveSizeInBits(); - TypeSize ASize = ExtOpTypes[0]->getPrimitiveSizeInBits(); - if (!PHISize.hasKnownScalarFactor(ASize)) - return false; - unsigned TargetScaleFactor = PHISize.getKnownScalarFactor(ASize); - - if (LoopVectorizationPlanner::getDecisionAndClampRange( - [&](ElementCount VF) { - InstructionCost Cost = TTI->getPartialReductionCost( - Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1], - PHI->getType(), VF, ExtKinds[0], ExtKinds[1], BinOpc, - CM.CostKind); - return Cost.isValid(); - }, - Range)) { - Chains.emplace_back(Chain, TargetScaleFactor); - return true; - } - - return false; -} - VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R, VFRange &Range) { // First, check for specific widening recipes that deal with inductions, Phi @@ -8183,12 +8011,11 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R, assert(RdxDesc.getRecurrenceStartValue() == Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader())); - // If the PHI is used by a partial reduction, set the scale factor. - unsigned ScaleFactor = - getScalingForReduction(RdxDesc.getLoopExitInstr()).value_or(1); - PhiRecipe = new VPReductionPHIRecipe( - Phi, RdxDesc.getRecurrenceKind(), *StartV, CM.isInLoopReduction(Phi), - CM.useOrderedReductions(RdxDesc), ScaleFactor); + // Always create with scale factor 1. Partial reductions will be created + // later in createPartialReductions transform. + PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc.getRecurrenceKind(), + *StartV, CM.isInLoopReduction(Phi), + CM.useOrderedReductions(RdxDesc)); } else { // TODO: Currently fixed-order recurrences are modeled as chains of // first-order recurrences. If there are no users of the intermediate @@ -8224,9 +8051,6 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R, VPI->getOpcode() == Instruction::Store) return tryToWidenMemory(VPI, Range); - if (std::optional ScaleFactor = getScalingForReduction(Instr)) - return tryToCreatePartialReduction(VPI, ScaleFactor.value()); - if (!shouldWiden(Instr, Range)) return nullptr; @@ -8247,41 +8071,6 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R, return tryToWiden(VPI); } -VPRecipeBase * -VPRecipeBuilder::tryToCreatePartialReduction(VPInstruction *Reduction, - unsigned ScaleFactor) { - assert(Reduction->getNumOperands() == 2 && - "Unexpected number of operands for partial reduction"); - - VPValue *BinOp = Reduction->getOperand(0); - VPValue *Accumulator = Reduction->getOperand(1); - if (isa(BinOp) || isa(BinOp)) - std::swap(BinOp, Accumulator); - - assert(ScaleFactor == - vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()) && - "all accumulators in chain must have same scale factor"); - - unsigned ReductionOpcode = Reduction->getOpcode(); - auto *ReductionI = Reduction->getUnderlyingInstr(); - if (ReductionOpcode == Instruction::Sub) { - auto *const Zero = ConstantInt::get(ReductionI->getType(), 0); - SmallVector Ops; - Ops.push_back(Plan.getOrAddLiveIn(Zero)); - Ops.push_back(BinOp); - BinOp = new VPWidenRecipe(*ReductionI, Ops, VPIRMetadata(), - ReductionI->getDebugLoc()); - Builder.insert(BinOp->getDefiningRecipe()); - ReductionOpcode = Instruction::Add; - } - - VPValue *Cond = nullptr; - if (CM.blockNeedsPredicationForAnyReason(ReductionI->getParent())) - Cond = getBlockInMask(Builder.getInsertBlock()); - return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond, - ScaleFactor, ReductionI); -} - void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, ElementCount MaxVF) { if (ElementCount::isKnownGT(MinVF, MaxVF)) @@ -8408,11 +8197,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( // Construct wide recipes and apply predication for original scalar // VPInstructions in the loop. // --------------------------------------------------------------------------- - VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE, - Builder, BlockMaskCache); - // TODO: Handle partial reductions with EVL tail folding. - if (!CM.foldTailWithEVL()) - RecipeBuilder.collectScaledReductions(Range); + VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder, + BlockMaskCache); // Scan the body of the loop in a topological order to visit each basic block // after having visited its predecessor basic blocks. @@ -8521,11 +8307,10 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( *Plan)) return nullptr; - // Transform recipes to abstract recipes if it is legal and beneficial and - // clamp the range for better cost estimation. - // TODO: Enable following transform when the EVL-version of extended-reduction - // and mulacc-reduction are implemented. if (!CM.foldTailWithEVL()) { + // Create partial reduction recipes for scaled reductions. + VPlanTransforms::createPartialReductions(*Plan, Range, &TTI, CM.CostKind); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, *CM.PSE.getSE(), OrigLoop); VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan, @@ -8606,8 +8391,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) { // Collect mapping of IR header phis to header phi recipes, to be used in // addScalarResumePhis. DenseMap BlockMaskCache; - VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE, - Builder, BlockMaskCache); + VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder, + BlockMaskCache); for (auto &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) { if (isa(&R)) continue; @@ -8957,11 +8742,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( VPBuilder PHBuilder(Plan->getVectorPreheader()); VPValue *Iden = Plan->getOrAddLiveIn( getRecurrenceIdentity(RK, PhiTy, RdxDesc.getFastMathFlags())); - // If the PHI is used by a partial reduction, set the scale factor. - unsigned ScaleFactor = - RecipeBuilder.getScalingForReduction(RdxDesc.getLoopExitInstr()) - .value_or(1); - auto *ScaleFactorVPV = Plan->getConstantInt(32, ScaleFactor); + auto *ScaleFactorVPV = Plan->getConstantInt(32, 1); VPValue *StartV = PHBuilder.createNaryOp( VPInstruction::ReductionStartVector, {PhiR->getStartValue(), Iden, ScaleFactorVPV}, diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h index 87280b83fc0e5..e901ffd38d422 100644 --- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h +++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h @@ -19,30 +19,9 @@ namespace llvm { class LoopVectorizationLegality; class LoopVectorizationCostModel; class TargetLibraryInfo; -class TargetTransformInfo; struct HistogramInfo; struct VFRange; -/// A chain of instructions that form a partial reduction. -/// Designed to match either: -/// reduction_bin_op (extend (A), accumulator), or -/// reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator). -struct PartialReductionChain { - PartialReductionChain(Instruction *Reduction, Instruction *ExtendA, - Instruction *ExtendB, Instruction *ExtendUser) - : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), - ExtendUser(ExtendUser) {} - /// The top-level binary operation that forms the reduction to a scalar - /// after the loop body. - Instruction *Reduction; - /// The extension of each of the inner binary operation's operands. - Instruction *ExtendA; - Instruction *ExtendB; - - /// The user of the extends that is then reduced. - Instruction *ExtendUser; -}; - /// Helper class to create VPRecipies from IR instructions. class VPRecipeBuilder { /// The VPlan new recipes are added to. @@ -54,9 +33,6 @@ class VPRecipeBuilder { /// Target Library Info. const TargetLibraryInfo *TLI; - // Target Transform Info. - const TargetTransformInfo *TTI; - /// The legality analysis. LoopVectorizationLegality *Legal; @@ -81,9 +57,6 @@ class VPRecipeBuilder { /// created. SmallVector PhisToFix; - /// A mapping of partial reduction exit instructions to their scaling factor. - DenseMap ScaledReductionMap; - /// Check if \p I can be widened at the start of \p Range and possibly /// decrease the range such that the returned value holds for the entire \p /// Range. The function should not be called for memory instructions or calls. @@ -121,48 +94,19 @@ class VPRecipeBuilder { VPHistogramRecipe *tryToWidenHistogram(const HistogramInfo *HI, VPInstruction *VPI); - /// Examines reduction operations to see if the target can use a cheaper - /// operation with a wider per-iteration input VF and narrower PHI VF. - /// Each element within Chains is a pair with a struct containing reduction - /// information and the scaling factor between the number of elements in - /// the input and output. - /// Recursively calls itself to identify chained scaled reductions. - /// Returns true if this invocation added an entry to Chains, otherwise false. - /// i.e. returns false in the case that a subcall adds an entry to Chains, - /// but the top-level call does not. - bool getScaledReductions( - Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range, - SmallVectorImpl> &Chains); - public: VPRecipeBuilder(VPlan &Plan, Loop *OrigLoop, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, LoopVectorizationLegality *Legal, LoopVectorizationCostModel &CM, PredicatedScalarEvolution &PSE, VPBuilder &Builder, DenseMap &BlockMaskCache) - : Plan(Plan), OrigLoop(OrigLoop), TLI(TLI), TTI(TTI), Legal(Legal), - CM(CM), PSE(PSE), Builder(Builder), BlockMaskCache(BlockMaskCache) {} - - std::optional getScalingForReduction(const Instruction *ExitInst) { - auto It = ScaledReductionMap.find(ExitInst); - return It == ScaledReductionMap.end() ? std::nullopt - : std::make_optional(It->second); - } - - /// Find all possible partial reductions in the loop and track all of those - /// that are valid so recipes can be formed later. - void collectScaledReductions(VFRange &Range); + : Plan(Plan), OrigLoop(OrigLoop), TLI(TLI), Legal(Legal), CM(CM), + PSE(PSE), Builder(Builder), BlockMaskCache(BlockMaskCache) {} /// Create and return a widened recipe for \p R if one can be created within /// the given VF \p Range. VPRecipeBase *tryToCreateWidenRecipe(VPSingleDefRecipe *R, VFRange &Range); - /// Create and return a partial reduction recipe for a reduction instruction - /// along with binary operation and reduction phi operands. - VPRecipeBase *tryToCreatePartialReduction(VPInstruction *Reduction, - unsigned ScaleFactor); - /// Set the recipe created for given ingredient. void setRecipe(Instruction *I, VPRecipeBase *R) { assert(!Ingredient2Recipe.contains(I) && diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index c81834e401726..72f932c6c8c7d 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2408,6 +2408,9 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe, /// Get the factor that the VF of this recipe's output should be scaled by. unsigned getVFScaleFactor() const { return VFScaleFactor; } + /// Set the factor that the VF of this recipe's output should be scaled by. + void setVFScaleFactor(unsigned Factor) { VFScaleFactor = Factor; } + /// Returns the number of incoming values, also number of incoming blocks. /// Note that at the moment, VPWidenPointerInductionRecipe only has a single /// incoming value, its start value. @@ -2835,8 +2838,7 @@ class VPPartialReductionRecipe : public VPReductionRecipe { VPPartialReductionRecipe *clone() override { return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1), - getCondOp(), VFScaleFactor, - getUnderlyingInstr()); + getCondOp(), VFScaleFactor); } VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC) diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp index 612202d049774..11265e3f44e8c 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp @@ -12,14 +12,19 @@ //===----------------------------------------------------------------------===// #include "LoopVectorizationPlanner.h" +#include "VPRecipeBuilder.h" #include "VPlan.h" +#include "VPlanAnalysis.h" #include "VPlanCFG.h" #include "VPlanDominatorTree.h" +#include "VPlanHelpers.h" #include "VPlanPatternMatch.h" #include "VPlanTransforms.h" +#include "VPlanUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/MDBuilder.h" #include "llvm/Transforms/Utils/LoopVersioning.h" @@ -991,3 +996,381 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) { MiddleTerm->setOperand(0, NewCond); return true; } + +namespace { +/// A VPlan-based chain of recipes that form a partial reduction. +/// Designed to match either: +/// reduction_bin_op (extend (A), accumulator), or +/// reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator). +struct VPPartialReductionChain { + VPPartialReductionChain(VPWidenRecipe *Reduction, VPWidenCastRecipe *ExtendA, + VPWidenCastRecipe *ExtendB, VPWidenRecipe *ExtendUser) + : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), + ExtendUser(ExtendUser) {} + /// The top-level binary operation that forms the reduction to a scalar + /// after the loop body. + VPWidenRecipe *Reduction; + /// The extension of each of the inner binary operation's operands. + VPWidenCastRecipe *ExtendA; + VPWidenCastRecipe *ExtendB; + + /// The user of the extends that is then reduced. + VPWidenRecipe *ExtendUser; +}; + +// Helper to transform a single widen recipe into a partial reduction recipe. +// Returns true if transformation succeeded. +static bool transformToPartialReduction(VPWidenRecipe *WidenRecipe, + unsigned ScaleFactor, VPlan &Plan) { + assert(WidenRecipe->getNumOperands() == 2 && "Expected binary operation"); + + VPValue *BinOp = WidenRecipe->getOperand(0); + VPValue *Accumulator = WidenRecipe->getOperand(1); + + // Swap if needed to ensure Accumulator is the PHI or partial reduction. + VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe(); + if (BinOpRecipe && (isa(BinOpRecipe) || + isa(BinOpRecipe))) + std::swap(BinOp, Accumulator); + + // For chained reductions, only transform if accumulator is already a PHI or + // partial reduction. Otherwise, it needs to be transformed first. + VPRecipeBase *AccumRecipe = Accumulator->getDefiningRecipe(); + if (!AccumRecipe || (!isa(AccumRecipe) && + !isa(AccumRecipe))) + return false; + + if (auto *RdxPhi = dyn_cast(AccumRecipe)) { + assert(RdxPhi->getVFScaleFactor() == 1 && "scale factor must not be set"); + RdxPhi->setVFScaleFactor(ScaleFactor); + + // Update ReductionStartVector instruction scale factor. + VPValue *StartValue = RdxPhi->getOperand(0); + auto *StartInst = cast(StartValue); + assert(StartInst->getOpcode() == VPInstruction::ReductionStartVector); + auto *NewScaleFactor = Plan.getConstantInt(32, ScaleFactor); + StartInst->setOperand(2, NewScaleFactor); + } + + // Handle SUB by negating the operand and using ADD for the partial reduction. + unsigned ReductionOpcode = WidenRecipe->getOpcode(); + if (ReductionOpcode == Instruction::Sub) { + VPBuilder Builder(WidenRecipe); + + // Infer the scalar type for creating the zero constant + Type *ElemTy = VPTypeAnalysis(Plan).inferScalarType(BinOp); + auto *Zero = Plan.getConstantInt(ElemTy, 0); + + // Create a negation: 0 - BinOp + VPIRFlags Flags; + if (auto *I = WidenRecipe->getUnderlyingInstr()) + Flags = VPIRFlags(*I); + + auto *NegRecipe = new VPWidenRecipe(Instruction::Sub, {Zero, BinOp}, Flags, + VPIRMetadata(), DebugLoc()); + Builder.insert(NegRecipe); + BinOp = NegRecipe; + ReductionOpcode = Instruction::Add; + } + + VPValue *Cond = nullptr; + VPValue *ExitValue = nullptr; + if (auto *RedPhiR = dyn_cast(Accumulator)) { + ExitValue = findComputeReductionResult(RedPhiR)->getOperand(1); + match(ExitValue, m_Select(m_VPValue(Cond), m_VPValue(), m_VPValue())); + } + + auto *PartialRed = new VPPartialReductionRecipe( + ReductionOpcode, Accumulator, BinOp, Cond, ScaleFactor, + WidenRecipe->getUnderlyingInstr()); + PartialRed->insertBefore(WidenRecipe); + + // Remove the select recipe if we extracted it, as the + // VPPartialReductionRecipe now handles the predication. + if (Cond) { + ExitValue->replaceAllUsesWith(PartialRed); + ExitValue->getDefiningRecipe()->eraseFromParent(); + } + WidenRecipe->replaceAllUsesWith(PartialRed); + WidenRecipe->eraseFromParent(); + + return true; +} + +// Helper to validate that extend users are only partial reduction operations. +static bool validateExtendUsers( + VPWidenCastRecipe *Extend, + const SmallPtrSetImpl &PartialReductionOps) { + return !Extend || all_of(Extend->users(), [&](VPUser *U) { + auto *R = dyn_cast(U); + return !R || PartialReductionOps.contains(R); + }); +} + +// Helper to collect extension information from operands for partial reductions. +// Returns true if at least one cast recipe was found. +static bool collectExtensionInfo( + ArrayRef Operands, VPTypeAnalysis &TypeInfo, + VPWidenCastRecipe *CastRecipes[2], Type *ExtOpTypes[2], + TargetTransformInfo::PartialReductionExtendKind ExtKinds[2]) { + if (Operands.size() > 2) + return false; + + for (const auto &[I, OpVal] : enumerate(Operands)) { + // Check for constant that can be extended. + if (I > 0 && ExtKinds[0] != TTI::PR_None) { + const APInt *C; + if (match(OpVal, m_APInt(C)) && + canConstantBeExtended(C, ExtOpTypes[0], ExtKinds[0])) { + ExtOpTypes[I] = ExtOpTypes[0]; + ExtKinds[I] = ExtKinds[0]; + continue; + } + } + + auto *CastRecipe = dyn_cast(OpVal); + if (!CastRecipe) + return false; + + Instruction::CastOps CastOp = + static_cast(CastRecipe->getOpcode()); + if (CastOp != Instruction::SExt && CastOp != Instruction::ZExt) + return false; + + CastRecipes[I] = CastRecipe; + ExtOpTypes[I] = TypeInfo.inferScalarType(CastRecipe->getOperand(0)); + ExtKinds[I] = TTI::getPartialReductionExtendKind(CastOp); + } + return CastRecipes[0] != nullptr; +} + +// VPlan-based pattern detection for partial reductions. +static bool getScaledReductions( + VPReductionPHIRecipe *RedPhiR, VPValue *PrevValue, VFRange &Range, + SmallVectorImpl> &Chains, + VPlan &Plan, const TargetTransformInfo *TTI, + TargetTransformInfo::TargetCostKind CostKind) { + auto *UpdateRecipe = dyn_cast(PrevValue); + if (!UpdateRecipe || UpdateRecipe->getNumOperands() != 2) + return false; + + VPTypeAnalysis TypeInfo(Plan); + // Helper to finalize and validate chain. + auto FinalizeChain = [&](VPWidenRecipe *ExtendUser, Type *ExtOpTypes[2], + TTI::PartialReductionExtendKind ExtKinds[2], + VPWidenCastRecipe *CastRecipes[2], + std::optional BinOpc) { + Type *PhiType = TypeInfo.inferScalarType(RedPhiR); + TypeSize PHISize = PhiType->getPrimitiveSizeInBits(); + TypeSize ASize = ExtOpTypes[0]->getPrimitiveSizeInBits(); + if (!PHISize.hasKnownScalarFactor(ASize)) + return false; + if (LoopVectorizationPlanner::getDecisionAndClampRange( + [&](ElementCount VF) { + return TTI + ->getPartialReductionCost( + UpdateRecipe->getOpcode(), ExtOpTypes[0], ExtOpTypes[1], + PhiType, VF, ExtKinds[0], ExtKinds[1], BinOpc, CostKind) + .isValid(); + }, + Range)) { + unsigned TargetScaleFactor = PHISize.getKnownScalarFactor(ASize); + VPPartialReductionChain Chain(UpdateRecipe, CastRecipes[0], + CastRecipes[1], ExtendUser); + Chains.emplace_back(Chain, TargetScaleFactor); + return true; + } + return false; + }; + + VPValue *Op = UpdateRecipe->getOperand(0); + VPValue *PhiOp = UpdateRecipe->getOperand(1); + if (Op == RedPhiR) + std::swap(Op, PhiOp); + + bool FoundChainedReduction = false; + if (isa(Op)) { + FoundChainedReduction = + getScaledReductions(RedPhiR, Op, Range, Chains, Plan, TTI, CostKind); + } + + if (!FoundChainedReduction && PhiOp != RedPhiR) + return false; + + if (FoundChainedReduction) { + VPValue *ExtendedVal = PhiOp; + + if (UpdateRecipe->getOpcode() != Instruction::Add && + UpdateRecipe->getOpcode() != Instruction::Sub) + return false; + + VPWidenCastRecipe *CastRecipes[2] = {nullptr}; + Type *ExtOpTypes[2] = {nullptr}; + TTI::PartialReductionExtendKind ExtKinds[2] = {TTI::PR_None}; + VPWidenRecipe *ExtendUser = nullptr; + std::optional BinOpc; + + if (auto *CastRecipe = dyn_cast(ExtendedVal)) { + Instruction::CastOps CastOp = + static_cast(CastRecipe->getOpcode()); + if (CastOp == Instruction::SExt || CastOp == Instruction::ZExt) { + ExtOpTypes[0] = TypeInfo.inferScalarType(CastRecipe->getOperand(0)); + ExtKinds[0] = TTI::getPartialReductionExtendKind(CastOp); + CastRecipes[0] = CastRecipe; + ExtendUser = UpdateRecipe; + } + } + + if (!CastRecipes[0]) { + if (auto *BinOpRecipe = dyn_cast(ExtendedVal)) { + unsigned ExtUserOpcode = BinOpRecipe->getOpcode(); + if (ExtUserOpcode == Instruction::Mul || + ExtUserOpcode == Instruction::Add || + ExtUserOpcode == Instruction::Sub) { + SmallVector BinOpOperands(BinOpRecipe->operands()); + if (collectExtensionInfo(BinOpOperands, TypeInfo, CastRecipes, + ExtOpTypes, ExtKinds)) { + ExtendUser = BinOpRecipe; + BinOpc = ExtUserOpcode; + } + } + } + + if (!CastRecipes[0]) + return false; + } + + return FinalizeChain(ExtendUser, ExtOpTypes, ExtKinds, CastRecipes, BinOpc); + } + + // Check if Op comes from a binary operation with extended operands. + std::optional BinOpc; + VPWidenCastRecipe *CastRecipes[2] = {nullptr}; + Type *ExtOpTypes[2] = {nullptr}; + TTI::PartialReductionExtendKind ExtKinds[2] = {TTI::PR_None}; + + VPWidenRecipe *ExtendUser = dyn_cast(Op); + if (ExtendUser) { + unsigned ExtUserOpcode = ExtendUser->getOpcode(); + if (!Instruction::isBinaryOp(ExtUserOpcode) || !ExtendUser->hasOneUse()) + return false; + + SmallVector Operands(ExtendUser->operands()); + + // Check for negation pattern + VPValue *OtherOp; + if (match(ExtendUser, m_Sub(m_ZeroInt(), m_VPValue(OtherOp))) && + isa(OtherOp)) { + ExtendUser = cast(OtherOp); + ExtUserOpcode = ExtendUser->getOpcode(); + auto OpRange = ExtendUser->operands(); + Operands.assign(OpRange.begin(), OpRange.end()); + } + + if (!collectExtensionInfo(Operands, TypeInfo, CastRecipes, ExtOpTypes, + ExtKinds)) + return false; + + BinOpc = std::make_optional(ExtUserOpcode); + } else if (UpdateRecipe->getOpcode() == Instruction::Add) { + if (!collectExtensionInfo({Op}, TypeInfo, CastRecipes, ExtOpTypes, + ExtKinds)) + return false; + + ExtendUser = UpdateRecipe; + BinOpc = std::nullopt; + } else + return false; + + return FinalizeChain(ExtendUser, ExtOpTypes, ExtKinds, CastRecipes, BinOpc); +} +} // namespace + +void VPlanTransforms::createPartialReductions( + VPlan &Plan, VFRange &Range, const TargetTransformInfo *TTI, + TargetTransformInfo::TargetCostKind CostKind) { + // Collect all partial reduction chains. + SmallVector> + AllPartialReductionChains; + VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); + for (VPRecipeBase &R : *HeaderVPBB) { + auto *RedPhiR = dyn_cast(&R); + if (!RedPhiR) + continue; + + VPValue *ExitValue = nullptr; + if (auto *RdxResult = findComputeReductionResult(RedPhiR)) + ExitValue = RdxResult->getOperand(1); + else + continue; + + match(ExitValue, m_Select(m_VPValue(), m_VPValue(ExitValue), m_VPValue())); + getScaledReductions(RedPhiR, ExitValue, Range, AllPartialReductionChains, + Plan, TTI, CostKind); + } + + if (AllPartialReductionChains.empty()) + return; + + // Build set of all reduction operations for usage validation. + SmallPtrSet PartialReductionOps; + for (const auto &[Chain, _] : AllPartialReductionChains) { + PartialReductionOps.insert(Chain.ExtendUser); + PartialReductionOps.insert(Chain.Reduction); + } + + // Validate extends and build map of valid reductions. + DenseMap ScaledReductionRecipeMap; + for (const auto &[Chain, ScaleFactor] : AllPartialReductionChains) { + if (validateExtendUsers(Chain.ExtendA, PartialReductionOps) && + validateExtendUsers(Chain.ExtendB, PartialReductionOps)) { + ScaledReductionRecipeMap.try_emplace(Chain.Reduction, ScaleFactor); + } + } + + // Validate that reductions are only used by other reductions with the same + // scale factor. + VPRegionBlock *VectorLoopRegion = Plan.getVectorLoopRegion(); + for (const auto &[Chain, ScaleFactor] : AllPartialReductionChains) { + unsigned Scale = ScaleFactor; + auto AllUsersPartialRdx = [Scale, &ScaledReductionRecipeMap, + VectorLoopRegion, HeaderVPBB](VPUser *U) { + if (auto *RedPhiR = dyn_cast(U)) { + if (RedPhiR->getParent() != HeaderVPBB) + return true; + return all_of(RedPhiR->users(), + [Scale, &ScaledReductionRecipeMap](VPUser *PhiUser) { + auto *WidenRecipe = dyn_cast(PhiUser); + return !WidenRecipe || ScaledReductionRecipeMap.lookup( + WidenRecipe) == Scale; + }); + } + + if (auto *WidenRecipe = dyn_cast(U)) { + bool InLoop = WidenRecipe->getRegion() == VectorLoopRegion; + return !InLoop || ScaledReductionRecipeMap.lookup(WidenRecipe) == Scale; + } + + return true; + }; + + if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx)) + ScaledReductionRecipeMap.erase(Chain.Reduction); + } + + // Transform validated reductions. Use iterative approach to handle chains. + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly( + vp_depth_first_shallow(VectorLoopRegion->getEntry()))) { + for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { + auto *WidenRecipe = dyn_cast(&R); + if (!WidenRecipe) + continue; + + unsigned ScaleFactor = ScaledReductionRecipeMap.lookup(WidenRecipe); + if (!ScaleFactor) + continue; + + transformToPartialReduction(WidenRecipe, ScaleFactor, Plan); + } + } +} diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index fca6554ad77c6..4d7bd7d7e04b6 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -400,7 +400,7 @@ VPPartialReductionRecipe::computeCost(ElementCount VF, void VPPartialReductionRecipe::execute(VPTransformState &State) { auto &Builder = State.Builder; - assert(getOpcode() == Instruction::Add && + assert((getOpcode() == Instruction::Add || getOpcode() == Instruction::Sub) && "Unhandled partial reduction opcode"); Value *BinOpVal = State.get(getVecOp()); diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index a44a4f69c917b..9de1d0ce5ab90 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -16,6 +16,7 @@ #include "VPlan.h" #include "VPlanVerifier.h" #include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" @@ -28,6 +29,7 @@ class PHINode; class ScalarEvolution; class PredicatedScalarEvolution; class TargetLibraryInfo; +class TargetTransformInfo; class VPBuilder; class VPRecipeBuilder; struct VFRange; @@ -379,6 +381,14 @@ struct VPlanTransforms { /// users in the original exit block using the VPIRInstruction wrapping to the /// LCSSA phi. static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range); + + /// Detect and create partial reduction recipes for scaled reductions in + /// \p Plan. Must be called after recipe construction. If partial reductions + /// are only valid for a subset of VFs in Range, Range.End is updated. + static void + createPartialReductions(VPlan &Plan, VFRange &Range, + const TargetTransformInfo *TTI, + TargetTransformInfo::TargetCostKind CostKind); }; } // namespace llvm