diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index b62d6ec3b1ab6c..26da7fae391b66 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -2235,8 +2235,8 @@ Constant *ConstantExpr::getPtrToInt(Constant *C, Type *DstTy, "PtrToInt destination must be integer or integer vector"); assert(isa(C->getType()) == isa(DstTy)); if (isa(C->getType())) - assert(cast(C->getType())->getNumElements() == - cast(DstTy)->getNumElements() && + assert(cast(C->getType())->getElementCount() == + cast(DstTy)->getElementCount() && "Invalid cast between a different number of vector elements"); return getFoldedCast(Instruction::PtrToInt, C, DstTy, OnlyIfReduced); } diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp index ad43fbcb0531f5..b764af6cafae15 100644 --- a/llvm/unittests/IR/ConstantsTest.cpp +++ b/llvm/unittests/IR/ConstantsTest.cpp @@ -134,6 +134,9 @@ TEST(ConstantsTest, PointerCast) { VectorType *Int8PtrVecTy = FixedVectorType::get(Int8PtrTy, 4); VectorType *Int32PtrVecTy = FixedVectorType::get(Int32PtrTy, 4); VectorType *Int64VecTy = FixedVectorType::get(Int64Ty, 4); + VectorType *Int8PtrScalableVecTy = ScalableVectorType::get(Int8PtrTy, 4); + VectorType *Int32PtrScalableVecTy = ScalableVectorType::get(Int32PtrTy, 4); + VectorType *Int64ScalableVecTy = ScalableVectorType::get(Int64Ty, 4); // ptrtoint i8* to i64 EXPECT_EQ( @@ -150,11 +153,23 @@ TEST(ConstantsTest, PointerCast) { ConstantExpr::getPointerCast(Constant::getNullValue(Int8PtrVecTy), Int64VecTy)); + // ptrtoint to + EXPECT_EQ( + Constant::getNullValue(Int64ScalableVecTy), + ConstantExpr::getPointerCast(Constant::getNullValue(Int8PtrScalableVecTy), + Int64ScalableVecTy)); + // bitcast <4 x i8*> to <4 x i32*> EXPECT_EQ(Constant::getNullValue(Int32PtrVecTy), ConstantExpr::getPointerCast(Constant::getNullValue(Int8PtrVecTy), Int32PtrVecTy)); + // bitcast to + EXPECT_EQ( + Constant::getNullValue(Int32PtrScalableVecTy), + ConstantExpr::getPointerCast(Constant::getNullValue(Int8PtrScalableVecTy), + Int32PtrScalableVecTy)); + Type *Int32Ptr1Ty = Type::getInt32PtrTy(C, 1); ConstantInt *K = ConstantInt::get(Type::getInt64Ty(C), 1234);