Conversation

Collaborator

@sdesmalen-arm sdesmalen-arm commented Sep 15, 2025

This cost model takes into account any type legalisation that would happen on vectors, such as splitting and promotion. This results in wider VFs being chosen for loops that can use partial reductions.

The cost model now also assumes that when SVE is available, the SVE dot-product instructions for i16 -> i64 can be used for fixed-length vectors. In practice this means that loops with non-scalable VFs are vectorized using partial reductions where they wouldn't have been before, e.g.

  int64_t foo2(int8_t *src1, int8_t *src2, int N) {
    int64_t sum = 0;
    for (int i=0; i<N; ++i)
      sum += (int64_t)src1[i] * (int64_t)src2[i];
    return sum;
  }

These changes also fix an issue where previously a partial reduction would be used for mixed sign/zero-extends (USDOT), even when +i8mm was not available.
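
As an additional, hypothetical illustration (not part of the PR description): an i16 -> i64 reduction of the same shape is now also a candidate for fixed-length partial reductions when SVE is available, since its dot instructions handle i16 -> i64 natively.

```cpp
#include <cstdint>

// Hypothetical i16 -> i64 variant of foo2 above (illustration only, not from
// the patch). With the new cost model this loop can be vectorized with
// fixed-length partial reductions when SVE is available.
int64_t dot16(const int16_t *src1, const int16_t *src2, int N) {
  int64_t sum = 0;
  for (int i = 0; i < N; ++i)
    sum += (int64_t)src1[i] * (int64_t)src2[i];
  return sum;
}
```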

Member

llvmbot commented Sep 15, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

This cost model takes into account any type legalisation that would happen on vectors, such as splitting and promotion. This results in wider VFs being chosen for loops that can use partial reductions.

The cost model now also assumes that when SVE is available, the SVE dot-product instructions for i16 -> i64 can be used for fixed-length vectors. In practice this means that loops with non-scalable VFs are vectorized using partial reductions where they wouldn't have been before, e.g.

  int64_t foo2(int8_t *src1, int8_t *src2, int N) {
    int64_t sum = 0;
    for (int i=0; i<N; ++i)
      sum += (int64_t)src1[i] * (int64_t)src2[i];
    return sum;
  }

These changes also fix an issue where previously a partial reduction would be used for mixed sign/zero-extends (USDOT), even when +i8mm was not available.


Patch is 154.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158641.diff

10 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+62-49)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/fully-unrolled-cost.ll (+2-2)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll (+82-106)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll (+10-10)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll (+50-50)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll (+226-336)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-sub.ll (+37-51)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll (+47-47)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/reg-usage.ll (+1-1)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll (+2-2)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 8c4b4f6e4d6de..a2ee509fe4cb7 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5632,75 +5632,88 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
     TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
-  InstructionCost Cost(TTI::TCC_Basic);
 
   if (CostKind != TTI::TCK_RecipThroughput)
     return Invalid;
 
-  // Sub opcodes currently only occur in chained cases.
-  // Independent partial reduction subtractions are still costed as an add
+  if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+    return Invalid;
+
+  if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
+      (!ST->isNeonAvailable() || !ST->hasDotProd()))
+    return Invalid;
+
   if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
       OpAExtend == TTI::PR_None)
     return Invalid;
 
   // We only support multiply binary operations for now, and for muls we
   // require the types being extended to be the same.
-  // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
-  // only if the i8mm or sve/streaming features are available.
-  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
-                OpBExtend == TTI::PR_None ||
-                (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-                 !ST->isSVEorStreamingSVEAvailable())))
+  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
     return Invalid;
   assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
          "Unexpected values for OpBExtend or InputTypeB");
 
-  EVT InputEVT = EVT::getEVT(InputTypeA);
-  EVT AccumEVT = EVT::getEVT(AccumType);
+  bool IsUSDot = OpBExtend && OpAExtend != OpBExtend;
+  if (IsUSDot && !ST->hasMatMulInt8())
+    return Invalid;
 
-  unsigned VFMinValue = VF.getKnownMinValue();
+  unsigned Ratio =
+      AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
+  if (VF.getKnownMinValue() < Ratio)
+    return Invalid;
 
-  if (VF.isScalable()) {
-    if (!ST->isSVEorStreamingSVEAvailable())
-      return Invalid;
+  VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
+  VectorType *AccumVectorType =
+      VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
 
-    // Don't accept a partial reduction if the scaled accumulator is vscale x 1,
-    // since we can't lower that type.
-    unsigned Scale =
-        AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits();
-    if (VFMinValue == Scale)
-      return Invalid;
-  }
-  if (VF.isFixed() &&
-      (!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64))
+  // We don't yet support widening for <vscale x 1 x ..> accumulators.
+  if (AccumVectorType->getElementCount() == ElementCount::getScalable(1))
     return Invalid;
 
-  if (InputEVT == MVT::i8) {
-    switch (VFMinValue) {
-    default:
-      return Invalid;
-    case 8:
-      if (AccumEVT == MVT::i32)
-        Cost *= 2;
-      else if (AccumEVT != MVT::i64)
-        return Invalid;
-      break;
-    case 16:
-      if (AccumEVT == MVT::i64)
-        Cost *= 2;
-      else if (AccumEVT != MVT::i32)
-        return Invalid;
-      break;
-    }
-  } else if (InputEVT == MVT::i16) {
-    // FIXME: Allow i32 accumulator but increase cost, as we would extend
-    //        it to i64.
-    if (VFMinValue != 8 || AccumEVT != MVT::i64)
-      return Invalid;
-  } else
-    return Invalid;
+  // Check what kind of type-legalisation happens.
+  std::pair<InstructionCost, MVT> AccumLT =
+      getTypeLegalizationCost(AccumVectorType);
+  std::pair<InstructionCost, MVT> InputLT =
+      getTypeLegalizationCost(InputVectorType);
 
-  return Cost;
+  InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
+
+  // Prefer using full types by costing half-full input types as more expensive.
+  if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
+                          TypeSize::getScalable(128)))
+    // FIXME: This can be removed after the cost of the extends are folded into
+    // the dot-product expression in VPlan, after landing:
+    //  https://github.com/llvm/llvm-project/pull/147302
+    Cost *= 2;
+
+  if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) {
+    // i16 -> i64 is natively supported for udot/sdot
+    if (AccumLT.second.getScalarType() == MVT::i64 &&
+        InputLT.second.getScalarType() == MVT::i16)
+      return Cost;
+    // i8 -> i64 is supported with an extra level of extends
+    if (AccumLT.second.getScalarType() == MVT::i64 &&
+        InputLT.second.getScalarType() == MVT::i8)
+      // FIXME: This cost should probably be a little higher, e.g. Cost + 2
+      // because it requires two extra extends on the inputs. But if we'd change
+      // that now, a regular reduction would be cheaper because the costs of
+      // the extends in the IR are still counted. This can be fixed
+      // after https://github.com/llvm/llvm-project/pull/147302 has landed.
+      return Cost;
+  }
+
+  // i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
+  if (ST->isSVEorStreamingSVEAvailable() ||
+      (AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() &&
+       ST->hasDotProd())) {
+    if (AccumLT.second.getScalarType() == MVT::i32 &&
+        InputLT.second.getScalarType() == MVT::i8)
+      return Cost;
+  }
+
+  // Add additional cost for the extends that would need to be inserted.
+  return Cost + 4;
 }
 
 InstructionCost
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/fully-unrolled-cost.ll b/llvm/test/Transforms/LoopVectorize/AArch64/fully-unrolled-cost.ll
index c3b0bc8c00a74..47f43339dc35f 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/fully-unrolled-cost.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/fully-unrolled-cost.ll
@@ -82,11 +82,11 @@ define i64 @test_two_ivs(ptr %a, ptr %b, i64 %start) #0 {
 ; CHECK-NEXT: Cost of 0 for VF 8: induction instruction   %j.iv = phi i64 [ %start, %entry ], [ %j.iv.next, %for.body ]
 ; CHECK-NEXT: Cost of 1 for VF 8: exit condition instruction   %exitcond.not = icmp eq i64 %i.iv.next, 16
 ; CHECK-NEXT: Cost of 0 for VF 8: EMIT vp<{{.+}}> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
-; CHECK: Cost for VF 8: 27
+; CHECK: Cost for VF 8: 25
 ; CHECK-NEXT: Cost of 0 for VF 16: induction instruction   %i.iv = phi i64 [ 0, %entry ], [ %i.iv.next, %for.body ]
 ; CHECK-NEXT: Cost of 0 for VF 16: induction instruction   %j.iv = phi i64 [ %start, %entry ], [ %j.iv.next, %for.body ]
 ; CHECK-NEXT: Cost of 0 for VF 16: EMIT vp<{{.+}}> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
-; CHECK: Cost for VF 16: 48
+; CHECK: Cost for VF 16: 41
 ; CHECK: LV: Selecting VF: 16
 entry:
   br label %for.body
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index c0995ec150c8d..e40f6f3647e52 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -204,37 +204,33 @@ define i32 @chained_partial_reduce_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
 ; CHECK-SVE-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
 ; CHECK-SVE-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
 ; CHECK-SVE-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
-; CHECK-SVE-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-SVE-NEXT:    [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], 16
 ; CHECK-SVE-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK-SVE:       vector.ph:
-; CHECK-SVE-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-SVE-NEXT:    [[TMP3:%.*]] = mul nuw i64 [[TMP2]], 4
-; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], 16
 ; CHECK-SVE-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
 ; CHECK-SVE-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-SVE:       vector.body:
 ; CHECK-SVE-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-SVE-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[INDEX]]
 ; CHECK-SVE-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[INDEX]]
 ; CHECK-SVE-NEXT:    [[TMP9:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[INDEX]]
-; CHECK-SVE-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
-; CHECK-SVE-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP8]], align 1
-; CHECK-SVE-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 4 x i8>, ptr [[TMP9]], align 1
-; CHECK-SVE-NEXT:    [[TMP13:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
-; CHECK-SVE-NEXT:    [[TMP14:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
-; CHECK-SVE-NEXT:    [[TMP15:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD2]] to <vscale x 4 x i32>
-; CHECK-SVE-NEXT:    [[TMP16:%.*]] = mul nsw <vscale x 4 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-NEXT:    [[TMP17:%.*]] = add <vscale x 4 x i32> [[VEC_PHI]], [[TMP16]]
-; CHECK-SVE-NEXT:    [[TMP18:%.*]] = mul nsw <vscale x 4 x i32> [[TMP13]], [[TMP15]]
-; CHECK-SVE-NEXT:    [[TMP19]] = add <vscale x 4 x i32> [[TMP17]], [[TMP18]]
-; CHECK-SVE-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
+; CHECK-SVE-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
+; CHECK-SVE-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP9]], align 1
+; CHECK-SVE-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-SVE-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-SVE-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-SVE-NEXT:    [[TMP6:%.*]] = mul nsw <16 x i32> [[TMP3]], [[TMP4]]
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP6]])
+; CHECK-SVE-NEXT:    [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP3]], [[TMP5]]
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP10]])
+; CHECK-SVE-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
 ; CHECK-SVE-NEXT:    [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-SVE-NEXT:    br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK-SVE:       middle.block:
-; CHECK-SVE-NEXT:    [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP19]])
+; CHECK-SVE-NEXT:    [[TMP11:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
 ; CHECK-SVE-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
 ; CHECK-SVE-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
 ; CHECK-SVE:       scalar.ph:
@@ -670,39 +666,35 @@ define i32 @chained_partial_reduce_add_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #
 ; CHECK-SVE-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
 ; CHECK-SVE-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
 ; CHECK-SVE-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
-; CHECK-SVE-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-SVE-NEXT:    [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], 16
 ; CHECK-SVE-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK-SVE:       vector.ph:
-; CHECK-SVE-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-SVE-NEXT:    [[TMP3:%.*]] = mul nuw i64 [[TMP2]], 4
-; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], 16
 ; CHECK-SVE-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
 ; CHECK-SVE-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-SVE:       vector.body:
 ; CHECK-SVE-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP21:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE4:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-SVE-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[INDEX]]
 ; CHECK-SVE-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[INDEX]]
 ; CHECK-SVE-NEXT:    [[TMP9:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[INDEX]]
-; CHECK-SVE-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
-; CHECK-SVE-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP8]], align 1
-; CHECK-SVE-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 4 x i8>, ptr [[TMP9]], align 1
-; CHECK-SVE-NEXT:    [[TMP13:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
-; CHECK-SVE-NEXT:    [[TMP14:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
-; CHECK-SVE-NEXT:    [[TMP15:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD2]] to <vscale x 4 x i32>
-; CHECK-SVE-NEXT:    [[TMP16:%.*]] = mul nsw <vscale x 4 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-NEXT:    [[TMP17:%.*]] = add <vscale x 4 x i32> [[VEC_PHI]], [[TMP16]]
-; CHECK-SVE-NEXT:    [[TMP18:%.*]] = mul nsw <vscale x 4 x i32> [[TMP13]], [[TMP15]]
-; CHECK-SVE-NEXT:    [[TMP19:%.*]] = add <vscale x 4 x i32> [[TMP17]], [[TMP18]]
-; CHECK-SVE-NEXT:    [[TMP20:%.*]] = mul nsw <vscale x 4 x i32> [[TMP14]], [[TMP15]]
-; CHECK-SVE-NEXT:    [[TMP21]] = add <vscale x 4 x i32> [[TMP19]], [[TMP20]]
-; CHECK-SVE-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
+; CHECK-SVE-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
+; CHECK-SVE-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP9]], align 1
+; CHECK-SVE-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-SVE-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-SVE-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-SVE-NEXT:    [[TMP6:%.*]] = mul nsw <16 x i32> [[TMP3]], [[TMP4]]
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP6]])
+; CHECK-SVE-NEXT:    [[TMP11:%.*]] = mul nsw <16 x i32> [[TMP3]], [[TMP5]]
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE3:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP11]])
+; CHECK-SVE-NEXT:    [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP4]], [[TMP5]]
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE4]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE3]], <16 x i32> [[TMP12]])
+; CHECK-SVE-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
 ; CHECK-SVE-NEXT:    [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-SVE-NEXT:    br i1 [[TMP22]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
 ; CHECK-SVE:       middle.block:
-; CHECK-SVE-NEXT:    [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP21]])
+; CHECK-SVE-NEXT:    [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE4]])
 ; CHECK-SVE-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
 ; CHECK-SVE-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
 ; CHECK-SVE:       scalar.ph:
@@ -996,36 +988,32 @@ define i32 @chained_partial_reduce_madd_extadd(ptr %a, ptr %b, ptr %c, i32 %N) #
 ; CHECK-SVE-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
 ; CHECK-SVE-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
 ; CHECK-SVE-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
-; CHECK-SVE-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-SVE-NEXT:    [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], 16
 ; CHECK-SVE-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK-SVE:       vector.ph:
-; CHECK-SVE-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-SVE-NEXT:    [[TMP3:%.*]] = mul nuw i64 [[TMP2]], 4
-; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], 16
 ; CHECK-SVE-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
 ; CHECK-SVE-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-SVE:       vector.body:
 ; CHECK-SVE-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP18:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-SVE-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[INDEX]]
 ; CHECK-SVE-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[INDEX]]
 ; CHECK-SVE-NEXT:    [[TMP9:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[INDEX]]
-; CHECK-SVE-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
-; CHECK-SVE-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP8]], align 1
-; CHECK-SVE-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 4 x i8>, ptr [[TMP9]], align 1
-; CHECK-SVE-NEXT:    [[TMP13:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
-; CHECK-SVE-NEXT:    [[TMP14:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
-; CHECK-SVE-NEXT:    [[TMP15:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD2]] to <vscale x 4 x i32>
-; CHECK-SVE-NEXT:    [[TMP16:%.*]] = mul nsw <vscale x 4 x i32> [[TMP13]], [[TMP14]]
-; CHECK-SVE-NEXT:    [[TMP17:%.*]] = add <vscale x 4 x i32> [[VEC_PHI]], [[TMP16]]
-; CHECK-SVE-NEXT:    [[TMP18]] = add <vscale x 4 x i32> [[TMP17]], [[TMP15]]
-; CHECK-SVE-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
+; CHECK-SVE-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
+; CHECK-SVE-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP9]], align 1
+; CHECK-SVE-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-SVE-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-SVE-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-SVE-NEXT:    [[TMP6:%.*]] = mul nsw <16 x i32> [[TMP3]], [[TMP4]]
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP6]])
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP5]])
+; CHECK-SVE-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
 ; CHECK-SVE-NEXT:    [[TMP19:%.*]] = icmp eq i64 [[IND...
[truncated]

@SamTebbs33 SamTebbs33 self-requested a review September 24, 2025 15:45
@sdesmalen-arm sdesmalen-arm force-pushed the improve-partial-reduce-costmodel branch from 5a9c0a3 to 5b0c1ec September 30, 2025 13:46
unsigned VFMinValue = VF.getKnownMinValue();
unsigned Ratio =
AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
if (VF.getKnownMinValue() < Ratio)
Collaborator

Don't we still want to be catching the cases where the VF is equal to the ratio? Or is that covered by the type legalisation costs below?

Collaborator Author

The type legalisation code below checks most of it, but apparently lets through reductions into <1 x i64>, which I guess we don't want to support (as that would just be an in-loop reduction). Good catch!
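
To make the arithmetic concrete, here is a small sketch with assumed element widths (i8 inputs, i64 accumulator; not code from the patch) showing why a VF equal to the ratio leaves only a single accumulator element:

```cpp
#include <cstdio>

int main() {
  // Assumed widths for illustration: i8 inputs accumulating into i64.
  const unsigned InputBits = 8, AccumBits = 64;
  const unsigned Ratio = AccumBits / InputBits; // 8
  const unsigned VFs[] = {4, 8, 16};
  for (unsigned VF : VFs) {
    // The accumulator vector has VF / Ratio elements after scaling down.
    unsigned AccumElts = VF / Ratio;
    std::printf("VF=%u -> %u accumulator element(s)%s\n", VF, AccumElts,
                AccumElts <= 1 ? " (rejected)" : "");
  }
  return 0;
}
```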

Contributor

@david-arm david-arm left a comment

I'm still reviewing the legalisation code, but I've left the comments I have so far ...

Contributor

Looks like this change can be reverted now?

Collaborator Author

Good point, thanks!


if (CostKind != TTI::TCK_RecipThroughput)
return Invalid;

// Sub opcodes currently only occur in chained cases.
// Independent partial reduction subtractions are still costed as an add
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
Contributor

Is this even possible? The only way we could get here with a scalable VF is if the target said it supported scalable vectors, which presumably means we also have SVE or SME and we're in streaming mode. If there is a way to expose this via a test that would be great!

Collaborator Author

I moved it from the existing code, but you're right that this check is bogus: it can never happen. I've removed it.

VectorType *AccumVectorType =
VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
// We don't yet support all kinds of legalization (e.g. widening
// of <[vscale x] 1 x ..> accumulators)
Contributor

Can this specific example (widening a <vscale x 1 x ...> type) ever happen given the check above?

  if (VF.getKnownMinValue() <= Ratio)
    return Invalid;

Also, the VF should always be a power of 2 (and hence ratio should be a power of 2), which means we shouldn't really end up with TypeWidenVector. Perhaps a better example would be something like <2 x i128> where presumably we'd see a TypeExpandInteger action? Or maybe if we ever support FP element types we'd end up with TypePromoteFloat?

Collaborator Author

> Can this specific example (widening a <vscale x 1 x ...> type) ever happen given the check above?

Not anymore, but I think specifically for the scalable case, we'd want to allow that case in the future when we add widening for partial reductions.

At the moment though, I think it still makes sense to return Invalid in case it has to do any kind of legalisation that isn't yet supported.

Either way, you're right that the comment needs updating.

if (!ST->isSVEorStreamingSVEAvailable())
return Invalid;
// Prefer using full types by costing half-full input types as more expensive.
if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
Contributor

Don't you also need to check if the input vector type is scalable first? Otherwise you're potentially asking if a legal 64-bit vector is less than a legal SVE vector, where the answer is always going to be true. I guess this may be a valid thing to do if we're going to lower to SVE?

Collaborator Author

Yes, that was actually the point. We only want to do this for fully packed scalable/fixed vectors, so a 64-bit fixed-length vector would be one we'd prefer not to favour.
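
A minimal sketch of that comparison, with assumed types (not the patch's code), shows that only the half-full 64-bit fixed case is known to be smaller than a full SVE register:

```cpp
#include "llvm/Support/TypeSize.h"

using namespace llvm;

void illustrate() {
  // Assumed example types: <8 x i8> (64 bits), <16 x i8> (128 bits), and an
  // SVE register with a minimum size of 128 bits.
  TypeSize HalfNeon = TypeSize::getFixed(64);
  TypeSize FullNeon = TypeSize::getFixed(128);
  TypeSize SveMin = TypeSize::getScalable(128);

  bool HalfIsSmaller = TypeSize::isKnownLT(HalfNeon, SveMin); // true: cost doubled
  bool FullIsSmaller = TypeSize::isKnownLT(FullNeon, SveMin); // false: cost unchanged
  (void)HalfIsSmaller;
  (void)FullIsSmaller;
}
```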

This cost-model takes into account any type-legalisation
that would happen on vectors such as splitting and promotion.
This results in wider VFs being chosen for loops that can
use partial reductions.

The cost-model now also assumes that when SVE is available, the SVE dot
instructions for i16 -> i64 dot products can be used for fixed-length
vectors. In practice this means that loops with non-scalable VFs are
vectorized using partial reductions where they wouldn't before, e.g.

  int64_t foo2(int8_t *src1, int8_t *src2, int N) {
    int64_t sum = 0;
    for (int i=0; i<N; ++i)
      sum += (int64_t)src1[i] * (int64_t)src2[i];
    return sum;
  }

These changes also fix an issue where previously a partial reduction
would be used for mixed sign/zero-extends (USDOT), even when
+i8mm was not available.
@sdesmalen-arm sdesmalen-arm force-pushed the improve-partial-reduce-costmodel branch from 034a54e to 03c0f18 October 2, 2025 13:36
Contributor

@david-arm david-arm left a comment

LGTM!

@sdesmalen-arm sdesmalen-arm merged commit cc9c64d into llvm:main Oct 3, 2025
9 checks passed
MixedMatched pushed a commit to MixedMatched/llvm-project that referenced this pull request Oct 3, 2025
…#158641)

This cost-model takes into account any type-legalisation that would
happen on vectors such as splitting and promotion. This results in wider
VFs being chosen for loops that can use partial reductions.

The cost-model now also assumes that when SVE is available, the SVE dot
instructions for i16 -> i64 dot products can be used for fixed-length
vectors. In practice this means that loops with non-scalable VFs are
vectorized using partial reductions where they wouldn't before, e.g.

```
  int64_t foo2(int8_t *src1, int8_t *src2, int N) {
    int64_t sum = 0;
    for (int i=0; i<N; ++i)
      sum += (int64_t)src1[i] * (int64_t)src2[i];
    return sum;
  }
```

These changes also fix an issue where previously a partial reduction
would be used for mixed sign/zero-extends (USDOT), even when +i8mm was
not available.
sdesmalen-arm added a commit to sdesmalen-arm/llvm-project that referenced this pull request Oct 6, 2025
PR llvm#158641 introduced an issue where i128 accumulator types resulted
in a valid cost, because for a <2 x i128> type the code that
checks for unsupported type legalization would see a type action
of 'TypeSplitVector' which is supported, even though the legalised
type of <1 x i128> would require further scalarization.
sdesmalen-arm added a commit to sdesmalen-arm/llvm-project that referenced this pull request Oct 6, 2025
PR llvm#158641 introduced an issue where i128 accumulator types resulted
in a valid cost, because for a <2 x i128> type the code that
checks for unsupported type legalization would see a type action
of 'TypeSplitVector' which is supported, even though the legalised
type of <1 x i128> would require further scalarization.
sdesmalen-arm added a commit that referenced this pull request Oct 6, 2025
…#162066)

PR #158641 introduced an issue where i128 accumulator types resulted
in a valid cost, because for a <2 x i128> type the code that
checks for unsupported type legalization would see a type action
of 'TypeSplitVector' which is supported, even though the legalised
type of <1 x i128> would require further scalarization.

This fixes #162009
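
For context, a loop of roughly the following shape would produce an i128 accumulator type (a hypothetical sketch, not the actual reproducer from #162009):

```cpp
#include <cstdint>

// Hypothetical loop with an i128 accumulator (illustration only). The wide
// accumulator is what leads the cost model to consider <2 x i128>, whose
// split half <1 x i128> still requires further scalarization.
__int128 wide_dot(const int8_t *src1, const int8_t *src2, int N) {
  __int128 sum = 0;
  for (int i = 0; i < N; ++i)
    sum += (__int128)src1[i] * (__int128)src2[i];
  return sum;
}
```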