diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index e5d6c8118eb55..7fa787bc9befd 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7954,6 +7954,13 @@ bool VPRecipeBuilder::getScaledReductions( auto CollectExtInfo = [this, &Exts, &ExtOpTypes, &ExtKinds](SmallVectorImpl &Ops) -> bool { for (const auto &[I, OpI] : enumerate(Ops)) { + auto *CI = dyn_cast(OpI); + if (I > 0 && CI && + canConstantBeExtended(CI, ExtOpTypes[0], ExtKinds[0])) { + ExtOpTypes[I] = ExtOpTypes[0]; + ExtKinds[I] = ExtKinds[0]; + continue; + } Value *ExtOp; if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp)))) return false; diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 02eb6375aac41..07b191a787806 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -1753,6 +1753,16 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) { } #endif +bool llvm::canConstantBeExtended(const ConstantInt *CI, Type *NarrowType, + TTI::PartialReductionExtendKind ExtKind) { + APInt TruncatedVal = CI->getValue().trunc(NarrowType->getScalarSizeInBits()); + unsigned WideSize = CI->getType()->getScalarSizeInBits(); + APInt ExtendedVal = ExtKind == TTI::PR_SignExtend + ? TruncatedVal.sext(WideSize) + : TruncatedVal.zext(WideSize); + return ExtendedVal == CI->getValue(); +} + TargetTransformInfo::OperandValueInfo VPCostContext::getOperandInfo(VPValue *V) const { if (!V->isLiveIn()) diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h index fe59774b7c838..fc1a09e9850f6 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h +++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h @@ -468,6 +468,10 @@ class VPlanPrinter { }; #endif +/// Check if a constant \p CI can be safely treated as having been extended +/// from a narrower type with the given extension kind. +bool canConstantBeExtended(const ConstantInt *CI, Type *NarrowType, + TTI::PartialReductionExtendKind ExtKind); } // end namespace llvm #endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 46909a53a9547..67b9244e9dc72 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -340,6 +340,14 @@ VPPartialReductionRecipe::computeCost(ElementCount VF, : Widen->getOperand(1)); ExtAType = GetExtendKind(ExtAR); ExtBType = GetExtendKind(ExtBR); + + if (!ExtBR && Widen->getOperand(1)->isLiveIn()) { + auto *CI = cast(Widen->getOperand(1)->getLiveInIRValue()); + if (canConstantBeExtended(CI, InputTypeA, ExtAType)) { + InputTypeB = InputTypeA; + ExtBType = ExtAType; + } + } }; if (isa(OpR)) { diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll index 0086f6e61cd36..b033f6051f812 100644 --- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll +++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll @@ -20,22 +20,22 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) { ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] ; CHECK: [[VECTOR_BODY]]: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ] ; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]] ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1 ; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63) -; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]] +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16 -; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] -; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]] +; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] +; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]] ; CHECK: [[MIDDLE_BLOCK]]: -; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]]) +; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]]) ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]] ; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]] ; CHECK: [[SCALAR_PH]]: ; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ] -; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] +; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] ; CHECK-NEXT: br label %[[LOOP:.*]] ; CHECK: [[LOOP]]: ; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ] @@ -48,7 +48,7 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) { ; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]] ; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]] ; CHECK: [[EXIT]]: -; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ] +; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ] ; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]] ; entry: @@ -86,17 +86,17 @@ define i32 @red_zext_mul_by_255(ptr %start, ptr %end) { ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] ; CHECK: [[VECTOR_BODY]]: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ] ; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]] ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1 ; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 255) -; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]] +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16 ; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] ; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] ; CHECK: [[MIDDLE_BLOCK]]: -; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]]) +; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]]) ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]] ; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]] ; CHECK: [[SCALAR_PH]]: @@ -218,22 +218,22 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) { ; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] ; CHECK: [[VECTOR_BODY]]: ; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ] +; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ] ; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]] ; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1 ; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32> ; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63) -; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]] +; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]]) ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16 -; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] -; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]] +; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] +; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]] ; CHECK: [[MIDDLE_BLOCK]]: -; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]]) +; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]]) ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]] ; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]] ; CHECK: [[SCALAR_PH]]: ; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ] -; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] +; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ] ; CHECK-NEXT: br label %[[LOOP:.*]] ; CHECK: [[LOOP]]: ; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ] @@ -246,7 +246,7 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) { ; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]] ; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP9:![0-9]+]] ; CHECK: [[EXIT]]: -; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ] +; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ] ; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]] ; entry: