
mlir/Presburger/MPInt: move into llvm/ADT #94953

Merged (2 commits) on Jun 12, 2024

Conversation

artagnon
Contributor

@artagnon commented Jun 10, 2024

MPInt is an arbitrary-precision integer library that builds on top of APInt, and has a fast-path when the number fits within 64 bits. It was originally written for the Presburger library in MLIR, but seems useful to the LLVM project in general, independently of the Presburger library or MLIR. Hence, move it into LLVM/ADT under the name DynamicAPInt.

This patch is part of a project to move the Presburger library into LLVM.

@llvmbot
Collaborator

llvmbot commented Jun 10, 2024

@llvm/pr-subscribers-mlir-func
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-cf
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-presburger
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Ramkumar Ramachandra (artagnon)

Changes

MPInt is an arbitrary-precision integer library that builds on top of APInt, and has a fast-path when the number fits within 64 bits. It was originally written for the Presburger library in MLIR, but seems useful to the LLVM project in general, independently of the Presburger library or MLIR. Hence, move it into LLVM/ADT.

This patch is part of a project to move the Presburger library into LLVM.


Patch is 107.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94953.diff

26 Files Affected:

  • (added) llvm/include/llvm/ADT/MPInt.h (+644)
  • (added) llvm/include/llvm/ADT/SlowMPInt.h (+138)
  • (modified) llvm/lib/Support/CMakeLists.txt (+2)
  • (added) llvm/lib/Support/MPInt.cpp (+25)
  • (added) llvm/lib/Support/SlowMPInt.cpp (+276)
  • (modified) llvm/unittests/ADT/CMakeLists.txt (+1)
  • (added) llvm/unittests/ADT/MPIntTest.cpp (+200)
  • (modified) mlir/include/mlir/Analysis/Presburger/Fraction.h (+2-2)
  • (modified) mlir/include/mlir/Analysis/Presburger/IntegerRelation.h (+3)
  • (removed) mlir/include/mlir/Analysis/Presburger/MPInt.h (-617)
  • (removed) mlir/include/mlir/Analysis/Presburger/SlowMPInt.h (-136)
  • (modified) mlir/include/mlir/Analysis/Presburger/Utils.h (+1-2)
  • (modified) mlir/include/mlir/Support/LLVM.h (+2)
  • (modified) mlir/lib/Analysis/Presburger/CMakeLists.txt (-2)
  • (modified) mlir/lib/Analysis/Presburger/IntegerRelation.cpp (+4-4)
  • (modified) mlir/lib/Analysis/Presburger/LinearTransform.cpp (+1-1)
  • (removed) mlir/lib/Analysis/Presburger/MPInt.cpp (-38)
  • (modified) mlir/lib/Analysis/Presburger/Matrix.cpp (+1-1)
  • (modified) mlir/lib/Analysis/Presburger/PWMAFunction.cpp (+1-1)
  • (modified) mlir/lib/Analysis/Presburger/PresburgerRelation.cpp (-1)
  • (modified) mlir/lib/Analysis/Presburger/Simplex.cpp (+2-2)
  • (removed) mlir/lib/Analysis/Presburger/SlowMPInt.cpp (-290)
  • (modified) mlir/lib/Analysis/Presburger/Utils.cpp (+6-5)
  • (modified) mlir/unittests/Analysis/Presburger/CMakeLists.txt (-1)
  • (removed) mlir/unittests/Analysis/Presburger/MPIntTest.cpp (-200)
  • (modified) mlir/unittests/Analysis/Presburger/Utils.h (+2)
diff --git a/llvm/include/llvm/ADT/MPInt.h b/llvm/include/llvm/ADT/MPInt.h
new file mode 100644
index 0000000000000..dc387d7d0e5db
--- /dev/null
+++ b/llvm/include/llvm/ADT/MPInt.h
@@ -0,0 +1,644 @@
+//===- MPInt.h - MPInt Class ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a simple class to represent arbitrary-precision signed integers.
+// Unlike APInt, one does not have to specify a fixed maximum size, and the
+// integer can take on arbitrarily large values. It is optimized for small
+// values by providing fast paths for the case where the value fits in 64 bits.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_MPINT_H
+#define LLVM_ADT_MPINT_H
+
+#include "llvm/ADT/SlowMPInt.h"
+#include "llvm/Support/raw_ostream.h"
+#include <numeric>
+
+namespace llvm {
+namespace detail {
+/// ---------------------------------------------------------------------------
+/// Some helpers from MLIR/MathExtras.
+/// ---------------------------------------------------------------------------
+LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t ceilDiv(int64_t Numerator,
+                                             int64_t Denominator) {
+  assert(Denominator);
+  if (!Numerator)
+    return 0;
+  // C's integer division rounds towards 0.
+  int64_t X = (Denominator > 0) ? -1 : 1;
+  bool SameSign = (Numerator > 0) == (Denominator > 0);
+  return SameSign ? ((Numerator + X) / Denominator) + 1
+                  : -(-Numerator / Denominator);
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t floorDiv(int64_t Numerator,
+                                              int64_t Denominator) {
+  assert(Denominator);
+  if (!Numerator)
+    return 0;
+  // C's integer division rounds towards 0.
+  int64_t X = (Denominator > 0) ? -1 : 1;
+  bool SameSign = (Numerator > 0) == (Denominator > 0);
+  return SameSign ? Numerator / Denominator
+                  : -((-Numerator + X) / Denominator) - 1;
+}
+
+/// Returns the remainder of the Euclidean division of Numerator by
+/// Denominator. The result is always non-negative.
+LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t mod(int64_t Numerator,
+                                         int64_t Denominator) {
+  assert(Denominator >= 1);
+  return Numerator % Denominator < 0 ? Numerator % Denominator + Denominator
+                                     : Numerator % Denominator;
+}
+
+/// If builtin intrinsics for overflow-checked arithmetic are available,
+/// use them. Otherwise, call through to LLVM's overflow-checked arithmetic
+/// functions. Those functions also have such macro-gated uses of the
+/// intrinsics, but they are not always_inlined, which is important for
+/// achieving high performance; calling them directly would result in a
+/// slowdown of 1.15x.
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool addOverflow(int64_t X, int64_t Y,
+                                              int64_t &Result) {
+#if __has_builtin(__builtin_add_overflow)
+  return __builtin_add_overflow(X, Y, &Result);
+#else
+  return AddOverflow(X, Y, Result);
+#endif
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool subOverflow(int64_t X, int64_t Y,
+                                              int64_t &Result) {
+#if __has_builtin(__builtin_sub_overflow)
+  return __builtin_sub_overflow(X, Y, &Result);
+#else
+  return SubOverflow(X, Y, Result);
+#endif
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool mulOverflow(int64_t X, int64_t Y,
+                                              int64_t &Result) {
+#if __has_builtin(__builtin_mul_overflow)
+  return __builtin_mul_overflow(X, Y, &Result);
+#else
+  return MulOverflow(X, Y, Result);
+#endif
+}
+} // namespace detail
+
+/// This class provides support for multi-precision arithmetic.
+///
+/// Unlike APInt, this extends the precision as necessary to prevent overflows
+/// and supports operations between objects with differing internal precisions.
+///
+/// This is optimized for small values by providing fast paths for the cases
+/// when the value stored fits in 64 bits. We annotate all fast paths with
+/// LLVM_LIKELY/LLVM_UNLIKELY; removing these annotations would result in a
+/// 1.2x performance slowdown.
+///
+/// We always_inline all operations; removing the always_inline attributes
+/// results in a 1.5x performance slowdown.
+///
+/// When HoldsLarge is true, a SlowMPInt is held in the union. If it is false,
+/// the int64_t is held. Using std::variant instead would lead to significantly
+/// worse performance.
+class MPInt {
+private:
+  union {
+    int64_t ValSmall;
+    detail::SlowMPInt ValLarge;
+  };
+  unsigned HoldsLarge;
+
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void initSmall(int64_t O) {
+    if (LLVM_UNLIKELY(isLarge()))
+      ValLarge.detail::SlowMPInt::~SlowMPInt();
+    ValSmall = O;
+    HoldsLarge = false;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void initLarge(const detail::SlowMPInt &O) {
+    if (LLVM_LIKELY(isSmall())) {
+      // The data in memory could be in an arbitrary state, not necessarily
+      // corresponding to any valid state of ValLarge; we cannot call any member
+      // functions, e.g. the assignment operator on it, as they may access the
+      // invalid internal state. We instead construct a new object using
+      // placement new.
+      new (&ValLarge) detail::SlowMPInt(O);
+    } else {
+      // In this case, we need to use the assignment operator, because if we use
+      // placement-new as above we would lose track of allocated memory
+      // and leak it.
+      ValLarge = O;
+    }
+    HoldsLarge = true;
+  }
+
+  LLVM_ATTRIBUTE_ALWAYS_INLINE explicit MPInt(const detail::SlowMPInt &Val)
+      : ValLarge(Val), HoldsLarge(true) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isSmall() const { return !HoldsLarge; }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isLarge() const { return HoldsLarge; }
+  /// Get the stored value. For getSmall/Large,
+  /// the stored value should be small/large.
+  LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t getSmall() const {
+    assert(isSmall() &&
+           "getSmall should only be called when the value stored is small!");
+    return ValSmall;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t &getSmall() {
+    assert(isSmall() &&
+           "getSmall should only be called when the value stored is small!");
+    return ValSmall;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE const detail::SlowMPInt &getLarge() const {
+    assert(isLarge() &&
+           "getLarge should only be called when the value stored is large!");
+    return ValLarge;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE detail::SlowMPInt &getLarge() {
+    assert(isLarge() &&
+           "getLarge should only be called when the value stored is large!");
+    return ValLarge;
+  }
+  explicit operator detail::SlowMPInt() const {
+    if (isSmall())
+      return detail::SlowMPInt(getSmall());
+    return getLarge();
+  }
+
+public:
+  LLVM_ATTRIBUTE_ALWAYS_INLINE explicit MPInt(int64_t Val)
+      : ValSmall(Val), HoldsLarge(false) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt() : MPInt(0) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE ~MPInt() {
+    if (LLVM_UNLIKELY(isLarge()))
+      ValLarge.detail::SlowMPInt::~SlowMPInt();
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt(const MPInt &O)
+      : ValSmall(O.ValSmall), HoldsLarge(false) {
+    if (LLVM_UNLIKELY(O.isLarge()))
+      initLarge(O.ValLarge);
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator=(const MPInt &O) {
+    if (LLVM_LIKELY(O.isSmall())) {
+      initSmall(O.ValSmall);
+      return *this;
+    }
+    initLarge(O.ValLarge);
+    return *this;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator=(int X) {
+    initSmall(X);
+    return *this;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE explicit operator int64_t() const {
+    if (isSmall())
+      return getSmall();
+    return static_cast<int64_t>(getLarge());
+  }
+
+  bool operator==(const MPInt &O) const;
+  bool operator!=(const MPInt &O) const;
+  bool operator>(const MPInt &O) const;
+  bool operator<(const MPInt &O) const;
+  bool operator<=(const MPInt &O) const;
+  bool operator>=(const MPInt &O) const;
+  MPInt operator+(const MPInt &O) const;
+  MPInt operator-(const MPInt &O) const;
+  MPInt operator*(const MPInt &O) const;
+  MPInt operator/(const MPInt &O) const;
+  MPInt operator%(const MPInt &O) const;
+  MPInt &operator+=(const MPInt &O);
+  MPInt &operator-=(const MPInt &O);
+  MPInt &operator*=(const MPInt &O);
+  MPInt &operator/=(const MPInt &O);
+  MPInt &operator%=(const MPInt &O);
+  MPInt operator-() const;
+  MPInt &operator++();
+  MPInt &operator--();
+
+  // Divide by a number that is known to be positive.
+  // This is slightly more efficient because it saves an overflow check.
+  MPInt divByPositive(const MPInt &O) const;
+  MPInt &divByPositiveInPlace(const MPInt &O);
+
+  friend MPInt abs(const MPInt &X);
+  friend MPInt ceilDiv(const MPInt &LHS, const MPInt &RHS);
+  friend MPInt floorDiv(const MPInt &LHS, const MPInt &RHS);
+  // The operands must be non-negative for gcd.
+  friend MPInt gcd(const MPInt &A, const MPInt &B);
+  friend MPInt lcm(const MPInt &A, const MPInt &B);
+  friend MPInt mod(const MPInt &LHS, const MPInt &RHS);
+
+  /// ---------------------------------------------------------------------------
+  /// Convenience operator overloads for int64_t.
+  /// ---------------------------------------------------------------------------
+  friend MPInt &operator+=(MPInt &A, int64_t B);
+  friend MPInt &operator-=(MPInt &A, int64_t B);
+  friend MPInt &operator*=(MPInt &A, int64_t B);
+  friend MPInt &operator/=(MPInt &A, int64_t B);
+  friend MPInt &operator%=(MPInt &A, int64_t B);
+
+  friend bool operator==(const MPInt &A, int64_t B);
+  friend bool operator!=(const MPInt &A, int64_t B);
+  friend bool operator>(const MPInt &A, int64_t B);
+  friend bool operator<(const MPInt &A, int64_t B);
+  friend bool operator<=(const MPInt &A, int64_t B);
+  friend bool operator>=(const MPInt &A, int64_t B);
+  friend MPInt operator+(const MPInt &A, int64_t B);
+  friend MPInt operator-(const MPInt &A, int64_t B);
+  friend MPInt operator*(const MPInt &A, int64_t B);
+  friend MPInt operator/(const MPInt &A, int64_t B);
+  friend MPInt operator%(const MPInt &A, int64_t B);
+
+  friend bool operator==(int64_t A, const MPInt &B);
+  friend bool operator!=(int64_t A, const MPInt &B);
+  friend bool operator>(int64_t A, const MPInt &B);
+  friend bool operator<(int64_t A, const MPInt &B);
+  friend bool operator<=(int64_t A, const MPInt &B);
+  friend bool operator>=(int64_t A, const MPInt &B);
+  friend MPInt operator+(int64_t A, const MPInt &B);
+  friend MPInt operator-(int64_t A, const MPInt &B);
+  friend MPInt operator*(int64_t A, const MPInt &B);
+  friend MPInt operator/(int64_t A, const MPInt &B);
+  friend MPInt operator%(int64_t A, const MPInt &B);
+
+  friend hash_code hash_value(const MPInt &x); // NOLINT
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  raw_ostream &print(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+inline raw_ostream &operator<<(raw_ostream &OS, const MPInt &X) {
+  X.print(OS);
+  return OS;
+}
+#endif
+
+/// Redeclaration of the friend declaration above,
+/// to make it discoverable by lookup.
+hash_code hash_value(const MPInt &X); // NOLINT
+
+/// This just calls through to the operator int64_t, but it's useful when a
+/// function pointer is required. (Although this is marked inline, it is still
+/// possible to obtain and use a function pointer to this.)
+static inline int64_t int64FromMPInt(const MPInt &X) { return int64_t(X); }
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mpintFromInt64(int64_t X) {
+  return MPInt(X);
+}
+
+/// The RHS is always expected to be positive, and the result
+/// is always non-negative.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mod(const MPInt &LHS, const MPInt &RHS);
+
+namespace detail {
+// Division overflows only when trying to negate the minimal signed value.
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool divWouldOverflow(int64_t X, int64_t Y) {
+  return X == std::numeric_limits<int64_t>::min() && Y == -1;
+}
+} // namespace detail
+
+/// We define the operations here in the header to facilitate inlining.
+
+/// ---------------------------------------------------------------------------
+/// Comparison operators.
+/// ---------------------------------------------------------------------------
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator==(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() == O.getSmall();
+  return detail::SlowMPInt(*this) == detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator!=(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() != O.getSmall();
+  return detail::SlowMPInt(*this) != detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator>(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() > O.getSmall();
+  return detail::SlowMPInt(*this) > detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator<(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() < O.getSmall();
+  return detail::SlowMPInt(*this) < detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator<=(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() <= O.getSmall();
+  return detail::SlowMPInt(*this) <= detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator>=(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() >= O.getSmall();
+  return detail::SlowMPInt(*this) >= detail::SlowMPInt(O);
+}
+
+/// ---------------------------------------------------------------------------
+/// Arithmetic operators.
+/// ---------------------------------------------------------------------------
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator+(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    MPInt Result;
+    bool Overflow =
+        detail::addOverflow(getSmall(), O.getSmall(), Result.getSmall());
+    if (LLVM_LIKELY(!Overflow))
+      return Result;
+    return MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+  }
+  return MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator-(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    MPInt Result;
+    bool Overflow =
+        detail::subOverflow(getSmall(), O.getSmall(), Result.getSmall());
+    if (LLVM_LIKELY(!Overflow))
+      return Result;
+    return MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+  }
+  return MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator*(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    MPInt Result;
+    bool Overflow =
+        detail::mulOverflow(getSmall(), O.getSmall(), Result.getSmall());
+    if (LLVM_LIKELY(!Overflow))
+      return Result;
+    return MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(O));
+  }
+  return MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(O));
+}
+
+// Division overflows only occur when negating the minimal possible value.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::divByPositive(const MPInt &O) const {
+  assert(O > 0);
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return MPInt(getSmall() / O.getSmall());
+  return MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(O));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator/(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    // Division overflows only occur when negating the minimal possible value.
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(getSmall(), O.getSmall())))
+      return -*this;
+    return MPInt(getSmall() / O.getSmall());
+  }
+  return MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(O));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &X) {
+  return MPInt(X >= 0 ? X : -X);
+}
+// Division overflows only occur when negating the minimal possible value.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt ceilDiv(const MPInt &LHS, const MPInt &RHS) {
+  if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall())) {
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(LHS.getSmall(), RHS.getSmall())))
+      return -LHS;
+    return MPInt(detail::ceilDiv(LHS.getSmall(), RHS.getSmall()));
+  }
+  return MPInt(ceilDiv(detail::SlowMPInt(LHS), detail::SlowMPInt(RHS)));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt floorDiv(const MPInt &LHS,
+                                            const MPInt &RHS) {
+  if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall())) {
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(LHS.getSmall(), RHS.getSmall())))
+      return -LHS;
+    return MPInt(detail::floorDiv(LHS.getSmall(), RHS.getSmall()));
+  }
+  return MPInt(floorDiv(detail::SlowMPInt(LHS), detail::SlowMPInt(RHS)));
+}
+/// The RHS is always expected to be positive, and the result
+/// is always non-negative.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mod(const MPInt &LHS, const MPInt &RHS) {
+  if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall()))
+    return MPInt(detail::mod(LHS.getSmall(), RHS.getSmall()));
+  return MPInt(mod(detail::SlowMPInt(LHS), detail::SlowMPInt(RHS)));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &A, const MPInt &B) {
+  assert(A >= 0 && B >= 0 && "operands must be non-negative!");
+  if (LLVM_LIKELY(A.isSmall() && B.isSmall()))
+    return MPInt(std::gcd(A.getSmall(), B.getSmall()));
+  return MPInt(gcd(detail::SlowMPInt(A), detail::SlowMPInt(B)));
+}
+
+/// Returns the least common multiple of A and B.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt lcm(const MPInt &A, const MPInt &B) {
+  MPInt X = abs(A);
+  MPInt Y = abs(B);
+  return (X * Y) / gcd(X, Y);
+}
+
+/// This operation cannot overflow.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator%(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return MPInt(getSmall() % O.getSmall());
+  return MPInt(detail::SlowMPInt(*this) % detail::SlowMPInt(O));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator-() const {
+  if (LLVM_LIKELY(isSmall())) {
+    if (LLVM_LIKELY(getSmall() != std::numeric_limits<int64_t>::min()))
+      return MPInt(-getSmall());
+    return MPInt(-detail::SlowMPInt(*this));
+  }
+  return MPInt(-detail::SlowMPInt(*this));
+}
+
+/// ---------------------------------------------------------------------------
+/// Assignment operators, preincrement, predecrement.
+/// ---------------------------------------------------------------------------
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator+=(const MPInt &O) {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    int64_t Result = getSmall();
+    bool Overflow = detail::addOverflow(getSmall(), O.getSmall(), Result);
+    if (LLVM_LIKELY(!Overflow)) {
+      getSmall() = Result;
+      return *this;
+    }
+    // Note: this return is not strictly required but
+    // removing it leads to a performance regression.
+    return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator-=(const MPInt &O) {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    int64_t Result = getSmall();
+    bool Overflow = detail::subOverflow(getSmall(), O.getSmall(), Result);
+    if (LLVM_LIKELY(!Overflow)) {
+      getSmall() = Result;
+      return *this;
+    }
+    // Note: this return is not strictly required but
+    // removing it leads to a performance regression.
+    return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator*=(const MPInt &O) {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    int64_t Result = getSmall();
+    bool Overflow = detail::mulOverflow(getSmall(), O.getSmall(), Result);
+    if (LLVM_LIKELY(!Overflow)) {
+      getSmall() = Result;
+      return *this;
+    }
+    // Note: this return is not strictly require...
[truncated]

@llvmbot
Collaborator

llvmbot commented Jun 10, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Ramkumar Ramachandra (artagnon)

Changes

MPInt is an arbitrary-precision integer library that builds on top of APInt, and has a fast-path when the number fits within 64 bits. It was originally written for the Presburger library in MLIR, but seems useful to the LLVM project in general, independently of the Presburger library or MLIR. Hence, move it into LLVM/ADT.

This patch is part of a project to move the Presburger library into LLVM.


Patch is 107.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94953.diff

26 Files Affected:

  • (added) llvm/include/llvm/ADT/MPInt.h (+644)
  • (added) llvm/include/llvm/ADT/SlowMPInt.h (+138)
  • (modified) llvm/lib/Support/CMakeLists.txt (+2)
  • (added) llvm/lib/Support/MPInt.cpp (+25)
  • (added) llvm/lib/Support/SlowMPInt.cpp (+276)
  • (modified) llvm/unittests/ADT/CMakeLists.txt (+1)
  • (added) llvm/unittests/ADT/MPIntTest.cpp (+200)
  • (modified) mlir/include/mlir/Analysis/Presburger/Fraction.h (+2-2)
  • (modified) mlir/include/mlir/Analysis/Presburger/IntegerRelation.h (+3)
  • (removed) mlir/include/mlir/Analysis/Presburger/MPInt.h (-617)
  • (removed) mlir/include/mlir/Analysis/Presburger/SlowMPInt.h (-136)
  • (modified) mlir/include/mlir/Analysis/Presburger/Utils.h (+1-2)
  • (modified) mlir/include/mlir/Support/LLVM.h (+2)
  • (modified) mlir/lib/Analysis/Presburger/CMakeLists.txt (-2)
  • (modified) mlir/lib/Analysis/Presburger/IntegerRelation.cpp (+4-4)
  • (modified) mlir/lib/Analysis/Presburger/LinearTransform.cpp (+1-1)
  • (removed) mlir/lib/Analysis/Presburger/MPInt.cpp (-38)
  • (modified) mlir/lib/Analysis/Presburger/Matrix.cpp (+1-1)
  • (modified) mlir/lib/Analysis/Presburger/PWMAFunction.cpp (+1-1)
  • (modified) mlir/lib/Analysis/Presburger/PresburgerRelation.cpp (-1)
  • (modified) mlir/lib/Analysis/Presburger/Simplex.cpp (+2-2)
  • (removed) mlir/lib/Analysis/Presburger/SlowMPInt.cpp (-290)
  • (modified) mlir/lib/Analysis/Presburger/Utils.cpp (+6-5)
  • (modified) mlir/unittests/Analysis/Presburger/CMakeLists.txt (-1)
  • (removed) mlir/unittests/Analysis/Presburger/MPIntTest.cpp (-200)
  • (modified) mlir/unittests/Analysis/Presburger/Utils.h (+2)
diff --git a/llvm/include/llvm/ADT/MPInt.h b/llvm/include/llvm/ADT/MPInt.h
new file mode 100644
index 0000000000000..dc387d7d0e5db
--- /dev/null
+++ b/llvm/include/llvm/ADT/MPInt.h
@@ -0,0 +1,644 @@
+//===- MPInt.h - MPInt Class ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a simple class to represent arbitrary precision signed integers.
+// Unlike APInt, one does not have to specify a fixed maximum size, and the
+// integer can take on any arbitrary values. This is optimized for small-values
+// by providing fast-paths for the cases when the value stored fits in 64-bits.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_MPINT_H
+#define LLVM_ADT_MPINT_H
+
+#include "llvm/ADT/SlowMPInt.h"
+#include "llvm/Support/raw_ostream.h"
+#include <numeric>
+
+namespace llvm {
+namespace detail {
+/// ---------------------------------------------------------------------------
+/// Some helpers from MLIR/MathExtras.
+/// ---------------------------------------------------------------------------
+LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t ceilDiv(int64_t Numerator,
+                                             int64_t Denominator) {
+  assert(Denominator);
+  if (!Numerator)
+    return 0;
+  // C's integer division rounds towards 0.
+  int64_t X = (Denominator > 0) ? -1 : 1;
+  bool SameSign = (Numerator > 0) == (Denominator > 0);
+  return SameSign ? ((Numerator + X) / Denominator) + 1
+                  : -(-Numerator / Denominator);
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t floorDiv(int64_t Numerator,
+                                              int64_t Denominator) {
+  assert(Denominator);
+  if (!Numerator)
+    return 0;
+  // C's integer division rounds towards 0.
+  int64_t X = (Denominator > 0) ? -1 : 1;
+  bool SameSign = (Numerator > 0) == (Denominator > 0);
+  return SameSign ? Numerator / Denominator
+                  : -((-Numerator + X) / Denominator) - 1;
+}
+
+/// Returns the remainder of the Euclidean division of LHS by RHS. Result is
+/// always non-negative.
+LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t mod(int64_t Numerator,
+                                         int64_t Denominator) {
+  assert(Denominator >= 1);
+  return Numerator % Denominator < 0 ? Numerator % Denominator + Denominator
+                                     : Numerator % Denominator;
+}
+
+/// If builtin intrinsics for overflow-checked arithmetic are available,
+/// use them. Otherwise, call through to LLVM's overflow-checked arithmetic
+/// functionality. Those functions also have such macro-gated uses of intrinsics
+/// but they are not always_inlined, which is important for us to achieve
+/// high-performance; calling the functions directly would result in a slowdown
+/// of 1.15x.
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool addOverflow(int64_t X, int64_t Y,
+                                              int64_t &Result) {
+#if __has_builtin(__builtin_add_overflow)
+  return __builtin_add_overflow(X, Y, &Result);
+#else
+  return AddOverflow(x, y, result);
+#endif
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool subOverflow(int64_t X, int64_t Y,
+                                              int64_t &Result) {
+#if __has_builtin(__builtin_sub_overflow)
+  return __builtin_sub_overflow(X, Y, &Result);
+#else
+  return SubOverflow(x, y, result);
+#endif
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool mulOverflow(int64_t X, int64_t Y,
+                                              int64_t &Result) {
+#if __has_builtin(__builtin_mul_overflow)
+  return __builtin_mul_overflow(X, Y, &Result);
+#else
+  return MulOverflow(x, y, result);
+#endif
+}
+} // namespace detail
+
+/// This class provides support for multi-precision arithmetic.
+///
+/// Unlike APInt, this extends the precision as necessary to prevent overflows
+/// and supports operations between objects with differing internal precisions.
+///
+/// This is optimized for small-values by providing fast-paths for the cases
+/// when the value stored fits in 64-bits. We annotate all fastpaths by using
+/// the LLVM_LIKELY/LLVM_UNLIKELY annotations. Removing these would result in
+/// a 1.2x performance slowdown.
+///
+/// We always_inline all operations; removing these results in a 1.5x
+/// performance slowdown.
+///
+/// When holdsLarge is true, a SlowMPInt is held in the union. If it is false,
+/// the int64_t is held. Using std::variant instead would lead to significantly
+/// worse performance.
+class MPInt {
+private:
+  union {
+    int64_t ValSmall;
+    detail::SlowMPInt ValLarge;
+  };
+  unsigned HoldsLarge;
+
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void initSmall(int64_t O) {
+    if (LLVM_UNLIKELY(isLarge()))
+      ValLarge.detail::SlowMPInt::~SlowMPInt();
+    ValSmall = O;
+    HoldsLarge = false;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void initLarge(const detail::SlowMPInt &O) {
+    if (LLVM_LIKELY(isSmall())) {
+      // The data in memory could be in an arbitrary state, not necessarily
+      // corresponding to any valid state of ValLarge; we cannot call any member
+      // functions, e.g. the assignment operator on it, as they may access the
+      // invalid internal state. We instead construct a new object using
+      // placement new.
+      new (&ValLarge) detail::SlowMPInt(O);
+    } else {
+      // In this case, we need to use the assignment operator, because if we use
+      // placement-new as above we would lose track of allocated memory
+      // and leak it.
+      ValLarge = O;
+    }
+    HoldsLarge = true;
+  }
+
+  LLVM_ATTRIBUTE_ALWAYS_INLINE explicit MPInt(const detail::SlowMPInt &Val)
+      : ValLarge(Val), HoldsLarge(true) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isSmall() const { return !HoldsLarge; }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isLarge() const { return HoldsLarge; }
+  /// Get the stored value. For getSmall/getLarge, the stored value must be
+  /// small/large, respectively.
+  LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t getSmall() const {
+    assert(isSmall() &&
+           "getSmall should only be called when the value stored is small!");
+    return ValSmall;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t &getSmall() {
+    assert(isSmall() &&
+           "getSmall should only be called when the value stored is small!");
+    return ValSmall;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE const detail::SlowMPInt &getLarge() const {
+    assert(isLarge() &&
+           "getLarge should only be called when the value stored is large!");
+    return ValLarge;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE detail::SlowMPInt &getLarge() {
+    assert(isLarge() &&
+           "getLarge should only be called when the value stored is large!");
+    return ValLarge;
+  }
+  explicit operator detail::SlowMPInt() const {
+    if (isSmall())
+      return detail::SlowMPInt(getSmall());
+    return getLarge();
+  }
+
+public:
+  LLVM_ATTRIBUTE_ALWAYS_INLINE explicit MPInt(int64_t Val)
+      : ValSmall(Val), HoldsLarge(false) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt() : MPInt(0) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE ~MPInt() {
+    if (LLVM_UNLIKELY(isLarge()))
+      ValLarge.detail::SlowMPInt::~SlowMPInt();
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt(const MPInt &O)
+      : ValSmall(O.ValSmall), HoldsLarge(false) {
+    if (LLVM_UNLIKELY(O.isLarge()))
+      initLarge(O.ValLarge);
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator=(const MPInt &O) {
+    if (LLVM_LIKELY(O.isSmall())) {
+      initSmall(O.ValSmall);
+      return *this;
+    }
+    initLarge(O.ValLarge);
+    return *this;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator=(int X) {
+    initSmall(X);
+    return *this;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE explicit operator int64_t() const {
+    if (isSmall())
+      return getSmall();
+    return static_cast<int64_t>(getLarge());
+  }
+
+  bool operator==(const MPInt &O) const;
+  bool operator!=(const MPInt &O) const;
+  bool operator>(const MPInt &O) const;
+  bool operator<(const MPInt &O) const;
+  bool operator<=(const MPInt &O) const;
+  bool operator>=(const MPInt &O) const;
+  MPInt operator+(const MPInt &O) const;
+  MPInt operator-(const MPInt &O) const;
+  MPInt operator*(const MPInt &O) const;
+  MPInt operator/(const MPInt &O) const;
+  MPInt operator%(const MPInt &O) const;
+  MPInt &operator+=(const MPInt &O);
+  MPInt &operator-=(const MPInt &O);
+  MPInt &operator*=(const MPInt &O);
+  MPInt &operator/=(const MPInt &O);
+  MPInt &operator%=(const MPInt &O);
+  MPInt operator-() const;
+  MPInt &operator++();
+  MPInt &operator--();
+
+  // Divide by a number that is known to be positive.
+  // This is slightly more efficient because it saves an overflow check.
+  MPInt divByPositive(const MPInt &O) const;
+  MPInt &divByPositiveInPlace(const MPInt &O);
+
+  friend MPInt abs(const MPInt &X);
+  friend MPInt ceilDiv(const MPInt &LHS, const MPInt &RHS);
+  friend MPInt floorDiv(const MPInt &LHS, const MPInt &RHS);
+  // The operands must be non-negative for gcd.
+  friend MPInt gcd(const MPInt &A, const MPInt &B);
+  friend MPInt lcm(const MPInt &A, const MPInt &B);
+  friend MPInt mod(const MPInt &LHS, const MPInt &RHS);
+
+  /// ---------------------------------------------------------------------------
+  /// Convenience operator overloads for int64_t.
+  /// ---------------------------------------------------------------------------
+  friend MPInt &operator+=(MPInt &A, int64_t B);
+  friend MPInt &operator-=(MPInt &A, int64_t B);
+  friend MPInt &operator*=(MPInt &A, int64_t B);
+  friend MPInt &operator/=(MPInt &A, int64_t B);
+  friend MPInt &operator%=(MPInt &A, int64_t B);
+
+  friend bool operator==(const MPInt &A, int64_t B);
+  friend bool operator!=(const MPInt &A, int64_t B);
+  friend bool operator>(const MPInt &A, int64_t B);
+  friend bool operator<(const MPInt &A, int64_t B);
+  friend bool operator<=(const MPInt &A, int64_t B);
+  friend bool operator>=(const MPInt &A, int64_t B);
+  friend MPInt operator+(const MPInt &A, int64_t B);
+  friend MPInt operator-(const MPInt &A, int64_t B);
+  friend MPInt operator*(const MPInt &A, int64_t B);
+  friend MPInt operator/(const MPInt &A, int64_t B);
+  friend MPInt operator%(const MPInt &A, int64_t B);
+
+  friend bool operator==(int64_t A, const MPInt &B);
+  friend bool operator!=(int64_t A, const MPInt &B);
+  friend bool operator>(int64_t A, const MPInt &B);
+  friend bool operator<(int64_t A, const MPInt &B);
+  friend bool operator<=(int64_t A, const MPInt &B);
+  friend bool operator>=(int64_t A, const MPInt &B);
+  friend MPInt operator+(int64_t A, const MPInt &B);
+  friend MPInt operator-(int64_t A, const MPInt &B);
+  friend MPInt operator*(int64_t A, const MPInt &B);
+  friend MPInt operator/(int64_t A, const MPInt &B);
+  friend MPInt operator%(int64_t A, const MPInt &B);
+
+  friend hash_code hash_value(const MPInt &X); // NOLINT
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  raw_ostream &print(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+inline raw_ostream &operator<<(raw_ostream &OS, const MPInt &X) {
+  X.print(OS);
+  return OS;
+}
+#endif
+
+/// Redeclaration of the friend declaration above to
+/// make it discoverable by lookups.
+hash_code hash_value(const MPInt &X); // NOLINT
+
+/// This just calls through to the operator int64_t, but it's useful when a
+/// function pointer is required. (Although this is marked inline, it is still
+/// possible to obtain and use a function pointer to this.)
+static inline int64_t int64FromMPInt(const MPInt &X) { return int64_t(X); }
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mpintFromInt64(int64_t X) {
+  return MPInt(X);
+}
+
+/// The RHS is always expected to be positive, and the result
+/// is always non-negative.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mod(const MPInt &LHS, const MPInt &RHS);
+
+namespace detail {
+// Division overflows only when trying to negate the minimal signed value.
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool divWouldOverflow(int64_t X, int64_t Y) {
+  return X == std::numeric_limits<int64_t>::min() && Y == -1;
+}
+} // namespace detail
+
+/// We define the operations here in the header to facilitate inlining.
+
+/// ---------------------------------------------------------------------------
+/// Comparison operators.
+/// ---------------------------------------------------------------------------
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator==(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() == O.getSmall();
+  return detail::SlowMPInt(*this) == detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator!=(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() != O.getSmall();
+  return detail::SlowMPInt(*this) != detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator>(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() > O.getSmall();
+  return detail::SlowMPInt(*this) > detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator<(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() < O.getSmall();
+  return detail::SlowMPInt(*this) < detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator<=(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() <= O.getSmall();
+  return detail::SlowMPInt(*this) <= detail::SlowMPInt(O);
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator>=(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return getSmall() >= O.getSmall();
+  return detail::SlowMPInt(*this) >= detail::SlowMPInt(O);
+}
+
+/// ---------------------------------------------------------------------------
+/// Arithmetic operators.
+/// ---------------------------------------------------------------------------
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator+(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    MPInt Result;
+    bool Overflow =
+        detail::addOverflow(getSmall(), O.getSmall(), Result.getSmall());
+    if (LLVM_LIKELY(!Overflow))
+      return Result;
+    return MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+  }
+  return MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator-(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    MPInt Result;
+    bool Overflow =
+        detail::subOverflow(getSmall(), O.getSmall(), Result.getSmall());
+    if (LLVM_LIKELY(!Overflow))
+      return Result;
+    return MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+  }
+  return MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator*(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    MPInt Result;
+    bool Overflow =
+        detail::mulOverflow(getSmall(), O.getSmall(), Result.getSmall());
+    if (LLVM_LIKELY(!Overflow))
+      return Result;
+    return MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(O));
+  }
+  return MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(O));
+}
+
+// Division overflows only occur when negating the minimal possible value.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::divByPositive(const MPInt &O) const {
+  assert(O > 0);
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return MPInt(getSmall() / O.getSmall());
+  return MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(O));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator/(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    // Division overflows only occur when negating the minimal possible value.
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(getSmall(), O.getSmall())))
+      return -*this;
+    return MPInt(getSmall() / O.getSmall());
+  }
+  return MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(O));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &X) {
+  return MPInt(X >= 0 ? X : -X);
+}
+// Division overflows only occur when negating the minimal possible value.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt ceilDiv(const MPInt &LHS, const MPInt &RHS) {
+  if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall())) {
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(LHS.getSmall(), RHS.getSmall())))
+      return -LHS;
+    return MPInt(detail::ceilDiv(LHS.getSmall(), RHS.getSmall()));
+  }
+  return MPInt(ceilDiv(detail::SlowMPInt(LHS), detail::SlowMPInt(RHS)));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt floorDiv(const MPInt &LHS,
+                                            const MPInt &RHS) {
+  if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall())) {
+    if (LLVM_UNLIKELY(detail::divWouldOverflow(LHS.getSmall(), RHS.getSmall())))
+      return -LHS;
+    return MPInt(detail::floorDiv(LHS.getSmall(), RHS.getSmall()));
+  }
+  return MPInt(floorDiv(detail::SlowMPInt(LHS), detail::SlowMPInt(RHS)));
+}
+/// The RHS is always expected to be positive, and the result
+/// is always non-negative.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mod(const MPInt &LHS, const MPInt &RHS) {
+  if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall()))
+    return MPInt(detail::mod(LHS.getSmall(), RHS.getSmall()));
+  return MPInt(mod(detail::SlowMPInt(LHS), detail::SlowMPInt(RHS)));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &A, const MPInt &B) {
+  assert(A >= 0 && B >= 0 && "operands must be non-negative!");
+  if (LLVM_LIKELY(A.isSmall() && B.isSmall()))
+    return MPInt(std::gcd(A.getSmall(), B.getSmall()));
+  return MPInt(gcd(detail::SlowMPInt(A), detail::SlowMPInt(B)));
+}
+
+/// Returns the least common multiple of A and B.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt lcm(const MPInt &A, const MPInt &B) {
+  MPInt X = abs(A);
+  MPInt Y = abs(B);
+  return (X * Y) / gcd(X, Y);
+}
+
+/// This operation cannot overflow.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator%(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return MPInt(getSmall() % O.getSmall());
+  return MPInt(detail::SlowMPInt(*this) % detail::SlowMPInt(O));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator-() const {
+  if (LLVM_LIKELY(isSmall())) {
+    if (LLVM_LIKELY(getSmall() != std::numeric_limits<int64_t>::min()))
+      return MPInt(-getSmall());
+    return MPInt(-detail::SlowMPInt(*this));
+  }
+  return MPInt(-detail::SlowMPInt(*this));
+}
+
+/// ---------------------------------------------------------------------------
+/// Assignment operators, preincrement, predecrement.
+/// ---------------------------------------------------------------------------
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator+=(const MPInt &O) {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    int64_t Result = getSmall();
+    bool Overflow = detail::addOverflow(getSmall(), O.getSmall(), Result);
+    if (LLVM_LIKELY(!Overflow)) {
+      getSmall() = Result;
+      return *this;
+    }
+    // Note: this return is not strictly required but
+    // removing it leads to a performance regression.
+    return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator-=(const MPInt &O) {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    int64_t Result = getSmall();
+    bool Overflow = detail::subOverflow(getSmall(), O.getSmall(), Result);
+    if (LLVM_LIKELY(!Overflow)) {
+      getSmall() = Result;
+      return *this;
+    }
+    // Note: this return is not strictly required but
+    // removing it leads to a performance regression.
+    return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator*=(const MPInt &O) {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    int64_t Result = getSmall();
+    bool Overflow = detail::mulOverflow(getSmall(), O.getSmall(), Result);
+    if (LLVM_LIKELY(!Overflow)) {
+      getSmall() = Result;
+      return *this;
+    }
+    // Note: this return is not strictly require...
[truncated]
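The class comment in the diff above describes MPInt's core design: an `int64_t` held inline in a union alongside a `SlowMPInt`, a `HoldsLarge` tag, and overflow-checked fast paths that promote to the large representation only when an operation overflows. The following is a minimal standalone sketch of that technique, not the patch's implementation: `SmallBigInt` and `add` are illustrative names, `__int128` stands in for the APInt-based `SlowMPInt`, and a GCC/Clang-compatible compiler is assumed for `__builtin_add_overflow`.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative miniature of the MPInt small/large tagged-union design.
struct SmallBigInt {
  union {
    int64_t Small;    // Active when HoldsLarge is false.
    __int128 Large;   // Stand-in for the arbitrary-precision fallback.
  };
  bool HoldsLarge;

  explicit SmallBigInt(int64_t V) : Small(V), HoldsLarge(false) {}
};

SmallBigInt add(const SmallBigInt &A, const SmallBigInt &B) {
  // Fast path: both operands fit in 64 bits and the sum does not overflow.
  if (!A.HoldsLarge && !B.HoldsLarge) {
    int64_t R;
    if (!__builtin_add_overflow(A.Small, B.Small, &R))
      return SmallBigInt(R);
  }
  // Slow path: widen both operands and add in the large representation.
  SmallBigInt Res(0);
  Res.HoldsLarge = true;
  Res.Large = (A.HoldsLarge ? A.Large : (__int128)A.Small) +
              (B.HoldsLarge ? B.Large : (__int128)B.Small);
  return Res;
}
```

Small sums stay on the inline path; adding 1 to INT64_MAX trips the overflow check and lands in the wide representation, mirroring how MPInt falls back to SlowMPInt.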

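The `mod`/`floorDiv`/`ceilDiv` helpers in the diff implement rounding conventions that differ from C++'s truncating division: `mod` is the remainder of Euclidean division (always non-negative for a positive RHS), while `floorDiv` and `ceilDiv` round toward negative and positive infinity. A sketch of those semantics, written independently of the patch (the quotient/remainder adjustment here is a common alternative formulation, not the patch's exact code):

```cpp
#include <cstdint>

// Division rounding toward negative infinity; C++'s '/' truncates toward 0.
int64_t floorDiv(int64_t Num, int64_t Den) {
  int64_t Q = Num / Den, R = Num % Den;
  // Step down when truncation rounded toward zero from below.
  return (R != 0 && (R < 0) != (Den < 0)) ? Q - 1 : Q;
}

// Division rounding toward positive infinity.
int64_t ceilDiv(int64_t Num, int64_t Den) {
  int64_t Q = Num / Den, R = Num % Den;
  // Step up when truncation rounded toward zero from above.
  return (R != 0 && (R < 0) == (Den < 0)) ? Q + 1 : Q;
}

// Remainder of Euclidean division; Den must be >= 1, result is in [0, Den).
int64_t mod(int64_t Num, int64_t Den) {
  int64_t R = Num % Den;
  return R < 0 ? R + Den : R;
}
```

For example, `floorDiv(-5, 3)` is -2 and `mod(-5, 3)` is 1, whereas C++'s `-5 / 3` is -1 and `-5 % 3` is -2.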
@llvmbot
Copy link
Collaborator

llvmbot commented Jun 10, 2024

@llvm/pr-subscribers-llvm-adt

+    if (LLVM_UNLIKELY(detail::divWouldOverflow(LHS.getSmall(), RHS.getSmall())))
+      return -LHS;
+    return MPInt(detail::floorDiv(LHS.getSmall(), RHS.getSmall()));
+  }
+  return MPInt(floorDiv(detail::SlowMPInt(LHS), detail::SlowMPInt(RHS)));
+}
+/// The RHS is always expected to be positive, and the result
+/// is always non-negative.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mod(const MPInt &LHS, const MPInt &RHS) {
+  if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall()))
+    return MPInt(detail::mod(LHS.getSmall(), RHS.getSmall()));
+  return MPInt(mod(detail::SlowMPInt(LHS), detail::SlowMPInt(RHS)));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &A, const MPInt &B) {
+  assert(A >= 0 && B >= 0 && "operands must be non-negative!");
+  if (LLVM_LIKELY(A.isSmall() && B.isSmall()))
+    return MPInt(std::gcd(A.getSmall(), B.getSmall()));
+  return MPInt(gcd(detail::SlowMPInt(A), detail::SlowMPInt(B)));
+}
+
+/// Returns the least common multiple of A and B.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt lcm(const MPInt &A, const MPInt &B) {
+  MPInt X = abs(A);
+  MPInt Y = abs(B);
+  return (X * Y) / gcd(X, Y);
+}
+
+/// This operation cannot overflow.
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator%(const MPInt &O) const {
+  if (LLVM_LIKELY(isSmall() && O.isSmall()))
+    return MPInt(getSmall() % O.getSmall());
+  return MPInt(detail::SlowMPInt(*this) % detail::SlowMPInt(O));
+}
+
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator-() const {
+  if (LLVM_LIKELY(isSmall())) {
+    if (LLVM_LIKELY(getSmall() != std::numeric_limits<int64_t>::min()))
+      return MPInt(-getSmall());
+    return MPInt(-detail::SlowMPInt(*this));
+  }
+  return MPInt(-detail::SlowMPInt(*this));
+}
+
+/// ---------------------------------------------------------------------------
+/// Assignment operators, preincrement, predecrement.
+/// ---------------------------------------------------------------------------
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator+=(const MPInt &O) {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    int64_t Result = getSmall();
+    bool Overflow = detail::addOverflow(getSmall(), O.getSmall(), Result);
+    if (LLVM_LIKELY(!Overflow)) {
+      getSmall() = Result;
+      return *this;
+    }
+    // Note: this return is not strictly required but
+    // removing it leads to a performance regression.
+    return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator-=(const MPInt &O) {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    int64_t Result = getSmall();
+    bool Overflow = detail::subOverflow(getSmall(), O.getSmall(), Result);
+    if (LLVM_LIKELY(!Overflow)) {
+      getSmall() = Result;
+      return *this;
+    }
+    // Note: this return is not strictly required but
+    // removing it leads to a performance regression.
+    return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+  }
+  return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(O));
+}
+LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator*=(const MPInt &O) {
+  if (LLVM_LIKELY(isSmall() && O.isSmall())) {
+    int64_t Result = getSmall();
+    bool Overflow = detail::mulOverflow(getSmall(), O.getSmall(), Result);
+    if (LLVM_LIKELY(!Overflow)) {
+      getSmall() = Result;
+      return *this;
+    }
+    // Note: this return is not strictly require...
[truncated]

@Groverkss Groverkss requested a review from kuhar June 10, 2024 10:55
@Groverkss
Member

From the Presburger side, LGTM. This is great! I'm going to let the reviewers for ADT have the final say in this.

llvm/include/llvm/ADT/MPInt.h Outdated
llvm/include/llvm/ADT/MPInt.h Outdated
Contributor

@nikic nikic left a comment


Moving this to ADT sounds reasonable. Some high-level thoughts after a brief look:

  • The name MPInt isn't great, in that it does not give you any idea of what the difference between APInt and MPInt is supposed to be. Calling this DynamicAPInt would make it clearer that the prime distinction is the dynamic bit width.
  • The representation of MPInt seems somewhat non-optimal to me, in that it has the separate unsigned HoldsLarge field -- we could use APInt.BitWidth == 0 to indicate this and cut the structure by 8 bytes (I think? I don't think the APInt padding would get used with the current implementation). This probably wasn't an option with the class in MLIR, but if it's in ADT we can make it a friend of APInt and make use of internal implementation details.

But I guess it would make sense to move it first and do any implementation changes later. I'd still consider whether it might make sense to adjust the name as part of the move (but I also don't feel very strongly about that).

@tschuett
Member

There are a lot of micro-optimizations in this PR. I would prefer readable code.

@artagnon
Contributor Author

There are a lot of micro optimizations in this PR. I would prefer readable code.

The micro-optimizations are necessitated by the Presburger library, which is very compute-heavy.

@artagnon artagnon changed the title mlir/Presburger/MPInt: move into LLVM/ADT mlir/Presburger/MPInt: move into LLVM/ADT/DynamicAPInt Jun 10, 2024
@artagnon
Contributor Author

I've additionally consolidated mlir/MathExtras with llvm/MathExtras. This patch is now ready, pending #95046 for no performance regression.

@nikic
Contributor

nikic commented Jun 11, 2024

I've additionally consolidated mlir/MathExtras with llvm/MathExtras. This patch is now ready, pending #95046 for no performance regression.

Please split this off into a separate PR.

@artagnon artagnon removed the request for review from fhahn June 11, 2024 10:05
@artagnon
Contributor Author

Once #95087 is merged, I will post a rebase.

MPInt is an arbitrary-precision integer library that builds on top of
APInt, and has a fast-path when the number fits within 64 bits. It was
originally written for the Presburger library in MLIR, but seems useful
to the LLVM project in general, independently of the Presburger library
or MLIR. Hence, move it into LLVM/ADT.

This patch is part of a project to move the Presburger library into
LLVM.
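The fast-path-with-fallback pattern the commit message describes can be sketched in a few lines. This is a minimal illustration only, not the real class: `MiniMPInt` is a hypothetical name, and `__int128` stands in for the APInt-backed `detail::SlowMPInt` slow representation.

```cpp
#include <cassert>
#include <cstdint>

// Minimal sketch of the small-value fast path with an overflow fallback.
// MiniMPInt is a hypothetical stand-in: the real class keeps an int64_t
// fast path and falls back to an APInt-backed detail::SlowMPInt; here
// __int128 plays the role of the slow representation.
struct MiniMPInt {
  bool Small = true;
  int64_t Fast = 0;
  __int128 Slow = 0;

  explicit MiniMPInt(int64_t V) : Fast(V) {}

  MiniMPInt operator+(const MiniMPInt &O) const {
    MiniMPInt R(0);
    // Fast path: both operands are small and the sum fits in 64 bits.
    if (Small && O.Small && !__builtin_add_overflow(Fast, O.Fast, &R.Fast))
      return R;
    // Slow path: widen both operands and add without overflow.
    R.Small = false;
    R.Slow = (Small ? (__int128)Fast : Slow) +
             (O.Small ? (__int128)O.Fast : O.Slow);
    return R;
  }
};
```

For example, `MiniMPInt(INT64_MAX) + MiniMPInt(1)` takes the slow path, while `MiniMPInt(2) + MiniMPInt(3)` stays entirely on the fast path — the same shape as the `addOverflow`/`SlowMPInt` dance in the diff above.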
@artagnon
Contributor Author

Rebased. Ready to re-review.

Contributor

@nikic nikic left a comment


LGTM. I did not review the implementation in detail.

@artagnon artagnon merged commit 76030dc into llvm:main Jun 12, 2024
4 of 6 checks passed
@artagnon artagnon deleted the mpint-llvm branch June 12, 2024 08:19
@artagnon
Contributor Author

Sorry, the Release build seems to be broken, due to operator<< being guarded. I'm currently working on a fix.

makslevental added a commit to makslevental/llvm-project that referenced this pull request Jun 12, 2024
makslevental added a commit that referenced this pull request Jun 12, 2024
@artagnon
Contributor Author

artagnon commented Jul 2, 2024

  • The representation of MPInt seems somewhat non-optimal to me, in that it has the separate unsigned HoldsLarge field -- we could use APInt.BitWidth == 0 to indicate this and cut the structure by 8 bytes (I think? I don't think the APInt padding would get used with the current implementation). This probably wasn't an option with the class in MLIR, but if it's in ADT we can make it a friend of APInt and make use of internal implementation details.

Hi @nikic, I'm trying to fix this issue, and I have a question: before SlowDynamicAPInt or APInt are even constructed for the fast-case of DynamicAPInt, how can we be assured that Val.BitWidth == 0?

@Superty
Member

Superty commented Jul 2, 2024

If they are friends, I believe you can reach inside and set that location to zero (even if the APInt object hasn't been constructed).
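One way to picture the idea (a hedged sketch with illustrative names — `BigInt` and `DynInt` are hypothetical stand-ins, not the actual `APInt`/`DynamicAPInt` code): the small representation's first member aliases APInt's leading `unsigned BitWidth`, and the constructor writes the tag through the union before, and instead of, constructing the large object.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-in for llvm::APInt; its first member really is an
// unsigned bit width, which is what makes the aliasing trick possible.
struct BigInt {
  unsigned BitWidth;
  uint64_t Val;
};

// Hypothetical sketch of the tagged union: BitWidth == 0 marks the small
// representation, so no separate HoldsLarge flag is needed.
class DynInt {
  union {
    struct {
      unsigned BitWidth; // aliases BigInt::BitWidth
      int64_t ValSmall;
    } S;
    BigInt Large;
  };

public:
  explicit DynInt(int64_t V) {
    // Set the tag before (and instead of) constructing a BigInt here;
    // zero is never a valid bit width for a large value.
    S.BitWidth = 0;
    S.ValSmall = V;
  }
  bool isSmall() const { return S.BitWidth == 0; }
  int64_t getSmall() const { return S.ValSmall; }
};
```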

@artagnon
Contributor Author

artagnon commented Jul 2, 2024

Thanks for the hint. I will attempt to post a patch shortly.

@nikic
Contributor

nikic commented Jul 2, 2024

Changing the union to something like

union {
  struct {
    unsigned BitWidth;
    int64_t ValSmall;
  };
  detail::SlowDynamicAPInt ValLarge;
};

would probably also work.

@Superty
Member

Superty commented Jul 2, 2024

I would suggest adding a unit test / static assert that explicitly checks that the layout of APInt is as expected (whichever solution you choose)
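A layout check in that spirit might look like the following sketch. `APIntMock` is a hypothetical stand-in used only to show the shape of the assertions; the real check would inspect `llvm::APInt` itself via friendship.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for llvm::APInt, used only to illustrate the
// assertions; a real test would check APInt's actual layout.
struct APIntMock {
  unsigned BitWidth;
  union {
    uint64_t VAL;
    uint64_t *pVal;
  } U;
};

// The small representation proposed for the union.
struct SmallRep {
  unsigned BitWidth; // must alias APIntMock::BitWidth
  int64_t ValSmall;
};

// Compile-time layout checks: the tag must sit at offset 0 in both
// representations, or the BitWidth == 0 union trick is unsound.
static_assert(offsetof(APIntMock, BitWidth) == 0,
              "BitWidth must be the first member of APInt");
static_assert(offsetof(SmallRep, BitWidth) == 0,
              "BitWidth must be the first member of the small rep");
```

A unit test could additionally assert these offsets at runtime, so a layout change fails loudly rather than silently corrupting the tag.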

9 participants