
Conversation

@SamTebbs33
Collaborator

@SamTebbs33 SamTebbs33 commented Oct 29, 2025

A pattern of the form reduce.add(ext(mul)) is valid for a partial reduction as long as the mul and its operands fulfill the requirements of a normal partial reduction. The mul's extend operands will be optimised to the wider extend, and we already have oneUse checks in place to make sure the mul and operands can be modified safely.
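
For reference, a minimal scalar loop with this shape (a simplified sketch of the partial_reduction_ext_mul test added below; names are illustrative) looks like:

loop:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
  %acc = phi i32 [ 0, %entry ], [ %add, %loop ]
  %conv = zext i8 %b to i16         ; inner extend feeding the mul
  %mul = mul i16 %conv, %conv
  %mul.ext = zext i16 %mul to i32   ; outer extend of the mul: ext(mul)
  %add = add i32 %acc, %mul.ext     ; reduce.add accumulates the extended product
  %iv.next = add i64 %iv, 1
  %ec = icmp eq i64 %iv, %n
  br i1 %ec, label %exit, label %loop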

  1. -> [LV] Allow partial reductions with an extended bin op #165536
  2. [LV] Use assertion in VPExpressionRecipe creation #165543

@llvmbot
Member

llvmbot commented Oct 29, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-vectorizers

Author: Sam Tebbs (SamTebbs33)

Changes

A pattern of the form reduce.add(ext(mul)) is valid for a partial reduction as long as the mul and its operands fulfill the requirements of a normal partial reduction. The mul's extend operands will be optimised to the wider extend, and we already have oneUse checks in place to make sure the mul and operands can be modified safely.


Patch is 28.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165536.diff

5 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+9-2)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll (-80)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll (+125)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll (+144)
  • (modified) llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll (+90)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f7968abbe5b6b..f83fe82c2dfbe 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7946,6 +7946,15 @@ bool VPRecipeBuilder::getScaledReductions(
   if (Op == PHI)
     std::swap(Op, PhiOp);
 
+  using namespace llvm::PatternMatch;
+  // If Op is an extend, then it's still a valid partial reduction if the
+  // extended mul fulfills the other requirements.
+  // For example, reduce.add(ext(mul(ext(A), ext(B)))) is still a valid partial
+  // reduction since the inner extends will be widened. We already have oneUse
+  // checks on the inner extends so widening them is safe.
+  if (match(Op, m_ZExtOrSExt(m_Mul(m_Value(), m_Value()))))
+    Op = cast<Instruction>(Op)->getOperand(0);
+
   // Try and get a scaled reduction from the first non-phi operand.
   // If one is found, we use the discovered reduction instruction in
   // place of the accumulator for costing.
@@ -7962,8 +7971,6 @@ bool VPRecipeBuilder::getScaledReductions(
   if (PhiOp != PHI)
     return false;
 
-  using namespace llvm::PatternMatch;
-
   // If the update is a binary operator, check both of its operands to see if
   // they are extends. Otherwise, see if the update comes directly from an
   // extend.
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
index b430efc9e5283..b033f6051f812 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll
@@ -467,83 +467,3 @@ loop:
 exit:
   ret i32 %red.next
 }
-
-define i64 @partial_reduction_mul_two_users(i64 %n, ptr %a, i16 %b, i32 %c) {
-; CHECK-LABEL: define i64 @partial_reduction_mul_two_users(
-; CHECK-SAME: i64 [[N:%.*]], ptr [[A:%.*]], i16 [[B:%.*]], i32 [[C:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:  [[ENTRY:.*]]:
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[N]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 8
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 8
-; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
-; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <8 x i16> poison, i16 [[B]], i64 0
-; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <8 x i16> [[BROADCAST_SPLATINSERT]], <8 x i16> poison, <8 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <8 x i16> [[BROADCAST_SPLAT]] to <8 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i32> [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
-; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i64> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP4:%.*]] = load i16, ptr [[A]], align 2
-; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <8 x i16> poison, i16 [[TMP4]], i64 0
-; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <8 x i16> [[BROADCAST_SPLATINSERT1]], <8 x i16> poison, <8 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <8 x i32> [[TMP2]] to <8 x i64>
-; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i64> @llvm.vector.partial.reduce.add.v4i64.v8i64(<4 x i64> [[VEC_PHI]], <8 x i64> [[TMP3]])
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[BROADCAST_SPLAT2]] to <8 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i32> [[TMP5]] to <8 x i64>
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
-; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP7]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
-; CHECK:       [[MIDDLE_BLOCK]]:
-; CHECK-NEXT:    [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[PARTIAL_REDUCE]])
-; CHECK-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <8 x i64> [[TMP6]], i32 7
-; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i64 [ [[VECTOR_RECUR_EXTRACT]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP8]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[LOOP:.*]]
-; CHECK:       [[LOOP]]:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
-; CHECK-NEXT:    [[RES1:%.*]] = phi i64 [ [[SCALAR_RECUR_INIT]], %[[SCALAR_PH]] ], [ [[LOAD_EXT_EXT:%.*]], %[[LOOP]] ]
-; CHECK-NEXT:    [[RES2:%.*]] = phi i64 [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ], [ [[ADD:%.*]], %[[LOOP]] ]
-; CHECK-NEXT:    [[LOAD:%.*]] = load i16, ptr [[A]], align 2
-; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
-; CHECK-NEXT:    [[CONV:%.*]] = sext i16 [[B]] to i32
-; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[CONV]], [[CONV]]
-; CHECK-NEXT:    [[MUL_EXT:%.*]] = zext i32 [[MUL]] to i64
-; CHECK-NEXT:    [[ADD]] = add i64 [[RES2]], [[MUL_EXT]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[MUL]], [[C]]
-; CHECK-NEXT:    [[LOAD_EXT:%.*]] = sext i16 [[LOAD]] to i32
-; CHECK-NEXT:    [[LOAD_EXT_EXT]] = sext i32 [[LOAD_EXT]] to i64
-; CHECK-NEXT:    [[EXITCOND740_NOT:%.*]] = icmp eq i64 [[IV]], [[N]]
-; CHECK-NEXT:    br i1 [[EXITCOND740_NOT]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP19:![0-9]+]]
-; CHECK:       [[EXIT]]:
-; CHECK-NEXT:    [[ADD_LCSSA:%.*]] = phi i64 [ [[ADD]], %[[LOOP]] ], [ [[TMP8]], %[[MIDDLE_BLOCK]] ]
-; CHECK-NEXT:    ret i64 [[ADD_LCSSA]]
-;
-entry:
-  br label %loop
-
-loop:
-  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
-  %res1 = phi i64 [ 0, %entry ], [ %load.ext.ext, %loop ]
-  %res2 = phi i64 [ 0, %entry ], [ %add, %loop ]
-  %load = load i16, ptr %a, align 2
-  %iv.next = add i64 %iv, 1
-  %conv = sext i16 %b to i32
-  %mul = mul i32 %conv, %conv
-  %mul.ext = zext i32 %mul to i64
-  %add = add i64 %res2, %mul.ext
-  %second_use = or i32 %mul, %c ; this value is otherwise unused, but that's sufficient for the test
-  %load.ext = sext i16 %load to i32
-  %load.ext.ext = sext i32 %load.ext to i64
-  %exitcond740.not = icmp eq i64 %iv, %n
-  br i1 %exitcond740.not, label %exit, label %loop
-
-exit:
-  ret i64 %add
-}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll
index b84763142b686..2ad0bb350392b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll
@@ -2123,6 +2123,131 @@ for.exit:                        ; preds = %for.body
   ret i32 %result
 }
 
+define i32 @partial_reduction_ext_mul(i64 %n, ptr %a, i8 %b) {
+; CHECK-INTERLEAVE1-LABEL: define i32 @partial_reduction_ext_mul(
+; CHECK-INTERLEAVE1-SAME: i64 [[N:%.*]], ptr [[A:%.*]], i8 [[B:%.*]]) #[[ATTR0]] {
+; CHECK-INTERLEAVE1-NEXT:  entry:
+; CHECK-INTERLEAVE1-NEXT:    [[TMP0:%.*]] = add i64 [[N]], 1
+; CHECK-INTERLEAVE1-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 8
+; CHECK-INTERLEAVE1-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-INTERLEAVE1:       vector.ph:
+; CHECK-INTERLEAVE1-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 8
+; CHECK-INTERLEAVE1-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
+; CHECK-INTERLEAVE1-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <8 x i8> poison, i8 [[B]], i64 0
+; CHECK-INTERLEAVE1-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <8 x i8> [[BROADCAST_SPLATINSERT]], <8 x i8> poison, <8 x i32> zeroinitializer
+; CHECK-INTERLEAVE1-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-INTERLEAVE1:       vector.body:
+; CHECK-INTERLEAVE1-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVE1-NEXT:    [[VEC_PHI:%.*]] = phi <2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVE1-NEXT:    [[TMP1:%.*]] = load i16, ptr [[A]], align 2
+; CHECK-INTERLEAVE1-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <8 x i16> poison, i16 [[TMP1]], i64 0
+; CHECK-INTERLEAVE1-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <8 x i16> [[BROADCAST_SPLATINSERT1]], <8 x i16> poison, <8 x i32> zeroinitializer
+; CHECK-INTERLEAVE1-NEXT:    [[TMP2:%.*]] = zext <8 x i8> [[BROADCAST_SPLAT]] to <8 x i32>
+; CHECK-INTERLEAVE1-NEXT:    [[TMP3:%.*]] = mul <8 x i32> [[TMP2]], [[TMP2]]
+; CHECK-INTERLEAVE1-NEXT:    [[PARTIAL_REDUCE]] = call <2 x i32> @llvm.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> [[VEC_PHI]], <8 x i32> [[TMP3]])
+; CHECK-INTERLEAVE1-NEXT:    [[TMP4:%.*]] = sext <8 x i16> [[BROADCAST_SPLAT2]] to <8 x i32>
+; CHECK-INTERLEAVE1-NEXT:    [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
+; CHECK-INTERLEAVE1-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-INTERLEAVE1-NEXT:    [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-INTERLEAVE1-NEXT:    br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP11:![0-9]+]]
+; CHECK-INTERLEAVE1:       middle.block:
+; CHECK-INTERLEAVE1-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[PARTIAL_REDUCE]])
+; CHECK-INTERLEAVE1-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <8 x i64> [[TMP5]], i32 7
+; CHECK-INTERLEAVE1-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]]
+; CHECK-INTERLEAVE1-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-INTERLEAVE1:       scalar.ph:
+;
+; CHECK-INTERLEAVED-LABEL: define i32 @partial_reduction_ext_mul(
+; CHECK-INTERLEAVED-SAME: i64 [[N:%.*]], ptr [[A:%.*]], i8 [[B:%.*]]) #[[ATTR0]] {
+; CHECK-INTERLEAVED-NEXT:  entry:
+; CHECK-INTERLEAVED-NEXT:    [[TMP0:%.*]] = add i64 [[N]], 1
+; CHECK-INTERLEAVED-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 16
+; CHECK-INTERLEAVED-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-INTERLEAVED:       vector.ph:
+; CHECK-INTERLEAVED-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 16
+; CHECK-INTERLEAVED-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
+; CHECK-INTERLEAVED-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <8 x i8> poison, i8 [[B]], i64 0
+; CHECK-INTERLEAVED-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <8 x i8> [[BROADCAST_SPLATINSERT]], <8 x i8> poison, <8 x i32> zeroinitializer
+; CHECK-INTERLEAVED-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-INTERLEAVED:       vector.body:
+; CHECK-INTERLEAVED-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI:%.*]] = phi <2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT:    [[VEC_PHI1:%.*]] = phi <2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE2:%.*]], [[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT:    [[TMP1:%.*]] = load i16, ptr [[A]], align 2
+; CHECK-INTERLEAVED-NEXT:    [[BROADCAST_SPLATINSERT3:%.*]] = insertelement <8 x i16> poison, i16 [[TMP1]], i64 0
+; CHECK-INTERLEAVED-NEXT:    [[BROADCAST_SPLAT4:%.*]] = shufflevector <8 x i16> [[BROADCAST_SPLATINSERT3]], <8 x i16> poison, <8 x i32> zeroinitializer
+; CHECK-INTERLEAVED-NEXT:    [[TMP2:%.*]] = zext <8 x i8> [[BROADCAST_SPLAT]] to <8 x i32>
+; CHECK-INTERLEAVED-NEXT:    [[TMP3:%.*]] = mul <8 x i32> [[TMP2]], [[TMP2]]
+; CHECK-INTERLEAVED-NEXT:    [[PARTIAL_REDUCE]] = call <2 x i32> @llvm.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> [[VEC_PHI]], <8 x i32> [[TMP3]])
+; CHECK-INTERLEAVED-NEXT:    [[PARTIAL_REDUCE2]] = call <2 x i32> @llvm.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> [[VEC_PHI1]], <8 x i32> [[TMP3]])
+; CHECK-INTERLEAVED-NEXT:    [[TMP4:%.*]] = sext <8 x i16> [[BROADCAST_SPLAT4]] to <8 x i32>
+; CHECK-INTERLEAVED-NEXT:    [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
+; CHECK-INTERLEAVED-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-INTERLEAVED-NEXT:    [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-INTERLEAVED-NEXT:    br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP11:![0-9]+]]
+; CHECK-INTERLEAVED:       middle.block:
+; CHECK-INTERLEAVED-NEXT:    [[BIN_RDX:%.*]] = add <2 x i32> [[PARTIAL_REDUCE2]], [[PARTIAL_REDUCE]]
+; CHECK-INTERLEAVED-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[BIN_RDX]])
+; CHECK-INTERLEAVED-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <8 x i64> [[TMP5]], i32 7
+; CHECK-INTERLEAVED-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]]
+; CHECK-INTERLEAVED-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-INTERLEAVED:       scalar.ph:
+;
+; CHECK-MAXBW-LABEL: define i32 @partial_reduction_ext_mul(
+; CHECK-MAXBW-SAME: i64 [[N:%.*]], ptr [[A:%.*]], i8 [[B:%.*]]) #[[ATTR0]] {
+; CHECK-MAXBW-NEXT:  entry:
+; CHECK-MAXBW-NEXT:    [[TMP0:%.*]] = add i64 [[N]], 1
+; CHECK-MAXBW-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 8
+; CHECK-MAXBW-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-MAXBW:       vector.ph:
+; CHECK-MAXBW-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 8
+; CHECK-MAXBW-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
+; CHECK-MAXBW-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <8 x i8> poison, i8 [[B]], i64 0
+; CHECK-MAXBW-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <8 x i8> [[BROADCAST_SPLATINSERT]], <8 x i8> poison, <8 x i32> zeroinitializer
+; CHECK-MAXBW-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-MAXBW:       vector.body:
+; CHECK-MAXBW-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <2 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
+; CHECK-MAXBW-NEXT:    [[TMP1:%.*]] = load i16, ptr [[A]], align 2
+; CHECK-MAXBW-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <8 x i16> poison, i16 [[TMP1]], i64 0
+; CHECK-MAXBW-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <8 x i16> [[BROADCAST_SPLATINSERT1]], <8 x i16> poison, <8 x i32> zeroinitializer
+; CHECK-MAXBW-NEXT:    [[TMP2:%.*]] = zext <8 x i8> [[BROADCAST_SPLAT]] to <8 x i32>
+; CHECK-MAXBW-NEXT:    [[TMP3:%.*]] = mul <8 x i32> [[TMP2]], [[TMP2]]
+; CHECK-MAXBW-NEXT:    [[PARTIAL_REDUCE]] = call <2 x i32> @llvm.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> [[VEC_PHI]], <8 x i32> [[TMP3]])
+; CHECK-MAXBW-NEXT:    [[TMP4:%.*]] = sext <8 x i16> [[BROADCAST_SPLAT2]] to <8 x i32>
+; CHECK-MAXBW-NEXT:    [[TMP5:%.*]] = sext <8 x i32> [[TMP4]] to <8 x i64>
+; CHECK-MAXBW-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-MAXBW-NEXT:    [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-MAXBW-NEXT:    br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP11:![0-9]+]]
+; CHECK-MAXBW:       middle.block:
+; CHECK-MAXBW-NEXT:    [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[PARTIAL_REDUCE]])
+; CHECK-MAXBW-NEXT:    [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <8 x i64> [[TMP5]], i32 7
+; CHECK-MAXBW-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]]
+; CHECK-MAXBW-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-MAXBW:       scalar.ph:
+;
+entry:
+  br label %loop
+
+loop:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
+  %res1 = phi i64 [ 0, %entry ], [ %load.ext.ext, %loop ]
+  %res2 = phi i32 [ 0, %entry ], [ %add, %loop ]
+  %load = load i16, ptr %a, align 2
+  %iv.next = add i64 %iv, 1
+  %conv = zext i8 %b to i16
+  %mul = mul i16 %conv, %conv
+  %mul.ext = zext i16 %mul to i32
+  %add = add i32 %res2, %mul.ext
+  %load.ext = sext i16 %load to i32
+  %load.ext.ext = sext i32 %load.ext to i64
+  %exitcond740.not = icmp eq i64 %iv, %n
+  br i1 %exitcond740.not, label %exit, label %loop
+
+exit:
+  ret i32 %add
+}
+
 !7 = distinct !{!7, !8, !9, !10}
 !8 = !{!"llvm.loop.mustprogress"}
 !9 = !{!"llvm.loop.vectorize.predicate.enable", i1 true}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
index 49f663f5703b6..3ad357974d992 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -146,3 +146,147 @@ for.body:                                         ; preds = %for.body, %entry
 exit:
   ret i32 %add
 }
+
+define i32 @print_partial_reduction_ext_mul(i64 %n, ptr %a, i8 %b) {
+; CHECK:       VPlan 'Initial VPlan for VF={8},UF>=1' {
+; CHECK-NEXT:  Live-in vp<%0> = VF * UF
+; CHECK-NEXT:  Live-in vp<%1> = vector-trip-count
+; CHECK-NEXT:  vp<%2> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT:  ir-bb<entry>:
+; CHECK-NEXT:    EMIT vp<%2> = EXPAND SCEV (1 + %n)
+; CHECK-NEXT:  Successor(s): scalar.ph, vector.ph
+; CHECK-EMPTY:
+; CHECK-NEXT:  vector.ph:
+; CHECK-NEXT:    EMIT vp<%3> = reduction-start-vector ir<0>, ir<0>, ir<4>
+; CHECK-NEXT:  Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  <x1> vector loop: {
+; CHECK-NEXT:    vector.body:
+; CHECK-NEXT:      EMIT vp<%4> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:      WIDEN-REDUCTION-PHI ir<%res2> = phi vp<%3>, vp<%5> (VF scaled by 1/4)
+; CHECK-NEXT:      CLONE ir<%load> = load ir<%a>
+; CHECK-NEXT:      EXPRESSION vp<%5> = ir<%res2> + partial.reduce.add (mul (ir<%b> zext to i32), (ir<%b> zext to i32))
+; CHECK-NEXT:      WIDEN-CAST ir<%load.ext> = sext ir<%load> to i32
+; CHECK-NEXT:      WIDEN-CAST ir<%load.ext.ext> = sext ir<%load.ext> to i64
+; CHECK-NEXT:      EMIT vp<%index.next> = add nuw vp<%4>, vp<%0>
+; CHECK-NEXT:      EMIT branch-on-count vp<%index.next>, vp<%1>
+; CHECK-NEXT:    No successors
+; CHECK-NEXT:  }
+; CHECK-NEXT:  Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT:  middle.block:
+; CHECK-NEXT:    EMIT vp<%7> = compute-reduction-result ir<%res2>, vp<%5>
+; CHECK-NEXT:    EMIT vp<%vector.recur.extract> = extract-last-element ir<%load.ext.ext>
+; CHECK-NEXT:    EMIT vp<%cmp.n> = icmp eq vp<%2>, vp<%1>
+; CHECK-NEXT:    EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT:  Successor(s): ir-bb<exit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT:  ir-bb<exit>:
+; CHECK-NEXT:    IR   %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<%7> from middle.block)
+; CHECK-NEXT:  No successors
+; CHECK-EMPTY:
+; CHECK-NEXT:  scalar.ph:
+; CHECK-NEXT:    EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<%1>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT:    EMIT-SCALAR vp<%scalar.recur.init> = phi [ vp<%vector.recur.extract>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT:    EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%7>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT:  Successor(s): ir-bb<loop>
+; CHECK-EMPTY:
+; CHECK-NEXT:  ir-bb<loop>:
+; CHECK-NEXT:    IR   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] (extra operand: vp<%bc.resume.val> from scalar.ph)
+; CHECK-NEXT:    IR   %res1 = phi i64 [ 0, %entry ], [ %load.ext.ext, %loop ] (extra operand: vp<%scalar.recur.init> from scalar.ph)
+; CHECK-NEXT:    IR   %res2 = phi i32 [ 0, %entry ], [ %add, %loop ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
+; CHECK-NEXT:    IR   %load = load i16, ptr %a, align 2
+; CHECK-NEXT:    IR   %iv.next = add i64 %iv, 1
+; CHECK-NEXT:    IR   %conv = zext i8 %b to i16
+; CHECK-NEXT:    IR   %mul = mul i16 %conv, %conv
+; CHECK-NEXT:    IR   %mul.ext = zext i16 %mul to i32
+; CHECK-NEXT:    IR   %add = add i32 %res2, %mul.ext
+; CHECK-NEXT:    IR   %load.ext = sext i16 %load to i32
+; CHECK-NEXT:    IR   %load.ext.ext = sext i32 %load.ext to i64
+; CHECK-NEXT:    IR   %exitcond740.not = icmp eq i64 %iv, %n
+; CHECK-NEXT:  No successors
+; CHECK-NEXT:  }
+
+; CHECK:  VPlan 'Final VPlan for VF={8},UF={1...
[truncated]

ret i32 %red.next
}

define i64 @partial_reduction_mul_two_users(i64 %n, ptr %a, i16 %b, i32 %c) {
Collaborator Author

I had to remove this test because it no longer tests what we want it to test. It's been replaced by a print test that makes sure that an ExtendedReduction is created instead of an ExtendedMulAccReduction.

// For example, reduce.add(ext(mul(ext(A), ext(B)))) is still a valid partial
// reduction since the inner extends will be widened. We already have oneUse
// checks on the inner extends so widening them is safe.
if (match(Op, m_ZExtOrSExt(m_Mul(m_Value(), m_Value()))))
Collaborator

Should this check that the operands of the mul are also extends?

Collaborator Author

We don't need to check what the operands are since the later checks in this function do that, but that does make me realise that we need to make sure the outer extend is compatible with the inner extends (i.e. the outer is sext or the same as the inner extends).

Collaborator Author

@SamTebbs33 SamTebbs33 Nov 4, 2025

Done. I've added some tests to cover each extend permutation and an i8 to i32 to i64 test to show that the rest of the function does still check the types correctly, even considering the outer extend.
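
For context, the i8-to-i32-to-i64 case would look roughly like the hypothetical sketch below (illustrative only, not necessarily the committed test): the mul operands are extended from i8 directly to i32, and the product is then extended to i64 before being accumulated.

  %ext.a = zext i8 %load.a to i32   ; inner extends: i8 -> i32
  %ext.b = zext i8 %load.b to i32
  %mul = mul i32 %ext.a, %ext.b
  %mul.ext = zext i32 %mul to i64   ; outer extend: i32 -> i64
  %add = add i64 %acc, %mul.ext     ; accumulated into an i64 add reduction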

Comment on lines 185 to 203
; CHECK-NEXT: ir-bb<exit>:
; CHECK-NEXT: IR %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<%7> from middle.block)
; CHECK-NEXT: No successors
; CHECK-EMPTY:
; CHECK-NEXT: scalar.ph:
; CHECK-NEXT: EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<%1>, middle.block ], [ ir<0>, ir-bb<entry> ]
; CHECK-NEXT: EMIT-SCALAR vp<%scalar.recur.init> = phi [ vp<%vector.recur.extract>, middle.block ], [ ir<0>, ir-bb<entry> ]
; CHECK-NEXT: EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%7>, middle.block ], [ ir<0>, ir-bb<entry> ]
; CHECK-NEXT: Successor(s): ir-bb<loop>
; CHECK-EMPTY:
; CHECK-NEXT: ir-bb<loop>:
; CHECK-NEXT: IR %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] (extra operand: vp<%bc.resume.val> from scalar.ph)
; CHECK-NEXT: IR %res1 = phi i64 [ 0, %entry ], [ %load.ext.ext, %loop ] (extra operand: vp<%scalar.recur.init> from scalar.ph)
; CHECK-NEXT: IR %res2 = phi i32 [ 0, %entry ], [ %add, %loop ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
; CHECK-NEXT: IR %load = load i16, ptr %a, align 2
; CHECK-NEXT: IR %iv.next = add i64 %iv, 1
; CHECK-NEXT: IR %conv = zext i8 %b to i16
; CHECK-NEXT: IR %mul = mul i16 %conv, %conv
; CHECK-NEXT: IR %mul.ext = zext i16 %mul to i32
Contributor

can strip this, and only check the VPlan with the VPExpressionRecipe

Collaborator Author

Done.

; CHECK-NEXT: vp<%2> = original trip-count
; CHECK-EMPTY:
; CHECK-NEXT: ir-bb<entry>:
; CHECK-NEXT: EMIT vp<%2> = EXPAND SCEV (1 + %n)
Contributor

please capture the unnamed VPValues to make the test more robust w.r.t. future changes

Collaborator Author

I always forget that, thanks.

ret i32 %result
}

define i32 @dotp_ext_mul(i64 %n, ptr %a, i8 %b) {
Contributor

I think it might make sense to split up the file, as it's getting quite big. We probably also do not need to check for all possible combinations for all tests, especially the negative ones.

Collaborator Author

I think splitting this up is out of scope for the PR, since it would involve touching other tests that aren't related to this work.

We probably also do not need to check for all possible combinations for all tests, especially the negative ones.

Are there any specific test functions in here that you think are unnecessary?

Contributor

I think all test functions are relevant, but we don't need to check INTERLEAVE1, INTERLEAVED and MAXBW for all of them.

Collaborator Author

That's fair enough 👍 Are you okay with that being done separately?

%add = add i32 %res2, %mul.ext
%load.ext = sext i16 %load to i32
%load.ext.ext = sext i32 %load.ext to i64
%exitcond740.not = icmp eq i64 %iv, %n
Contributor

@copilot

Suggested change
%exitcond740.not = icmp eq i64 %iv, %n
%ec = icmp eq i64 %iv, %n

Collaborator Author

Done.

Comment on lines 8008 to 8009
// Make sure that the outer extend is either sext or the same kind as the
// inner extend.
Member

nit: This comment says what the code does; it'd be better if it said why other cases are not allowed.

Collaborator Author

Done, hopefully that's better.

// inner extend.
if (OuterExtKind.has_value()) {
TTI::PartialReductionExtendKind OuterKind = OuterExtKind.value();
if (OuterKind != TTI::PartialReductionExtendKind::PR_SignExtend &&
Member

I'm not sure why the outer being a sign extend would always be okay? Wouldn't that be the same as: https://alive2.llvm.org/ce/z/Te3kiL

Collaborator Author

Yes you're right! Fixed.

; CHECK-NEXT: CLONE ir<%load> = load ir<%a>
; CHECK-NEXT: WIDEN-CAST ir<%load.ext> = sext ir<%load> to i32
; CHECK-NEXT: WIDEN-CAST ir<%load.ext.ext> = sext ir<%load.ext> to i64
; CHECK-NEXT: EXPRESSION vp<%5> = ir<%res2> + reduce.add (ir<%mul> zext to i64)
Collaborator

Seems odd to be creating extend-accumulate expression recipes for values which are loop invariant.

Collaborator Author

Re-worked it to not be loop invariant now.

%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
%res1 = phi i64 [ 0, %entry ], [ %load.ext.ext, %loop ]
%res2 = phi i32 [ 0, %entry ], [ %add, %loop ]
%load = load i16, ptr %a, align 2
Collaborator

is the load needed, given that res1 isn't used?

Collaborator Author

I found that the reduction wasn't recognised without res1, load.ext.ext and load.ext, but I've modeled it after the other printing function here and it's looking better now.

ret i32 %add
}

define i32 @not_dotp_zext_mul_sext(i64 %n, ptr %a, i8 %b) {
Contributor

I might have missed it, but could you make sure we also have tests where the operands have different extend kinds and the mul is also extended?

Collaborator Author

I was convinced they existed but it seems not, added now.
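
For illustration, such a loop body could look like the hypothetical sketch below (not the exact test added here): the mul operands use different extend kinds and the product is itself extended before the add reduction.

  %ext.a = sext i8 %load.a to i16   ; sign-extended operand
  %ext.b = zext i8 %load.b to i16   ; zero-extended operand
  %mul = mul i16 %ext.a, %ext.b
  %mul.ext = zext i16 %mul to i32   ; the mul itself is extended
  %add = add i32 %acc, %mul.ext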

Comment on lines +188 to +212
; CHECK-NEXT: ir-bb<exit>:
; CHECK-NEXT: IR %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<[[RED_RESULT]]> from middle.block)
; CHECK-NEXT: No successors
; CHECK-EMPTY:
; CHECK-NEXT: scalar.ph:
; CHECK-NEXT: EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<[[VEC_TC]]>, middle.block ], [ ir<0>, ir-bb<entry> ]
; CHECK-NEXT: EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<[[RED_RESULT]]>, middle.block ], [ ir<0>, ir-bb<entry> ]
; CHECK-NEXT: Successor(s): ir-bb<for.body>
; CHECK-EMPTY:
; CHECK-NEXT: ir-bb<for.body>:
; CHECK-NEXT: IR %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
; CHECK-NEXT: IR %accum = phi i32 [ 0, %entry ], [ %add, %for.body ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
; CHECK-NEXT: IR %gep.a = getelementptr i8, ptr %a, i64 %iv
; CHECK-NEXT: IR %load.a = load i8, ptr %gep.a, align 1
; CHECK-NEXT: IR %ext.a = zext i8 %load.a to i16
; CHECK-NEXT: IR %gep.b = getelementptr i8, ptr %b, i64 %iv
; CHECK-NEXT: IR %load.b = load i8, ptr %gep.b, align 1
; CHECK-NEXT: IR %ext.b = zext i8 %load.b to i16
; CHECK-NEXT: IR %mul = mul i16 %ext.b, %ext.a
; CHECK-NEXT: IR %mul.ext = zext i16 %mul to i32
; CHECK-NEXT: IR %add = add i32 %mul.ext, %accum
; CHECK-NEXT: IR %iv.next = add i64 %iv, 1
; CHECK-NEXT: IR %exitcond.not = icmp eq i64 %iv.next, 1024
; CHECK-NEXT: No successors
; CHECK-NEXT: }
Contributor

I think you can drop those checks, as the test is mainly interested in checking VPExpressionRecipes are created as expected?

Collaborator Author

Done.

Contributor

@fhahn fhahn left a comment

LGTM, thanks

A pattern of the form reduce.add(ext(mul)) is valid for a partial
reduction as long as the mul and its operands fulfill the requirements
of a normal partial reduction. The mul's extend operands will be
optimised to the wider extend, and we already have oneUse checks in
place to make sure the mul and operands can be modified safely.
@SamTebbs33 SamTebbs33 force-pushed the partialred-extended-binop branch from 5bb0c1d to f5ae82f on November 19, 2025 11:43
@github-actions

🐧 Linux x64 Test Results

  • 186372 tests passed
  • 4855 tests skipped
