From 13aeeb7cd5debc910c1c1af126a5ec43edad0ec6 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Thu, 25 Sep 2025 14:09:34 +0100 Subject: [PATCH 1/2] [LV] Support multiplies by constants when forming scaled reductions. We can create partial reductions for multiplies with constants, if the constant is small enough to be extended from source to destination type w/o changing the value. This only handles constant on the right side of a multiply, relying on other passes to canonicalize the input. Alive2 Proofs: https://alive2.llvm.org/ce/z/iWRMr6 --- .../Transforms/Vectorize/LoopVectorize.cpp | 7 ++++ llvm/lib/Transforms/Vectorize/VPlan.cpp | 10 ++++++ llvm/lib/Transforms/Vectorize/VPlanHelpers.h | 4 +++ .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 9 +++++ .../AArch64/partial-reduce-constant-ops.ll | 34 +++++++++---------- 5 files changed, 47 insertions(+), 17 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 12fb46da8e71a..b55685f735929 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7937,6 +7937,13 @@ bool VPRecipeBuilder::getScaledReductions( auto CollectExtInfo = [this, &Exts, &ExtOpTypes, &ExtKinds](SmallVectorImpl &Ops) -> bool { for (const auto &[I, OpI] : enumerate(Ops)) { + auto *CI = dyn_cast(OpI); + if (I > 0 && CI && + canConstantBeExtended(CI, ExtOpTypes[0], ExtKinds[0])) { + ExtOpTypes[I] = ExtOpTypes[0]; + ExtKinds[I] = ExtKinds[0]; + continue; + } Value *ExtOp; if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp)))) return false; diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 81f1956c96254..6273c066b00d4 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -1741,6 +1741,16 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) { } #endif +bool llvm::canConstantBeExtended(const ConstantInt *CI, Type *NarrowType, + TTI::PartialReductionExtendKind ExtKind) { + APInt TruncatedVal = CI->getValue().trunc(NarrowType->getScalarSizeInBits()); + unsigned WideSize = CI->getType()->getScalarSizeInBits(); + APInt ExtendedVal = ExtKind == TTI::PR_SignExtend + ? TruncatedVal.sext(WideSize) + : TruncatedVal.zext(WideSize); + return ExtendedVal == CI->getValue(); +} + TargetTransformInfo::OperandValueInfo VPCostContext::getOperandInfo(VPValue *V) const { if (!V->isLiveIn()) diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h index fe59774b7c838..fc1a09e9850f6 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h +++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h @@ -468,6 +468,10 @@ class VPlanPrinter { }; #endif +/// Check if a constant \p CI can be safely treated as having been extended +/// from a narrower type with the given extension kind. +bool canConstantBeExtended(const ConstantInt *CI, Type *NarrowType, + TTI::PartialReductionExtendKind ExtKind); } // end namespace llvm #endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 46909a53a9547..a9e8bd17d0ae1 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -340,6 +340,15 @@ VPPartialReductionRecipe::computeCost(ElementCount VF, : Widen->getOperand(1)); ExtAType = GetExtendKind(ExtAR); ExtBType = GetExtendKind(ExtBR); + + if (!ExtBR && Widen->getOperand(1)->isLiveIn()) { + auto *CI = + dyn_cast(Widen->getOperand(1)->getLiveInIRValue()); + if (CI && canConstantBeExtended(CI, InputTypeA, ExtAType)) { + InputTypeB = InputTypeA; + ExtBType = ExtAType; + } + } }; if (isa(OpR)) { diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll index 0086f6e61cd36..b033f6051f812 100644 --- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll +++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll @@ -20,22 +20,22 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) { ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] ; CHECK: [[VECTOR_BODY]]: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ] ; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]] ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1 ; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63) -; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]] +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16 -; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] -; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]] +; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] +; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]] ; CHECK: [[MIDDLE_BLOCK]]: -; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]]) +; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]]) ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]] ; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]] ; CHECK: [[SCALAR_PH]]: ; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ] -; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] +; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] ; CHECK-NEXT: br label %[[LOOP:.*]] ; CHECK: [[LOOP]]: ; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ] @@ -48,7 +48,7 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) { ; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]] ; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]] ; CHECK: [[EXIT]]: -; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ] +; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ] ; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]] ; entry: @@ -86,17 +86,17 @@ define i32 @red_zext_mul_by_255(ptr %start, ptr %end) { ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] ; CHECK: [[VECTOR_BODY]]: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ] ; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]] ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1 ; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 255) -; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]] +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16 ; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] ; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] ; CHECK: [[MIDDLE_BLOCK]]: -; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]]) +; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]]) ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]] ; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]] ; CHECK: [[SCALAR_PH]]: @@ -218,22 +218,22 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) { ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] ; CHECK: [[VECTOR_BODY]]: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ] ; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]] ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1 ; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63) -; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]] +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16 -; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] -; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]] +; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] +; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]] ; CHECK: [[MIDDLE_BLOCK]]: -; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]]) +; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]]) ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]] ; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]] ; CHECK: [[SCALAR_PH]]: ; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ] -; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] +; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] ; CHECK-NEXT: br label %[[LOOP:.*]] ; CHECK: [[LOOP]]: ; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ] @@ -246,7 +246,7 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) { ; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]] ; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP9:![0-9]+]] ; CHECK: [[EXIT]]: -; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ] +; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ] ; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]] ; entry: From a30869fbad5ca4216b5234e3d725dfe1e6727dfe Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Thu, 2 Oct 2025 09:31:30 +0100 Subject: [PATCH 2/2] !fixup dyn_cast->cast, thanks! --- llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index a9e8bd17d0ae1..67b9244e9dc72 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -342,9 +342,8 @@ VPPartialReductionRecipe::computeCost(ElementCount VF, ExtBType = GetExtendKind(ExtBR); if (!ExtBR && Widen->getOperand(1)->isLiveIn()) { - auto *CI = - dyn_cast(Widen->getOperand(1)->getLiveInIRValue()); - if (CI && canConstantBeExtended(CI, InputTypeA, ExtAType)) { + auto *CI = cast(Widen->getOperand(1)->getLiveInIRValue()); + if (canConstantBeExtended(CI, InputTypeA, ExtAType)) { InputTypeB = InputTypeA; ExtBType = ExtAType; }