diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h index 41814d5b50fef2..edc59a8be55a21 100644 --- a/llvm/include/llvm/IR/DerivedTypes.h +++ b/llvm/include/llvm/IR/DerivedTypes.h @@ -677,14 +677,17 @@ Type *Type::getExtendedType() const { return cast(this)->getExtendedType(); } +Type *Type::getWithNewType(Type *EltTy) const { + if (auto *VTy = dyn_cast(this)) + return VectorType::get(EltTy, VTy->getElementCount()); + return EltTy; +} + Type *Type::getWithNewBitWidth(unsigned NewBitWidth) const { assert( isIntOrIntVectorTy() && "Original type expected to be a vector of integers or a scalar integer."); - Type *NewType = getIntNTy(getContext(), NewBitWidth); - if (auto *VTy = dyn_cast(this)) - NewType = VectorType::get(NewType, VTy->getElementCount()); - return NewType; + return getWithNewType(getIntNTy(getContext(), NewBitWidth)); } unsigned Type::getPointerAddressSpace() const { diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h index 2b1f054b111860..e30d4ad2261394 100644 --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -380,6 +380,11 @@ class Type { return ContainedTys[0]; } + /// Given vector type, change the element type, + /// whilst keeping the old number of elements. + /// For non-vectors simply returns \p EltTy. + inline Type *getWithNewType(Type *EltTy) const; + /// Given an integer or vector type, change the lane bitwidth to NewBitwidth, /// whilst keeping the old number of lanes. inline Type *getWithNewBitWidth(unsigned NewBitWidth) const; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index ae72123e3f0011..75621da20a5ddc 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1927,11 +1927,8 @@ Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) { unsigned AS = CI.getAddressSpace(); if (CI.getOperand(0)->getType()->getScalarSizeInBits() != DL.getPointerSizeInBits(AS)) { - Type *Ty = DL.getIntPtrType(CI.getContext(), AS); - // Handle vectors of pointers. - if (auto *CIVTy = dyn_cast(CI.getType())) - Ty = VectorType::get(Ty, CIVTy->getElementCount()); - + Type *Ty = CI.getOperand(0)->getType()->getWithNewType( + DL.getIntPtrType(CI.getContext(), AS)); Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty); return new IntToPtrInst(P, CI.getType()); } @@ -1970,16 +1967,14 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { // do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast // to be exposed to other transforms. Value *SrcOp = CI.getPointerOperand(); + Type *SrcTy = SrcOp->getType(); Type *Ty = CI.getType(); unsigned AS = CI.getPointerAddressSpace(); unsigned TySize = Ty->getScalarSizeInBits(); unsigned PtrSize = DL.getPointerSizeInBits(AS); if (TySize != PtrSize) { - Type *IntPtrTy = DL.getIntPtrType(CI.getContext(), AS); - // Handle vectors of pointers. - if (auto *VecTy = dyn_cast(Ty)) - IntPtrTy = VectorType::get(IntPtrTy, VecTy->getElementCount()); - + Type *IntPtrTy = + SrcTy->getWithNewType(DL.getIntPtrType(CI.getContext(), AS)); Value *P = Builder.CreatePtrToInt(SrcOp, IntPtrTy); return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); }