diff --git a/llvm/include/llvm/ADT/bit.h b/llvm/include/llvm/ADT/bit.h index d6e33c3e6133a..2649f08743ef9 100644 --- a/llvm/include/llvm/ADT/bit.h +++ b/llvm/include/llvm/ADT/bit.h @@ -154,6 +154,27 @@ template >> /// Only unsigned integral types are allowed. /// /// Returns std::numeric_limits::digits on an input of 0. +template [[nodiscard]] constexpr int countr_zero_constexpr(T Val) { + static_assert(std::is_unsigned_v, + "Only unsigned integral types are allowed."); + if (!Val) + return std::numeric_limits::digits; + + // Use the bisection method. + unsigned ZeroBits = 0; + T Shift = std::numeric_limits::digits >> 1; + T Mask = std::numeric_limits::max() >> Shift; + while (Shift) { + if ((Val & Mask) == 0) { + Val >>= Shift; + ZeroBits |= Shift; + } + Shift >>= 1; + Mask >>= Shift; + } + return ZeroBits; +} + template [[nodiscard]] int countr_zero(T Val) { static_assert(std::is_unsigned_v, "Only unsigned integral types are allowed."); @@ -179,19 +200,8 @@ template [[nodiscard]] int countr_zero(T Val) { #endif } - // Fall back to the bisection method. - unsigned ZeroBits = 0; - T Shift = std::numeric_limits::digits >> 1; - T Mask = std::numeric_limits::max() >> Shift; - while (Shift) { - if ((Val & Mask) == 0) { - Val >>= Shift; - ZeroBits |= Shift; - } - Shift >>= 1; - Mask >>= Shift; - } - return ZeroBits; + // Fall back to the constexpr implementation. + return countr_zero_constexpr(Val); } /// Count number of 0's from the most significant bit to the least diff --git a/llvm/unittests/ADT/BitTest.cpp b/llvm/unittests/ADT/BitTest.cpp index 2377ce3b78261..bc441df89f5ad 100644 --- a/llvm/unittests/ADT/BitTest.cpp +++ b/llvm/unittests/ADT/BitTest.cpp @@ -279,6 +279,26 @@ TEST(BitTest, CountlZero) { } } +TEST(BitTest, CountrZeroConstexpr) { + constexpr uint8_t Z8 = 0; + constexpr uint16_t Z16 = 0; + constexpr uint32_t Z32 = 0; + constexpr uint64_t Z64 = 0; + static_assert(llvm::countr_zero_constexpr(Z8) == 8, ""); + static_assert(llvm::countr_zero_constexpr(Z16) == 16, ""); + static_assert(llvm::countr_zero_constexpr(Z32) == 32, ""); + static_assert(llvm::countr_zero_constexpr(Z64) == 64, ""); + + constexpr uint8_t NZ8 = 42; + constexpr uint16_t NZ16 = 42; + constexpr uint32_t NZ32 = 42; + constexpr uint64_t NZ64 = 42; + static_assert(llvm::countr_zero_constexpr(NZ8) == 1, ""); + static_assert(llvm::countr_zero_constexpr(NZ16) == 1, ""); + static_assert(llvm::countr_zero_constexpr(NZ32) == 1, ""); + static_assert(llvm::countr_zero_constexpr(NZ64) == 1, ""); +} + TEST(BitTest, CountrZero) { uint8_t Z8 = 0; uint16_t Z16 = 0;