Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[libc][NFC] Simplify BigInt #81992

Merged
merged 1 commit into from
Feb 16, 2024
Merged

[libc][NFC] Simplify BigInt #81992

merged 1 commit into from
Feb 16, 2024

Conversation

gchatelet
Copy link
Contributor

@gchatelet gchatelet commented Feb 16, 2024

  • Add a single cmp function to derive all comparison operators
  • Use the friend version of the member functions for symmetry
  • Add a is_neg function to factor sign extraction
  • Implement binary op through macro expansion

@llvmbot llvmbot added the libc label Feb 16, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 16, 2024

@llvm/pr-subscribers-libc

Author: Guillaume Chatelet (gchatelet)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/81992.diff

2 Files Affected:

  • (modified) libc/src/__support/CPP/array.h (+2-2)
  • (modified) libc/src/__support/UInt.h (+72-144)
diff --git a/libc/src/__support/CPP/array.h b/libc/src/__support/CPP/array.h
index 1897066514092c..fb5a79225beb7d 100644
--- a/libc/src/__support/CPP/array.h
+++ b/libc/src/__support/CPP/array.h
@@ -28,10 +28,10 @@ template <class T, size_t N> struct array {
   LIBC_INLINE constexpr const T *data() const { return Data; }
 
   LIBC_INLINE constexpr T &front() { return Data[0]; }
-  LIBC_INLINE constexpr T &front() const { return Data[0]; }
+  LIBC_INLINE constexpr const T &front() const { return Data[0]; }
 
   LIBC_INLINE constexpr T &back() { return Data[N - 1]; }
-  LIBC_INLINE constexpr T &back() const { return Data[N - 1]; }
+  LIBC_INLINE constexpr const T &back() const { return Data[N - 1]; }
 
   LIBC_INLINE constexpr T &operator[](size_t Index) { return Data[Index]; }
 
diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h
index b90275035a23ea..f2084300d14208 100644
--- a/libc/src/__support/UInt.h
+++ b/libc/src/__support/UInt.h
@@ -68,8 +68,8 @@ struct BigInt {
         val[i] = other[i];
       WordType sign = 0;
       if constexpr (Signed && OtherSigned) {
-        sign = static_cast<WordType>(-static_cast<make_signed_t<WordType>>(
-            other[OtherBits / WORD_SIZE - 1] >> (WORD_SIZE - 1)));
+        sign = static_cast<WordType>(
+            -static_cast<make_signed_t<WordType>>(other.is_neg()));
       }
       for (; i < WORD_COUNT; ++i)
         val[i] = sign;
@@ -125,6 +125,11 @@ struct BigInt {
       val[i] = words[i];
   }
 
+  // TODO: Reuse the Sign type.
+  LIBC_INLINE constexpr bool is_neg() const {
+    return val.back() >> (WORD_SIZE - 1);
+  }
+
   template <typename T> LIBC_INLINE constexpr explicit operator T() const {
     return to<T>();
   }
@@ -148,7 +153,7 @@ struct BigInt {
     if constexpr (Signed && (T_BITS > Bits)) {
       // Extend sign for negative numbers.
       constexpr T MASK = (~T(0) << Bits);
-      if (val[WORD_COUNT - 1] >> (WORD_SIZE - 1))
+      if (is_neg())
         lo |= MASK;
     }
 
@@ -267,8 +272,8 @@ struct BigInt {
     if constexpr (Signed) {
       BigInt<Bits, false, WordType> a(*this);
       BigInt<Bits, false, WordType> b(other);
-      bool a_neg = (a.val[WORD_COUNT - 1] >> (WORD_SIZE - 1));
-      bool b_neg = (b.val[WORD_COUNT - 1] >> (WORD_SIZE - 1));
+      const bool a_neg = a.is_neg();
+      const bool b_neg = b.is_neg();
       if (a_neg)
         a = -a;
       if (b_neg)
@@ -278,7 +283,6 @@ struct BigInt {
         prod = -prod;
       return static_cast<BigInt<Bits, true, WordType>>(prod);
     } else {
-
       if constexpr (WORD_COUNT == 1) {
         return {val[0] * other.val[0]};
       } else {
@@ -383,10 +387,9 @@ struct BigInt {
     BigInt cur_power = *this;
 
     while (power > 0) {
-      if ((power % 2) > 0) {
-        result = result * cur_power;
-      }
-      power = power >> 1;
+      if ((power % 2) > 0)
+        result *= cur_power;
+      power >>= 1;
       cur_power *= cur_power;
     }
     *this = result;
@@ -709,7 +712,7 @@ struct BigInt {
     const size_t shift = s % WORD_SIZE; // Bit shift in the remaining words.
 
     size_t i = 0;
-    WordType sign = Signed ? (val[WORD_COUNT - 1] >> (WORD_SIZE - 1)) : 0;
+    WordType sign = Signed ? is_neg() : 0;
 
     if (drop < WORD_COUNT) {
       if (shift > 0) {
@@ -747,49 +750,31 @@ struct BigInt {
     return *this;
   }
 
-  LIBC_INLINE constexpr BigInt operator&(const BigInt &other) const {
-    BigInt result;
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      result.val[i] = val[i] & other.val[i];
-    return result;
+#define DEFINE_BINOP(OP)                                                       \
+  LIBC_INLINE friend constexpr BigInt operator OP(const BigInt &lhs,           \
+                                                  const BigInt &rhs) {         \
+    BigInt result;                                                             \
+    for (size_t i = 0; i < WORD_COUNT; ++i)                                    \
+      result[i] = lhs[i] OP rhs[i];                                            \
+    return result;                                                             \
+  }                                                                            \
+  LIBC_INLINE friend constexpr BigInt operator OP##=(BigInt &lhs,              \
+                                                     const BigInt &rhs) {      \
+    for (size_t i = 0; i < WORD_COUNT; ++i)                                    \
+      lhs[i] OP## = rhs[i];                                                    \
+    return lhs;                                                                \
   }
 
-  LIBC_INLINE constexpr BigInt &operator&=(const BigInt &other) {
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      val[i] &= other.val[i];
-    return *this;
-  }
+  DEFINE_BINOP(&)
+  DEFINE_BINOP(|)
+  DEFINE_BINOP(^)
 
-  LIBC_INLINE constexpr BigInt operator|(const BigInt &other) const {
-    BigInt result;
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      result.val[i] = val[i] | other.val[i];
-    return result;
-  }
-
-  LIBC_INLINE constexpr BigInt &operator|=(const BigInt &other) {
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      val[i] |= other.val[i];
-    return *this;
-  }
-
-  LIBC_INLINE constexpr BigInt operator^(const BigInt &other) const {
-    BigInt result;
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      result.val[i] = val[i] ^ other.val[i];
-    return result;
-  }
-
-  LIBC_INLINE constexpr BigInt &operator^=(const BigInt &other) {
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      val[i] ^= other.val[i];
-    return *this;
-  }
+#undef DEFINE_BINOP
 
   LIBC_INLINE constexpr BigInt operator~() const {
     BigInt result;
     for (size_t i = 0; i < WORD_COUNT; ++i)
-      result.val[i] = ~val[i];
+      result[i] = ~val[i];
     return result;
   }
 
@@ -799,139 +784,82 @@ struct BigInt {
     return result;
   }
 
-  LIBC_INLINE constexpr bool operator==(const BigInt &other) const {
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      if (val[i] != other.val[i])
+  LIBC_INLINE friend constexpr bool operator==(const BigInt &lhs,
+                                               const BigInt &rhs) {
+    for (size_t i = 0; i < WORD_COUNT; ++i)
+      if (lhs.val[i] != rhs.val[i])
         return false;
-    }
     return true;
   }
 
-  LIBC_INLINE constexpr bool operator!=(const BigInt &other) const {
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      if (val[i] != other.val[i])
-        return true;
-    }
-    return false;
+  LIBC_INLINE friend constexpr bool operator!=(const BigInt &lhs,
+                                               const BigInt &rhs) {
+    return !(lhs == rhs);
   }
 
-  LIBC_INLINE constexpr bool operator>(const BigInt &other) const {
+private:
+  LIBC_INLINE friend constexpr int cmp(const BigInt &lhs, const BigInt &rhs) {
+    const auto compare = [](WordType a, WordType b) {
+      return a == b ? 0 : a > b ? 1 : -1;
+    };
     if constexpr (Signed) {
-      // Check for different signs;
-      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      if (a_sign != b_sign) {
-        return b_sign;
-      }
-    }
-    for (size_t i = WORD_COUNT; i > 0; --i) {
-      WordType word = val[i - 1];
-      WordType other_word = other.val[i - 1];
-      if (word > other_word)
-        return true;
-      else if (word < other_word)
-        return false;
+      const bool lhs_is_neg = lhs.is_neg();
+      const bool rhs_is_neg = rhs.is_neg();
+      if (lhs_is_neg != rhs_is_neg)
+        return rhs_is_neg ? 1 : -1;
     }
-    // Equal
-    return false;
+    for (size_t i = WORD_COUNT; i-- > 0;)
+      if (auto cmp = compare(lhs[i], rhs[i]); cmp != 0)
+        return cmp;
+    return 0;
   }
 
-  LIBC_INLINE constexpr bool operator>=(const BigInt &other) const {
-    if constexpr (Signed) {
-      // Check for different signs;
-      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      if (a_sign != b_sign) {
-        return b_sign;
-      }
-    }
-    for (size_t i = WORD_COUNT; i > 0; --i) {
-      WordType word = val[i - 1];
-      WordType other_word = other.val[i - 1];
-      if (word > other_word)
-        return true;
-      else if (word < other_word)
-        return false;
-    }
-    // Equal
-    return true;
+public:
+  LIBC_INLINE friend constexpr bool operator>(const BigInt &lhs,
+                                              const BigInt &rhs) {
+    return cmp(lhs, rhs) > 0;
   }
-
-  LIBC_INLINE constexpr bool operator<(const BigInt &other) const {
-    if constexpr (Signed) {
-      // Check for different signs;
-      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      if (a_sign != b_sign) {
-        return a_sign;
-      }
-    }
-
-    for (size_t i = WORD_COUNT; i > 0; --i) {
-      WordType word = val[i - 1];
-      WordType other_word = other.val[i - 1];
-      if (word > other_word)
-        return false;
-      else if (word < other_word)
-        return true;
-    }
-    // Equal
-    return false;
+  LIBC_INLINE friend constexpr bool operator>=(const BigInt &lhs,
+                                               const BigInt &rhs) {
+    return cmp(lhs, rhs) >= 0;
   }
-
-  LIBC_INLINE constexpr bool operator<=(const BigInt &other) const {
-    if constexpr (Signed) {
-      // Check for different signs;
-      bool a_sign = val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      bool b_sign = other.val[WORD_COUNT - 1] >> (WORD_SIZE - 1);
-      if (a_sign != b_sign) {
-        return a_sign;
-      }
-    }
-    for (size_t i = WORD_COUNT; i > 0; --i) {
-      WordType word = val[i - 1];
-      WordType other_word = other.val[i - 1];
-      if (word > other_word)
-        return false;
-      else if (word < other_word)
-        return true;
-    }
-    // Equal
-    return true;
+  LIBC_INLINE friend constexpr bool operator<(const BigInt &lhs,
+                                              const BigInt &rhs) {
+    return cmp(lhs, rhs) < 0;
+  }
+  LIBC_INLINE friend constexpr bool operator<=(const BigInt &lhs,
+                                               const BigInt &rhs) {
+    return cmp(lhs, rhs) <= 0;
   }
 
   LIBC_INLINE constexpr BigInt &operator++() {
-    BigInt one(1);
-    add(one);
+    add(BigInt(1));
     return *this;
   }
 
   LIBC_INLINE constexpr BigInt operator++(int) {
     BigInt oldval(*this);
-    BigInt one(1);
-    add(one);
+    add(BigInt(1));
     return oldval;
   }
 
   LIBC_INLINE constexpr BigInt &operator--() {
-    BigInt one(1);
-    sub(one);
+    sub(BigInt(1));
     return *this;
   }
 
   LIBC_INLINE constexpr BigInt operator--(int) {
     BigInt oldval(*this);
-    BigInt one(1);
-    sub(one);
+    sub(BigInt(1));
     return oldval;
   }
 
-  // Return the i-th 64-bit word of the number.
+  // Return the i-th word of the number.
   LIBC_INLINE constexpr const WordType &operator[](size_t i) const {
     return val[i];
   }
 
-  // Return the i-th 64-bit word of the number.
+  // Return the i-th word of the number.
   LIBC_INLINE constexpr WordType &operator[](size_t i) { return val[i]; }
 
   LIBC_INLINE WordType *data() { return val; }

@gchatelet gchatelet requested a review from lntue February 16, 2024 13:47
@gchatelet gchatelet merged commit 2b677fa into llvm:main Feb 16, 2024
6 checks passed
@gchatelet gchatelet deleted the simplify_bigint branch February 16, 2024 15:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants