Skip to content

Conversation

@sdesmalen-arm
Copy link
Collaborator

This means that VPExpressions will now be constructed for
VPPartialReductionRecipe's when the loop has tail-folding predication.

Note that control-flow (if/else) predication is not yet handled
for partial reductions, because of the way partial reductions
are recognised and built up.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR moves the conditional selection logic from VPlan construction time to VPlan execution time for partial reductions. When tail-folding predication is enabled, VPExpressions are now constructed for VPPartialReductionRecipe's, improving the VPlan's expressiveness.

Key changes:

  • Conditional selection (select instruction) for predicated partial reductions is now created during execution in VPPartialReductionRecipe::execute() instead of during VPlan construction
  • VPExpressions are now properly constructed for partial reductions with tail-folding predication
  • Test outputs updated to reflect the new VPlan structure and code generation order

Reviewed Changes

Copilot reviewed 5 out of 6 changed files in this pull request and generated no comments.

Show a summary per file
File Description
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp Removes conditional selection logic from VPlan construction phase
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp Adds conditional selection logic to execute() method
llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll Adds test case verifying VPExpressions with predication
llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll Updates test output reflecting reordered instruction generation
llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll Updates test output with restructured predicated load patterns

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@llvmbot
Copy link
Member

llvmbot commented Nov 3, 2025

@llvm/pr-subscribers-vectorizers

Author: Sander de Smalen (sdesmalen-arm)

Changes

This means that VPExpressions will now be constructed for
VPPartialReductionRecipe's when the loop has tail-folding predication.

Note that control-flow (if/else) predication is not yet handled
for partial reductions, because of the way partial reductions
are recognised and built up.


Patch is 138.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166136.diff

6 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+1-8)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+8-2)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll (+81-161)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll (+243-483)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll (+1-1)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll (+71)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e5c3f17860103..8ec11be3037a2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8171,15 +8171,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
   }
 
   VPValue *Cond = nullptr;
-  if (CM.blockNeedsPredicationForAnyReason(Reduction->getParent())) {
-    assert((ReductionOpcode == Instruction::Add ||
-            ReductionOpcode == Instruction::Sub) &&
-           "Expected an ADD or SUB operation for predicated partial "
-           "reductions (because the neutral element in the mask is zero)!");
+  if (CM.blockNeedsPredicationForAnyReason(Reduction->getParent()))
     Cond = getBlockInMask(Builder.getInsertBlock());
-    VPValue *Zero = Plan.getConstantInt(Reduction->getType(), 0);
-    BinOp = Builder.createSelect(Cond, BinOp, Zero, Reduction->getDebugLoc());
-  }
   return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
                                       ScaleFactor, Reduction);
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 1ee405a62aa68..1677d3f522b80 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -395,12 +395,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   assert(getOpcode() == Instruction::Add &&
          "Unhandled partial reduction opcode");
 
-  Value *BinOpVal = State.get(getOperand(1));
-  Value *PhiVal = State.get(getOperand(0));
+  Value *BinOpVal = State.get(getVecOp());
+  Value *PhiVal = State.get(getChainOp());
   assert(PhiVal && BinOpVal && "Phi and Mul must be set");
 
   Type *RetTy = PhiVal->getType();
 
+  if (isConditional()) {
+    Value *Cond = State.get(getCondOp());
+    Value *Zero = ConstantInt::get(BinOpVal->getType(), 0);
+    BinOpVal = Builder.CreateSelect(Cond, BinOpVal, Zero);
+  }
+
   CallInst *V =
       Builder.CreateIntrinsic(RetTy, Intrinsic::vector_partial_reduce_add,
                               {PhiVal, BinOpVal}, nullptr, "partial.reduce");
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
index d8f1a86c9ebda..5b9bd0997f2fa 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
@@ -182,313 +182,233 @@ define i32 @dotp_predicated(i64 %N, ptr %a, ptr %b) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[PRED_LOAD_CONTINUE62:%.*]] ]
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <16 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[PRED_LOAD_CONTINUE62]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[PRED_LOAD_CONTINUE62]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
-; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP3:%.*]] = add i64 [[INDEX]], 3
-; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], 4
-; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[INDEX]], 5
-; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 6
-; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[INDEX]], 7
-; CHECK-NEXT:    [[TMP8:%.*]] = add i64 [[INDEX]], 8
-; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 9
-; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 10
-; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 11
-; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[INDEX]], 12
-; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[INDEX]], 13
-; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 14
-; CHECK-NEXT:    [[TMP15:%.*]] = add i64 [[INDEX]], 15
 ; CHECK-NEXT:    [[TMP16:%.*]] = icmp ule <16 x i64> [[VEC_IND]], [[BROADCAST_SPLAT]]
 ; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x i1> [[TMP16]], i32 0
 ; CHECK-NEXT:    br i1 [[TMP17]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
 ; CHECK:       pred.load.if:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[TMP19:%.*]] = load i8, ptr [[TMP18]], align 1
 ; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <16 x i8> poison, i8 [[TMP19]], i32 0
+; CHECK-NEXT:    [[TMP99:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP0]]
+; CHECK-NEXT:    [[TMP101:%.*]] = load i8, ptr [[TMP99]], align 1
+; CHECK-NEXT:    [[TMP102:%.*]] = insertelement <16 x i8> poison, i8 [[TMP101]], i32 0
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE]]
 ; CHECK:       pred.load.continue:
 ; CHECK-NEXT:    [[TMP21:%.*]] = phi <16 x i8> [ poison, [[VECTOR_BODY]] ], [ [[TMP20]], [[PRED_LOAD_IF]] ]
+; CHECK-NEXT:    [[TMP103:%.*]] = phi <16 x i8> [ poison, [[VECTOR_BODY]] ], [ [[TMP102]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP22:%.*]] = extractelement <16 x i1> [[TMP16]], i32 1
 ; CHECK-NEXT:    br i1 [[TMP22]], label [[PRED_LOAD_IF1:%.*]], label [[PRED_LOAD_CONTINUE2:%.*]]
 ; CHECK:       pred.load.if1:
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP23:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP1]]
 ; CHECK-NEXT:    [[TMP24:%.*]] = load i8, ptr [[TMP23]], align 1
 ; CHECK-NEXT:    [[TMP25:%.*]] = insertelement <16 x i8> [[TMP21]], i8 [[TMP24]], i32 1
+; CHECK-NEXT:    [[TMP104:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP105:%.*]] = load i8, ptr [[TMP104]], align 1
+; CHECK-NEXT:    [[TMP109:%.*]] = insertelement <16 x i8> [[TMP103]], i8 [[TMP105]], i32 1
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE2]]
 ; CHECK:       pred.load.continue2:
 ; CHECK-NEXT:    [[TMP26:%.*]] = phi <16 x i8> [ [[TMP21]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP25]], [[PRED_LOAD_IF1]] ]
+; CHECK-NEXT:    [[TMP111:%.*]] = phi <16 x i8> [ [[TMP103]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP109]], [[PRED_LOAD_IF1]] ]
 ; CHECK-NEXT:    [[TMP27:%.*]] = extractelement <16 x i1> [[TMP16]], i32 2
 ; CHECK-NEXT:    br i1 [[TMP27]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
 ; CHECK:       pred.load.if3:
+; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[INDEX]], 2
 ; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP2]]
 ; CHECK-NEXT:    [[TMP29:%.*]] = load i8, ptr [[TMP28]], align 1
 ; CHECK-NEXT:    [[TMP30:%.*]] = insertelement <16 x i8> [[TMP26]], i8 [[TMP29]], i32 2
+; CHECK-NEXT:    [[TMP112:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP113:%.*]] = load i8, ptr [[TMP112]], align 1
+; CHECK-NEXT:    [[TMP114:%.*]] = insertelement <16 x i8> [[TMP111]], i8 [[TMP113]], i32 2
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE4]]
 ; CHECK:       pred.load.continue4:
 ; CHECK-NEXT:    [[TMP31:%.*]] = phi <16 x i8> [ [[TMP26]], [[PRED_LOAD_CONTINUE2]] ], [ [[TMP30]], [[PRED_LOAD_IF3]] ]
+; CHECK-NEXT:    [[TMP115:%.*]] = phi <16 x i8> [ [[TMP111]], [[PRED_LOAD_CONTINUE2]] ], [ [[TMP114]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP32:%.*]] = extractelement <16 x i1> [[TMP16]], i32 3
 ; CHECK-NEXT:    br i1 [[TMP32]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6:%.*]]
 ; CHECK:       pred.load.if5:
+; CHECK-NEXT:    [[TMP3:%.*]] = add i64 [[INDEX]], 3
 ; CHECK-NEXT:    [[TMP33:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP3]]
 ; CHECK-NEXT:    [[TMP34:%.*]] = load i8, ptr [[TMP33]], align 1
 ; CHECK-NEXT:    [[TMP35:%.*]] = insertelement <16 x i8> [[TMP31]], i8 [[TMP34]], i32 3
+; CHECK-NEXT:    [[TMP119:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP121:%.*]] = load i8, ptr [[TMP119]], align 1
+; CHECK-NEXT:    [[TMP122:%.*]] = insertelement <16 x i8> [[TMP115]], i8 [[TMP121]], i32 3
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE6]]
 ; CHECK:       pred.load.continue6:
 ; CHECK-NEXT:    [[TMP36:%.*]] = phi <16 x i8> [ [[TMP31]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP35]], [[PRED_LOAD_IF5]] ]
+; CHECK-NEXT:    [[TMP123:%.*]] = phi <16 x i8> [ [[TMP115]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP122]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP37:%.*]] = extractelement <16 x i1> [[TMP16]], i32 4
 ; CHECK-NEXT:    br i1 [[TMP37]], label [[PRED_LOAD_IF7:%.*]], label [[PRED_LOAD_CONTINUE8:%.*]]
 ; CHECK:       pred.load.if7:
+; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP38:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP4]]
 ; CHECK-NEXT:    [[TMP39:%.*]] = load i8, ptr [[TMP38]], align 1
 ; CHECK-NEXT:    [[TMP40:%.*]] = insertelement <16 x i8> [[TMP36]], i8 [[TMP39]], i32 4
+; CHECK-NEXT:    [[TMP124:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP4]]
+; CHECK-NEXT:    [[TMP125:%.*]] = load i8, ptr [[TMP124]], align 1
+; CHECK-NEXT:    [[TMP129:%.*]] = insertelement <16 x i8> [[TMP123]], i8 [[TMP125]], i32 4
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.continue8:
 ; CHECK-NEXT:    [[TMP41:%.*]] = phi <16 x i8> [ [[TMP36]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP40]], [[PRED_LOAD_IF7]] ]
+; CHECK-NEXT:    [[TMP131:%.*]] = phi <16 x i8> [ [[TMP123]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP129]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP42:%.*]] = extractelement <16 x i1> [[TMP16]], i32 5
 ; CHECK-NEXT:    br i1 [[TMP42]], label [[PRED_LOAD_IF9:%.*]], label [[PRED_LOAD_CONTINUE10:%.*]]
 ; CHECK:       pred.load.if9:
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[INDEX]], 5
 ; CHECK-NEXT:    [[TMP43:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP5]]
 ; CHECK-NEXT:    [[TMP44:%.*]] = load i8, ptr [[TMP43]], align 1
 ; CHECK-NEXT:    [[TMP45:%.*]] = insertelement <16 x i8> [[TMP41]], i8 [[TMP44]], i32 5
+; CHECK-NEXT:    [[TMP132:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP133:%.*]] = load i8, ptr [[TMP132]], align 1
+; CHECK-NEXT:    [[TMP134:%.*]] = insertelement <16 x i8> [[TMP131]], i8 [[TMP133]], i32 5
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE10]]
 ; CHECK:       pred.load.continue10:
 ; CHECK-NEXT:    [[TMP46:%.*]] = phi <16 x i8> [ [[TMP41]], [[PRED_LOAD_CONTINUE8]] ], [ [[TMP45]], [[PRED_LOAD_IF9]] ]
+; CHECK-NEXT:    [[TMP135:%.*]] = phi <16 x i8> [ [[TMP131]], [[PRED_LOAD_CONTINUE8]] ], [ [[TMP134]], [[PRED_LOAD_IF9]] ]
 ; CHECK-NEXT:    [[TMP47:%.*]] = extractelement <16 x i1> [[TMP16]], i32 6
 ; CHECK-NEXT:    br i1 [[TMP47]], label [[PRED_LOAD_IF11:%.*]], label [[PRED_LOAD_CONTINUE12:%.*]]
 ; CHECK:       pred.load.if11:
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 6
 ; CHECK-NEXT:    [[TMP48:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP6]]
 ; CHECK-NEXT:    [[TMP49:%.*]] = load i8, ptr [[TMP48]], align 1
 ; CHECK-NEXT:    [[TMP50:%.*]] = insertelement <16 x i8> [[TMP46]], i8 [[TMP49]], i32 6
+; CHECK-NEXT:    [[TMP139:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP141:%.*]] = load i8, ptr [[TMP139]], align 1
+; CHECK-NEXT:    [[TMP142:%.*]] = insertelement <16 x i8> [[TMP135]], i8 [[TMP141]], i32 6
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE12]]
 ; CHECK:       pred.load.continue12:
 ; CHECK-NEXT:    [[TMP51:%.*]] = phi <16 x i8> [ [[TMP46]], [[PRED_LOAD_CONTINUE10]] ], [ [[TMP50]], [[PRED_LOAD_IF11]] ]
+; CHECK-NEXT:    [[TMP143:%.*]] = phi <16 x i8> [ [[TMP135]], [[PRED_LOAD_CONTINUE10]] ], [ [[TMP142]], [[PRED_LOAD_IF11]] ]
 ; CHECK-NEXT:    [[TMP52:%.*]] = extractelement <16 x i1> [[TMP16]], i32 7
 ; CHECK-NEXT:    br i1 [[TMP52]], label [[PRED_LOAD_IF13:%.*]], label [[PRED_LOAD_CONTINUE14:%.*]]
 ; CHECK:       pred.load.if13:
+; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[INDEX]], 7
 ; CHECK-NEXT:    [[TMP53:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP7]]
 ; CHECK-NEXT:    [[TMP54:%.*]] = load i8, ptr [[TMP53]], align 1
 ; CHECK-NEXT:    [[TMP55:%.*]] = insertelement <16 x i8> [[TMP51]], i8 [[TMP54]], i32 7
