-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[AArch64] Refactor and refine cost-model for partial reductions #158641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b58b094
0a91b0d
bf2645a
03c0f18
5a464a3
224bbed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5632,75 +5632,94 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost( | |
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp, | ||
TTI::TargetCostKind CostKind) const { | ||
InstructionCost Invalid = InstructionCost::getInvalid(); | ||
InstructionCost Cost(TTI::TCC_Basic); | ||
|
||
if (CostKind != TTI::TCK_RecipThroughput) | ||
return Invalid; | ||
|
||
// Sub opcodes currently only occur in chained cases. | ||
// Independent partial reduction subtractions are still costed as an add | ||
if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() && | ||
(!ST->isNeonAvailable() || !ST->hasDotProd())) | ||
return Invalid; | ||
|
||
if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) || | ||
OpAExtend == TTI::PR_None) | ||
return Invalid; | ||
|
||
assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) && | ||
(!BinOp || (OpBExtend != TTI::PR_None && InputTypeB)) && | ||
"Unexpected values for OpBExtend or InputTypeB"); | ||
|
||
// We only support multiply binary operations for now, and for muls we | ||
// require the types being extended to be the same. | ||
// NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but | ||
// only if the i8mm or sve/streaming features are available. | ||
if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB || | ||
OpBExtend == TTI::PR_None || | ||
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() && | ||
!ST->isSVEorStreamingSVEAvailable()))) | ||
if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB)) | ||
return Invalid; | ||
assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) && | ||
"Unexpected values for OpBExtend or InputTypeB"); | ||
|
||
EVT InputEVT = EVT::getEVT(InputTypeA); | ||
EVT AccumEVT = EVT::getEVT(AccumType); | ||
bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend; | ||
if (IsUSDot && !ST->hasMatMulInt8()) | ||
return Invalid; | ||
|
||
unsigned Ratio = | ||
AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits(); | ||
if (VF.getKnownMinValue() <= Ratio) | ||
return Invalid; | ||
|
||
VectorType *InputVectorType = VectorType::get(InputTypeA, VF); | ||
VectorType *AccumVectorType = | ||
VectorType::get(AccumType, VF.divideCoefficientBy(Ratio)); | ||
// We don't yet support all kinds of legalization. | ||
auto TA = TLI->getTypeAction(AccumVectorType->getContext(), | ||
EVT::getEVT(AccumVectorType)); | ||
switch (TA) { | ||
default: | ||
return Invalid; | ||
case TargetLowering::TypeLegal: | ||
case TargetLowering::TypePromoteInteger: | ||
case TargetLowering::TypeSplitVector: | ||
break; | ||
} | ||
|
||
// Check what kind of type-legalisation happens. | ||
std::pair<InstructionCost, MVT> AccumLT = | ||
getTypeLegalizationCost(AccumVectorType); | ||
std::pair<InstructionCost, MVT> InputLT = | ||
getTypeLegalizationCost(InputVectorType); | ||
|
||
unsigned VFMinValue = VF.getKnownMinValue(); | ||
InstructionCost Cost = InputLT.first * TTI::TCC_Basic; | ||
|
||
if (VF.isScalable()) { | ||
if (!ST->isSVEorStreamingSVEAvailable()) | ||
return Invalid; | ||
// Prefer using full types by costing half-full input types as more expensive. | ||
if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't you also need to check if the input vector type is scalable first? Otherwise you're potentially asking if a legal 64-bit vector is less than a legal SVE vector, where the answer is always going to be true. I guess this may be a valid thing to do if we're going to lower to SVE? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, that was actually the point. We only want to do this for fully packed scalable/fixed vectors, so a 64-bit fixed-length vector would be one we'd prefer not to favour. |
||
TypeSize::getScalable(128))) | ||
// FIXME: This can be removed after the costs of the extends are folded into | ||
// the dot-product expression in VPlan, after landing: | ||
// https://github.com/llvm/llvm-project/pull/147302 | ||
Cost *= 2; | ||
|
||
// Don't accept a partial reduction if the scaled accumulator is vscale x 1, | ||
// since we can't lower that type. | ||
unsigned Scale = | ||
AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits(); | ||
if (VFMinValue == Scale) | ||
return Invalid; | ||
if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) { | ||
// i16 -> i64 is natively supported for udot/sdot | ||
if (AccumLT.second.getScalarType() == MVT::i64 && | ||
InputLT.second.getScalarType() == MVT::i16) | ||
return Cost; | ||
// i8 -> i64 is supported with an extra level of extends | ||
if (AccumLT.second.getScalarType() == MVT::i64 && | ||
InputLT.second.getScalarType() == MVT::i8) | ||
// FIXME: This cost should probably be a little higher, e.g. Cost + 2 | ||
// because it requires two extra extends on the inputs. But if we changed | ||
// that now, a regular reduction would be cheaper because the costs of | ||
// the extends in the IR are still counted. This can be fixed | ||
// after https://github.com/llvm/llvm-project/pull/147302 has landed. | ||
return Cost; | ||
} | ||
if (VF.isFixed() && | ||
(!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64)) | ||
return Invalid; | ||
|
||
if (InputEVT == MVT::i8) { | ||
switch (VFMinValue) { | ||
default: | ||
return Invalid; | ||
case 8: | ||
if (AccumEVT == MVT::i32) | ||
Cost *= 2; | ||
else if (AccumEVT != MVT::i64) | ||
return Invalid; | ||
break; | ||
case 16: | ||
if (AccumEVT == MVT::i64) | ||
Cost *= 2; | ||
else if (AccumEVT != MVT::i32) | ||
return Invalid; | ||
break; | ||
} | ||
} else if (InputEVT == MVT::i16) { | ||
// FIXME: Allow i32 accumulator but increase cost, as we would extend | ||
// it to i64. | ||
if (VFMinValue != 8 || AccumEVT != MVT::i64) | ||
return Invalid; | ||
} else | ||
return Invalid; | ||
// i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE. | ||
if (ST->isSVEorStreamingSVEAvailable() || | ||
(AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() && | ||
ST->hasDotProd())) { | ||
if (AccumLT.second.getScalarType() == MVT::i32 && | ||
InputLT.second.getScalarType() == MVT::i8) | ||
return Cost; | ||
} | ||
|
||
return Cost; | ||
// Add additional cost for the extends that would need to be inserted. | ||
return Cost + 4; | ||
} | ||
|
||
InstructionCost | ||
|
Uh oh!
There was an error while loading. Please reload this page.