Skip to content

Commit

Permalink
[AArch64] Improve cost of umull from known bits
Browse files Browse the repository at this point in the history
As in D140287, we can now generate umull from mul(zext(x), y) in cases where we
know that the top bits of y are zero. This teaches that to the cost model,
adjusting how isWideningInstruction detects mul operations that can extend both
operands. This helps for constants and other cases where the operands of the
mul are known to be extended, but are not extend instructions themselves.

Differential Revision: https://reviews.llvm.org/D154936
  • Loading branch information
davemgreen committed Jul 12, 2023
1 parent 6c388e0 commit 1712ae6
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 55 deletions.
106 changes: 53 additions & 53 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Expand Up @@ -1944,22 +1944,23 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
}

bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
ArrayRef<Type *> SrcTys,
ArrayRef<const Value *> Args) {

ArrayRef<const Value *> Args,
Type *SrcOverrideTy) {
// A helper that returns a vector type from the given type. The number of
// elements in type Ty determines the vector width.
auto toVectorTy = [&](Type *ArgTy) {
return VectorType::get(ArgTy->getScalarType(),
cast<VectorType>(DstTy)->getElementCount());
};

// Exit early if DstTy is not a vector type whose elements are at least
// 16-bits wide. SVE doesn't generally have the same set of instructions to
// Exit early if DstTy is not a vector type whose elements are one of [i16,
// i32, i64]. SVE doesn't generally have the same set of instructions to
// perform an extend with the add/sub/mul. There are SMULLB style
// instructions, but they operate on top/bottom, requiring some sort of lane
// interleaving to be used with zext/sext.
if (!useNeonVector(DstTy) || DstTy->getScalarSizeInBits() < 16)
unsigned DstEltSize = DstTy->getScalarSizeInBits();
if (!useNeonVector(DstTy) || Args.size() != 2 ||
(DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
return false;

// Determine if the operation has a widening variant. We consider both the
Expand All @@ -1969,42 +1970,55 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
// TODO: Add additional widening operations (e.g., shl, etc.) once we
// verify that their extending operands are eliminated during code
// generation.
Type *SrcTy = SrcOverrideTy;
switch (Opcode) {
case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
case Instruction::Mul: // SMULL(2), UMULL(2)
// The second operand needs to be an extend
if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
if (!SrcTy)
SrcTy =
toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
} else
return false;
break;
case Instruction::Mul: { // SMULL(2), UMULL(2)
// Both operands need to be extends of the same type.
if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
(isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
if (!SrcTy)
SrcTy =
toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
} else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
// If one of the operands is a Zext and the other has enough zero bits to
// be treated as unsigned, we can still general a umull, meaning the zext
// is free.
KnownBits Known =
computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
if (Args[0]->getType()->getScalarSizeInBits() -
Known.Zero.countLeadingOnes() >
DstTy->getScalarSizeInBits() / 2)
return false;
if (!SrcTy)
SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
DstTy->getScalarSizeInBits() / 2));
} else
return false;
break;
}
default:
return false;
}

// To be a widening instruction (either the "wide" or "long" versions), the
// second operand must be a sign- or zero extend.
if (Args.size() != 2 ||
(!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])))
return false;
auto *Extend = cast<CastInst>(Args[1]);
auto *Arg0 = dyn_cast<CastInst>(Args[0]);

// A mul only has a mull version (not like addw). Both operands need to be
// extending and the same type.
if (Opcode == Instruction::Mul &&
(!Arg0 || Arg0->getOpcode() != Extend->getOpcode() ||
(SrcTys.size() == 2 && SrcTys[0] != SrcTys[1])))
return false;

// Legalize the destination type and ensure it can be used in a widening
// operation.
auto DstTyL = getTypeLegalizationCost(DstTy);
unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits())
return false;

// Legalize the source type and ensure it can be used in a widening
// operation.
Type *SrcTy =
SrcTys.size() > 0 ? SrcTys.back() : toVectorTy(Extend->getSrcTy());

assert(SrcTy && "Expected some SrcTy");
auto SrcTyL = getTypeLegalizationCost(SrcTy);
unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
Expand All @@ -2018,7 +2032,7 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,

// Return true if the legalized types have the same number of vector elements
// and the destination element type size is twice that of the source type.
return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
}

InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
Expand All @@ -2033,31 +2047,17 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
if (I && I->hasOneUser()) {
auto *SingleUser = cast<Instruction>(*I->user_begin());
SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
SmallVector<Type *, 2> SrcTys;
for (const Value *Op : Operands) {
auto *Cast = dyn_cast<CastInst>(Op);
if (!Cast)
continue;
// Use provided Src type for I and other casts that have the same source
// type.
if (Op == I || cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
SrcTys.push_back(Src);
else
SrcTys.push_back(Cast->getSrcTy());
}
if (isWideningInstruction(Dst, SingleUser->getOpcode(), SrcTys, Operands)) {
// If the cast is the second operand, it is free. We will generate either
// a "wide" or "long" version of the widening instruction.
if (I == SingleUser->getOperand(1))
return 0;
// If the cast is not the second operand, it will be free if it looks the
// same as the second operand. In this case, we will generate a "long"
// version of the widening instruction.
if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
(Src == Cast->getSrcTy() ||
cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy()))
if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
// For adds only count the second operand as free if both operands are
// extends but not the same operation. (i.e both operands are not free in
// add(sext, zext)).
if (SingleUser->getOpcode() == Instruction::Add) {
if (I == SingleUser->getOperand(1) ||
(isa<CastInst>(SingleUser->getOperand(1)) &&
cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
return 0;
} else // Others are free so long as isWideningInstruction returned true.
return 0;
}
}

Expand Down Expand Up @@ -2680,7 +2680,7 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
// LT.first = 2 the cost is 28. If both operands are extensions it will not
// need to scalarize so the cost can be cheaper (smull or umull).
// so the cost can be cheaper (smull or umull).
if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, {}, Args))
if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
return LT.first;
return LT.first * 14;
case ISD::ADD:
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Expand Up @@ -58,8 +58,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
};

bool isWideningInstruction(Type *DstTy, unsigned Opcode,
ArrayRef<Type *> SrcTys,
ArrayRef<const Value *> Args);
ArrayRef<const Value *> Args,
Type *SrcOverrideTy = nullptr);

// A helper function called by 'getVectorInstrCost'.
//
Expand Down
24 changes: 24 additions & 0 deletions llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
Expand Up @@ -2087,3 +2087,27 @@ define void @extmulv16(<16 x i8> %i8, <16 x i16> %i16, <16 x i32> %i32, <16 x i6

ret void
}

; Cost-model coverage for mul-by-constant feeding an extend. Per the CHECK
; lines: a zext whose only user is a mul with a small constant (10) or a
; masked (and 255) operand is costed 0 — the cost model assumes a umull can
; be generated, so the zext is free. The sext variant still costs 1, since
; the constant operand is not itself a sign-extend.
define void @extmul_const(<8 x i8> %i8, <8 x i16> %i16, <8 x i32> %i32, <8 x i64> %i64) {
; CHECK-LABEL: 'extmul_const'
; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %sl1_8_16 = sext <8 x i8> %i8 to <8 x i16>
; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %asl_8_16 = mul <8 x i16> %sl1_8_16, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %zl1_8_16 = zext <8 x i8> %i8 to <8 x i16>
; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %azl_8_16 = mul <8 x i16> %zl1_8_16, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %zl1_8_16b = zext <8 x i8> %i8 to <8 x i16>
; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %and = and <8 x i16> %sl1_8_16, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %aal_8_16 = mul <8 x i16> %zl1_8_16b, %and
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
  %sl1_8_16 = sext <8 x i8> %i8 to <8 x i16>
  %asl_8_16 = mul <8 x i16> %sl1_8_16, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>

  %zl1_8_16 = zext <8 x i8> %i8 to <8 x i16>
  %azl_8_16 = mul <8 x i16> %zl1_8_16, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>

  %zl1_8_16b = zext <8 x i8> %i8 to <8 x i16>
  %and = and <8 x i16> %sl1_8_16, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
  %aal_8_16 = mul <8 x i16> %zl1_8_16b, %and

  ret void
}

0 comments on commit 1712ae6

Please sign in to comment.