diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 2d2ef0c9c4a27..0201942183e68 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -655,6 +655,12 @@ class ScalarEvolution { /// Return a SCEV for the constant 1 of a specific type. const SCEV *getOne(Type *Ty) { return getConstant(Ty, 1); } + /// Return a SCEV for the constant \p Power of two. + const SCEV *getPowerOfTwo(Type *Ty, unsigned Power) { + assert(Power < getTypeSizeInBits(Ty) && "Power out of range"); + return getConstant(APInt::getOneBitSet(getTypeSizeInBits(Ty), Power)); + } + /// Return a SCEV for the constant -1 of a specific type. const SCEV *getMinusOne(Type *Ty) { return getConstant(Ty, -1, /*isSigned=*/true); diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index 8756e2c66c25a..985d1cbc642a3 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -1744,4 +1744,20 @@ TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromMultiDemArray) { }); } +TEST_F(ScalarEvolutionsTest, CheckGetPowerOfTwo) { + Module M("CheckGetPowerOfTwo", Context); + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), {}, false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", M); + BasicBlock *Entry = BasicBlock::Create(Context, "entry", F); + IRBuilder<> Builder(Entry); + Builder.CreateRetVoid(); + ScalarEvolution SE = buildSE(*F); + + for (unsigned short i = 0; i < 64; ++i) + EXPECT_TRUE( + dyn_cast(SE.getPowerOfTwo(Type::getInt64Ty(Context), i)) + ->getValue() + ->equalsInt(1ULL << i)); +} + } // end namespace llvm