+; CHECK-NEXT:    [[TMP144:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP7]]
+; CHECK-NEXT:    [[TMP145:%.*]] = load i8, ptr [[TMP144]], align 1
+; CHECK-NEXT:    [[TMP149:%.*]] = insertelement <16 x i8> [[TMP143]], i8 [[TMP145]], i32 7
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE14]]
 ; CHECK:       pred.load.continue14:
 ; CHECK-NEXT:    [[TMP56:%.*]] = phi <16 x i8> [ [[TMP51]], [[PRED_LOAD_CONTINUE12]] ], [ [[TMP55]], [[PRED_LOAD_IF13]] ]
+; CHECK-NEXT:    [[TMP150:%.*]] = phi <16 x i8> [ [[TMP143]], [[PRED_LOAD_CONTINUE12]] ], [ [[TMP149]], [[PRED_LOAD_IF13]] ]
 ; CHECK-NEXT:    [[TMP57:%.*]] = extractelement <16 x i1> [[TMP16]], i32 8
 ; CHECK-NEXT:    br i1 [[TMP57]], label [[PRED_LOAD_IF15:%.*]], label [[PRED_LOAD_CONTINUE16:%.*]]
 ; CHECK:       pred.load.if15:
+; CHECK-NEXT:    [[TMP8:%.*]] = add i64 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP58:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP8]]
 ; CHECK-NEXT:    [[TMP59:%.*]] = load i8, ptr [[TMP58]], align 1
 ; CHECK-NEXT:    [[TMP60:%.*]] = insertelement <16 x i8> [[TMP56]], i8 [[TMP59]], i32 8
+; CHECK-NEXT:    [[TMP151:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP8]]
+; CHECK-NEXT:    [[TMP152:%.*]] = load i8, ptr [[TMP151]], align 1
+; CHECK-NEXT:    [[TMP153:%.*]] = insertelement <16 x i8> [[TMP150]], i8 [[TMP152]], i32 8
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE16]]
 ; CHECK:       pred.load.continue16:
 ; CHECK-NEXT:    [[TMP61:%.*]] = phi <16 x i8> [ [[TMP56]], [[PRED_LOAD_CONTINUE14]] ], [ [[TMP60]], [[PRED_LOAD_IF15]] ]
+; CHECK-NEXT:    [[TMP154:%.*]] = phi <16 x i8> [ [[TMP150]], [[PRED_LOAD_CONTINUE14]] ], [ [[TMP153]], [[PRED_LOAD_IF15]] ]
 ; CHECK-NEXT:    [[TMP62:%.*]] = extractelement <16 x i1> [[TMP16]], i32 9
 ; CHECK-NEXT:    br i1 [[TMP62]], label [[PRED_LOAD_IF17:%.*]], label [[PRED_LOAD_CONTINUE18:%.*]]
 ; CHECK:       pred.load.if17:
+; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 9
 ; CHECK-NEXT:    [[TMP63:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP9]]
 ; CHECK-NEXT:    [[TMP64:%.*]] = load i8, ptr [[TMP63]], align 1
 ; CHECK-NEXT:    [[TMP65:%.*]] = insertelement <16 x i8> [[TMP61]], i8 [[TMP64]], i32 9
+; CHECK-NEXT:    [[TMP96:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP155:%.*]] = load i8, ptr [[TMP96]], align 1
+; CHECK-NEXT:    [[TMP98:%.*]] = insertelement <16 x i8> [[TMP154]], i8 [[TMP155]], i32 9
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE18]]
 ; CHECK:       pred.load.continue18:
 ; CHECK-NEXT:    [[TMP66:%.*]] = phi <16 x i8> [ [[TMP61]], [[PRED_LOAD_CONTINUE16]] ], [ [[TMP65]], [[PRED_LOAD_IF17]] ]
+; CHECK-NEXT:    [[TMP100:%.*]] = phi <16 x i8> [ [[TMP154]], [[PRED_LOAD_CONTINUE16]] ], [ [[TMP98]], [[PRED_LOAD_IF17]] ]
 ; CHECK-NEXT:    [[TMP67:%.*]] = extractelement <16 x i1> [[TMP16]], i32 10
 ; CHECK-NEXT:    br i1 [[TMP67]], label [[PRED_LOAD_IF19:%.*]], label [[PRED_LOAD_CONTINUE20:%.*]]
 ; CHECK:       pred.load.if19:
+; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 10
 ; CHECK-NEXT:    [[TMP68:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP10]]
 ; CHECK-NEXT:    [[TMP69:%.*]] = load i8, ptr [[TMP68]], align 1
 ; CHECK-NEXT:    [[TMP70:%.*]] = insertelement <16 x i8> [[TMP66]], i8 [[TMP69]], i32 10
+; CHECK-NEXT:    [[TMP106:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP10]]
+; CHECK-NEXT:    [[TMP107:%.*]] = load i8, ptr [[TMP106]], align 1
+; CHECK-NEXT:    [[TMP108:%.*]] = insertelement <16 x i8> [[TMP100]], i8 [[TMP107]], i32 10
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE20]]
 ; CHECK:       pred.load.continue20:
 ; CHECK-NEXT:    [[TMP71:%.*]] = phi <16 x i8> [ [[TMP66]], [[PRED_LOAD_CONTINUE18]] ], [ [[TMP70]], [[PRED_LOAD_IF19]] ]
+; CHECK-NEXT:    [[TMP110:%.*]] = phi <16 x i8> [ [[TMP100]], [[PRED_LOAD_CONTINUE18]] ], [ [[TMP108]], [[PRED_LOAD_IF19]] ]
 ; CHECK-NEXT:    [[TMP72:%.*]] = extractelement <16 x i1> [[TMP16]], i32 11
 ; CHECK-NEXT:    br i1 [[TMP72]], label [[PRED_LOAD_IF21:%.*]], label [[PRED_LOAD_CONTINUE22:%.*]]
 ; CHECK:       pred.load.if21:
+; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 11
 ; CHECK-NEXT:    [[TMP73:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP11]]
 ; CHECK-NEXT:    [[TMP74:%.*]] = load i8, ptr [[TMP73]], align 1
 ; CHECK-NEXT:    [[TMP75:%.*]] = insertelement <16 x i8> [[TMP71]], i8 [[TMP74]], i32 11
+; CHECK-NEXT:    [[TMP116:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP117:%.*]] = load i8, ptr [[TMP116]], align 1
+; CHECK-NEXT:    [[TMP118:%.*]] = insertelement <16 x i8> [[TMP110]], i8 [[TMP117]], i32 11
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE22]]
 ; CHECK:       pred.load.continue22:
 ; CHECK-NEXT:    [[TMP76:%.*]] = phi <16 x i8> [ [[TMP71]], [[PRED_LOAD_CONTINUE20]] ], [ [[TMP75]], [[PRED_LOAD_IF21]] ]
+; CHECK-NEXT:    [[TMP120:%.*]] = phi <16 x i8> [ [[TMP110]], [[PRED_LOAD_CONTINUE20]] ], [ [[TMP118]], [[PRED_LOAD_IF21]] ]
 ; CHECK-NEXT:    [[TMP77:%.*]] = extractelement <16 x i1> [[TMP16]], i32 12
 ; CHECK-NEXT:    br i1 [[TMP77]], label [[PRED_LOAD_IF23:%.*]], label [[PRED_LOAD_CONTINUE24:%.*]]
 ; CHECK:       pred.load.if23:
+; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[INDEX]], 12
 ; CHECK-NEXT:    [[TMP78:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP12]]
 ; CHECK-NEXT:    [[TMP79:%.*]] = load i8, ptr [[TMP78]], align 1
 ; CHECK-NEXT:    [[TMP80:%.*]] = insertelement <16 x i8> [[TMP76]], i8 [[TMP79]], i32 12
+; CHECK-NEXT:    [[TMP126:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP12]]
+; CHECK-NEXT:    [[TMP127:%.*]] = load i8, ptr [[TMP126]], align 1
+; CHECK-NEXT:    [[TMP128:%.*]] = insertelement <16 x i8> [[TMP120]], i8 [[TMP127]], i32 12
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE24]]
 ; CHECK:       pred.load.continue24:
 ; CHECK-NEXT:    [[TMP81:%.*]] = phi <16 x i8> [ [[TMP76]], [[PRED_LOAD_CONTINUE22]] ], [ [[TMP80]], [[PRED_LOAD_IF23]] ]
+; CHECK-NEXT:    [[TMP130:%.*]] = phi <16 x i8> [ [[TMP120]], [[PRED_LOAD_CONTINUE22]] ], [ [[TMP128]], [[PRED_LOAD_IF23]] ]
 ; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <16 x i1> [[TMP16]], i32 13
 ; CHECK-NEXT:    br i1 [[TMP82]], label [[PRED_LOAD_IF25:%.*]], label [[PRED_LOAD_CONTINUE26:%.*]]
 ; CHECK:       pred.load.if25:
+; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[INDEX]], 13
 ; CHECK-NEXT:    [[TMP83:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP13]]
 ; CHECK-NEXT:    [[TMP84:%.*]] = load i8, ptr [[TMP83]], align 1
 ; CHECK-NEXT:    [[TMP85:%.*]] = insertelement <16 x i8> [[TMP81]], i8 [[TMP84]], i32 13
+; CHECK-NEXT:    [[TMP136:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP13]]
+; CHECK-NEXT:    [[TMP137:%.*]] = load i8, ptr [[TMP136]], align 1
+; CHECK-NEXT:    [[TMP138:%.*]] = insertelement <16 x i8> [[TMP130]], i8 [[TMP137]], i32 13
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE26]]
 ; CHECK:       pred.load.continue26:
 ; CHECK-NEXT:    [[TMP86:%.*]] = phi <16 x i8> [ [[TMP81]], [[PRED_LOAD_CONTINUE24]] ], [ [[TMP85]], [[PRED_LOAD_IF25]] ]
+; CHECK-NEXT:    [[TMP140:%.*]] = phi <16 x i8> [ [[TMP130]], [[PRED_LOAD_CONTINUE24]] ], [ [[TMP138]], [[PRED_LOAD_IF25]] ]
 ; CHECK-NEXT:    [[TMP87:%.*]] = extractelement <16 x i1> [[TMP16]], i32 14
 ; CHECK-NEXT:    br i1 [[TMP87]], label [[PRED_LOAD_IF27:%.*]], label [[PRED_LOAD_CONTINUE28:%.*]]
 ; CHECK:       pred.load.if27:
+; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 14
 ; CHECK-NEXT:    [[TMP88:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP14]]
 ; CHECK-NEXT:    [[TMP89:%.*]] = load i8, ptr [[TMP88]], align 1
 ; CHECK-NEXT:    [[TMP90:%.*]] = insertelement <16 x i8> [[TMP86]], i8 [[TMP89]], i32 14
+; CHECK-NEXT:    [[TMP146:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP14]]
+; CHECK-NEXT:    [[TMP147:%.*]] = load i8, ptr [[TMP146]], align 1
+; CHECK-NEXT:    [[TMP148:%.*]] = insertelement <16 x i8> [[TMP140]], i8 [[TMP147]], i32 14
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE28]]
 ; CHECK:       pred.load.continue28:
 ; CHECK-NEXT:    [[TMP91:%.*]] = phi <16 x i8> [ [[TMP86]], [[PRED_LOAD_CONTINUE26]] ], [ [[TMP90]], [[PRED_LOAD_IF27]] ]
+; CHECK-NEXT:    [[TMP172:%.*]] = phi <16 x i8> [ [[TMP140]], [[PRED_LOAD_CONTINUE26]] ], [ [[TMP148]], [[PRED_LOAD_IF27]] ]
 ; CHECK-NEXT:    [[TMP92:%.*]] = extractelement <16 x i1> [[TMP16]], i32 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 3, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Sander de Smalen (sdesmalen-arm)

Changes

This means that VPExpressions will now be constructed for
VPPartialReductionRecipe's when the loop has tail-folding predication.

Note that control-flow (if/else) predication is not yet handled
for partial reductions, because of the way partial reductions
are recognised and built up.


Patch is 138.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166136.diff

6 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+1-8)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+8-2)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll (+81-161)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll (+243-483)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll (+1-1)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll (+71)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e5c3f17860103..8ec11be3037a2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8171,15 +8171,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
   }
 
   VPValue *Cond = nullptr;
-  if (CM.blockNeedsPredicationForAnyReason(Reduction->getParent())) {
-    assert((ReductionOpcode == Instruction::Add ||
-            ReductionOpcode == Instruction::Sub) &&
-           "Expected an ADD or SUB operation for predicated partial "
-           "reductions (because the neutral element in the mask is zero)!");
+  if (CM.blockNeedsPredicationForAnyReason(Reduction->getParent()))
     Cond = getBlockInMask(Builder.getInsertBlock());
-    VPValue *Zero = Plan.getConstantInt(Reduction->getType(), 0);
-    BinOp = Builder.createSelect(Cond, BinOp, Zero, Reduction->getDebugLoc());
-  }
   return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
                                       ScaleFactor, Reduction);
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 1ee405a62aa68..1677d3f522b80 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -395,12 +395,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   assert(getOpcode() == Instruction::Add &&
          "Unhandled partial reduction opcode");
 
-  Value *BinOpVal = State.get(getOperand(1));
-  Value *PhiVal = State.get(getOperand(0));
+  Value *BinOpVal = State.get(getVecOp());
+  Value *PhiVal = State.get(getChainOp());
   assert(PhiVal && BinOpVal && "Phi and Mul must be set");
 
   Type *RetTy = PhiVal->getType();
 
