Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 70 additions & 51 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5632,75 +5632,94 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
TTI::TargetCostKind CostKind) const {
InstructionCost Invalid = InstructionCost::getInvalid();
InstructionCost Cost(TTI::TCC_Basic);

if (CostKind != TTI::TCK_RecipThroughput)
return Invalid;

// Sub opcodes currently only occur in chained cases.
// Independent partial reduction subtractions are still costed as an add.
if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
(!ST->isNeonAvailable() || !ST->hasDotProd()))
return Invalid;

if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
OpAExtend == TTI::PR_None)
return Invalid;

assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
(!BinOp || (OpBExtend != TTI::PR_None && InputTypeB)) &&
"Unexpected values for OpBExtend or InputTypeB");

// We only support multiply binary operations for now, and for muls we
// require the types being extended to be the same.
// NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
// only if the i8mm or sve/streaming features are available.
if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
OpBExtend == TTI::PR_None ||
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
!ST->isSVEorStreamingSVEAvailable())))
if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
return Invalid;
assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
"Unexpected values for OpBExtend or InputTypeB");

EVT InputEVT = EVT::getEVT(InputTypeA);
EVT AccumEVT = EVT::getEVT(AccumType);
bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend;
if (IsUSDot && !ST->hasMatMulInt8())
return Invalid;

unsigned Ratio =
AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
if (VF.getKnownMinValue() <= Ratio)
return Invalid;

VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
VectorType *AccumVectorType =
VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
// We don't yet support all kinds of legalization.
auto TA = TLI->getTypeAction(AccumVectorType->getContext(),
EVT::getEVT(AccumVectorType));
switch (TA) {
default:
return Invalid;
case TargetLowering::TypeLegal:
case TargetLowering::TypePromoteInteger:
case TargetLowering::TypeSplitVector:
break;
}

// Check what kind of type-legalisation happens.
std::pair<InstructionCost, MVT> AccumLT =
getTypeLegalizationCost(AccumVectorType);
std::pair<InstructionCost, MVT> InputLT =
getTypeLegalizationCost(InputVectorType);

unsigned VFMinValue = VF.getKnownMinValue();
InstructionCost Cost = InputLT.first * TTI::TCC_Basic;

if (VF.isScalable()) {
if (!ST->isSVEorStreamingSVEAvailable())
return Invalid;
// Prefer using full types by costing half-full input types as more expensive.
if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you also need to check whether the input vector type is scalable first? Otherwise you're potentially asking whether a legal 64-bit fixed-length vector is smaller than a legal SVE vector, where the answer is always going to be true. I guess this may be a valid thing to do if we're going to lower to SVE?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that was actually the point. We only want to do this for fully packed scalable/fixed vectors, so a 64-bit fixed-length vector would be one we'd prefer not to favour.

TypeSize::getScalable(128)))
// FIXME: This can be removed after the cost of the extends is folded into
// the dot-product expression in VPlan, after landing:
// https://github.com/llvm/llvm-project/pull/147302
Cost *= 2;

// Don't accept a partial reduction if the scaled accumulator is vscale x 1,
// since we can't lower that type.
unsigned Scale =
AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits();
if (VFMinValue == Scale)
return Invalid;
if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) {
// i16 -> i64 is natively supported for udot/sdot
if (AccumLT.second.getScalarType() == MVT::i64 &&
InputLT.second.getScalarType() == MVT::i16)
return Cost;
// i8 -> i64 is supported with an extra level of extends
if (AccumLT.second.getScalarType() == MVT::i64 &&
InputLT.second.getScalarType() == MVT::i8)
// FIXME: This cost should probably be a little higher, e.g. Cost + 2
// because it requires two extra extends on the inputs. But if we changed
// that now, a regular reduction would be cheaper because the costs of
// the extends in the IR are still counted. This can be fixed
// after https://github.com/llvm/llvm-project/pull/147302 has landed.
return Cost;
}
if (VF.isFixed() &&
(!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64))
return Invalid;

if (InputEVT == MVT::i8) {
switch (VFMinValue) {
default:
return Invalid;
case 8:
if (AccumEVT == MVT::i32)
Cost *= 2;
else if (AccumEVT != MVT::i64)
return Invalid;
break;
case 16:
if (AccumEVT == MVT::i64)
Cost *= 2;
else if (AccumEVT != MVT::i32)
return Invalid;
break;
}
} else if (InputEVT == MVT::i16) {
// FIXME: Allow i32 accumulator but increase cost, as we would extend
// it to i64.
if (VFMinValue != 8 || AccumEVT != MVT::i64)
return Invalid;
} else
return Invalid;
// i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
if (ST->isSVEorStreamingSVEAvailable() ||
(AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() &&
ST->hasDotProd())) {
if (AccumLT.second.getScalarType() == MVT::i32 &&
InputLT.second.getScalarType() == MVT::i8)
return Cost;
}

return Cost;
// Add additional cost for the extends that would need to be inserted.
return Cost + 4;
}

InstructionCost
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ define i64 @test_two_ivs(ptr %a, ptr %b, i64 %start) #0 {
; CHECK-NEXT: Cost of 0 for VF 16: induction instruction %i.iv = phi i64 [ 0, %entry ], [ %i.iv.next, %for.body ]
; CHECK-NEXT: Cost of 0 for VF 16: induction instruction %j.iv = phi i64 [ %start, %entry ], [ %j.iv.next, %for.body ]
; CHECK-NEXT: Cost of 0 for VF 16: EMIT vp<{{.+}}> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
; CHECK: Cost for VF 16: 48
; CHECK: Cost for VF 16: 41
; CHECK: LV: Selecting VF: 16
entry:
br label %for.body
Expand Down
Loading