Skip to content

Commit

Permalink
Clean up usages of asserting vector getters in Type
Browse files Browse the repository at this point in the history
Summary:
Remove usages of asserting vector getters in Type in preparation for the
VectorType refactor. The existence of these functions complicates the
refactor while adding little value.

Reviewers: sdesmalen, rriddle, efriedma

Reviewed By: sdesmalen

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77263
  • Loading branch information
christetreault-llvm committed Apr 8, 2020
1 parent ff1658b commit 155740c
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 129 deletions.
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Expand Up @@ -1652,7 +1652,7 @@ static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) {

if (C->getType()->isVectorTy()) {
// Check each element of a constant vector.
unsigned NumElts = C->getType()->getVectorNumElements();
unsigned NumElts = cast<VectorType>(C->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
Expand Down Expand Up @@ -2082,7 +2082,7 @@ static Instruction *matchRotate(Instruction &Or) {

/// If all elements of two constant vectors are 0/-1 and inverses, return true.
static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) {
unsigned NumElts = C1->getType()->getVectorNumElements();
unsigned NumElts = cast<VectorType>(C1->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *EltC1 = C1->getAggregateElement(i);
Constant *EltC2 = C2->getAggregateElement(i);
Expand Down
65 changes: 35 additions & 30 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Expand Up @@ -523,7 +523,7 @@ static Value *simplifyX86varShift(const IntrinsicInst &II,
auto Vec = II.getArgOperand(0);
auto Amt = II.getArgOperand(1);
auto VT = cast<VectorType>(II.getType());
auto SVT = VT->getVectorElementType();
auto SVT = VT->getElementType();
int NumElts = VT->getNumElements();
int BitWidth = SVT->getIntegerBitWidth();

Expand Down Expand Up @@ -620,10 +620,10 @@ static Value *simplifyX86pack(IntrinsicInst &II,
if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1))
return UndefValue::get(ResTy);

Type *ArgTy = Arg0->getType();
auto *ArgTy = cast<VectorType>(Arg0->getType());
unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128;
unsigned NumSrcElts = ArgTy->getVectorNumElements();
assert(ResTy->getVectorNumElements() == (2 * NumSrcElts) &&
unsigned NumSrcElts = ArgTy->getNumElements();
assert(cast<VectorType>(ResTy)->getNumElements() == (2 * NumSrcElts) &&
"Unexpected packing types");

unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes;
Expand Down Expand Up @@ -680,23 +680,23 @@ static Value *simplifyX86movmsk(const IntrinsicInst &II,
InstCombiner::BuilderTy &Builder) {
Value *Arg = II.getArgOperand(0);
Type *ResTy = II.getType();
Type *ArgTy = Arg->getType();

// movmsk(undef) -> zero as we must ensure the upper bits are zero.
if (isa<UndefValue>(Arg))
return Constant::getNullValue(ResTy);

auto *ArgTy = dyn_cast<VectorType>(Arg->getType());
// We can't easily peek through x86_mmx types.
if (!ArgTy->isVectorTy())
if (!ArgTy)
return nullptr;

// Expand MOVMSK to compare/bitcast/zext:
// e.g. PMOVMSKB(v16i8 x):
// %cmp = icmp slt <16 x i8> %x, zeroinitializer
// %int = bitcast <16 x i1> %cmp to i16
// %res = zext i16 %int to i32
unsigned NumElts = ArgTy->getVectorNumElements();
Type *IntegerVecTy = VectorType::getInteger(cast<VectorType>(ArgTy));
unsigned NumElts = ArgTy->getNumElements();
Type *IntegerVecTy = VectorType::getInteger(ArgTy);
Type *IntegerTy = Builder.getIntNTy(NumElts);

Value *Res = Builder.CreateBitCast(Arg, IntegerVecTy);
Expand Down Expand Up @@ -1036,7 +1036,7 @@ static Value *simplifyX86vpermilvar(const IntrinsicInst &II,

auto *VecTy = cast<VectorType>(II.getType());
auto *MaskEltTy = Type::getInt32Ty(II.getContext());
unsigned NumElts = VecTy->getVectorNumElements();
unsigned NumElts = VecTy->getNumElements();
bool IsPD = VecTy->getScalarType()->isDoubleTy();
unsigned NumLaneElts = IsPD ? 2 : 4;
assert(NumElts == 16 || NumElts == 8 || NumElts == 4 || NumElts == 2);
Expand Down Expand Up @@ -1955,8 +1955,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
}

// For vector result intrinsics, use the generic demanded vector support.
if (II->getType()->isVectorTy()) {
auto VWidth = II->getType()->getVectorNumElements();
if (auto *IIVTy = dyn_cast<VectorType>(II->getType())) {
auto VWidth = IIVTy->getNumElements();
APInt UndefElts(VWidth, 0);
APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) {
Expand Down Expand Up @@ -2505,8 +2505,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// Turn PPC QPX qvlfs -> load if the pointer is known aligned.
if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, &AC,
&DT) >= 16) {
Type *VTy = VectorType::get(Builder.getFloatTy(),
II->getType()->getVectorNumElements());
Type *VTy =
VectorType::get(Builder.getFloatTy(),
cast<VectorType>(II->getType())->getElementCount());
Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0),
PointerType::getUnqual(VTy));
Value *Load = Builder.CreateLoad(VTy, Ptr);
Expand All @@ -2526,8 +2527,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// Turn PPC QPX qvstfs -> store if the pointer is known aligned.
if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, &AC,
&DT) >= 16) {
Type *VTy = VectorType::get(Builder.getFloatTy(),
II->getArgOperand(0)->getType()->getVectorNumElements());
Type *VTy = VectorType::get(
Builder.getFloatTy(),
cast<VectorType>(II->getArgOperand(0)->getType())->getElementCount());
Value *TOp = Builder.CreateFPTrunc(II->getArgOperand(0), VTy);
Type *OpPtrTy = PointerType::getUnqual(VTy);
Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy);
Expand Down Expand Up @@ -2676,7 +2678,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// These intrinsics only demand the 0th element of their input vectors. If
// we can simplify the input based on that, do so now.
Value *Arg = II->getArgOperand(0);
unsigned VWidth = Arg->getType()->getVectorNumElements();
unsigned VWidth = cast<VectorType>(Arg->getType())->getNumElements();
if (Value *V = SimplifyDemandedVectorEltsLow(Arg, VWidth, 1))
return replaceOperand(*II, 0, V);
break;
Expand Down Expand Up @@ -2726,7 +2728,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
bool MadeChange = false;
Value *Arg0 = II->getArgOperand(0);
Value *Arg1 = II->getArgOperand(1);
unsigned VWidth = Arg0->getType()->getVectorNumElements();
unsigned VWidth = cast<VectorType>(Arg0->getType())->getNumElements();
if (Value *V = SimplifyDemandedVectorEltsLow(Arg0, VWidth, 1)) {
replaceOperand(*II, 0, V);
MadeChange = true;
Expand Down Expand Up @@ -2944,7 +2946,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
Value *Arg1 = II->getArgOperand(1);
assert(Arg1->getType()->getPrimitiveSizeInBits() == 128 &&
"Unexpected packed shift size");
unsigned VWidth = Arg1->getType()->getVectorNumElements();
unsigned VWidth = cast<VectorType>(Arg1->getType())->getNumElements();

if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, VWidth / 2))
return replaceOperand(*II, 1, V);
Expand Down Expand Up @@ -3011,7 +3013,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
bool MadeChange = false;
Value *Arg0 = II->getArgOperand(0);
Value *Arg1 = II->getArgOperand(1);
unsigned VWidth = Arg0->getType()->getVectorNumElements();
unsigned VWidth = cast<VectorType>(Arg0->getType())->getNumElements();

APInt UndefElts1(VWidth, 0);
APInt DemandedElts1 = APInt::getSplat(VWidth,
Expand Down Expand Up @@ -3051,8 +3053,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::x86_sse4a_extrq: {
Value *Op0 = II->getArgOperand(0);
Value *Op1 = II->getArgOperand(1);
unsigned VWidth0 = Op0->getType()->getVectorNumElements();
unsigned VWidth1 = Op1->getType()->getVectorNumElements();
unsigned VWidth0 = cast<VectorType>(Op0->getType())->getNumElements();
unsigned VWidth1 = cast<VectorType>(Op1->getType())->getNumElements();
assert(Op0->getType()->getPrimitiveSizeInBits() == 128 &&
Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth0 == 2 &&
VWidth1 == 16 && "Unexpected operand sizes");
Expand Down Expand Up @@ -3090,7 +3092,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// EXTRQI: Extract Length bits starting from Index. Zero pad the remaining
// bits of the lower 64-bits. The upper 64-bits are undefined.
Value *Op0 = II->getArgOperand(0);
unsigned VWidth = Op0->getType()->getVectorNumElements();
unsigned VWidth = cast<VectorType>(Op0->getType())->getNumElements();
assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && VWidth == 2 &&
"Unexpected operand size");

Expand All @@ -3112,10 +3114,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::x86_sse4a_insertq: {
Value *Op0 = II->getArgOperand(0);
Value *Op1 = II->getArgOperand(1);
unsigned VWidth = Op0->getType()->getVectorNumElements();
unsigned VWidth = cast<VectorType>(Op0->getType())->getNumElements();
assert(Op0->getType()->getPrimitiveSizeInBits() == 128 &&
Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth == 2 &&
Op1->getType()->getVectorNumElements() == 2 &&
cast<VectorType>(Op1->getType())->getNumElements() == 2 &&
"Unexpected operand size");

// See if we're dealing with constant values.
Expand Down Expand Up @@ -3146,8 +3148,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// undefined.
Value *Op0 = II->getArgOperand(0);
Value *Op1 = II->getArgOperand(1);
unsigned VWidth0 = Op0->getType()->getVectorNumElements();
unsigned VWidth1 = Op1->getType()->getVectorNumElements();
unsigned VWidth0 = cast<VectorType>(Op0->getType())->getNumElements();
unsigned VWidth1 = cast<VectorType>(Op1->getType())->getNumElements();
assert(Op0->getType()->getPrimitiveSizeInBits() == 128 &&
Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth0 == 2 &&
VWidth1 == 2 && "Unexpected operand sizes");
Expand Down Expand Up @@ -3214,8 +3216,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
II->getType()->getPrimitiveSizeInBits() &&
"Not expecting mask and operands with different sizes");

unsigned NumMaskElts = Mask->getType()->getVectorNumElements();
unsigned NumOperandElts = II->getType()->getVectorNumElements();
unsigned NumMaskElts =
cast<VectorType>(Mask->getType())->getNumElements();
unsigned NumOperandElts =
cast<VectorType>(II->getType())->getNumElements();
if (NumMaskElts == NumOperandElts)
return SelectInst::Create(BoolVec, Op1, Op0);

Expand Down Expand Up @@ -3306,7 +3310,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// the permutation mask with respect to 31 and reverse the order of
// V1 and V2.
if (Constant *Mask = dyn_cast<Constant>(II->getArgOperand(2))) {
assert(Mask->getType()->getVectorNumElements() == 16 &&
assert(cast<VectorType>(Mask->getType())->getNumElements() == 16 &&
"Bad type for intrinsic!");

// Check that all of the elements are integer constants or undefs.
Expand Down Expand Up @@ -3464,7 +3468,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
if (auto *CI = dyn_cast<ConstantInt>(XorMask)) {
if (CI->getValue().trunc(16).isAllOnesValue()) {
auto TrueVector = Builder.CreateVectorSplat(
II->getType()->getVectorNumElements(), Builder.getTrue());
cast<VectorType>(II->getType())->getNumElements(),
Builder.getTrue());
return BinaryOperator::Create(Instruction::Xor, ArgArg, TrueVector);
}
}
Expand Down
29 changes: 16 additions & 13 deletions llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
Expand Up @@ -856,10 +856,10 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
Value *VecOp;
if (match(Src,
m_OneUse(m_ExtractElement(m_Value(VecOp), m_ConstantInt(Cst))))) {
Type *VecOpTy = VecOp->getType();
auto *VecOpTy = cast<VectorType>(VecOp->getType());
unsigned DestScalarSize = DestTy->getScalarSizeInBits();
unsigned VecOpScalarSize = VecOpTy->getScalarSizeInBits();
unsigned VecNumElts = VecOpTy->getVectorNumElements();
unsigned VecNumElts = VecOpTy->getNumElements();

// A badly fit destination size would result in an invalid cast.
if (VecOpScalarSize % DestScalarSize == 0) {
Expand Down Expand Up @@ -1514,12 +1514,13 @@ static Type *shrinkFPConstant(ConstantFP *CFP) {
// TODO: Make these support undef elements.
static Type *shrinkFPConstantVector(Value *V) {
auto *CV = dyn_cast<Constant>(V);
if (!CV || !CV->getType()->isVectorTy())
auto *CVVTy = dyn_cast<VectorType>(V->getType());
if (!CV || !CVVTy)
return nullptr;

Type *MinType = nullptr;

unsigned NumElts = CV->getType()->getVectorNumElements();
unsigned NumElts = CVVTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i));
if (!CFP)
Expand Down Expand Up @@ -1820,8 +1821,9 @@ Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) {
if (CI.getOperand(0)->getType()->getScalarSizeInBits() !=
DL.getPointerSizeInBits(AS)) {
Type *Ty = DL.getIntPtrType(CI.getContext(), AS);
if (CI.getType()->isVectorTy()) // Handle vectors of pointers.
Ty = VectorType::get(Ty, CI.getType()->getVectorNumElements());
// Handle vectors of pointers.
if (auto *CIVTy = dyn_cast<VectorType>(CI.getType()))
Ty = VectorType::get(Ty, CIVTy->getElementCount());

Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty);
return new IntToPtrInst(P, CI.getType());
Expand Down Expand Up @@ -1868,8 +1870,8 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
return commonPointerCastTransforms(CI);

Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS);
if (Ty->isVectorTy()) // Handle vectors of pointers.
PtrTy = VectorType::get(PtrTy, Ty->getVectorNumElements());
if (auto *VTy = dyn_cast<VectorType>(Ty)) // Handle vectors of pointers.
PtrTy = VectorType::get(PtrTy, VTy->getNumElements());

Value *P = Builder.CreatePtrToInt(CI.getOperand(0), PtrTy);
return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
Expand Down Expand Up @@ -2199,10 +2201,10 @@ static Instruction *foldBitCastSelect(BitCastInst &BitCast,
// A vector select must maintain the same number of elements in its operands.
Type *CondTy = Cond->getType();
Type *DestTy = BitCast.getType();
if (CondTy->isVectorTy()) {
if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) {
if (!DestTy->isVectorTy())
return nullptr;
if (DestTy->getVectorNumElements() != CondTy->getVectorNumElements())
if (cast<VectorType>(DestTy)->getNumElements() != CondVTy->getNumElements())
return nullptr;
}

Expand Down Expand Up @@ -2536,10 +2538,11 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
// a bitcast to a vector with the same # elts.
Value *ShufOp0 = Shuf->getOperand(0);
Value *ShufOp1 = Shuf->getOperand(1);
unsigned NumShufElts = Shuf->getType()->getVectorNumElements();
unsigned NumSrcVecElts = ShufOp0->getType()->getVectorNumElements();
unsigned NumShufElts = Shuf->getType()->getNumElements();
unsigned NumSrcVecElts =
cast<VectorType>(ShufOp0->getType())->getNumElements();
if (Shuf->hasOneUse() && DestTy->isVectorTy() &&
DestTy->getVectorNumElements() == NumShufElts &&
cast<VectorType>(DestTy)->getNumElements() == NumShufElts &&
NumShufElts == NumSrcVecElts) {
BitCastInst *Tmp;
// If either of the operands is a cast from CI.getType(), then
Expand Down
24 changes: 13 additions & 11 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Expand Up @@ -897,7 +897,7 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
// For vectors, we apply the same reasoning on a per-lane basis.
auto *Base = GEPLHS->getPointerOperand();
if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) {
int NumElts = GEPLHS->getType()->getVectorNumElements();
int NumElts = cast<VectorType>(GEPLHS->getType())->getNumElements();
Base = Builder.CreateVectorSplat(NumElts, Base);
}
return new ICmpInst(Cond, Base,
Expand Down Expand Up @@ -1861,8 +1861,8 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
int32_t ExactLogBase2 = C2->exactLogBase2();
if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) {
Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1);
if (And->getType()->isVectorTy())
NTy = VectorType::get(NTy, And->getType()->getVectorNumElements());
if (auto *AndVTy = dyn_cast<VectorType>(And->getType()))
NTy = VectorType::get(NTy, AndVTy->getNumElements());
Value *Trunc = Builder.CreateTrunc(X, NTy);
auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE
: CmpInst::ICMP_SLT;
Expand Down Expand Up @@ -2147,8 +2147,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt &&
DL.isLegalInteger(TypeBits - Amt)) {
Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
if (ShType->isVectorTy())
TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements());
if (auto *ShVTy = dyn_cast<VectorType>(ShType))
TruncTy = VectorType::get(TruncTy, ShVTy->getNumElements());
Constant *NewC =
ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
Expand Down Expand Up @@ -2776,8 +2776,8 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
// (bitcast (fpext/fptrunc X)) to iX) > -1 --> (bitcast X to iY) > -1
Type *XType = X->getType();
Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits());
if (XType->isVectorTy())
NewType = VectorType::get(NewType, XType->getVectorNumElements());
if (auto *XVTy = dyn_cast<VectorType>(XType))
NewType = VectorType::get(NewType, XVTy->getNumElements());
Value *NewBitcast = Builder.CreateBitCast(X, NewType);
if (TrueIfSigned)
return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast,
Expand Down Expand Up @@ -3354,8 +3354,9 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
Type *OpTy = M->getType();
auto *VecC = dyn_cast<Constant>(M);
if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) {
auto *OpVTy = cast<VectorType>(OpTy);
Constant *SafeReplacementConstant = nullptr;
for (unsigned i = 0, e = OpTy->getVectorNumElements(); i != e; ++i) {
for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
SafeReplacementConstant = VecC->getAggregateElement(i);
break;
Expand Down Expand Up @@ -5189,8 +5190,8 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
// Bail out if the constant can't be safely incremented/decremented.
if (!ConstantIsOk(CI))
return llvm::None;
} else if (Type->isVectorTy()) {
unsigned NumElts = Type->getVectorNumElements();
} else if (auto *VTy = dyn_cast<VectorType>(Type)) {
unsigned NumElts = VTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
Expand Down Expand Up @@ -5411,7 +5412,8 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
if (ScalarC && match(M, m_SplatOrUndefMask(MaskSplatIndex))) {
// We allow undefs in matching, but this transform removes those for safety.
// Demanded elements analysis should be able to recover some/all of that.
C = ConstantVector::getSplat(V1Ty->getVectorElementCount(), ScalarC);
C = ConstantVector::getSplat(cast<VectorType>(V1Ty)->getElementCount(),
ScalarC);
SmallVector<int, 8> NewM(M.size(), MaskSplatIndex);
Value *NewCmp = IsFP ? Builder.CreateFCmp(Pred, V1, C)
: Builder.CreateICmp(Pred, V1, C);
Expand Down

0 comments on commit 155740c

Please sign in to comment.