21 changes: 18 additions & 3 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8072,6 +8072,19 @@ bool VPRecipeBuilder::getScaledReductions(
if (Op == PHI)
std::swap(Op, PhiOp);

using namespace llvm::PatternMatch;
// If Op is an extend, then it's still a valid partial reduction if the
// extended mul fulfills the other requirements.
// For example, reduce.add(ext(mul(ext(A), ext(B)))) is still a valid partial
// reduction since the inner extends will be widened. We already have oneUse
// checks on the inner extends so widening them is safe.
std::optional<TTI::PartialReductionExtendKind> OuterExtKind = std::nullopt;
if (match(Op, m_ZExtOrSExt(m_Mul(m_Value(), m_Value())))) {
auto *Cast = cast<CastInst>(Op);
OuterExtKind = TTI::getPartialReductionExtendKind(Cast->getOpcode());
Op = Cast->getOperand(0);
}

// Try and get a scaled reduction from the first non-phi operand.
// If one is found, we use the discovered reduction instruction in
// place of the accumulator for costing.
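For illustration, a minimal IR sketch (hypothetical function and value names, not taken from this patch's tests) of the shape the new code accepts: the outer sext around the mul is peeled off here, and the inner sexts are then collected and checked by CollectExtInfo as before.

define i64 @ext_mul_ext_reduction(ptr %a, ptr %b, i64 %n) {
entry:
  br label %loop

loop:                                        ; reduce.add(sext(mul(sext(A), sext(B))))
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
  %red = phi i64 [ 0, %entry ], [ %red.next, %loop ]
  %gep.a = getelementptr i16, ptr %a, i64 %iv
  %gep.b = getelementptr i16, ptr %b, i64 %iv
  %la = load i16, ptr %gep.a, align 2
  %lb = load i16, ptr %gep.b, align 2
  %la.ext = sext i16 %la to i32              ; inner extends feeding the mul
  %lb.ext = sext i16 %lb to i32
  %mul = mul i32 %la.ext, %lb.ext
  %mul.ext = sext i32 %mul to i64            ; outer extend, same kind as the inner ones
  %red.next = add i64 %red, %mul.ext         ; reduction update
  %iv.next = add i64 %iv, 1
  %ec = icmp eq i64 %iv.next, %n
  br i1 %ec, label %exit, label %loop

exit:
  ret i64 %red.next
}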
@@ -8088,8 +8101,6 @@ bool VPRecipeBuilder::getScaledReductions(
if (PhiOp != PHI)
return false;

using namespace llvm::PatternMatch;

// If the update is a binary operator, check both of its operands to see if
// they are extends. Otherwise, see if the update comes directly from an
// extend.
@@ -8099,7 +8110,7 @@ bool VPRecipeBuilder::getScaledReductions(
Type *ExtOpTypes[2] = {nullptr};
TTI::PartialReductionExtendKind ExtKinds[2] = {TTI::PR_None};

auto CollectExtInfo = [this, &Exts, &ExtOpTypes,
auto CollectExtInfo = [this, OuterExtKind, &Exts, &ExtOpTypes,
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
for (const auto &[I, OpI] : enumerate(Ops)) {
const APInt *C;
@@ -8120,6 +8131,10 @@ bool VPRecipeBuilder::getScaledReductions(

ExtOpTypes[I] = ExtOp->getType();
ExtKinds[I] = TTI::getPartialReductionExtendKind(Exts[I]);
// The outer extend kind must be the same as the inner extends, so that
// they can be folded together.
if (OuterExtKind.has_value() && OuterExtKind.value() != ExtKinds[I])
return false;
}
return true;
};
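For contrast, the removed test below contains the mismatched shape that this new check rejects for folding: the outer extend is a zext while the inner extend is a sext, so the extend kinds differ and CollectExtInfo returns false rather than treating the whole chain as one extended-mul partial reduction. The key lines from that test are:

%conv = sext i16 %b to i32       ; inner extend is a sext
%mul = mul i32 %conv, %conv
%mul.ext = zext i32 %mul to i64  ; outer extend is a zext, so the kinds differ
%add = add i64 %res2, %mul.ext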
@@ -467,83 +467,3 @@ loop:
exit:
ret i32 %red.next
}

define i64 @partial_reduction_mul_two_users(i64 %n, ptr %a, i16 %b, i32 %c) {
I had to remove this test because it no longer tests what we want it to test. It's been replaced by a print test that makes sure that an ExtendedReduction is created instead of an ExtendedMulAccReduction.

; CHECK-LABEL: define i64 @partial_reduction_mul_two_users(
; CHECK-SAME: i64 [[N:%.*]], ptr [[A:%.*]], i16 [[B:%.*]], i32 [[C:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[ENTRY:.*]]:
; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[N]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 8
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
; CHECK: [[VECTOR_PH]]:
; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 8
; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <8 x i16> poison, i16 [[B]], i64 0
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <8 x i16> [[BROADCAST_SPLATINSERT]], <8 x i16> poison, <8 x i32> zeroinitializer
; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i16> [[BROADCAST_SPLAT]] to <8 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = mul <8 x i32> [[TMP1]], [[TMP1]]
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i64> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP4:%.*]] = load i16, ptr [[A]], align 2
; CHECK-NEXT: [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <8 x i16> poison, i16 [[TMP4]], i64 0
; CHECK-NEXT: [[BROADCAST_SPLAT2:%.*]] = shufflevector <8 x i16> [[BROADCAST_SPLATINSERT1]], <8 x i16> poison, <8 x i32> zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i32> [[TMP2]] to <8 x i64>
; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i64> @llvm.vector.partial.reduce.add.v4i64.v8i64(<4 x i64> [[VEC_PHI]], <8 x i64> [[TMP3]])
; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i16> [[BROADCAST_SPLAT2]] to <8 x i32>
; CHECK-NEXT: [[TMP6:%.*]] = sext <8 x i32> [[TMP5]] to <8 x i64>
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP7]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[PARTIAL_REDUCE]])
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <8 x i64> [[TMP6]], i32 7
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
; CHECK: [[SCALAR_PH]]:
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
; CHECK-NEXT: [[SCALAR_RECUR_INIT:%.*]] = phi i64 [ [[VECTOR_RECUR_EXTRACT]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP8]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: [[RES1:%.*]] = phi i64 [ [[SCALAR_RECUR_INIT]], %[[SCALAR_PH]] ], [ [[LOAD_EXT_EXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: [[RES2:%.*]] = phi i64 [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ], [ [[ADD:%.*]], %[[LOOP]] ]
; CHECK-NEXT: [[LOAD:%.*]] = load i16, ptr [[A]], align 2
; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1
; CHECK-NEXT: [[CONV:%.*]] = sext i16 [[B]] to i32
; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[CONV]], [[CONV]]
; CHECK-NEXT: [[MUL_EXT:%.*]] = zext i32 [[MUL]] to i64
; CHECK-NEXT: [[ADD]] = add i64 [[RES2]], [[MUL_EXT]]
; CHECK-NEXT: [[OR:%.*]] = or i32 [[MUL]], [[C]]
; CHECK-NEXT: [[LOAD_EXT:%.*]] = sext i16 [[LOAD]] to i32
; CHECK-NEXT: [[LOAD_EXT_EXT]] = sext i32 [[LOAD_EXT]] to i64
; CHECK-NEXT: [[EXITCOND740_NOT:%.*]] = icmp eq i64 [[IV]], [[N]]
; CHECK-NEXT: br i1 [[EXITCOND740_NOT]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP19:![0-9]+]]
; CHECK: [[EXIT]]:
; CHECK-NEXT: [[ADD_LCSSA:%.*]] = phi i64 [ [[ADD]], %[[LOOP]] ], [ [[TMP8]], %[[MIDDLE_BLOCK]] ]
; CHECK-NEXT: ret i64 [[ADD_LCSSA]]
;
entry:
br label %loop

loop:
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
%res1 = phi i64 [ 0, %entry ], [ %load.ext.ext, %loop ]
%res2 = phi i64 [ 0, %entry ], [ %add, %loop ]
%load = load i16, ptr %a, align 2
%iv.next = add i64 %iv, 1
%conv = sext i16 %b to i32
%mul = mul i32 %conv, %conv
%mul.ext = zext i32 %mul to i64
%add = add i64 %res2, %mul.ext
%second_use = or i32 %mul, %c ; this value is otherwise unused, but that's sufficient for the test
%load.ext = sext i16 %load to i32
%load.ext.ext = sext i32 %load.ext to i64
%exitcond740.not = icmp eq i64 %iv, %n
br i1 %exitcond740.not, label %exit, label %loop

exit:
ret i64 %add
}