+  if (isConditional()) {
+    Value *Cond = State.get(getCondOp());
+    Value *Zero = ConstantInt::get(BinOpVal->getType(), 0);
+    BinOpVal = Builder.CreateSelect(Cond, BinOpVal, Zero);
+  }
+
   CallInst *V =
       Builder.CreateIntrinsic(RetTy, Intrinsic::vector_partial_reduce_add,
                               {PhiVal, BinOpVal}, nullptr, "partial.reduce");
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
index d8f1a86c9ebda..5b9bd0997f2fa 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
@@ -182,313 +182,233 @@ define i32 @dotp_predicated(i64 %N, ptr %a, ptr %b) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[PRED_LOAD_CONTINUE62:%.*]] ]
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <16 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[PRED_LOAD_CONTINUE62]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[PRED_LOAD_CONTINUE62]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
-; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP3:%.*]] = add i64 [[INDEX]], 3
-; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], 4
-; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[INDEX]], 5
-; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 6
-; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[INDEX]], 7
-; CHECK-NEXT:    [[TMP8:%.*]] = add i64 [[INDEX]], 8
-; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 9
-; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 10
-; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 11
-; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[INDEX]], 12
-; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[INDEX]], 13
-; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 14
-; CHECK-NEXT:    [[TMP15:%.*]] = add i64 [[INDEX]], 15
 ; CHECK-NEXT:    [[TMP16:%.*]] = icmp ule <16 x i64> [[VEC_IND]], [[BROADCAST_SPLAT]]
 ; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x i1> [[TMP16]], i32 0
 ; CHECK-NEXT:    br i1 [[TMP17]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
 ; CHECK:       pred.load.if:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[TMP19:%.*]] = load i8, ptr [[TMP18]], align 1
 ; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <16 x i8> poison, i8 [[TMP19]], i32 0
+; CHECK-NEXT:    [[TMP99:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP0]]
+; CHECK-NEXT:    [[TMP101:%.*]] = load i8, ptr [[TMP99]], align 1
+; CHECK-NEXT:    [[TMP102:%.*]] = insertelement <16 x i8> poison, i8 [[TMP101]], i32 0
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE]]
 ; CHECK:       pred.load.continue:
 ; CHECK-NEXT:    [[TMP21:%.*]] = phi <16 x i8> [ poison, [[VECTOR_BODY]] ], [ [[TMP20]], [[PRED_LOAD_IF]] ]
+; CHECK-NEXT:    [[TMP103:%.*]] = phi <16 x i8> [ poison, [[VECTOR_BODY]] ], [ [[TMP102]], [[PRED_LOAD_IF]] ]
 ; CHECK-NEXT:    [[TMP22:%.*]] = extractelement <16 x i1> [[TMP16]], i32 1
 ; CHECK-NEXT:    br i1 [[TMP22]], label [[PRED_LOAD_IF1:%.*]], label [[PRED_LOAD_CONTINUE2:%.*]]
 ; CHECK:       pred.load.if1:
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP23:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP1]]
 ; CHECK-NEXT:    [[TMP24:%.*]] = load i8, ptr [[TMP23]], align 1
 ; CHECK-NEXT:    [[TMP25:%.*]] = insertelement <16 x i8> [[TMP21]], i8 [[TMP24]], i32 1
+; CHECK-NEXT:    [[TMP104:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP105:%.*]] = load i8, ptr [[TMP104]], align 1
+; CHECK-NEXT:    [[TMP109:%.*]] = insertelement <16 x i8> [[TMP103]], i8 [[TMP105]], i32 1
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE2]]
 ; CHECK:       pred.load.continue2:
 ; CHECK-NEXT:    [[TMP26:%.*]] = phi <16 x i8> [ [[TMP21]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP25]], [[PRED_LOAD_IF1]] ]
+; CHECK-NEXT:    [[TMP111:%.*]] = phi <16 x i8> [ [[TMP103]], [[PRED_LOAD_CONTINUE]] ], [ [[TMP109]], [[PRED_LOAD_IF1]] ]
 ; CHECK-NEXT:    [[TMP27:%.*]] = extractelement <16 x i1> [[TMP16]], i32 2
 ; CHECK-NEXT:    br i1 [[TMP27]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
 ; CHECK:       pred.load.if3:
+; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[INDEX]], 2
 ; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP2]]
 ; CHECK-NEXT:    [[TMP29:%.*]] = load i8, ptr [[TMP28]], align 1
 ; CHECK-NEXT:    [[TMP30:%.*]] = insertelement <16 x i8> [[TMP26]], i8 [[TMP29]], i32 2
+; CHECK-NEXT:    [[TMP112:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP113:%.*]] = load i8, ptr [[TMP112]], align 1
+; CHECK-NEXT:    [[TMP114:%.*]] = insertelement <16 x i8> [[TMP111]], i8 [[TMP113]], i32 2
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE4]]
 ; CHECK:       pred.load.continue4:
 ; CHECK-NEXT:    [[TMP31:%.*]] = phi <16 x i8> [ [[TMP26]], [[PRED_LOAD_CONTINUE2]] ], [ [[TMP30]], [[PRED_LOAD_IF3]] ]
+; CHECK-NEXT:    [[TMP115:%.*]] = phi <16 x i8> [ [[TMP111]], [[PRED_LOAD_CONTINUE2]] ], [ [[TMP114]], [[PRED_LOAD_IF3]] ]
 ; CHECK-NEXT:    [[TMP32:%.*]] = extractelement <16 x i1> [[TMP16]], i32 3
 ; CHECK-NEXT:    br i1 [[TMP32]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6:%.*]]
 ; CHECK:       pred.load.if5:
+; CHECK-NEXT:    [[TMP3:%.*]] = add i64 [[INDEX]], 3
 ; CHECK-NEXT:    [[TMP33:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP3]]
 ; CHECK-NEXT:    [[TMP34:%.*]] = load i8, ptr [[TMP33]], align 1
 ; CHECK-NEXT:    [[TMP35:%.*]] = insertelement <16 x i8> [[TMP31]], i8 [[TMP34]], i32 3
+; CHECK-NEXT:    [[TMP119:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP121:%.*]] = load i8, ptr [[TMP119]], align 1
+; CHECK-NEXT:    [[TMP122:%.*]] = insertelement <16 x i8> [[TMP115]], i8 [[TMP121]], i32 3
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE6]]
 ; CHECK:       pred.load.continue6:
 ; CHECK-NEXT:    [[TMP36:%.*]] = phi <16 x i8> [ [[TMP31]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP35]], [[PRED_LOAD_IF5]] ]
+; CHECK-NEXT:    [[TMP123:%.*]] = phi <16 x i8> [ [[TMP115]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP122]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP37:%.*]] = extractelement <16 x i1> [[TMP16]], i32 4
 ; CHECK-NEXT:    br i1 [[TMP37]], label [[PRED_LOAD_IF7:%.*]], label [[PRED_LOAD_CONTINUE8:%.*]]
 ; CHECK:       pred.load.if7:
+; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP38:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP4]]
 ; CHECK-NEXT:    [[TMP39:%.*]] = load i8, ptr [[TMP38]], align 1
 ; CHECK-NEXT:    [[TMP40:%.*]] = insertelement <16 x i8> [[TMP36]], i8 [[TMP39]], i32 4
+; CHECK-NEXT:    [[TMP124:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP4]]
+; CHECK-NEXT:    [[TMP125:%.*]] = load i8, ptr [[TMP124]], align 1
+; CHECK-NEXT:    [[TMP129:%.*]] = insertelement <16 x i8> [[TMP123]], i8 [[TMP125]], i32 4
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE8]]
 ; CHECK:       pred.load.continue8:
 ; CHECK-NEXT:    [[TMP41:%.*]] = phi <16 x i8> [ [[TMP36]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP40]], [[PRED_LOAD_IF7]] ]
+; CHECK-NEXT:    [[TMP131:%.*]] = phi <16 x i8> [ [[TMP123]], [[PRED_LOAD_CONTINUE6]] ], [ [[TMP129]], [[PRED_LOAD_IF7]] ]
 ; CHECK-NEXT:    [[TMP42:%.*]] = extractelement <16 x i1> [[TMP16]], i32 5
 ; CHECK-NEXT:    br i1 [[TMP42]], label [[PRED_LOAD_IF9:%.*]], label [[PRED_LOAD_CONTINUE10:%.*]]
 ; CHECK:       pred.load.if9:
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[INDEX]], 5
 ; CHECK-NEXT:    [[TMP43:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP5]]
 ; CHECK-NEXT:    [[TMP44:%.*]] = load i8, ptr [[TMP43]], align 1
 ; CHECK-NEXT:    [[TMP45:%.*]] = insertelement <16 x i8> [[TMP41]], i8 [[TMP44]], i32 5
+; CHECK-NEXT:    [[TMP132:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP133:%.*]] = load i8, ptr [[TMP132]], align 1
+; CHECK-NEXT:    [[TMP134:%.*]] = insertelement <16 x i8> [[TMP131]], i8 [[TMP133]], i32 5
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE10]]
 ; CHECK:       pred.load.continue10:
 ; CHECK-NEXT:    [[TMP46:%.*]] = phi <16 x i8> [ [[TMP41]], [[PRED_LOAD_CONTINUE8]] ], [ [[TMP45]], [[PRED_LOAD_IF9]] ]
+; CHECK-NEXT:    [[TMP135:%.*]] = phi <16 x i8> [ [[TMP131]], [[PRED_LOAD_CONTINUE8]] ], [ [[TMP134]], [[PRED_LOAD_IF9]] ]
 ; CHECK-NEXT:    [[TMP47:%.*]] = extractelement <16 x i1> [[TMP16]], i32 6
 ; CHECK-NEXT:    br i1 [[TMP47]], label [[PRED_LOAD_IF11:%.*]], label [[PRED_LOAD_CONTINUE12:%.*]]
 ; CHECK:       pred.load.if11:
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 6
 ; CHECK-NEXT:    [[TMP48:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP6]]
 ; CHECK-NEXT:    [[TMP49:%.*]] = load i8, ptr [[TMP48]], align 1
 ; CHECK-NEXT:    [[TMP50:%.*]] = insertelement <16 x i8> [[TMP46]], i8 [[TMP49]], i32 6
+; CHECK-NEXT:    [[TMP139:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP141:%.*]] = load i8, ptr [[TMP139]], align 1
+; CHECK-NEXT:    [[TMP142:%.*]] = insertelement <16 x i8> [[TMP135]], i8 [[TMP141]], i32 6
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE12]]
 ; CHECK:       pred.load.continue12:
 ; CHECK-NEXT:    [[TMP51:%.*]] = phi <16 x i8> [ [[TMP46]], [[PRED_LOAD_CONTINUE10]] ], [ [[TMP50]], [[PRED_LOAD_IF11]] ]
+; CHECK-NEXT:    [[TMP143:%.*]] = phi <16 x i8> [ [[TMP135]], [[PRED_LOAD_CONTINUE10]] ], [ [[TMP142]], [[PRED_LOAD_IF11]] ]
 ; CHECK-NEXT:    [[TMP52:%.*]] = extractelement <16 x i1> [[TMP16]], i32 7
 ; CHECK-NEXT:    br i1 [[TMP52]], label [[PRED_LOAD_IF13:%.*]], label [[PRED_LOAD_CONTINUE14:%.*]]
 ; CHECK:       pred.load.if13:
+; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[INDEX]], 7
 ; CHECK-NEXT:    [[TMP53:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP7]]
 ; CHECK-NEXT:    [[TMP54:%.*]] = load i8, ptr [[TMP53]], align 1
 ; CHECK-NEXT:    [[TMP55:%.*]] = insertelement <16 x i8> [[TMP51]], i8 [[TMP54]], i32 7
+; CHECK-NEXT:    [[TMP144:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP7]]
+; CHECK-NEXT:    [[TMP145:%.*]] = load i8, ptr [[TMP144]], align 1
+; CHECK-NEXT:    [[TMP149:%.*]] = insertelement <16 x i8> [[TMP143]], i8 [[TMP145]], i32 7
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE14]]
 ; CHECK:       pred.load.continue14:
 ; CHECK-NEXT:    [[TMP56:%.*]] = phi <16 x i8> [ [[TMP51]], [[PRED_LOAD_CONTINUE12]] ], [ [[TMP55]], [[PRED_LOAD_IF13]] ]
+; CHECK-NEXT:    [[TMP150:%.*]] = phi <16 x i8> [ [[TMP143]], [[PRED_LOAD_CONTINUE12]] ], [ [[TMP149]], [[PRED_LOAD_IF13]] ]
 ; CHECK-NEXT:    [[TMP57:%.*]] = extractelement <16 x i1> [[TMP16]], i32 8
 ; CHECK-NEXT:    br i1 [[TMP57]], label [[PRED_LOAD_IF15:%.*]], label [[PRED_LOAD_CONTINUE16:%.*]]
 ; CHECK:       pred.load.if15:
+; CHECK-NEXT:    [[TMP8:%.*]] = add i64 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP58:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP8]]
 ; CHECK-NEXT:    [[TMP59:%.*]] = load i8, ptr [[TMP58]], align 1
 ; CHECK-NEXT:    [[TMP60:%.*]] = insertelement <16 x i8> [[TMP56]], i8 [[TMP59]], i32 8
+; CHECK-NEXT:    [[TMP151:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP8]]
+; CHECK-NEXT:    [[TMP152:%.*]] = load i8, ptr [[TMP151]], align 1
+; CHECK-NEXT:    [[TMP153:%.*]] = insertelement <16 x i8> [[TMP150]], i8 [[TMP152]], i32 8
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE16]]
 ; CHECK:       pred.load.continue16:
 ; CHECK-NEXT:    [[TMP61:%.*]] = phi <16 x i8> [ [[TMP56]], [[PRED_LOAD_CONTINUE14]] ], [ [[TMP60]], [[PRED_LOAD_IF15]] ]
+; CHECK-NEXT:    [[TMP154:%.*]] = phi <16 x i8> [ [[TMP150]], [[PRED_LOAD_CONTINUE14]] ], [ [[TMP153]], [[PRED_LOAD_IF15]] ]
 ; CHECK-NEXT:    [[TMP62:%.*]] = extractelement <16 x i1> [[TMP16]], i32 9
 ; CHECK-NEXT:    br i1 [[TMP62]], label [[PRED_LOAD_IF17:%.*]], label [[PRED_LOAD_CONTINUE18:%.*]]
 ; CHECK:       pred.load.if17:
+; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 9
 ; CHECK-NEXT:    [[TMP63:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP9]]
 ; CHECK-NEXT:    [[TMP64:%.*]] = load i8, ptr [[TMP63]], align 1
 ; CHECK-NEXT:    [[TMP65:%.*]] = insertelement <16 x i8> [[TMP61]], i8 [[TMP64]], i32 9
+; CHECK-NEXT:    [[TMP96:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP155:%.*]] = load i8, ptr [[TMP96]], align 1
+; CHECK-NEXT:    [[TMP98:%.*]] = insertelement <16 x i8> [[TMP154]], i8 [[TMP155]], i32 9
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE18]]
 ; CHECK:       pred.load.continue18:
 ; CHECK-NEXT:    [[TMP66:%.*]] = phi <16 x i8> [ [[TMP61]], [[PRED_LOAD_CONTINUE16]] ], [ [[TMP65]], [[PRED_LOAD_IF17]] ]
+; CHECK-NEXT:    [[TMP100:%.*]] = phi <16 x i8> [ [[TMP154]], [[PRED_LOAD_CONTINUE16]] ], [ [[TMP98]], [[PRED_LOAD_IF17]] ]
 ; CHECK-NEXT:    [[TMP67:%.*]] = extractelement <16 x i1> [[TMP16]], i32 10
 ; CHECK-NEXT:    br i1 [[TMP67]], label [[PRED_LOAD_IF19:%.*]], label [[PRED_LOAD_CONTINUE20:%.*]]
 ; CHECK:       pred.load.if19:
+; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 10
 ; CHECK-NEXT:    [[TMP68:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP10]]
 ; CHECK-NEXT:    [[TMP69:%.*]] = load i8, ptr [[TMP68]], align 1
 ; CHECK-NEXT:    [[TMP70:%.*]] = insertelement <16 x i8> [[TMP66]], i8 [[TMP69]], i32 10
+; CHECK-NEXT:    [[TMP106:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP10]]
+; CHECK-NEXT:    [[TMP107:%.*]] = load i8, ptr [[TMP106]], align 1
+; CHECK-NEXT:    [[TMP108:%.*]] = insertelement <16 x i8> [[TMP100]], i8 [[TMP107]], i32 10
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE20]]
 ; CHECK:       pred.load.continue20:
 ; CHECK-NEXT:    [[TMP71:%.*]] = phi <16 x i8> [ [[TMP66]], [[PRED_LOAD_CONTINUE18]] ], [ [[TMP70]], [[PRED_LOAD_IF19]] ]
+; CHECK-NEXT:    [[TMP110:%.*]] = phi <16 x i8> [ [[TMP100]], [[PRED_LOAD_CONTINUE18]] ], [ [[TMP108]], [[PRED_LOAD_IF19]] ]
 ; CHECK-NEXT:    [[TMP72:%.*]] = extractelement <16 x i1> [[TMP16]], i32 11
 ; CHECK-NEXT:    br i1 [[TMP72]], label [[PRED_LOAD_IF21:%.*]], label [[PRED_LOAD_CONTINUE22:%.*]]
 ; CHECK:       pred.load.if21:
+; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 11
 ; CHECK-NEXT:    [[TMP73:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP11]]
 ; CHECK-NEXT:    [[TMP74:%.*]] = load i8, ptr [[TMP73]], align 1
 ; CHECK-NEXT:    [[TMP75:%.*]] = insertelement <16 x i8> [[TMP71]], i8 [[TMP74]], i32 11
+; CHECK-NEXT:    [[TMP116:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP117:%.*]] = load i8, ptr [[TMP116]], align 1
+; CHECK-NEXT:    [[TMP118:%.*]] = insertelement <16 x i8> [[TMP110]], i8 [[TMP117]], i32 11
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE22]]
 ; CHECK:       pred.load.continue22:
 ; CHECK-NEXT:    [[TMP76:%.*]] = phi <16 x i8> [ [[TMP71]], [[PRED_LOAD_CONTINUE20]] ], [ [[TMP75]], [[PRED_LOAD_IF21]] ]
+; CHECK-NEXT:    [[TMP120:%.*]] = phi <16 x i8> [ [[TMP110]], [[PRED_LOAD_CONTINUE20]] ], [ [[TMP118]], [[PRED_LOAD_IF21]] ]
 ; CHECK-NEXT:    [[TMP77:%.*]] = extractelement <16 x i1> [[TMP16]], i32 12
 ; CHECK-NEXT:    br i1 [[TMP77]], label [[PRED_LOAD_IF23:%.*]], label [[PRED_LOAD_CONTINUE24:%.*]]
 ; CHECK:       pred.load.if23:
+; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[INDEX]], 12
 ; CHECK-NEXT:    [[TMP78:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP12]]
 ; CHECK-NEXT:    [[TMP79:%.*]] = load i8, ptr [[TMP78]], align 1
 ; CHECK-NEXT:    [[TMP80:%.*]] = insertelement <16 x i8> [[TMP76]], i8 [[TMP79]], i32 12
+; CHECK-NEXT:    [[TMP126:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP12]]
+; CHECK-NEXT:    [[TMP127:%.*]] = load i8, ptr [[TMP126]], align 1
+; CHECK-NEXT:    [[TMP128:%.*]] = insertelement <16 x i8> [[TMP120]], i8 [[TMP127]], i32 12
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE24]]
 ; CHECK:       pred.load.continue24:
 ; CHECK-NEXT:    [[TMP81:%.*]] = phi <16 x i8> [ [[TMP76]], [[PRED_LOAD_CONTINUE22]] ], [ [[TMP80]], [[PRED_LOAD_IF23]] ]
+; CHECK-NEXT:    [[TMP130:%.*]] = phi <16 x i8> [ [[TMP120]], [[PRED_LOAD_CONTINUE22]] ], [ [[TMP128]], [[PRED_LOAD_IF23]] ]
 ; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <16 x i1> [[TMP16]], i32 13
 ; CHECK-NEXT:    br i1 [[TMP82]], label [[PRED_LOAD_IF25:%.*]], label [[PRED_LOAD_CONTINUE26:%.*]]
 ; CHECK:       pred.load.if25:
+; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[INDEX]], 13
 ; CHECK-NEXT:    [[TMP83:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP13]]
 ; CHECK-NEXT:    [[TMP84:%.*]] = load i8, ptr [[TMP83]], align 1
 ; CHECK-NEXT:    [[TMP85:%.*]] = insertelement <16 x i8> [[TMP81]], i8 [[TMP84]], i32 13
+; CHECK-NEXT:    [[TMP136:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP13]]
+; CHECK-NEXT:    [[TMP137:%.*]] = load i8, ptr [[TMP136]], align 1
+; CHECK-NEXT:    [[TMP138:%.*]] = insertelement <16 x i8> [[TMP130]], i8 [[TMP137]], i32 13
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE26]]
 ; CHECK:       pred.load.continue26:
 ; CHECK-NEXT:    [[TMP86:%.*]] = phi <16 x i8> [ [[TMP81]], [[PRED_LOAD_CONTINUE24]] ], [ [[TMP85]], [[PRED_LOAD_IF25]] ]
+; CHECK-NEXT:    [[TMP140:%.*]] = phi <16 x i8> [ [[TMP130]], [[PRED_LOAD_CONTINUE24]] ], [ [[TMP138]], [[PRED_LOAD_IF25]] ]
 ; CHECK-NEXT:    [[TMP87:%.*]] = extractelement <16 x i1> [[TMP16]], i32 14
 ; CHECK-NEXT:    br i1 [[TMP87]], label [[PRED_LOAD_IF27:%.*]], label [[PRED_LOAD_CONTINUE28:%.*]]
 ; CHECK:       pred.load.if27:
+; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 14
 ; CHECK-NEXT:    [[TMP88:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP14]]
 ; CHECK-NEXT:    [[TMP89:%.*]] = load i8, ptr [[TMP88]], align 1
 ; CHECK-NEXT:    [[TMP90:%.*]] = insertelement <16 x i8> [[TMP86]], i8 [[TMP89]], i32 14
+; CHECK-NEXT:    [[TMP146:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP14]]
+; CHECK-NEXT:    [[TMP147:%.*]] = load i8, ptr [[TMP146]], align 1
+; CHECK-NEXT:    [[TMP148:%.*]] = insertelement <16 x i8> [[TMP140]], i8 [[TMP147]], i32 14
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE28]]
 ; CHECK:       pred.load.continue28:
 ; CHECK-NEXT:    [[TMP91:%.*]] = phi <16 x i8> [ [[TMP86]], [[PRED_LOAD_CONTINUE26]] ], [ [[TMP90]], [[PRED_LOAD_IF27]] ]
+; CHECK-NEXT:    [[TMP172:%.*]] = phi <16 x i8> [ [[TMP140]], [[PRED_LOAD_CONTINUE26]] ], [ [[TMP148]], [[PRED_LOAD_IF27]] ]
 ; CHECK-NEXT:    [[TMP92:%.*]] = extractelement <16 x i1> [[TMP16]], i32 ...
[truncated]


exit:
ret i32 %add
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the description, you mention predication won't work for if/else blocks. Is it maybe worth adding a small test to document that?

Copy link
Collaborator

@SamTebbs33 SamTebbs33 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a no-brainer to me. I think we should now be able to remove the m_Select check in computeCost, if you wouldn't mind doing that in this PR as well:

if (!match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue())) &&

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, with this change IIUC we now assume that the cost of partial reduction with and without selects will be the same. Is that correct in all cases?

Copy link
Contributor

@gbossu gbossu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

This shows that no VPExpression is built for partial reductions
that have some form of predication.
This means that VPExpressions will now be constructed for
VPPartialReductionRecipe's when the loop has tail-folding predication.

Note that control-flow (if/else) predication is not yet handled
for partial reductions, because of the way partial reductions
are recognised and built up.
… reduction.

In practice this won't make much difference, because VPExpressions already accounts
for the cost of the predication.
@sdesmalen-arm sdesmalen-arm force-pushed the users/sdesmalen-arm/lv-move-partial-reduction-condition branch from 46b0e9c to 80e5448 Compare November 10, 2025 16:32
@sdesmalen-arm sdesmalen-arm merged commit 517d725 into main Nov 11, 2025
10 checks passed
@sdesmalen-arm sdesmalen-arm deleted the users/sdesmalen-arm/lv-move-partial-reduction-condition branch November 11, 2025 09:42
Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, with this change IIUC we now assume that the cost of partial reduction with and without selects will be the same. Is that correct in all cases?

Not sure I missed the response to the comment, but it is still not entirely clear how the cost of the selects is properly accounted for. As far as I can tell, we now create a VPExpressionRecipe with a conditional VPPartialReductionRecipe, but dont account for the reduction being conditional neither when creating VPExpressionRecipe nor in VPExpressionRecipe::computeCost.

As a consequence, aren't we assuming a cost of 1 for VF = 16 for the whole ext/mul/select/reduced in the test case added in llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll, while it cannot be lowered to a single partial reduction instruction, with or without SVE? https://llvm.godbolt.org/z/ETas69vna

CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
auto *VecTy = Ctx.Types.inferScalarType(Op);
auto *CondTy = Ctx.Types.inferScalarType(getCondOp());
CondCost = Ctx.TTI.getCmpSelInstrCost(Instruction::Select, VecTy, CondTy,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this computes the correct costs; isn't this passing the scalar types here instead of the vector types?

fhahn added a commit to fhahn/llvm-project that referenced this pull request Nov 11, 2025
Follow-up to llvm#166136.

As far as I can tell, currently we cannot lower partial reductions fed
by a select efficiently (https://llvm.godbolt.org/z/ETas69vna), so
skipping them for partial reduction cost-computations seems inaccurate.

Unless I am missing something, we can for now simply avoid forming
conditional partial reductions, as they won't be profitable and using
regular reductions exposes more parallelism.
@fhahn
Copy link
Contributor

fhahn commented Nov 11, 2025

Hmm, with this change IIUC we now assume that the cost of partial reduction with and without selects will be the same. Is that correct in all cases?

Not sure I missed the response to the comment, but it is still not entirely clear how the cost of the selects is properly accounted for. As far as I can tell, we now create a VPExpressionRecipe with a conditional VPPartialReductionRecipe, but dont account for the reduction being conditional neither when creating VPExpressionRecipe nor in VPExpressionRecipe::computeCost.

As a consequence, aren't we assuming a cost of 1 for VF = 16 for the whole ext/mul/select/reduced in the test case added in llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll, while it cannot be lowered to a single partial reduction instruction, with or without SVE? https://llvm.godbolt.org/z/ETas69vna

Looking a bit more into this, I couldn't find a case where forming predicated partial reduction is actually benefifical, so unless I missed soemthing, it might be simplest to just avoid forming them in such cases in the first place for now? #167506

fhahn added a commit to fhahn/llvm-project that referenced this pull request Nov 11, 2025
Follow-up to llvm#166136.

As far as I can tell, currently we cannot lower partial reductions fed
by a select efficiently (https://llvm.godbolt.org/z/ETas69vna), so
skipping them for partial reduction cost-computations seems inaccurate.

Unless I am missing something, we can for now simply avoid forming
conditional partial reductions, as they won't be profitable and using
regular reductions exposes more parallelism.
fhahn added a commit to fhahn/llvm-project that referenced this pull request Nov 12, 2025
Follow-up to llvm#166136.

As far as I can tell, currently we cannot lower partial reductions fed
by a select efficiently (https://llvm.godbolt.org/z/ETas69vna), so
skipping them for partial reduction cost-computations seems inaccurate.

Unless I am missing something, we can for now simply avoid forming
conditional partial reductions, as they won't be profitable and using
regular reductions exposes more parallelism.
fhahn added a commit to fhahn/llvm-project that referenced this pull request Nov 13, 2025
Follow-up to llvm#166136.

As far as I can tell, currently we cannot lower partial reductions fed
by a select efficiently (https://llvm.godbolt.org/z/ETas69vna), so
skipping them for partial reduction cost-computations seems inaccurate.

Unless I am missing something, we can for now simply avoid forming
conditional partial reductions, as they won't be profitable and using
regular reductions exposes more parallelism.
fhahn added a commit to fhahn/llvm-project that referenced this pull request Nov 13, 2025
Follow-up to llvm#166136.

As far as I can tell, currently we cannot lower partial reductions fed
by a select efficiently (https://llvm.godbolt.org/z/ETas69vna), so
skipping them for partial reduction cost-computations seems inaccurate.

Unless I am missing something, we can for now simply avoid forming
conditional partial reductions, as they won't be profitable and using
regular reductions exposes more parallelism.
return CondCost + Ctx.TTI.getPartialReductionCost(
getOpcode(), InputTypeA, InputTypeB, PhiType, VF,
ExtAType, ExtBType, Opcode, Ctx.CostKind);
;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extraneous ; here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants