diff --git a/clang/unittests/CodeGen/CodeGenExternalTest.cpp b/clang/unittests/CodeGen/CodeGenExternalTest.cpp index 8dff45c8a0f53a..255b8c3e9d8cdc 100644 --- a/clang/unittests/CodeGen/CodeGenExternalTest.cpp +++ b/clang/unittests/CodeGen/CodeGenExternalTest.cpp @@ -199,7 +199,7 @@ static void test_codegen_fns(MyASTConsumer *my) { dbgs() << "\n"; } - llvm::CompositeType* structTy = dyn_cast(llvmTy); + auto* structTy = dyn_cast(llvmTy); ASSERT_TRUE(structTy != NULL); // Check getLLVMFieldNumber diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h index e6d8c0eb4d925a..d0ea1adbbf1838 100644 --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -392,7 +392,7 @@ class ConstantAggregateZero final : public ConstantData { /// use operands. class ConstantAggregate : public Constant { protected: - ConstantAggregate(CompositeType *T, ValueTy VT, ArrayRef V); + ConstantAggregate(Type *T, ValueTy VT, ArrayRef V); public: /// Transparently provide more efficient getOperand methods. diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h index 3b56da680c6e2c..ac3abe3c32dc66 100644 --- a/llvm/include/llvm/IR/DerivedTypes.h +++ b/llvm/include/llvm/IR/DerivedTypes.h @@ -195,26 +195,6 @@ class FunctionCallee { Value *Callee = nullptr; }; -/// Common super class of ArrayType, StructType and VectorType. -class CompositeType : public Type { -protected: - explicit CompositeType(LLVMContext &C, TypeID tid) : Type(C, tid) {} - -public: - /// Given an index value into the type, return the type of the element. - Type *getTypeAtIndex(const Value *V) const; - Type *getTypeAtIndex(unsigned Idx) const; - bool indexValid(const Value *V) const; - bool indexValid(unsigned Idx) const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Type *T) { - return T->getTypeID() == ArrayTyID || - T->getTypeID() == StructTyID || - T->getTypeID() == VectorTyID; - } -}; - /// Class to represent struct types. There are two different kinds of struct /// types: Literal structs and Identified structs. /// @@ -235,8 +215,8 @@ class CompositeType : public Type { /// elements as defined by DataLayout (which is required to match what the code /// generator for a target expects). /// -class StructType : public CompositeType { - StructType(LLVMContext &C) : CompositeType(C, StructTyID) {} +class StructType : public Type { + StructType(LLVMContext &C) : Type(C, StructTyID) {} enum { /// This is the contents of the SubClassData field. @@ -350,6 +330,11 @@ class StructType : public CompositeType { assert(N < NumContainedTys && "Element number out of range!"); return ContainedTys[N]; } + /// Given an index value into the type, return the type of the element. + Type *getTypeAtIndex(const Value *V) const; + Type *getTypeAtIndex(unsigned N) const { return getElementType(N); } + bool indexValid(const Value *V) const; + bool indexValid(unsigned Idx) const { return Idx < getNumElements(); } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Type *T) { @@ -375,14 +360,14 @@ Type *Type::getStructElementType(unsigned N) const { /// for use of SIMD instructions. SequentialType holds the common features of /// both, which stem from the fact that both lay their components out in memory /// identically. -class SequentialType : public CompositeType { +class SequentialType : public Type { Type *ContainedType; ///< Storage for the single contained type. uint64_t NumElements; protected: SequentialType(TypeID TID, Type *ElType, uint64_t NumElements) - : CompositeType(ElType->getContext(), TID), ContainedType(ElType), - NumElements(NumElements) { + : Type(ElType->getContext(), TID), ContainedType(ElType), + NumElements(NumElements) { ContainedTys = &ContainedType; NumContainedTys = 1; } diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h index e558d4317efc3b..713624d13bef0c 100644 --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -1008,16 +1008,23 @@ class GetElementPtrInst : public Instruction { return getPointerAddressSpace(); } - /// Returns the type of the element that would be loaded with - /// a load instruction with the specified parameters. + /// Returns the result type of a getelementptr with the given source + /// element type and indexes. /// /// Null is returned if the indices are invalid for the specified - /// pointer type. - /// + /// source element type. static Type *getIndexedType(Type *Ty, ArrayRef IdxList); static Type *getIndexedType(Type *Ty, ArrayRef IdxList); static Type *getIndexedType(Type *Ty, ArrayRef IdxList); + /// Return the type of the element at the given index of an indexable + /// type. This is equivalent to "getIndexedType(Agg, {Zero, Idx})". + /// + /// Returns null if the type can't be indexed, or the given index is not + /// legal for the given type. + static Type *getTypeAtIndex(Type *Ty, Value *Idx); + static Type *getTypeAtIndex(Type *Ty, uint64_t Idx); + inline op_iterator idx_begin() { return op_begin()+1; } inline const_op_iterator idx_begin() const { return op_begin()+1; } inline op_iterator idx_end() { return op_end(); } diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp index 1632895fe5fa3f..2a995922654081 100644 --- a/llvm/lib/CodeGen/Analysis.cpp +++ b/llvm/lib/CodeGen/Analysis.cpp @@ -395,7 +395,7 @@ static bool slotOnlyDiscardsData(const Value *RetVal, const Value *CallVal, /// For an aggregate type, determine whether a given index is within bounds or /// not. -static bool indexReallyValid(CompositeType *T, unsigned Idx) { +static bool indexReallyValid(Type *T, unsigned Idx) { if (ArrayType *AT = dyn_cast(T)) return Idx < AT->getNumElements(); @@ -419,7 +419,7 @@ static bool indexReallyValid(CompositeType *T, unsigned Idx) { /// function again on a finished iterator will repeatedly return /// false. SubTypes.back()->getTypeAtIndex(Path.back()) is either an empty /// aggregate or a non-aggregate -static bool advanceToNextLeafType(SmallVectorImpl &SubTypes, +static bool advanceToNextLeafType(SmallVectorImpl &SubTypes, SmallVectorImpl &Path) { // First march back up the tree until we can successfully increment one of the // coordinates in Path. @@ -435,16 +435,16 @@ static bool advanceToNextLeafType(SmallVectorImpl &SubTypes, // We know there's *some* valid leaf now, so march back down the tree picking // out the left-most element at each node. ++Path.back(); - Type *DeeperType = SubTypes.back()->getTypeAtIndex(Path.back()); + Type *DeeperType = + ExtractValueInst::getIndexedType(SubTypes.back(), Path.back()); while (DeeperType->isAggregateType()) { - CompositeType *CT = cast(DeeperType); - if (!indexReallyValid(CT, 0)) + if (!indexReallyValid(DeeperType, 0)) return true; - SubTypes.push_back(CT); + SubTypes.push_back(DeeperType); Path.push_back(0); - DeeperType = CT->getTypeAtIndex(0U); + DeeperType = ExtractValueInst::getIndexedType(DeeperType, 0); } return true; @@ -460,17 +460,15 @@ static bool advanceToNextLeafType(SmallVectorImpl &SubTypes, /// For example, if Next was {[0 x i64], {{}, i32, {}}, i32} then we would setup /// Path as [1, 1] and SubTypes as [Next, {{}, i32, {}}] to represent the first /// i32 in that type. -static bool firstRealType(Type *Next, - SmallVectorImpl &SubTypes, +static bool firstRealType(Type *Next, SmallVectorImpl &SubTypes, SmallVectorImpl &Path) { // First initialise the iterator components to the first "leaf" node // (i.e. node with no valid sub-type at any index, so {} does count as a leaf // despite nominally being an aggregate). - while (Next->isAggregateType() && - indexReallyValid(cast(Next), 0)) { - SubTypes.push_back(cast(Next)); + while (Type *FirstInner = ExtractValueInst::getIndexedType(Next, 0)) { + SubTypes.push_back(Next); Path.push_back(0); - Next = cast(Next)->getTypeAtIndex(0U); + Next = FirstInner; } // If there's no Path now, Next was originally scalar already (or empty @@ -480,7 +478,8 @@ static bool firstRealType(Type *Next, // Otherwise, use normal iteration to keep looking through the tree until we // find a non-aggregate type. - while (SubTypes.back()->getTypeAtIndex(Path.back())->isAggregateType()) { + while (ExtractValueInst::getIndexedType(SubTypes.back(), Path.back()) + ->isAggregateType()) { if (!advanceToNextLeafType(SubTypes, Path)) return false; } @@ -490,14 +489,15 @@ static bool firstRealType(Type *Next, /// Set the iterator data-structures to the next non-empty, non-aggregate /// subtype. -static bool nextRealType(SmallVectorImpl &SubTypes, +static bool nextRealType(SmallVectorImpl &SubTypes, SmallVectorImpl &Path) { do { if (!advanceToNextLeafType(SubTypes, Path)) return false; assert(!Path.empty() && "found a leaf but didn't set the path?"); - } while (SubTypes.back()->getTypeAtIndex(Path.back())->isAggregateType()); + } while (ExtractValueInst::getIndexedType(SubTypes.back(), Path.back()) + ->isAggregateType()); return true; } @@ -669,7 +669,7 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F, } SmallVector RetPath, CallPath; - SmallVector RetSubTypes, CallSubTypes; + SmallVector RetSubTypes, CallSubTypes; bool RetEmpty = !firstRealType(RetVal->getType(), RetSubTypes, RetPath); bool CallEmpty = !firstRealType(CallVal->getType(), CallSubTypes, CallPath); @@ -692,7 +692,8 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F, // We've exhausted the values produced by the tail call instruction, the // rest are essentially undef. The type doesn't really matter, but we need // *something*. - Type *SlotType = RetSubTypes.back()->getTypeAtIndex(RetPath.back()); + Type *SlotType = + ExtractValueInst::getIndexedType(RetSubTypes.back(), RetPath.back()); CallVal = UndefValue::get(SlotType); } diff --git a/llvm/lib/FuzzMutate/Operations.cpp b/llvm/lib/FuzzMutate/Operations.cpp index cf55d09caf7e63..43255e16140adb 100644 --- a/llvm/lib/FuzzMutate/Operations.cpp +++ b/llvm/lib/FuzzMutate/Operations.cpp @@ -244,20 +244,24 @@ static SourcePred matchScalarInAggregate() { static SourcePred validInsertValueIndex() { auto Pred = [](ArrayRef Cur, const Value *V) { - auto *CTy = cast(Cur[0]->getType()); if (auto *CI = dyn_cast(V)) - if (CI->getBitWidth() == 32 && - CTy->getTypeAtIndex(CI->getZExtValue()) == Cur[1]->getType()) - return true; + if (CI->getBitWidth() == 32) { + Type *Indexed = ExtractValueInst::getIndexedType(Cur[0]->getType(), + CI->getZExtValue()); + return Indexed == Cur[1]->getType(); + } return false; }; auto Make = [](ArrayRef Cur, ArrayRef Ts) { std::vector Result; auto *Int32Ty = Type::getInt32Ty(Cur[0]->getContext()); - auto *CTy = cast(Cur[0]->getType()); - for (int I = 0, E = getAggregateNumElements(CTy); I < E; ++I) - if (CTy->getTypeAtIndex(I) == Cur[1]->getType()) + auto *BaseTy = Cur[0]->getType(); + int I = 0; + while (Type *Indexed = ExtractValueInst::getIndexedType(BaseTy, I)) { + if (Indexed == Cur[1]->getType()) Result.push_back(ConstantInt::get(Int32Ty, I)); + ++I; + } return Result; }; return {Pred, Make}; diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index dc78c5537befc5..3e2e74c31fc0a6 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -2389,10 +2389,11 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C, SmallVector NewIdxs; Type *Ty = PointeeTy; Type *Prev = C->getType(); + auto GEPIter = gep_type_begin(PointeeTy, Idxs); bool Unknown = !isa(Idxs[0]) && !isa(Idxs[0]); for (unsigned i = 1, e = Idxs.size(); i != e; - Prev = Ty, Ty = cast(Ty)->getTypeAtIndex(Idxs[i]), ++i) { + Prev = Ty, Ty = (++GEPIter).getIndexedType(), ++i) { if (!isa(Idxs[i]) && !isa(Idxs[i])) { // We don't know if it's in range or not. Unknown = true; diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index c68b461f171992..bde4c07e15a315 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -1047,19 +1047,20 @@ static Constant *getSequenceIfElementsMatch(Constant *C, return nullptr; } -ConstantAggregate::ConstantAggregate(CompositeType *T, ValueTy VT, +ConstantAggregate::ConstantAggregate(Type *T, ValueTy VT, ArrayRef V) : Constant(T, VT, OperandTraits::op_end(this) - V.size(), V.size()) { llvm::copy(V, op_begin()); // Check that types match, unless this is an opaque struct. - if (auto *ST = dyn_cast(T)) + if (auto *ST = dyn_cast(T)) { if (ST->isOpaque()) return; - for (unsigned I = 0, E = V.size(); I != E; ++I) - assert(V[I]->getType() == T->getTypeAtIndex(I) && - "Initializer for composite element doesn't match!"); + for (unsigned I = 0, E = V.size(); I != E; ++I) + assert(V[I]->getType() == ST->getTypeAtIndex(I) && + "Initializer for struct element doesn't match!"); + } } ConstantArray::ConstantArray(ArrayType *T, ArrayRef V) diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 68eed612e4bfde..0884a24a709e5d 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -1659,35 +1659,44 @@ GetElementPtrInst::GetElementPtrInst(const GetElementPtrInst &GEPI) SubclassOptionalData = GEPI.SubclassOptionalData; } -/// getIndexedType - Returns the type of the element that would be accessed with -/// a gep instruction with the specified parameters. -/// -/// The Idxs pointer should point to a continuous piece of memory containing the -/// indices, either as Value* or uint64_t. -/// -/// A null type is returned if the indices are invalid for the specified -/// pointer type. -/// -template -static Type *getIndexedTypeInternal(Type *Agg, ArrayRef IdxList) { - // Handle the special case of the empty set index set, which is always valid. - if (IdxList.empty()) - return Agg; - - // If there is at least one index, the top level type must be sized, otherwise - // it cannot be 'stepped over'. - if (!Agg->isSized()) +Type *GetElementPtrInst::getTypeAtIndex(Type *Ty, Value *Idx) { + if (auto Struct = dyn_cast(Ty)) { + if (!Struct->indexValid(Idx)) + return nullptr; + return Struct->getTypeAtIndex(Idx); + } + if (!Idx->getType()->isIntOrIntVectorTy()) return nullptr; + if (auto Array = dyn_cast(Ty)) + return Array->getElementType(); + if (auto Vector = dyn_cast(Ty)) + return Vector->getElementType(); + return nullptr; +} - unsigned CurIdx = 1; - for (; CurIdx != IdxList.size(); ++CurIdx) { - CompositeType *CT = dyn_cast(Agg); - if (!CT || CT->isPointerTy()) return nullptr; - IndexTy Index = IdxList[CurIdx]; - if (!CT->indexValid(Index)) return nullptr; - Agg = CT->getTypeAtIndex(Index); +Type *GetElementPtrInst::getTypeAtIndex(Type *Ty, uint64_t Idx) { + if (auto Struct = dyn_cast(Ty)) { + if (Idx >= Struct->getNumElements()) + return nullptr; + return Struct->getElementType(Idx); } - return CurIdx == IdxList.size() ? Agg : nullptr; + if (auto Array = dyn_cast(Ty)) + return Array->getElementType(); + if (auto Vector = dyn_cast(Ty)) + return Vector->getElementType(); + return nullptr; +} + +template +static Type *getIndexedTypeInternal(Type *Ty, ArrayRef IdxList) { + if (IdxList.empty()) + return Ty; + for (IndexTy V : IdxList.slice(1)) { + Ty = GetElementPtrInst::getTypeAtIndex(Ty, V); + if (!Ty) + return Ty; + } + return Ty; } Type *GetElementPtrInst::getIndexedType(Type *Ty, ArrayRef IdxList) { @@ -2220,15 +2229,15 @@ Type *ExtractValueInst::getIndexedType(Type *Agg, if (ArrayType *AT = dyn_cast(Agg)) { if (Index >= AT->getNumElements()) return nullptr; + Agg = AT->getElementType(); } else if (StructType *ST = dyn_cast(Agg)) { if (Index >= ST->getNumElements()) return nullptr; + Agg = ST->getElementType(Index); } else { // Not a valid type to index into. return nullptr; } - - Agg = cast(Agg)->getTypeAtIndex(Index); } return const_cast(Agg); } diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp index 3eab5042b54248..e91bc8aa7e708b 100644 --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -529,52 +529,22 @@ StructType *Module::getTypeByName(StringRef Name) const { return getContext().pImpl->NamedStructTypes.lookup(Name); } -//===----------------------------------------------------------------------===// -// CompositeType Implementation -//===----------------------------------------------------------------------===// - -Type *CompositeType::getTypeAtIndex(const Value *V) const { - if (auto *STy = dyn_cast(this)) { - unsigned Idx = - (unsigned)cast(V)->getUniqueInteger().getZExtValue(); - assert(indexValid(Idx) && "Invalid structure index!"); - return STy->getElementType(Idx); - } - - return cast(this)->getElementType(); +Type *StructType::getTypeAtIndex(const Value *V) const { + unsigned Idx = (unsigned)cast(V)->getUniqueInteger().getZExtValue(); + assert(indexValid(Idx) && "Invalid structure index!"); + return getElementType(Idx); } -Type *CompositeType::getTypeAtIndex(unsigned Idx) const{ - if (auto *STy = dyn_cast(this)) { - assert(indexValid(Idx) && "Invalid structure index!"); - return STy->getElementType(Idx); - } - - return cast(this)->getElementType(); -} - -bool CompositeType::indexValid(const Value *V) const { - if (auto *STy = dyn_cast(this)) { - // Structure indexes require (vectors of) 32-bit integer constants. In the - // vector case all of the indices must be equal. - if (!V->getType()->isIntOrIntVectorTy(32)) - return false; - const Constant *C = dyn_cast(V); - if (C && V->getType()->isVectorTy()) - C = C->getSplatValue(); - const ConstantInt *CU = dyn_cast_or_null(C); - return CU && CU->getZExtValue() < STy->getNumElements(); - } - - // Sequential types can be indexed by any integer. - return V->getType()->isIntOrIntVectorTy(); -} - -bool CompositeType::indexValid(unsigned Idx) const { - if (auto *STy = dyn_cast(this)) - return Idx < STy->getNumElements(); - // Sequential types can be indexed by any integer. - return true; +bool StructType::indexValid(const Value *V) const { + // Structure indexes require (vectors of) 32-bit integer constants. In the + // vector case all of the indices must be equal. + if (!V->getType()->isIntOrIntVectorTy(32)) + return false; + const Constant *C = dyn_cast(V); + if (C && V->getType()->isVectorTy()) + C = C->getSplatValue(); + const ConstantInt *CU = dyn_cast_or_null(C); + return CU && CU->getZExtValue() < getNumElements(); } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 45a88a686dab6b..2e6bcb550999dd 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -17079,7 +17079,7 @@ bool ARMTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, case Intrinsic::arm_mve_vld4q: { Info.opc = ISD::INTRINSIC_W_CHAIN; // Conservatively set memVT to the entire set of vectors loaded. - Type *VecTy = cast(I.getType())->getTypeAtIndex(1); + Type *VecTy = cast(I.getType())->getElementType(1); unsigned Factor = Intrinsic == Intrinsic::arm_mve_vld2q ? 2 : 4; Info.memVT = EVT::getVectorVT(VecTy->getContext(), MVT::i64, Factor * 2); Info.ptrVal = I.getArgOperand(0); diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index 3be31bcd172be5..d718574a81c8ca 100644 --- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -295,7 +295,7 @@ doPromotion(Function *F, SmallPtrSetImpl &ArgsToPromote, if (auto *ElPTy = dyn_cast(ElTy)) ElTy = ElPTy->getElementType(); else - ElTy = cast(ElTy)->getTypeAtIndex(II); + ElTy = GetElementPtrInst::getTypeAtIndex(ElTy, II); } // And create a GEP to extract those indices. V = IRB.CreateGEP(ArgIndex.first, V, Ops, V->getName() + ".idx"); @@ -784,7 +784,7 @@ bool ArgumentPromotionPass::isDenselyPacked(Type *type, const DataLayout &DL) { if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type)) return false; - if (!isa(type)) + if (!isa(type) && !isa(type)) return true; // For homogenous sequential types, check for padding within members. diff --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp index ff746ad0cbb3a2..3b234ca0be7d36 100644 --- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -142,7 +142,7 @@ static bool isLeakCheckerRoot(GlobalVariable *GV) { E = STy->element_end(); I != E; ++I) { Type *InnerTy = *I; if (isa(InnerTy)) return true; - if (isa(InnerTy)) + if (isa(InnerTy) || isa(InnerTy)) Types.push_back(InnerTy); } break; diff --git a/llvm/lib/Transforms/IPO/StripSymbols.cpp b/llvm/lib/Transforms/IPO/StripSymbols.cpp index 6ce00714523b30..088091df770f9a 100644 --- a/llvm/lib/Transforms/IPO/StripSymbols.cpp +++ b/llvm/lib/Transforms/IPO/StripSymbols.cpp @@ -147,10 +147,12 @@ static void RemoveDeadConstant(Constant *C) { if (GlobalVariable *GV = dyn_cast(C)) { if (!GV->hasLocalLinkage()) return; // Don't delete non-static globals. GV->eraseFromParent(); - } - else if (!isa(C)) - if (isa(C->getType())) + } else if (!isa(C)) { + // FIXME: Why does the type of the constant matter here? + if (isa(C->getType()) || isa(C->getType()) || + isa(C->getType())) C->destroyConstant(); + } // If the constant referenced anything, see if we can delete it as well. for (Constant *O : Operands) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 0b1d9e8df03925..afdddad10cea22 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -2421,10 +2421,8 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep. // This can enhance SROA and other transforms that want type-safe pointers. unsigned NumZeros = 0; - while (SrcElTy != DstElTy && - isa(SrcElTy) && !SrcElTy->isPointerTy() && - SrcElTy->getNumContainedTypes() /* not "{}" */) { - SrcElTy = cast(SrcElTy)->getTypeAtIndex(0U); + while (SrcElTy && SrcElTy != DstElTy) { + SrcElTy = GetElementPtrInst::getTypeAtIndex(SrcElTy, (uint64_t)0); ++NumZeros; } diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index e0063eb24515fa..9d17e92eca203a 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1961,10 +1961,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (J > 0) { if (J == 1) { CurTy = Op1->getSourceElementType(); - } else if (auto *CT = dyn_cast(CurTy)) { - CurTy = CT->getTypeAtIndex(Op1->getOperand(J)); } else { - CurTy = nullptr; + CurTy = + GetElementPtrInst::getTypeAtIndex(CurTy, Op1->getOperand(J)); } } } diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index e411c4ece83d9d..377aa78730b047 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -3131,7 +3131,7 @@ unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const { unsigned N = 1; Type *EltTy = T; - while (isa(EltTy)) { + while (isa(EltTy) || isa(EltTy)) { if (auto *ST = dyn_cast(EltTy)) { // Check that struct is homogeneous. for (const auto *Ty : ST->elements())