
[llvm] Add support for zero-width integers in MathExtras.h #87193

Merged 3 commits into llvm:main on Apr 22, 2024

Conversation

Moxinilian
Member

MLIR uses zero-width integers, but it also reuses integer logic from LLVM to avoid duplication. This creates issues when LLVM logic is applied in MLIR to integers that can be zero-width. To avoid special-casing the bitwidth-related logic in MLIR, this PR adds support for zero-width integers in LLVM's MathExtras (and consequently APInt).

While most of the logic in theory works the same way out of the box, shifting right by the entire bit width is undefined behavior in C++ rather than yielding zero, so some special cases had to be added. Fortunately, the performance penalty appears small: on x86 this usually adds a predicated conditional move, and I checked that no branch is inserted on Arm either.

This happens to fix a crash in arith.extsi canonicalization in MLIR. I think a follow-up PR adding tests for i0 in arith would be beneficial.
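To illustrate the hazard and the fix, here is a minimal sketch modeled on the maxUIntN change in the diff below (the _old/_new suffixes are illustrative only, not names used in the patch):

#include <cassert>
#include <cstdint>

// Pre-patch shape: when N == 0, the shift amount (64 - N) is 64, and shifting
// a uint64_t right by 64 bits is undefined behavior in C++.
uint64_t maxUIntN_old(uint64_t N) {
  assert(N > 0 && N <= 64 && "integer width out of range");
  return UINT64_MAX >> (64 - N);
}

// Post-patch shape: special-case N == 0 so the shift amount stays below 64.
// Optimizers typically lower the ternary to a conditional move, not a branch.
uint64_t maxUIntN_new(uint64_t N) {
  assert(N <= 64 && "integer width out of range");
  return N == 0 ? 0 : UINT64_MAX >> (64 - N);
}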

@Moxinilian added the llvm (Umbrella label for LLVM issues) and llvm:adt labels on Mar 31, 2024
@Moxinilian self-assigned this on Mar 31, 2024
@llvmbot
Collaborator

llvmbot commented Mar 31, 2024

@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-llvm-adt

Author: Théo Degioanni (Moxinilian)


Full diff: https://github.com/llvm/llvm-project/pull/87193.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Support/MathExtras.h (+28-25)
  • (modified) llvm/unittests/ADT/APIntTest.cpp (+3)
  • (modified) llvm/unittests/Support/MathExtrasTest.cpp (+8)
diff --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
index 7278c4eb769588..5ac91528008c12 100644
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -94,7 +94,7 @@ static const unsigned char BitReverseTable256[256] = {
 #define R2(n) n, n + 2 * 64, n + 1 * 64, n + 3 * 64
 #define R4(n) R2(n), R2(n + 2 * 16), R2(n + 1 * 16), R2(n + 3 * 16)
 #define R6(n) R4(n), R4(n + 2 * 4), R4(n + 1 * 4), R4(n + 3 * 4)
-  R6(0), R6(2), R6(1), R6(3)
+    R6(0), R6(2), R6(1), R6(3)
 #undef R2
 #undef R4
 #undef R6
@@ -149,6 +149,8 @@ constexpr inline uint64_t Make_64(uint32_t High, uint32_t Low) {
 
 /// Checks if an integer fits into the given bit width.
 template <unsigned N> constexpr inline bool isInt(int64_t x) {
+  if constexpr (N == 0)
+    return 0 == x;
   if constexpr (N == 8)
     return static_cast<int8_t>(x) == x;
   if constexpr (N == 16)
@@ -164,15 +166,15 @@ template <unsigned N> constexpr inline bool isInt(int64_t x) {
 /// Checks if a signed integer is an N bit number shifted left by S.
 template <unsigned N, unsigned S>
 constexpr inline bool isShiftedInt(int64_t x) {
-  static_assert(
-      N > 0, "isShiftedInt<0> doesn't make sense (refers to a 0-bit number.");
+  static_assert(S < 64, "isShiftedInt<N, S> with S >= 64 is too much.");
   static_assert(N + S <= 64, "isShiftedInt<N, S> with N + S > 64 is too wide.");
   return isInt<N + S>(x) && (x % (UINT64_C(1) << S) == 0);
 }
 
 /// Checks if an unsigned integer fits into the given bit width.
 template <unsigned N> constexpr inline bool isUInt(uint64_t x) {
-  static_assert(N > 0, "isUInt<0> doesn't make sense");
+  if constexpr (N == 0)
+    return 0 == x;
   if constexpr (N == 8)
     return static_cast<uint8_t>(x) == x;
   if constexpr (N == 16)
@@ -188,8 +190,7 @@ template <unsigned N> constexpr inline bool isUInt(uint64_t x) {
 /// Checks if a unsigned integer is an N bit number shifted left by S.
 template <unsigned N, unsigned S>
 constexpr inline bool isShiftedUInt(uint64_t x) {
-  static_assert(
-      N > 0, "isShiftedUInt<0> doesn't make sense (refers to a 0-bit number)");
+  static_assert(S < 64, "isShiftedUInt<N, S> with S >= 64 is too much.");
   static_assert(N + S <= 64,
                 "isShiftedUInt<N, S> with N + S > 64 is too wide.");
   // Per the two static_asserts above, S must be strictly less than 64.  So
@@ -199,29 +200,32 @@ constexpr inline bool isShiftedUInt(uint64_t x) {
 
 /// Gets the maximum value for a N-bit unsigned integer.
 inline uint64_t maxUIntN(uint64_t N) {
-  assert(N > 0 && N <= 64 && "integer width out of range");
+  assert(N <= 64 && "integer width out of range");
 
   // uint64_t(1) << 64 is undefined behavior, so we can't do
   //   (uint64_t(1) << N) - 1
   // without checking first that N != 64.  But this works and doesn't have a
-  // branch.
-  return UINT64_MAX >> (64 - N);
+  // branch for N != 0.
+  // Unfortunately, shifting a uint64_t right by 64 bit is undefined
+  // behavior, so the condition on N == 0 is necessary. Fortunately, most
+  // optimizers do not emit branches for this check.
+  return N == 0 ? 0 : UINT64_MAX >> (64 - N);
 }
 
 /// Gets the minimum value for a N-bit signed integer.
 inline int64_t minIntN(int64_t N) {
-  assert(N > 0 && N <= 64 && "integer width out of range");
+  assert(N <= 64 && "integer width out of range");
 
-  return UINT64_C(1) + ~(UINT64_C(1) << (N - 1));
+  return N == 0 ? 0 : UINT64_C(1) + ~(UINT64_C(1) << (N - 1));
 }
 
 /// Gets the maximum value for a N-bit signed integer.
 inline int64_t maxIntN(int64_t N) {
-  assert(N > 0 && N <= 64 && "integer width out of range");
+  assert(N <= 64 && "integer width out of range");
 
   // This relies on two's complement wraparound when N == 64, so we convert to
   // int64_t only at the very end to avoid UB.
-  return (UINT64_C(1) << (N - 1)) - 1;
+  return N == 0 ? 0 : (UINT64_C(1) << (N - 1)) - 1;
 }
 
 /// Checks if an unsigned integer fits into the given (dynamic) bit width.
@@ -432,35 +436,35 @@ inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
 }
 
 /// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
-/// Requires 0 < B <= 32.
+/// Requires B <= 32.
 template <unsigned B> constexpr inline int32_t SignExtend32(uint32_t X) {
-  static_assert(B > 0, "Bit width can't be 0.");
   static_assert(B <= 32, "Bit width out of range.");
+  if constexpr (B == 0)
+    return 0;
   return int32_t(X << (32 - B)) >> (32 - B);
 }
 
 /// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
-/// Requires 0 < B <= 32.
+/// Requires B <= 32.
 inline int32_t SignExtend32(uint32_t X, unsigned B) {
-  assert(B > 0 && "Bit width can't be 0.");
   assert(B <= 32 && "Bit width out of range.");
-  return int32_t(X << (32 - B)) >> (32 - B);
+  return B == 0 ? 0 : int32_t(X << (32 - B)) >> (32 - B);
 }
 
 /// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
-/// Requires 0 < B <= 64.
+/// Requires B <= 64.
 template <unsigned B> constexpr inline int64_t SignExtend64(uint64_t x) {
-  static_assert(B > 0, "Bit width can't be 0.");
   static_assert(B <= 64, "Bit width out of range.");
+  if constexpr (B == 0)
+    return 0;
   return int64_t(x << (64 - B)) >> (64 - B);
 }
 
 /// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
-/// Requires 0 < B <= 64.
+/// Requires B <= 64.
 inline int64_t SignExtend64(uint64_t X, unsigned B) {
-  assert(B > 0 && "Bit width can't be 0.");
   assert(B <= 64 && "Bit width out of range.");
-  return int64_t(X << (64 - B)) >> (64 - B);
+  return B == 0 ? 0 : int64_t(X << (64 - B)) >> (64 - B);
 }
 
 /// Subtract two unsigned integers, X and Y, of type T and return the absolute
@@ -564,7 +568,6 @@ SaturatingMultiplyAdd(T X, T Y, T A, bool *ResultOverflowed = nullptr) {
 /// Use this rather than HUGE_VALF; the latter causes warnings on MSVC.
 extern const float huge_valf;
 
-
 /// Add two signed integers, computing the two's complement truncated result,
 /// returning true if overflow occurred.
 template <typename T>
@@ -644,6 +647,6 @@ std::enable_if_t<std::is_signed_v<T>, T> MulOverflow(T X, T Y, T &Result) {
     return UX > (static_cast<U>(std::numeric_limits<T>::max())) / UY;
 }
 
-} // End llvm namespace
+} // namespace llvm
 
 #endif
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 3b909f8f7d14e2..f14e4607c6deb1 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2692,6 +2692,9 @@ TEST(APIntTest, sext) {
   EXPECT_EQ(63U, i32_neg1.countl_one());
   EXPECT_EQ(0U, i32_neg1.countr_zero());
   EXPECT_EQ(63U, i32_neg1.popcount());
+
+  EXPECT_EQ(APInt(32u, 0), APInt(0u, 0).sext(32));
+  EXPECT_EQ(APInt(64u, 0), APInt(0u, 0).sext(64));
 }
 
 TEST(APIntTest, trunc) {
diff --git a/llvm/unittests/Support/MathExtrasTest.cpp b/llvm/unittests/Support/MathExtrasTest.cpp
index 72c765d9ba3004..218655851ca0d7 100644
--- a/llvm/unittests/Support/MathExtrasTest.cpp
+++ b/llvm/unittests/Support/MathExtrasTest.cpp
@@ -41,6 +41,9 @@ TEST(MathExtras, onesMask) {
 TEST(MathExtras, isIntN) {
   EXPECT_TRUE(isIntN(16, 32767));
   EXPECT_FALSE(isIntN(16, 32768));
+  EXPECT_TRUE(isIntN(0, 0));
+  EXPECT_FALSE(isIntN(0, 1));
+  EXPECT_FALSE(isIntN(0, -1));
 }
 
 TEST(MathExtras, isUIntN) {
@@ -48,6 +51,8 @@ TEST(MathExtras, isUIntN) {
   EXPECT_FALSE(isUIntN(16, 65536));
   EXPECT_TRUE(isUIntN(1, 0));
   EXPECT_TRUE(isUIntN(6, 63));
+  EXPECT_TRUE(isUIntN(0, 0));
+  EXPECT_FALSE(isUIntN(0, 1));
 }
 
 TEST(MathExtras, maxIntN) {
@@ -55,6 +60,7 @@ TEST(MathExtras, maxIntN) {
   EXPECT_EQ(2147483647, maxIntN(32));
   EXPECT_EQ(std::numeric_limits<int32_t>::max(), maxIntN(32));
   EXPECT_EQ(std::numeric_limits<int64_t>::max(), maxIntN(64));
+  EXPECT_EQ(0, maxIntN(0));
 }
 
 TEST(MathExtras, minIntN) {
@@ -62,6 +68,7 @@ TEST(MathExtras, minIntN) {
   EXPECT_EQ(-64LL, minIntN(7));
   EXPECT_EQ(std::numeric_limits<int32_t>::min(), minIntN(32));
   EXPECT_EQ(std::numeric_limits<int64_t>::min(), minIntN(64));
+  EXPECT_EQ(0, minIntN(0));
 }
 
 TEST(MathExtras, maxUIntN) {
@@ -70,6 +77,7 @@ TEST(MathExtras, maxUIntN) {
   EXPECT_EQ(0xffffffffffffffffULL, maxUIntN(64));
   EXPECT_EQ(1ULL, maxUIntN(1));
   EXPECT_EQ(0x0fULL, maxUIntN(4));
+  EXPECT_EQ(0, maxUIntN(0));
 }
 
 TEST(MathExtras, reverseBits) {
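A small usage sketch of the zero-width behavior these tests pin down (the calls are the existing MathExtras.h and APInt APIs; the expected values mirror the test expectations above):

#include "llvm/ADT/APInt.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>

using namespace llvm;

void zeroWidthBehavior() {
  // Width-0 helpers return the only representable value, zero.
  assert(maxUIntN(0) == 0);
  assert(maxIntN(0) == 0 && minIntN(0) == 0);

  // The dynamic fit checks accept only zero for a zero-width integer.
  assert(isUIntN(0, 0) && !isUIntN(0, 1));

  // Sign-extending a zero-width APInt yields an all-zero wider value.
  assert(APInt(0u, 0).sext(32) == APInt(32u, 0));
}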


@Moxinilian Moxinilian removed their assignment Mar 31, 2024
@makslevental
Contributor

Just chiming in to note that there's a downstream need for zero-width integer support in CIRCT that goes back ~2 years (llvm/circt#2660). So 👍 from me.

Member

@zero9178 left a comment


Code generally looks good to me, although I'd personally replace many of the N == 0 ? x : y expressions with

if (N == 0)
  return x;
return y;

This is IMO more readable, especially when y is a relatively large expression.

Regarding the state of 0-bit support, this PR only handles sext, right? Is everything else already supported and tested? Would it be more accurate to say in the PR title and description that this only adds support for sext of 0-bit integers?
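For illustration, the suggested early-return style applied to maxUIntN from the diff would read roughly as follows (a sketch of the review suggestion, not necessarily the form that was committed):

#include <cassert>
#include <cstdint>

inline uint64_t maxUIntN(uint64_t N) {
  assert(N <= 64 && "integer width out of range");
  // Handle the zero-width case up front so the main expression stays simple.
  if (N == 0)
    return 0;
  // The shift is well-defined here because 1 <= N <= 64.
  return UINT64_MAX >> (64 - N);
}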

@Moxinilian
Member Author

Moxinilian commented Apr 8, 2024

I have walked through the entire file to check that i0 works properly. But it’s probably a good idea to add tests even where it works to make sure there is no regression. I can look at that, probably after EuroLLVM.

Actually I remember that in many places the code is not really related to integer width, so there may be nothing to add. I’ll ping you either way.

@Moxinilian
Member Author

@zero9178 So it turns out that indeed everything else is either not dependent on bitwidth or already tested for 0. One of the functions in fact already accounted for the undefined behavior of shifting a 64-bit value by 64 in the zero case. I updated it to your recommended syntax for consistency, as I agree it makes the code easier to parse.

Member

@zero9178 left a comment


Sounds good, thank you!

@Moxinilian Moxinilian merged commit a54102a into llvm:main Apr 22, 2024
4 checks passed
@Moxinilian Moxinilian deleted the fix-i0 branch April 22, 2024 18:54
Moxinilian added a commit that referenced this pull request Apr 28, 2024
Before #87193, the canonicalizer in arith crashed when attempting signed extension of an i0 value. To help avoid this happening again, this PR introduces tests for canonicalization of arith operations on i0 values, focusing on operations related to bit width or signedness.
Labels: llvm:adt, llvm:support, llvm (Umbrella label for LLVM issues)

4 participants