[libc] Allow BigInt class to use base word types other than uint64_t. #81634

Merged (2 commits) on Feb 13, 2024

lntue
Contributor

@lntue lntue commented Feb 13, 2024

This will allow DyadicFloat class to replace NormalFloat class.

@llvmbot
Collaborator

llvmbot commented Feb 13, 2024

@llvm/pr-subscribers-libc

Author: None (lntue)

Changes

This will allow DyadicFloat class to replace NormalFloat class.


Patch is 54.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81634.diff

4 Files Affected:

  • (modified) libc/src/__support/FPUtil/dyadic_float.h (+3-3)
  • (modified) libc/src/__support/UInt.h (+389-344)
  • (modified) libc/src/__support/float_to_string.h (+12-10)
  • (modified) libc/test/src/__support/uint_test.cpp (+16-1)
diff --git a/libc/src/__support/FPUtil/dyadic_float.h b/libc/src/__support/FPUtil/dyadic_float.h
index 888d7ffec241ea..a8b3ad7a16d3bb 100644
--- a/libc/src/__support/FPUtil/dyadic_float.h
+++ b/libc/src/__support/FPUtil/dyadic_float.h
@@ -216,7 +216,7 @@ constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
     if (result.mantissa.add(b.mantissa)) {
       // Mantissa addition overflow.
       result.shift_right(1);
-      result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORDCOUNT - 1] |=
+      result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] |=
           (uint64_t(1) << 63);
     }
     // Result is already normalized.
@@ -243,7 +243,7 @@ constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
 //   result.mantissa = quick_mul_hi(a.mantissa + b.mantissa)
 //                   ~ (full product a.mantissa * b.mantissa) >> Bits.
 // The errors compared to the mathematical product is bounded by:
-//   2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORDCOUNT - 1) in ULPs.
+//   2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORD_COUNT - 1) in ULPs.
 // Assume inputs are normalized (by constructors or other functions) so that we
 // don't need to normalize the inputs again in this function.  If the inputs are
 // not normalized, the results might lose precision significantly.
@@ -258,7 +258,7 @@ constexpr DyadicFloat<Bits> quick_mul(DyadicFloat<Bits> a,
     result.mantissa = a.mantissa.quick_mul_hi(b.mantissa);
     // Check the leading bit directly, should be faster than using clz in
     // normalize().
-    if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORDCOUNT - 1] >>
+    if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] >>
             63 ==
         0)
       result.shift_left(1);
diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h
index 7726b6d88f0d21..5a60ea0e6d8135 100644
--- a/libc/src/__support/UInt.h
+++ b/libc/src/__support/UInt.h
@@ -25,35 +25,30 @@
 
 namespace LIBC_NAMESPACE::cpp {
 
-template <size_t Bits, bool Signed> struct BigInt {
+template <size_t Bits, bool Signed, typename WordType = uint64_t>
+struct BigInt {
+  static_assert(is_integral_v<WordType> && is_unsigned_v<WordType>,
+                "WordType must be unsigned integer.");
 
-  // This being hardcoded as 64 is okay because we're using uint64_t as our
-  // internal type which will always be 64 bits.
-  using word_type = uint64_t;
-  LIBC_INLINE_VAR static constexpr size_t WORD_SIZE =
-      sizeof(word_type) * CHAR_BIT;
+  LIBC_INLINE_VAR
+  static constexpr size_t WORD_SIZE = sizeof(WordType) * CHAR_BIT;
 
-  // TODO: Replace references to 64 with WORD_SIZE, and uint64_t with word_type.
-  static_assert(Bits > 0 && Bits % 64 == 0,
-                "Number of bits in BigInt should be a multiple of 64.");
-  LIBC_INLINE_VAR static constexpr size_t WORDCOUNT = Bits / 64;
-  cpp::array<word_type, WORDCOUNT> val{};
+  static_assert(Bits > 0 && Bits % WORD_SIZE == 0,
+                "Number of bits in BigInt should be a multiple of WORD_SIZE.");
 
-  LIBC_INLINE_VAR static constexpr uint64_t MASK32 = 0xFFFFFFFFu;
-
-  LIBC_INLINE static constexpr uint64_t low(uint64_t v) { return v & MASK32; }
-  LIBC_INLINE static constexpr uint64_t high(uint64_t v) {
-    return (v >> 32) & MASK32;
-  }
+  LIBC_INLINE_VAR static constexpr size_t WORD_COUNT = Bits / WORD_SIZE;
+  cpp::array<WordType, WORD_COUNT> val{};
 
   LIBC_INLINE constexpr BigInt() = default;
 
-  LIBC_INLINE constexpr BigInt(const BigInt<Bits, Signed> &other) = default;
+  LIBC_INLINE constexpr BigInt(const BigInt<Bits, Signed, WordType> &other) =
+      default;
 
   template <size_t OtherBits, bool OtherSigned>
-  LIBC_INLINE constexpr BigInt(const BigInt<OtherBits, OtherSigned> &other) {
+  LIBC_INLINE constexpr BigInt(
+      const BigInt<OtherBits, OtherSigned, WordType> &other) {
     if (OtherBits >= Bits) {
-      for (size_t i = 0; i < WORDCOUNT; ++i)
+      for (size_t i = 0; i < WORD_COUNT; ++i)
         val[i] = other[i];
     } else {
       size_t i = 0;
@@ -64,49 +59,57 @@ template <size_t Bits, bool Signed> struct BigInt {
         sign = static_cast<uint64_t>(
             -static_cast<int64_t>(other[OtherBits / 64 - 1] >> 63));
       }
-      for (; i < WORDCOUNT; ++i)
+      for (; i < WORD_COUNT; ++i)
         val[i] = sign;
     }
   }
 
   // Construct a BigInt from a C array.
-  template <size_t N, enable_if_t<N <= WORDCOUNT, int> = 0>
-  LIBC_INLINE constexpr BigInt(const uint64_t (&nums)[N]) {
-    size_t min_wordcount = N < WORDCOUNT ? N : WORDCOUNT;
+  template <size_t N, enable_if_t<N <= WORD_COUNT, int> = 0>
+  LIBC_INLINE constexpr BigInt(const WordType (&nums)[N]) {
+    size_t min_wordcount = N < WORD_COUNT ? N : WORD_COUNT;
     size_t i = 0;
     for (; i < min_wordcount; ++i)
       val[i] = nums[i];
 
     // If nums doesn't completely fill val, then fill the rest with zeroes.
-    for (; i < WORDCOUNT; ++i)
+    for (; i < WORD_COUNT; ++i)
       val[i] = 0;
   }
 
   // Initialize the first word to |v| and the rest to 0.
-  template <typename T,
-            typename = cpp::enable_if_t<is_integral_v<T> && sizeof(T) <= 16>>
+  template <typename T, typename = cpp::enable_if_t<is_integral_v<T>>>
   LIBC_INLINE constexpr BigInt(T v) {
-    val[0] = static_cast<uint64_t>(v);
+    val[0] = static_cast<WordType>(v);
 
-    if constexpr (Bits == 64)
+    if constexpr (WORD_COUNT == 1)
       return;
 
-    // Bits is at least 128.
-    size_t i = 1;
-    if constexpr (sizeof(T) == 16) {
-      val[1] = static_cast<uint64_t>(v >> 64);
-      i = 2;
+    if constexpr (Bits < sizeof(T) * CHAR_BIT) {
+      for (int i = 1; i < WORD_COUNT; ++i) {
+        v >>= WORD_SIZE;
+        val[i] = static_cast<WordType>(v);
+      }
+      return;
     }
 
-    uint64_t sign = (Signed && (v < 0)) ? 0xffff'ffff'ffff'ffff : 0;
-    for (; i < WORDCOUNT; ++i) {
+    size_t i = 1;
+
+    if constexpr (WORD_SIZE < sizeof(T) * CHAR_BIT)
+      for (; i < sizeof(T) * CHAR_BIT / WORD_SIZE; ++i) {
+        v >>= WORD_SIZE;
+        val[i] = static_cast<WordType>(v);
+      }
+
+    WordType sign = (Signed && (v < 0)) ? ~WordType(0) : WordType(0);
+    for (; i < WORD_COUNT; ++i) {
       val[i] = sign;
     }
   }
 
   LIBC_INLINE constexpr explicit BigInt(
-      const cpp::array<uint64_t, WORDCOUNT> &words) {
-    for (size_t i = 0; i < WORDCOUNT; ++i)
+      const cpp::array<WordType, WORD_COUNT> &words) {
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       val[i] = words[i];
   }
 
@@ -116,36 +119,37 @@ template <size_t Bits, bool Signed> struct BigInt {
 
   template <typename T>
   LIBC_INLINE constexpr cpp::enable_if_t<
-      cpp::is_integral_v<T> && sizeof(T) <= 8 && !cpp::is_same_v<T, bool>, T>
+      cpp::is_integral_v<T> && !cpp::is_same_v<T, bool>, T>
   to() const {
-    return static_cast<T>(val[0]);
-  }
-  template <typename T>
-  LIBC_INLINE constexpr cpp::enable_if_t<
-      cpp::is_integral_v<T> && sizeof(T) == 16, T>
-  to() const {
-    // T is 128-bit.
     T lo = static_cast<T>(val[0]);
 
-    if constexpr (Bits == 64) {
-      if constexpr (Signed) {
-        // Extend sign for negative numbers.
-        return (val[0] >> 63) ? ((T(-1) << 64) + lo) : lo;
-      } else {
-        return lo;
-      }
-    } else {
-      return static_cast<T>((static_cast<T>(val[1]) << 64) + lo);
+    constexpr size_t T_BITS = sizeof(T) * CHAR_BIT;
+
+    if constexpr (T_BITS <= WORD_SIZE)
+      return lo;
+
+    constexpr size_t MAX_COUNT =
+        T_BITS > Bits ? WORD_COUNT : T_BITS / WORD_SIZE;
+    for (size_t i = 1; i < MAX_COUNT; ++i)
+      lo += static_cast<T>(val[i]) << (WORD_SIZE * i);
+
+    if constexpr (Signed && (T_BITS > Bits)) {
+      // Extend sign for negative numbers.
+      constexpr T MASK = (~T(0) << Bits);
+      if (val[WORD_COUNT - 1] >> (WORD_SIZE - 1))
+        lo |= MASK;
     }
+
+    return lo;
   }
 
   LIBC_INLINE constexpr explicit operator bool() const { return !is_zero(); }
 
-  LIBC_INLINE BigInt<Bits, Signed> &
-  operator=(const BigInt<Bits, Signed> &other) = default;
+  LIBC_INLINE BigInt<Bits, Signed, WordType> &
+  operator=(const BigInt<Bits, Signed, WordType> &other) = default;
 
   LIBC_INLINE constexpr bool is_zero() const {
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       if (val[i] != 0)
         return false;
     }
@@ -154,20 +158,20 @@ template <size_t Bits, bool Signed> struct BigInt {
 
   // Add x to this number and store the result in this number.
   // Returns the carry value produced by the addition operation.
-  LIBC_INLINE constexpr uint64_t add(const BigInt<Bits, Signed> &x) {
-    SumCarry<uint64_t> s{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr WordType add(const BigInt<Bits, Signed, WordType> &x) {
+    SumCarry<WordType> s{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       s = add_with_carry_const(val[i], x.val[i], s.carry);
       val[i] = s.sum;
     }
     return s.carry;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator+(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result;
-    SumCarry<uint64_t> s{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator+(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result;
+    SumCarry<WordType> s{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       s = add_with_carry(val[i], other.val[i], s.carry);
       result.val[i] = s.sum;
     }
@@ -176,58 +180,58 @@ template <size_t Bits, bool Signed> struct BigInt {
 
   // This will only apply when initializing a variable from constant values, so
   // it will always use the constexpr version of add_with_carry.
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator+(BigInt<Bits, Signed> &&other) const {
-    BigInt<Bits, Signed> result;
-    SumCarry<uint64_t> s{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator+(BigInt<Bits, Signed, WordType> &&other) const {
+    BigInt<Bits, Signed, WordType> result;
+    SumCarry<WordType> s{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       s = add_with_carry_const(val[i], other.val[i], s.carry);
       result.val[i] = s.sum;
     }
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &
-  operator+=(const BigInt<Bits, Signed> &other) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &
+  operator+=(const BigInt<Bits, Signed, WordType> &other) {
     add(other); // Returned carry value is ignored.
     return *this;
   }
 
   // Subtract x to this number and store the result in this number.
   // Returns the carry value produced by the subtraction operation.
-  LIBC_INLINE constexpr uint64_t sub(const BigInt<Bits, Signed> &x) {
-    DiffBorrow<uint64_t> d{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr WordType sub(const BigInt<Bits, Signed, WordType> &x) {
+    DiffBorrow<WordType> d{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       d = sub_with_borrow_const(val[i], x.val[i], d.borrow);
       val[i] = d.diff;
     }
     return d.borrow;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator-(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result;
-    DiffBorrow<uint64_t> d{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator-(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result;
+    DiffBorrow<WordType> d{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       d = sub_with_borrow(val[i], other.val[i], d.borrow);
       result.val[i] = d.diff;
     }
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator-(BigInt<Bits, Signed> &&other) const {
-    BigInt<Bits, Signed> result;
-    DiffBorrow<uint64_t> d{0, 0};
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator-(BigInt<Bits, Signed, WordType> &&other) const {
+    BigInt<Bits, Signed, WordType> result;
+    DiffBorrow<WordType> d{0, 0};
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
       d = sub_with_borrow_const(val[i], other.val[i], d.borrow);
       result.val[i] = d.diff;
     }
     return result;
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed> &
-  operator-=(const BigInt<Bits, Signed> &other) {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType> &
+  operator-=(const BigInt<Bits, Signed, WordType> &other) {
     // TODO(lntue): Set overflow flag / errno when carry is true.
     sub(other);
     return *this;
@@ -239,12 +243,12 @@ template <size_t Bits, bool Signed> struct BigInt {
   // the operations using 64-bit numbers. This ensures that we don't lose the
   // carry bits.
   // Returns the carry value produced by the multiplication operation.
-  LIBC_INLINE constexpr uint64_t mul(uint64_t x) {
-    BigInt<128, Signed> partial_sum(0);
-    uint64_t carry = 0;
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
-      NumberPair<uint64_t> prod = full_mul(val[i], x);
-      BigInt<128, Signed> tmp({prod.lo, prod.hi});
+  LIBC_INLINE constexpr WordType mul(WordType x) {
+    BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
+    WordType carry = 0;
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
+      NumberPair<WordType> prod = full_mul(val[i], x);
+      BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
       carry += partial_sum.add(tmp);
       val[i] = partial_sum.val[0];
       partial_sum.val[0] = partial_sum.val[1];
@@ -254,33 +258,33 @@ template <size_t Bits, bool Signed> struct BigInt {
     return partial_sum.val[1];
   }
 
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  operator*(const BigInt<Bits, Signed> &other) const {
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  operator*(const BigInt<Bits, Signed, WordType> &other) const {
     if constexpr (Signed) {
-      BigInt<Bits, false> a(*this);
-      BigInt<Bits, false> b(other);
-      bool a_neg = (a.val[WORDCOUNT - 1] >> 63);
-      bool b_neg = (b.val[WORDCOUNT - 1] >> 63);
+      BigInt<Bits, false, WordType> a(*this);
+      BigInt<Bits, false, WordType> b(other);
+      bool a_neg = (a.val[WORD_COUNT - 1] >> (WORD_SIZE - 1));
+      bool b_neg = (b.val[WORD_COUNT - 1] >> (WORD_SIZE - 1));
       if (a_neg)
         a = -a;
       if (b_neg)
         b = -b;
-      BigInt<Bits, false> prod = a * b;
+      BigInt<Bits, false, WordType> prod = a * b;
       if (a_neg != b_neg)
         prod = -prod;
-      return static_cast<BigInt<Bits, true>>(prod);
+      return static_cast<BigInt<Bits, true, WordType>>(prod);
     } else {
 
-      if constexpr (WORDCOUNT == 1) {
+      if constexpr (WORD_COUNT == 1) {
         return {val[0] * other.val[0]};
       } else {
-        BigInt<Bits, Signed> result(0);
-        BigInt<128, Signed> partial_sum(0);
-        uint64_t carry = 0;
-        for (size_t i = 0; i < WORDCOUNT; ++i) {
+        BigInt<Bits, Signed, WordType> result(0);
+        BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
+        WordType carry = 0;
+        for (size_t i = 0; i < WORD_COUNT; ++i) {
           for (size_t j = 0; j <= i; j++) {
-            NumberPair<uint64_t> prod = full_mul(val[j], other.val[i - j]);
-            BigInt<128, Signed> tmp({prod.lo, prod.hi});
+            NumberPair<WordType> prod = full_mul(val[j], other.val[i - j]);
+            BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
             carry += partial_sum.add(tmp);
           }
           result.val[i] = partial_sum.val[0];
@@ -295,19 +299,20 @@ template <size_t Bits, bool Signed> struct BigInt {
 
   // Return the full product, only unsigned for now.
   template <size_t OtherBits>
-  LIBC_INLINE constexpr BigInt<Bits + OtherBits, Signed>
-  ful_mul(const BigInt<OtherBits, Signed> &other) const {
-    BigInt<Bits + OtherBits, Signed> result(0);
-    BigInt<128, Signed> partial_sum(0);
-    uint64_t carry = 0;
-    constexpr size_t OTHER_WORDCOUNT = BigInt<OtherBits, Signed>::WORDCOUNT;
-    for (size_t i = 0; i <= WORDCOUNT + OTHER_WORDCOUNT - 2; ++i) {
+  LIBC_INLINE constexpr BigInt<Bits + OtherBits, Signed, WordType>
+  ful_mul(const BigInt<OtherBits, Signed, WordType> &other) const {
+    BigInt<Bits + OtherBits, Signed, WordType> result(0);
+    BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
+    WordType carry = 0;
+    constexpr size_t OTHER_WORDCOUNT =
+        BigInt<OtherBits, Signed, WordType>::WORD_COUNT;
+    for (size_t i = 0; i <= WORD_COUNT + OTHER_WORDCOUNT - 2; ++i) {
       const size_t lower_idx =
           i < OTHER_WORDCOUNT ? 0 : i - OTHER_WORDCOUNT + 1;
-      const size_t upper_idx = i < WORDCOUNT ? i : WORDCOUNT - 1;
+      const size_t upper_idx = i < WORD_COUNT ? i : WORD_COUNT - 1;
       for (size_t j = lower_idx; j <= upper_idx; ++j) {
-        NumberPair<uint64_t> prod = full_mul(val[j], other.val[i - j]);
-        BigInt<128, Signed> tmp({prod.lo, prod.hi});
+        NumberPair<WordType> prod = full_mul(val[j], other.val[i - j]);
+        BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
         carry += partial_sum.add(tmp);
       }
       result.val[i] = partial_sum.val[0];
@@ -315,7 +320,7 @@ template <size_t Bits, bool Signed> struct BigInt {
       partial_sum.val[1] = carry;
       carry = 0;
     }
-    result.val[WORDCOUNT + OTHER_WORDCOUNT - 1] = partial_sum.val[0];
+    result.val[WORD_COUNT + OTHER_WORDCOUNT - 1] = partial_sum.val[0];
     return result;
   }
 
@@ -323,7 +328,7 @@ template <size_t Bits, bool Signed> struct BigInt {
   // `Bits` least significant bits of the full product, while this function will
   // approximate `Bits` most significant bits of the full product with errors
   // bounded by:
-  //   0 <= (a.full_mul(b) >> Bits) - a.quick_mul_hi(b)) <= WORDCOUNT - 1.
+  //   0 <= (a.full_mul(b) >> Bits) - a.quick_mul_hi(b)) <= WORD_COUNT - 1.
   //
   // An example usage of this is to quickly (but less accurately) compute the
   // product of (normalized) mantissas of floating point numbers:
@@ -335,44 +340,44 @@ template <size_t Bits, bool Signed> struct BigInt {
   //
   // Performance summary:
   //   Number of 64-bit x 64-bit -> 128-bit multiplications performed.
-  //   Bits  WORDCOUNT  ful_mul  quick_mul_hi  Error bound
+  //   Bits  WORD_COUNT  ful_mul  quick_mul_hi  Error bound
   //    128      2         4           3            1
   //    196      3         9           6            2
   //    256      4        16          10            3
   //    512      8        64          36            7
-  LIBC_INLINE constexpr BigInt<Bits, Signed>
-  quick_mul_hi(const BigInt<Bits, Signed> &other) const {
-    BigInt<Bits, Signed> result(0);
-    BigInt<128, Signed> partial_sum(0);
-    uint64_t carry = 0;
-    // First round of accumulation for those at WORDCOUNT - 1 in the full
+  LIBC_INLINE constexpr BigInt<Bits, Signed, WordType>
+  quick_mul_hi(const BigInt<Bits, Signed, WordType> &other) const {
+    BigInt<Bits, Signed, WordType> result(0);
+    BigInt<2 * WORD_SIZE, Signed, WordType> partial_sum(0);
+    WordType carry = 0;
+    // First round of accumulation for those at WORD_COUNT - 1 in the full
     // product.
-    for (size_t i = 0; i < WORDCOUNT; ++i) {
-      NumberPair<uint64_t> prod =
-          full_mul(val[i], other.val[WORDCOUNT - 1 - i]);
-      BigInt<128, Signed> tmp({prod.lo, prod.hi});
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
+      NumberPair<WordType> prod =
+          full_mul(val[i], other.val[WORD_COUNT - 1 - i]);
+      BigInt<2 * WORD_SIZE, Signed, WordType> tmp({prod.lo, prod.hi});
       carry += partial_sum.add(tmp);
     }
-    for (size_t i = WORDCOUNT; i < 2 * WORDCOUNT - 1; ++i) {
+    for (size_t i = WORD_COUNT; i < 2 * WORD_COUNT - 1; ++i) {
       partial_sum.val[0] = partial_sum.val[1];
       partial_sum.val[1] = carry;
       carry = 0;
-      for (size_t j = i - WORDCOUNT + 1; j < WORDCOUNT; ++j) {
-        NumberPair<uint64_t> prod = full_mul(val[j], other.val[i - j]);
-        BigInt<128, Signed> tmp({prod.lo, prod.hi});
+      for (size_t j = i - WORD_COUNT + 1; j < WORD_COUNT; ++j) {
+        NumberPair<WordType> prod = full_mul(val[j], other.val[i - j]);
+        BigInt<2 * WORD_SIZE, Signed, WordType> tmp({pro...
[truncated]
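
The central idea in the portion of the patch shown above is that `WORD_SIZE` and `WORD_COUNT` are derived from the `WordType` template parameter instead of being hardcoded to 64. The following is a minimal standalone sketch of that layout and of a word-by-word `to<T>()` conversion, restricted to unsigned targets. `WordLayoutSketch` and its use of `std::array` are assumptions made for illustration; the real class uses LLVM libc's own `cpp::` utilities and also handles signed values.

```cpp
#include <array>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical, simplified stand-in for BigInt<Bits, false, WordType>.
template <std::size_t Bits, typename WordType = std::uint64_t>
struct WordLayoutSketch {
  static_assert(std::is_integral_v<WordType> && std::is_unsigned_v<WordType>,
                "WordType must be an unsigned integer.");

  // The word size follows from the chosen word type.
  static constexpr std::size_t WORD_SIZE = sizeof(WordType) * CHAR_BIT;
  static_assert(Bits > 0 && Bits % WORD_SIZE == 0,
                "Bits must be a positive multiple of WORD_SIZE.");
  static constexpr std::size_t WORD_COUNT = Bits / WORD_SIZE;

  // Little-endian storage: val[0] is the least significant word.
  std::array<WordType, WORD_COUNT> val{};

  constexpr WordLayoutSketch() = default;
  constexpr explicit WordLayoutSketch(
      const std::array<WordType, WORD_COUNT> &words)
      : val(words) {}

  // Convert to an unsigned integer T by combining as many words as fit.
  template <typename T> constexpr T to() const {
    static_assert(std::is_integral_v<T> && std::is_unsigned_v<T>,
                  "This sketch only covers unsigned targets.");
    constexpr std::size_t T_BITS = sizeof(T) * CHAR_BIT;
    T result = static_cast<T>(val[0]);
    if constexpr (T_BITS > WORD_SIZE) {
      // Stop at whichever runs out first: the words or the bits of T.
      constexpr std::size_t MAX_COUNT =
          T_BITS > Bits ? WORD_COUNT : T_BITS / WORD_SIZE;
      for (std::size_t i = 1; i < MAX_COUNT; ++i)
        result = static_cast<T>(
            result | (static_cast<T>(val[i]) << (WORD_SIZE * i)));
    }
    return result;
  }
};
```

With this layout, a 128-bit value occupies two words under the default 64-bit word type and four words when `WordType` is `std::uint32_t`.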

           for (size_t j = 0; j <= i; j++) {
-            NumberPair<uint64_t> prod = full_mul(val[j], other.val[i - j]);
-            BigInt<128, Signed> tmp({prod.lo, prod.hi});
+            NumberPair<WordType> prod = full_mul(val[j], other.val[i - j]);
Contributor

full_mul can only handle sizes up to uint64_t. Should we add a check?

Contributor Author

I've updated full_mul to make it work for other types, and added tests for the __uint128_t base type.
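
The updated `full_mul` is not visible in the truncated diff. One common way to make a word-by-word full multiplication independent of the word width is to split each operand into half-words, as sketched below. `WordPair` and `full_mul_sketch` are hypothetical names, and this is not claimed to be the approach the patch takes; it only illustrates why nothing in the technique requires a 64-bit word.

```cpp
#include <climits>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for the NumberPair type used in the patch.
template <typename T> struct WordPair {
  T lo;
  T hi;
};

// Full product of two words of any standard unsigned type, computed from
// four half-word partial products (schoolbook multiplication).
template <typename T> constexpr WordPair<T> full_mul_sketch(T a, T b) {
  static_assert(std::is_integral_v<T> && std::is_unsigned_v<T>,
                "T must be an unsigned integer.");
  constexpr unsigned HALF = sizeof(T) * CHAR_BIT / 2;
  constexpr T LOW_MASK = static_cast<T>((T(1) << HALF) - 1);

  const T a_lo = static_cast<T>(a & LOW_MASK);
  const T a_hi = static_cast<T>(a >> HALF);
  const T b_lo = static_cast<T>(b & LOW_MASK);
  const T b_hi = static_cast<T>(b >> HALF);

  // Each partial product fits in T because both factors are half-words.
  const T ll = static_cast<T>(a_lo * b_lo);
  const T lh = static_cast<T>(a_lo * b_hi);
  const T hl = static_cast<T>(a_hi * b_lo);
  const T hh = static_cast<T>(a_hi * b_hi);

  // Middle column: three half-word-sized terms, so the sum cannot overflow T.
  const T mid =
      static_cast<T>((ll >> HALF) + (lh & LOW_MASK) + (hl & LOW_MASK));

  // The low half of `mid` completes `lo`; its high part carries into `hi`.
  const T lo = static_cast<T>((ll & LOW_MASK) | (mid << HALF));
  const T hi =
      static_cast<T>(hh + (lh >> HALF) + (hl >> HALF) + (mid >> HALF));
  return {lo, hi};
}
```

The explicit casts back to `T` matter for word types narrower than `int`, where the intermediate arithmetic is promoted to `int`.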

LIBC_INLINE constexpr optional<BigInt<Bits, Signed>>
div_uint32_times_pow_2(uint32_t x, size_t e) {
BigInt<Bits, Signed> remainder(0);
// template <typename T>
Contributor

It looks like you decided not to go with this design. Should the code be deleted and the comments reverted?

Contributor Author

I've removed the comments and fixed the input types.

@gchatelet
Contributor

I haven't read the patch yet, but just to make sure we're on the same page, I'm cross-posting this here: #81267 (comment)

@lntue
Contributor Author

lntue commented Feb 13, 2024

> I haven't read the patch yet, but just to make sure we're on the same page, I'm cross-posting this here: #81267 (comment)

Thanks, that patch would definitely help with initializing UInt classes. Hopefully there won't be any issue merging this and #81267.

Contributor

@michaelrj-google michaelrj-google left a comment

LGTM

@lntue lntue merged commit 4e00551 into llvm:main Feb 13, 2024
7 checks passed
@lntue lntue deleted the uint branch February 13, 2024 21:04
@gchatelet
Contributor

gchatelet commented Feb 14, 2024

> > I haven't read the patch yet, but just to make sure we're on the same page, I'm cross-posting this here: #81267 (comment)
>
> Thanks, that patch would definitely help with initializing UInt classes. Hopefully there won't be any issue merging this and #81267.

I synced #81267 and unfortunately there is one issue:

libc/src/__support/integer_literals.h:122:28: error: class template partial specialization contains a template parameter that cannot be deduced; this partial specialization will never be used [-Wunusable-partial-specialization]
template <size_t N> struct Parser<LIBC_NAMESPACE::cpp::UInt<N>> {
                           ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
libc/src/__support/integer_literals.h:122:18: note: non-deducible template parameter 'N'
template <size_t N> struct Parser<LIBC_NAMESPACE::cpp::UInt<N>> {

I need to think this through.

edit: I fixed it but the _u128 suffix will only work for WordType = uint64_t.
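
The diagnostic above reflects a general C++ rule: a partial specialization is usable only if every one of its template parameters can be deduced from the specialized type, and a parameter that appears only inside a nested `typename X<N>::type` is in a non-deduced context. How `UInt` is defined after this patch is not visible in the truncated diff, and the actual fix in #81267 is not shown here, so the sketch below uses hypothetical names (`BigIntSketch`, `SelectSketch`, `IndirectUInt`, `ParserSketch`) purely to illustrate the rule and one way of restoring deducibility by pinning the word type.

```cpp
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for BigInt with a defaulted word type.
template <std::size_t Bits, bool Signed, typename WordType = std::uint64_t>
struct BigIntSketch {};

// Hypothetical alias that reaches the class through a nested ::type.
template <std::size_t Bits> struct SelectSketch {
  using type = BigIntSketch<Bits, false>;
};
template <std::size_t Bits>
using IndirectUInt = typename SelectSketch<Bits>::type;

// Primary template: no parser available.
template <typename T> struct ParserSketch : std::false_type {};

// Rejected by compilers: N appears only in a non-deduced context.
// template <std::size_t N>
// struct ParserSketch<IndirectUInt<N>> : std::true_type {};

// Accepted: N is a direct template argument, and the word type is fixed.
// As a consequence, only the uint64_t word type is matched.
template <std::size_t N>
struct ParserSketch<BigIntSketch<N, false, std::uint64_t>> : std::true_type {};
```

Pinning the word type in the specialization is consistent with the remark that the `_u128` suffix only works for `WordType = uint64_t`: a `BigIntSketch<128, false, std::uint32_t>` falls back to the primary template.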

ASSERT_EQ(static_cast<int>(a), 3);
ASSERT_EQ(static_cast<uint64_t>(a), uint64_t(0x2'0000'0003));
ASSERT_EQ(static_cast<int>(a >> 32), 2);
ASSERT_EQ(static_cast<int>(a >> (128 + 32)), 1);
Contributor

Shouldn't this be ASSERT_EQ(static_cast<int>(a >> (128 + 32)), 0);?
This line started failing when running tests for #86137.

Contributor Author

No, after L689, the value of a is 2^(128+32) + 2 × 2^32 + 3, so a right shift by 128 + 32 should yield 1.
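
The arithmetic in this reply can be checked with a small model. The sketch assumes `a` is at least 192 bits wide and stored as little-endian 64-bit words; `Words`, `A`, and `low_word_after_shift` are hypothetical names and are not part of the test file.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical 192-bit value stored as three little-endian 64-bit words.
using Words = std::array<std::uint64_t, 3>;

// Return the lowest 64 bits of (value >> shift).
constexpr std::uint64_t low_word_after_shift(const Words &value,
                                             std::size_t shift) {
  const std::size_t word_shift = shift / 64;
  const std::size_t bit_shift = shift % 64;
  if (word_shift >= value.size())
    return 0;
  std::uint64_t result = value[word_shift] >> bit_shift;
  // Pull in the low bits of the next word, if there is one.
  if (bit_shift != 0 && word_shift + 1 < value.size())
    result |= value[word_shift + 1] << (64 - bit_shift);
  return result;
}

// a = 2^(128+32) + 2 * 2^32 + 3:
//   word 0 holds 2 * 2^32 + 3, word 1 is zero, word 2 holds bit 32 (2^160).
constexpr Words A = {0x2'0000'0003ULL, 0, std::uint64_t(1) << 32};
```

In this model, shifting by 0, 32, and 128 + 32 leaves low words of `0x2'0000'0003`, 2, and 1 respectively, matching the expected values in the assertions under discussion.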

Contributor

It was a bug in my code :) Sorry for the noise.

4 participants