Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[APInt] Restore multiplicativeInverse with explicit modulus and better testing #87812

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

j2kun
Copy link
Contributor

@j2kun j2kun commented Apr 5, 2024

This reverts commit 0b293e8 and adds additional improvements.

There are out-of-tree uses of this method, and it is planned to be used as part of a new polynomial dialect in MLIR, a starting PR of which is #72081 (later PRs will add lowerings that need the removed functionality)

…lvm#87644)"

This reverts commit 0b293e8.

There are out-of-tree uses of this method, and it is planned to be used
as part of a new polynomial dialect in MLIR, a starting PR of which is
llvm#72081 (later PRs will add
lowerings that need the removed functionality)
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-llvm-support

Author: Jeremy Kun (j2kun)

Changes

…87644)"

This reverts commit 0b293e8.

There are out-of-tree uses of this method, and it is planned to be used as part of a new polynomial dialect in MLIR, a starting PR of which is #72081 (later PRs will add lowerings that need the removed functionality)


Full diff: https://github.com/llvm/llvm-project/pull/87812.diff

3 Files Affected:

  • (modified) llvm/include/llvm/ADT/APInt.h (+3)
  • (modified) llvm/lib/Support/APInt.cpp (+49)
  • (modified) llvm/unittests/ADT/APIntTest.cpp (+15-4)
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 8d3c029b2e7e91..bd1716219ee5fc 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1740,6 +1740,9 @@ class [[nodiscard]] APInt {
     return *this;
   }
 
+  /// \returns the multiplicative inverse for a given modulo.
+  APInt multiplicativeInverse(const APInt &modulo) const;
+
   /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
   APInt multiplicativeInverse() const;
 
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 224ea0924f0aaa..f8f699f8f6ccd7 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -1240,6 +1240,55 @@ APInt APInt::sqrt() const {
   return x_old + 1;
 }
 
+/// Computes the multiplicative inverse of this APInt for a given modulo. The
+/// iterative extended Euclidean algorithm is used to solve for this value,
+/// however we simplify it to speed up calculating only the inverse, and take
+/// advantage of div+rem calculations. We also use some tricks to avoid copying
+/// (potentially large) APInts around.
+/// WARNING: a value of '0' may be returned,
+///          signifying that no multiplicative inverse exists!
+APInt APInt::multiplicativeInverse(const APInt& modulo) const {
+  assert(ult(modulo) && "This APInt must be smaller than the modulo");
+
+  // Using the properties listed at the following web page (accessed 06/21/08):
+  //   http://www.numbertheory.org/php/euclid.html
+  // (especially the properties numbered 3, 4 and 9) it can be proved that
+  // BitWidth bits suffice for all the computations in the algorithm implemented
+  // below. More precisely, this number of bits suffice if the multiplicative
+  // inverse exists, but may not suffice for the general extended Euclidean
+  // algorithm.
+
+  APInt r[2] = { modulo, *this };
+  APInt t[2] = { APInt(BitWidth, 0), APInt(BitWidth, 1) };
+  APInt q(BitWidth, 0);
+
+  unsigned i;
+  for (i = 0; r[i^1] != 0; i ^= 1) {
+    // An overview of the math without the confusing bit-flipping:
+    // q = r[i-2] / r[i-1]
+    // r[i] = r[i-2] % r[i-1]
+    // t[i] = t[i-2] - t[i-1] * q
+    udivrem(r[i], r[i^1], q, r[i]);
+    t[i] -= t[i^1] * q;
+  }
+
+  // If this APInt and the modulo are not coprime, there is no multiplicative
+  // inverse, so return 0. We check this by looking at the next-to-last
+  // remainder, which is the gcd(*this,modulo) as calculated by the Euclidean
+  // algorithm.
+  if (r[i] != 1)
+    return APInt(BitWidth, 0);
+
+  // The next-to-last t is the multiplicative inverse.  However, we are
+  // interested in a positive inverse. Calculate a positive one from a negative
+  // one if necessary. A simple addition of the modulo suffices because
+  // abs(t[i]) is known to be less than *this/2 (see the link above).
+  if (t[i].isNegative())
+    t[i] += modulo;
+
+  return std::move(t[i]);
+}
+
 /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
 APInt APInt::multiplicativeInverse() const {
   assert((*this)[0] &&
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 76fc26412407e7..23f9ee2d39c441 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -3249,11 +3249,22 @@ TEST(APIntTest, SolveQuadraticEquationWrap) {
 }
 
 TEST(APIntTest, MultiplicativeInverseExaustive) {
-  for (unsigned BitWidth = 1; BitWidth <= 8; ++BitWidth) {
-    for (unsigned Value = 1; Value < (1u << BitWidth); Value += 2) {
-      // Multiplicative inverse exists for all odd numbers.
+  for (unsigned BitWidth = 1; BitWidth <= 16; ++BitWidth) {
+    for (unsigned Value = 0; Value < (1u << BitWidth); ++Value) {
       APInt V = APInt(BitWidth, Value);
-      EXPECT_EQ(V * V.multiplicativeInverse(), 1);
+      APInt MulInv =
+          V.zext(BitWidth + 1)
+              .multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
+              .trunc(BitWidth);
+      APInt One = V * MulInv;
+      if (V[0]) {
+        // Multiplicative inverse exists for all odd numbers.
+        EXPECT_TRUE(One.isOne());
+        EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
+      } else {
+        // Multiplicative inverse does not exist for even numbers (and 0).
+        EXPECT_TRUE(MulInv.isZero());
+      }
     }
   }
 }

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-llvm-adt

Author: Jeremy Kun (j2kun)

Changes

…87644)"

This reverts commit 0b293e8.

There are out-of-tree uses of this method, and it is planned to be used as part of a new polynomial dialect in MLIR, a starting PR of which is #72081 (later PRs will add lowerings that need the removed functionality)


Full diff: https://github.com/llvm/llvm-project/pull/87812.diff

3 Files Affected:

  • (modified) llvm/include/llvm/ADT/APInt.h (+3)
  • (modified) llvm/lib/Support/APInt.cpp (+49)
  • (modified) llvm/unittests/ADT/APIntTest.cpp (+15-4)
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 8d3c029b2e7e91..bd1716219ee5fc 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1740,6 +1740,9 @@ class [[nodiscard]] APInt {
     return *this;
   }
 
+  /// \returns the multiplicative inverse for a given modulo.
+  APInt multiplicativeInverse(const APInt &modulo) const;
+
   /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
   APInt multiplicativeInverse() const;
 
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 224ea0924f0aaa..f8f699f8f6ccd7 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -1240,6 +1240,55 @@ APInt APInt::sqrt() const {
   return x_old + 1;
 }
 
+/// Computes the multiplicative inverse of this APInt for a given modulo. The
+/// iterative extended Euclidean algorithm is used to solve for this value,
+/// however we simplify it to speed up calculating only the inverse, and take
+/// advantage of div+rem calculations. We also use some tricks to avoid copying
+/// (potentially large) APInts around.
+/// WARNING: a value of '0' may be returned,
+///          signifying that no multiplicative inverse exists!
+APInt APInt::multiplicativeInverse(const APInt& modulo) const {
+  assert(ult(modulo) && "This APInt must be smaller than the modulo");
+
+  // Using the properties listed at the following web page (accessed 06/21/08):
+  //   http://www.numbertheory.org/php/euclid.html
+  // (especially the properties numbered 3, 4 and 9) it can be proved that
+  // BitWidth bits suffice for all the computations in the algorithm implemented
+  // below. More precisely, this number of bits suffice if the multiplicative
+  // inverse exists, but may not suffice for the general extended Euclidean
+  // algorithm.
+
+  APInt r[2] = { modulo, *this };
+  APInt t[2] = { APInt(BitWidth, 0), APInt(BitWidth, 1) };
+  APInt q(BitWidth, 0);
+
+  unsigned i;
+  for (i = 0; r[i^1] != 0; i ^= 1) {
+    // An overview of the math without the confusing bit-flipping:
+    // q = r[i-2] / r[i-1]
+    // r[i] = r[i-2] % r[i-1]
+    // t[i] = t[i-2] - t[i-1] * q
+    udivrem(r[i], r[i^1], q, r[i]);
+    t[i] -= t[i^1] * q;
+  }
+
+  // If this APInt and the modulo are not coprime, there is no multiplicative
+  // inverse, so return 0. We check this by looking at the next-to-last
+  // remainder, which is the gcd(*this,modulo) as calculated by the Euclidean
+  // algorithm.
+  if (r[i] != 1)
+    return APInt(BitWidth, 0);
+
+  // The next-to-last t is the multiplicative inverse.  However, we are
+  // interested in a positive inverse. Calculate a positive one from a negative
+  // one if necessary. A simple addition of the modulo suffices because
+  // abs(t[i]) is known to be less than *this/2 (see the link above).
+  if (t[i].isNegative())
+    t[i] += modulo;
+
+  return std::move(t[i]);
+}
+
 /// \returns the multiplicative inverse of an odd APInt modulo 2^BitWidth.
 APInt APInt::multiplicativeInverse() const {
   assert((*this)[0] &&
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 76fc26412407e7..23f9ee2d39c441 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -3249,11 +3249,22 @@ TEST(APIntTest, SolveQuadraticEquationWrap) {
 }
 
 TEST(APIntTest, MultiplicativeInverseExaustive) {
-  for (unsigned BitWidth = 1; BitWidth <= 8; ++BitWidth) {
-    for (unsigned Value = 1; Value < (1u << BitWidth); Value += 2) {
-      // Multiplicative inverse exists for all odd numbers.
+  for (unsigned BitWidth = 1; BitWidth <= 16; ++BitWidth) {
+    for (unsigned Value = 0; Value < (1u << BitWidth); ++Value) {
       APInt V = APInt(BitWidth, Value);
-      EXPECT_EQ(V * V.multiplicativeInverse(), 1);
+      APInt MulInv =
+          V.zext(BitWidth + 1)
+              .multiplicativeInverse(APInt::getSignedMinValue(BitWidth + 1))
+              .trunc(BitWidth);
+      APInt One = V * MulInv;
+      if (V[0]) {
+        // Multiplicative inverse exists for all odd numbers.
+        EXPECT_TRUE(One.isOne());
+        EXPECT_TRUE((V * V.multiplicativeInverse()).isOne());
+      } else {
+        // Multiplicative inverse does not exist for even numbers (and 0).
+        EXPECT_TRUE(MulInv.isZero());
+      }
     }
   }
 }

Copy link

github-actions bot commented Apr 5, 2024

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

Copy link

github-actions bot commented Apr 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

llvm/include/llvm/ADT/APInt.h Outdated Show resolved Hide resolved
llvm/include/llvm/ADT/APInt.h Outdated Show resolved Hide resolved
llvm/unittests/ADT/APIntTest.cpp Outdated Show resolved Hide resolved
llvm/lib/Support/APInt.cpp Outdated Show resolved Hide resolved
llvm/lib/Support/APInt.cpp Outdated Show resolved Hide resolved
@kuhar
Copy link
Member

kuhar commented Apr 5, 2024

…87644)"

This reverts commit 0b293e8.

There are out-of-tree uses of this method, and it is planned to be used as part of a new polynomial dialect in MLIR, a starting PR of which is #72081 (later PRs will add lowerings that need the removed functionality)

I'm generally sympathetic to downstream maintainers, but this seems like an unusual reason for revert. Can't these users implement it on their own as a free function? If we have it upstream with no uses in the codebase, this will lead to eventual bitrot.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want this to be a revert, don't apply any other changes even if this leads to CI complaining about formatting etc.

But from the sounds of it, I don't think we should revert this one.

@j2kun j2kun changed the title Revert "[APInt] Remove multiplicativeInverse with explicit modulus (#… [APInt] Restore multiplicativeInverse with explicit modulus and better testing Apr 5, 2024
@j2kun
Copy link
Contributor Author

j2kun commented Apr 5, 2024

Adding review comments to a "revert" PR sends mixed messages :)

I'm happy to sit on this PR until there's an in-tree user, which I intend to add. In the mean time you can let me know if the testing is exhaustive enough for LLVM standards. I did not change the implementation.

Copy link
Contributor

@jayfoad jayfoad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the extra testing! I just have a few nits inline.

llvm/unittests/ADT/APIntTest.cpp Outdated Show resolved Hide resolved
llvm/unittests/ADT/APIntTest.cpp Outdated Show resolved Hide resolved
llvm/unittests/ADT/APIntTest.cpp Outdated Show resolved Hide resolved
llvm/unittests/ADT/APIntTest.cpp Show resolved Hide resolved
@j2kun j2kun force-pushed the revert-apint-inverse-removal branch from d6bf8ab to ea08b21 Compare April 7, 2024 23:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants