From cb5d1b52ad2b34698a5023c50da4f59c70e05539 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 12 Jun 2024 06:55:48 -0700 Subject: [PATCH] Revert #95218 and #94953 (#95244) --- llvm/include/llvm/ADT/DynamicAPInt.h | 640 ------------------ llvm/include/llvm/ADT/SlowDynamicAPInt.h | 140 ---- llvm/lib/Support/CMakeLists.txt | 2 - llvm/lib/Support/DynamicAPInt.cpp | 29 - llvm/lib/Support/SlowDynamicAPInt.cpp | 288 -------- llvm/unittests/ADT/CMakeLists.txt | 1 - llvm/unittests/ADT/DynamicAPIntTest.cpp | 200 ------ .../mlir/Analysis/Presburger/Barvinok.h | 2 +- .../mlir/Analysis/Presburger/Fraction.h | 52 +- .../Analysis/Presburger/IntegerRelation.h | 102 ++- .../Analysis/Presburger/LinearTransform.h | 6 +- mlir/include/mlir/Analysis/Presburger/MPInt.h | 611 +++++++++++++++++ .../include/mlir/Analysis/Presburger/Matrix.h | 26 +- .../mlir/Analysis/Presburger/PWMAFunction.h | 18 +- .../Analysis/Presburger/PresburgerRelation.h | 8 +- .../mlir/Analysis/Presburger/Simplex.h | 41 +- .../mlir/Analysis/Presburger/SlowMPInt.h | 135 ++++ mlir/include/mlir/Analysis/Presburger/Utils.h | 62 +- mlir/include/mlir/Support/LLVM.h | 2 - .../Analysis/FlatLinearValueConstraints.cpp | 3 +- mlir/lib/Analysis/Presburger/Barvinok.cpp | 6 +- mlir/lib/Analysis/Presburger/CMakeLists.txt | 2 + .../Analysis/Presburger/IntegerRelation.cpp | 192 +++--- .../Analysis/Presburger/LinearTransform.cpp | 13 +- mlir/lib/Analysis/Presburger/MPInt.cpp | 38 ++ mlir/lib/Analysis/Presburger/Matrix.cpp | 23 +- mlir/lib/Analysis/Presburger/PWMAFunction.cpp | 39 +- .../Presburger/PresburgerRelation.cpp | 58 +- mlir/lib/Analysis/Presburger/Simplex.cpp | 172 +++-- mlir/lib/Analysis/Presburger/SlowMPInt.cpp | 290 ++++++++ mlir/lib/Analysis/Presburger/Utils.cpp | 137 ++-- .../Analysis/Presburger/CMakeLists.txt | 1 + .../Analysis/Presburger/FractionTest.cpp | 2 +- .../Presburger/IntegerPolyhedronTest.cpp | 22 +- .../Presburger/LinearTransformTest.cpp | 3 +- .../Analysis/Presburger/MPIntTest.cpp | 200 ++++++ .../Analysis/Presburger/MatrixTest.cpp | 2 +- .../Analysis/Presburger/SimplexTest.cpp | 29 +- mlir/unittests/Analysis/Presburger/Utils.h | 20 +- .../Analysis/Presburger/UtilsTest.cpp | 47 +- 40 files changed, 1783 insertions(+), 1881 deletions(-) delete mode 100644 llvm/include/llvm/ADT/DynamicAPInt.h delete mode 100644 llvm/include/llvm/ADT/SlowDynamicAPInt.h delete mode 100644 llvm/lib/Support/DynamicAPInt.cpp delete mode 100644 llvm/lib/Support/SlowDynamicAPInt.cpp delete mode 100644 llvm/unittests/ADT/DynamicAPIntTest.cpp create mode 100644 mlir/include/mlir/Analysis/Presburger/MPInt.h create mode 100644 mlir/include/mlir/Analysis/Presburger/SlowMPInt.h create mode 100644 mlir/lib/Analysis/Presburger/MPInt.cpp create mode 100644 mlir/lib/Analysis/Presburger/SlowMPInt.cpp create mode 100644 mlir/unittests/Analysis/Presburger/MPIntTest.cpp diff --git a/llvm/include/llvm/ADT/DynamicAPInt.h b/llvm/include/llvm/ADT/DynamicAPInt.h deleted file mode 100644 index 78a0372e09c49..0000000000000 --- a/llvm/include/llvm/ADT/DynamicAPInt.h +++ /dev/null @@ -1,640 +0,0 @@ -//===- DynamicAPInt.h - DynamicAPInt Class ----------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is a simple class to represent arbitrary precision signed integers. -// Unlike APInt, one does not have to specify a fixed maximum size, and the -// integer can take on any arbitrary values. This is optimized for small-values -// by providing fast-paths for the cases when the value stored fits in 64-bits. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_ADT_DYNAMICAPINT_H -#define LLVM_ADT_DYNAMICAPINT_H - -#include "llvm/ADT/SlowDynamicAPInt.h" -#include "llvm/Support/MathExtras.h" -#include "llvm/Support/raw_ostream.h" -#include - -namespace llvm { -/// This class provides support for dynamic arbitrary-precision arithmetic. -/// -/// Unlike APInt, this extends the precision as necessary to prevent overflows -/// and supports operations between objects with differing internal precisions. -/// -/// This is optimized for small-values by providing fast-paths for the cases -/// when the value stored fits in 64-bits. We annotate all fastpaths by using -/// the LLVM_LIKELY/LLVM_UNLIKELY annotations. Removing these would result in -/// a 1.2x performance slowdown. -/// -/// We always_inline all operations; removing these results in a 1.5x -/// performance slowdown. -/// -/// When HoldsLarge is true, a SlowMPInt is held in the union. If it is false, -/// the int64_t is held. Using std::variant instead would lead to significantly -/// worse performance. -class DynamicAPInt { - union { - int64_t ValSmall; - detail::SlowDynamicAPInt ValLarge; - }; - unsigned HoldsLarge; - - LLVM_ATTRIBUTE_ALWAYS_INLINE void initSmall(int64_t O) { - if (LLVM_UNLIKELY(isLarge())) - ValLarge.detail::SlowDynamicAPInt::~SlowDynamicAPInt(); - ValSmall = O; - HoldsLarge = false; - } - LLVM_ATTRIBUTE_ALWAYS_INLINE void - initLarge(const detail::SlowDynamicAPInt &O) { - if (LLVM_LIKELY(isSmall())) { - // The data in memory could be in an arbitrary state, not necessarily - // corresponding to any valid state of ValLarge; we cannot call any member - // functions, e.g. the assignment operator on it, as they may access the - // invalid internal state. We instead construct a new object using - // placement new. - new (&ValLarge) detail::SlowDynamicAPInt(O); - } else { - // In this case, we need to use the assignment operator, because if we use - // placement-new as above we would lose track of allocated memory - // and leak it. - ValLarge = O; - } - HoldsLarge = true; - } - - LLVM_ATTRIBUTE_ALWAYS_INLINE explicit DynamicAPInt( - const detail::SlowDynamicAPInt &Val) - : ValLarge(Val), HoldsLarge(true) {} - LLVM_ATTRIBUTE_ALWAYS_INLINE bool isSmall() const { return !HoldsLarge; } - LLVM_ATTRIBUTE_ALWAYS_INLINE bool isLarge() const { return HoldsLarge; } - /// Get the stored value. For getSmall/Large, - /// the stored value should be small/large. - LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t getSmall() const { - assert(isSmall() && - "getSmall should only be called when the value stored is small!"); - return ValSmall; - } - LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t &getSmall() { - assert(isSmall() && - "getSmall should only be called when the value stored is small!"); - return ValSmall; - } - LLVM_ATTRIBUTE_ALWAYS_INLINE const detail::SlowDynamicAPInt & - getLarge() const { - assert(isLarge() && - "getLarge should only be called when the value stored is large!"); - return ValLarge; - } - LLVM_ATTRIBUTE_ALWAYS_INLINE detail::SlowDynamicAPInt &getLarge() { - assert(isLarge() && - "getLarge should only be called when the value stored is large!"); - return ValLarge; - } - explicit operator detail::SlowDynamicAPInt() const { - if (isSmall()) - return detail::SlowDynamicAPInt(getSmall()); - return getLarge(); - } - -public: - LLVM_ATTRIBUTE_ALWAYS_INLINE explicit DynamicAPInt(int64_t Val) - : ValSmall(Val), HoldsLarge(false) {} - LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt() : DynamicAPInt(0) {} - LLVM_ATTRIBUTE_ALWAYS_INLINE ~DynamicAPInt() { - if (LLVM_UNLIKELY(isLarge())) - ValLarge.detail::SlowDynamicAPInt::~SlowDynamicAPInt(); - } - LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt(const DynamicAPInt &O) - : ValSmall(O.ValSmall), HoldsLarge(false) { - if (LLVM_UNLIKELY(O.isLarge())) - initLarge(O.ValLarge); - } - LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt &operator=(const DynamicAPInt &O) { - if (LLVM_LIKELY(O.isSmall())) { - initSmall(O.ValSmall); - return *this; - } - initLarge(O.ValLarge); - return *this; - } - LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt &operator=(int X) { - initSmall(X); - return *this; - } - LLVM_ATTRIBUTE_ALWAYS_INLINE explicit operator int64_t() const { - if (isSmall()) - return getSmall(); - return static_cast(getLarge()); - } - - bool operator==(const DynamicAPInt &O) const; - bool operator!=(const DynamicAPInt &O) const; - bool operator>(const DynamicAPInt &O) const; - bool operator<(const DynamicAPInt &O) const; - bool operator<=(const DynamicAPInt &O) const; - bool operator>=(const DynamicAPInt &O) const; - DynamicAPInt operator+(const DynamicAPInt &O) const; - DynamicAPInt operator-(const DynamicAPInt &O) const; - DynamicAPInt operator*(const DynamicAPInt &O) const; - DynamicAPInt operator/(const DynamicAPInt &O) const; - DynamicAPInt operator%(const DynamicAPInt &O) const; - DynamicAPInt &operator+=(const DynamicAPInt &O); - DynamicAPInt &operator-=(const DynamicAPInt &O); - DynamicAPInt &operator*=(const DynamicAPInt &O); - DynamicAPInt &operator/=(const DynamicAPInt &O); - DynamicAPInt &operator%=(const DynamicAPInt &O); - DynamicAPInt operator-() const; - DynamicAPInt &operator++(); - DynamicAPInt &operator--(); - - // Divide by a number that is known to be positive. - // This is slightly more efficient because it saves an overflow check. - DynamicAPInt divByPositive(const DynamicAPInt &O) const; - DynamicAPInt &divByPositiveInPlace(const DynamicAPInt &O); - - friend DynamicAPInt abs(const DynamicAPInt &X); - friend DynamicAPInt ceilDiv(const DynamicAPInt &LHS, const DynamicAPInt &RHS); - friend DynamicAPInt floorDiv(const DynamicAPInt &LHS, - const DynamicAPInt &RHS); - // The operands must be non-negative for gcd. - friend DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B); - friend DynamicAPInt lcm(const DynamicAPInt &A, const DynamicAPInt &B); - friend DynamicAPInt mod(const DynamicAPInt &LHS, const DynamicAPInt &RHS); - - /// --------------------------------------------------------------------------- - /// Convenience operator overloads for int64_t. - /// --------------------------------------------------------------------------- - friend DynamicAPInt &operator+=(DynamicAPInt &A, int64_t B); - friend DynamicAPInt &operator-=(DynamicAPInt &A, int64_t B); - friend DynamicAPInt &operator*=(DynamicAPInt &A, int64_t B); - friend DynamicAPInt &operator/=(DynamicAPInt &A, int64_t B); - friend DynamicAPInt &operator%=(DynamicAPInt &A, int64_t B); - - friend bool operator==(const DynamicAPInt &A, int64_t B); - friend bool operator!=(const DynamicAPInt &A, int64_t B); - friend bool operator>(const DynamicAPInt &A, int64_t B); - friend bool operator<(const DynamicAPInt &A, int64_t B); - friend bool operator<=(const DynamicAPInt &A, int64_t B); - friend bool operator>=(const DynamicAPInt &A, int64_t B); - friend DynamicAPInt operator+(const DynamicAPInt &A, int64_t B); - friend DynamicAPInt operator-(const DynamicAPInt &A, int64_t B); - friend DynamicAPInt operator*(const DynamicAPInt &A, int64_t B); - friend DynamicAPInt operator/(const DynamicAPInt &A, int64_t B); - friend DynamicAPInt operator%(const DynamicAPInt &A, int64_t B); - - friend bool operator==(int64_t A, const DynamicAPInt &B); - friend bool operator!=(int64_t A, const DynamicAPInt &B); - friend bool operator>(int64_t A, const DynamicAPInt &B); - friend bool operator<(int64_t A, const DynamicAPInt &B); - friend bool operator<=(int64_t A, const DynamicAPInt &B); - friend bool operator>=(int64_t A, const DynamicAPInt &B); - friend DynamicAPInt operator+(int64_t A, const DynamicAPInt &B); - friend DynamicAPInt operator-(int64_t A, const DynamicAPInt &B); - friend DynamicAPInt operator*(int64_t A, const DynamicAPInt &B); - friend DynamicAPInt operator/(int64_t A, const DynamicAPInt &B); - friend DynamicAPInt operator%(int64_t A, const DynamicAPInt &B); - - friend hash_code hash_value(const DynamicAPInt &x); // NOLINT - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) - raw_ostream &print(raw_ostream &OS) const; - LLVM_DUMP_METHOD void dump() const; -#endif -}; - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -inline raw_ostream &operator<<(raw_ostream &OS, const DynamicAPInt &X) { - X.print(OS); - return OS; -} -#endif - -/// Redeclarations of friend declaration above to -/// make it discoverable by lookups. -hash_code hash_value(const DynamicAPInt &X); // NOLINT - -/// This just calls through to the operator int64_t, but it's useful when a -/// function pointer is required. (Although this is marked inline, it is still -/// possible to obtain and use a function pointer to this.) -static inline int64_t int64fromDynamicAPInt(const DynamicAPInt &X) { - return int64_t(X); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt dynamicAPIntFromInt64(int64_t X) { - return DynamicAPInt(X); -} - -// The RHS is always expected to be positive, and the result -/// is always non-negative. -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt mod(const DynamicAPInt &LHS, - const DynamicAPInt &RHS); - -namespace detail { -// Division overflows only when trying to negate the minimal signed value. -LLVM_ATTRIBUTE_ALWAYS_INLINE bool divWouldOverflow(int64_t X, int64_t Y) { - return X == std::numeric_limits::min() && Y == -1; -} -} // namespace detail - -/// We define the operations here in the header to facilitate inlining. - -/// --------------------------------------------------------------------------- -/// Comparison operators. -/// --------------------------------------------------------------------------- -LLVM_ATTRIBUTE_ALWAYS_INLINE bool -DynamicAPInt::operator==(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) - return getSmall() == O.getSmall(); - return detail::SlowDynamicAPInt(*this) == detail::SlowDynamicAPInt(O); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool -DynamicAPInt::operator!=(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) - return getSmall() != O.getSmall(); - return detail::SlowDynamicAPInt(*this) != detail::SlowDynamicAPInt(O); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool -DynamicAPInt::operator>(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) - return getSmall() > O.getSmall(); - return detail::SlowDynamicAPInt(*this) > detail::SlowDynamicAPInt(O); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool -DynamicAPInt::operator<(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) - return getSmall() < O.getSmall(); - return detail::SlowDynamicAPInt(*this) < detail::SlowDynamicAPInt(O); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool -DynamicAPInt::operator<=(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) - return getSmall() <= O.getSmall(); - return detail::SlowDynamicAPInt(*this) <= detail::SlowDynamicAPInt(O); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool -DynamicAPInt::operator>=(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) - return getSmall() >= O.getSmall(); - return detail::SlowDynamicAPInt(*this) >= detail::SlowDynamicAPInt(O); -} - -/// --------------------------------------------------------------------------- -/// Arithmetic operators. -/// --------------------------------------------------------------------------- - -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt -DynamicAPInt::operator+(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) { - DynamicAPInt Result; - bool Overflow = AddOverflow(getSmall(), O.getSmall(), Result.getSmall()); - if (LLVM_LIKELY(!Overflow)) - return Result; - return DynamicAPInt(detail::SlowDynamicAPInt(*this) + - detail::SlowDynamicAPInt(O)); - } - return DynamicAPInt(detail::SlowDynamicAPInt(*this) + - detail::SlowDynamicAPInt(O)); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt -DynamicAPInt::operator-(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) { - DynamicAPInt Result; - bool Overflow = SubOverflow(getSmall(), O.getSmall(), Result.getSmall()); - if (LLVM_LIKELY(!Overflow)) - return Result; - return DynamicAPInt(detail::SlowDynamicAPInt(*this) - - detail::SlowDynamicAPInt(O)); - } - return DynamicAPInt(detail::SlowDynamicAPInt(*this) - - detail::SlowDynamicAPInt(O)); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt -DynamicAPInt::operator*(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) { - DynamicAPInt Result; - bool Overflow = MulOverflow(getSmall(), O.getSmall(), Result.getSmall()); - if (LLVM_LIKELY(!Overflow)) - return Result; - return DynamicAPInt(detail::SlowDynamicAPInt(*this) * - detail::SlowDynamicAPInt(O)); - } - return DynamicAPInt(detail::SlowDynamicAPInt(*this) * - detail::SlowDynamicAPInt(O)); -} - -// Division overflows only occur when negating the minimal possible value. -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt -DynamicAPInt::divByPositive(const DynamicAPInt &O) const { - assert(O > 0); - if (LLVM_LIKELY(isSmall() && O.isSmall())) - return DynamicAPInt(getSmall() / O.getSmall()); - return DynamicAPInt(detail::SlowDynamicAPInt(*this) / - detail::SlowDynamicAPInt(O)); -} - -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt -DynamicAPInt::operator/(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) { - // Division overflows only occur when negating the minimal possible value. - if (LLVM_UNLIKELY(detail::divWouldOverflow(getSmall(), O.getSmall()))) - return -*this; - return DynamicAPInt(getSmall() / O.getSmall()); - } - return DynamicAPInt(detail::SlowDynamicAPInt(*this) / - detail::SlowDynamicAPInt(O)); -} - -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt abs(const DynamicAPInt &X) { - return DynamicAPInt(X >= 0 ? X : -X); -} -// Division overflows only occur when negating the minimal possible value. -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt ceilDiv(const DynamicAPInt &LHS, - const DynamicAPInt &RHS) { - if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall())) { - if (LLVM_UNLIKELY(detail::divWouldOverflow(LHS.getSmall(), RHS.getSmall()))) - return -LHS; - return DynamicAPInt(divideCeilSigned(LHS.getSmall(), RHS.getSmall())); - } - return DynamicAPInt( - ceilDiv(detail::SlowDynamicAPInt(LHS), detail::SlowDynamicAPInt(RHS))); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt floorDiv(const DynamicAPInt &LHS, - const DynamicAPInt &RHS) { - if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall())) { - if (LLVM_UNLIKELY(detail::divWouldOverflow(LHS.getSmall(), RHS.getSmall()))) - return -LHS; - return DynamicAPInt(divideFloorSigned(LHS.getSmall(), RHS.getSmall())); - } - return DynamicAPInt( - floorDiv(detail::SlowDynamicAPInt(LHS), detail::SlowDynamicAPInt(RHS))); -} -// The RHS is always expected to be positive, and the result -/// is always non-negative. -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt mod(const DynamicAPInt &LHS, - const DynamicAPInt &RHS) { - if (LLVM_LIKELY(LHS.isSmall() && RHS.isSmall())) - return DynamicAPInt(mod(LHS.getSmall(), RHS.getSmall())); - return DynamicAPInt( - mod(detail::SlowDynamicAPInt(LHS), detail::SlowDynamicAPInt(RHS))); -} - -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, - const DynamicAPInt &B) { - assert(A >= 0 && B >= 0 && "operands must be non-negative!"); - if (LLVM_LIKELY(A.isSmall() && B.isSmall())) - return DynamicAPInt(std::gcd(A.getSmall(), B.getSmall())); - return DynamicAPInt( - gcd(detail::SlowDynamicAPInt(A), detail::SlowDynamicAPInt(B))); -} - -/// Returns the least common multiple of A and B. -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt lcm(const DynamicAPInt &A, - const DynamicAPInt &B) { - DynamicAPInt X = abs(A); - DynamicAPInt Y = abs(B); - return (X * Y) / gcd(X, Y); -} - -/// This operation cannot overflow. -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt -DynamicAPInt::operator%(const DynamicAPInt &O) const { - if (LLVM_LIKELY(isSmall() && O.isSmall())) - return DynamicAPInt(getSmall() % O.getSmall()); - return DynamicAPInt(detail::SlowDynamicAPInt(*this) % - detail::SlowDynamicAPInt(O)); -} - -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt DynamicAPInt::operator-() const { - if (LLVM_LIKELY(isSmall())) { - if (LLVM_LIKELY(getSmall() != std::numeric_limits::min())) - return DynamicAPInt(-getSmall()); - return DynamicAPInt(-detail::SlowDynamicAPInt(*this)); - } - return DynamicAPInt(-detail::SlowDynamicAPInt(*this)); -} - -/// --------------------------------------------------------------------------- -/// Assignment operators, preincrement, predecrement. -/// --------------------------------------------------------------------------- -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & -DynamicAPInt::operator+=(const DynamicAPInt &O) { - if (LLVM_LIKELY(isSmall() && O.isSmall())) { - int64_t Result = getSmall(); - bool Overflow = AddOverflow(getSmall(), O.getSmall(), Result); - if (LLVM_LIKELY(!Overflow)) { - getSmall() = Result; - return *this; - } - // Note: this return is not strictly required but - // removing it leads to a performance regression. - return *this = DynamicAPInt(detail::SlowDynamicAPInt(*this) + - detail::SlowDynamicAPInt(O)); - } - return *this = DynamicAPInt(detail::SlowDynamicAPInt(*this) + - detail::SlowDynamicAPInt(O)); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & -DynamicAPInt::operator-=(const DynamicAPInt &O) { - if (LLVM_LIKELY(isSmall() && O.isSmall())) { - int64_t Result = getSmall(); - bool Overflow = SubOverflow(getSmall(), O.getSmall(), Result); - if (LLVM_LIKELY(!Overflow)) { - getSmall() = Result; - return *this; - } - // Note: this return is not strictly required but - // removing it leads to a performance regression. - return *this = DynamicAPInt(detail::SlowDynamicAPInt(*this) - - detail::SlowDynamicAPInt(O)); - } - return *this = DynamicAPInt(detail::SlowDynamicAPInt(*this) - - detail::SlowDynamicAPInt(O)); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & -DynamicAPInt::operator*=(const DynamicAPInt &O) { - if (LLVM_LIKELY(isSmall() && O.isSmall())) { - int64_t Result = getSmall(); - bool Overflow = MulOverflow(getSmall(), O.getSmall(), Result); - if (LLVM_LIKELY(!Overflow)) { - getSmall() = Result; - return *this; - } - // Note: this return is not strictly required but - // removing it leads to a performance regression. - return *this = DynamicAPInt(detail::SlowDynamicAPInt(*this) * - detail::SlowDynamicAPInt(O)); - } - return *this = DynamicAPInt(detail::SlowDynamicAPInt(*this) * - detail::SlowDynamicAPInt(O)); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & -DynamicAPInt::operator/=(const DynamicAPInt &O) { - if (LLVM_LIKELY(isSmall() && O.isSmall())) { - // Division overflows only occur when negating the minimal possible value. - if (LLVM_UNLIKELY(detail::divWouldOverflow(getSmall(), O.getSmall()))) - return *this = -*this; - getSmall() /= O.getSmall(); - return *this; - } - return *this = DynamicAPInt(detail::SlowDynamicAPInt(*this) / - detail::SlowDynamicAPInt(O)); -} - -// Division overflows only occur when the divisor is -1. -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & -DynamicAPInt::divByPositiveInPlace(const DynamicAPInt &O) { - assert(O > 0); - if (LLVM_LIKELY(isSmall() && O.isSmall())) { - getSmall() /= O.getSmall(); - return *this; - } - return *this = DynamicAPInt(detail::SlowDynamicAPInt(*this) / - detail::SlowDynamicAPInt(O)); -} - -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt & -DynamicAPInt::operator%=(const DynamicAPInt &O) { - return *this = *this % O; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt &DynamicAPInt::operator++() { - return *this += 1; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt &DynamicAPInt::operator--() { - return *this -= 1; -} - -/// ---------------------------------------------------------------------------- -/// Convenience operator overloads for int64_t. -/// ---------------------------------------------------------------------------- -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt &operator+=(DynamicAPInt &A, - int64_t B) { - return A = A + B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt &operator-=(DynamicAPInt &A, - int64_t B) { - return A = A - B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt &operator*=(DynamicAPInt &A, - int64_t B) { - return A = A * B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt &operator/=(DynamicAPInt &A, - int64_t B) { - return A = A / B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt &operator%=(DynamicAPInt &A, - int64_t B) { - return A = A % B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator+(const DynamicAPInt &A, - int64_t B) { - return A + DynamicAPInt(B); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator-(const DynamicAPInt &A, - int64_t B) { - return A - DynamicAPInt(B); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator*(const DynamicAPInt &A, - int64_t B) { - return A * DynamicAPInt(B); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator/(const DynamicAPInt &A, - int64_t B) { - return A / DynamicAPInt(B); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator%(const DynamicAPInt &A, - int64_t B) { - return A % DynamicAPInt(B); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator+(int64_t A, - const DynamicAPInt &B) { - return DynamicAPInt(A) + B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator-(int64_t A, - const DynamicAPInt &B) { - return DynamicAPInt(A) - B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator*(int64_t A, - const DynamicAPInt &B) { - return DynamicAPInt(A) * B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator/(int64_t A, - const DynamicAPInt &B) { - return DynamicAPInt(A) / B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt operator%(int64_t A, - const DynamicAPInt &B) { - return DynamicAPInt(A) % B; -} - -/// We provide special implementations of the comparison operators rather than -/// calling through as above, as this would result in a 1.2x slowdown. -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator==(const DynamicAPInt &A, int64_t B) { - if (LLVM_LIKELY(A.isSmall())) - return A.getSmall() == B; - return A.getLarge() == B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator!=(const DynamicAPInt &A, int64_t B) { - if (LLVM_LIKELY(A.isSmall())) - return A.getSmall() != B; - return A.getLarge() != B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator>(const DynamicAPInt &A, int64_t B) { - if (LLVM_LIKELY(A.isSmall())) - return A.getSmall() > B; - return A.getLarge() > B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator<(const DynamicAPInt &A, int64_t B) { - if (LLVM_LIKELY(A.isSmall())) - return A.getSmall() < B; - return A.getLarge() < B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator<=(const DynamicAPInt &A, int64_t B) { - if (LLVM_LIKELY(A.isSmall())) - return A.getSmall() <= B; - return A.getLarge() <= B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator>=(const DynamicAPInt &A, int64_t B) { - if (LLVM_LIKELY(A.isSmall())) - return A.getSmall() >= B; - return A.getLarge() >= B; -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator==(int64_t A, const DynamicAPInt &B) { - if (LLVM_LIKELY(B.isSmall())) - return A == B.getSmall(); - return A == B.getLarge(); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator!=(int64_t A, const DynamicAPInt &B) { - if (LLVM_LIKELY(B.isSmall())) - return A != B.getSmall(); - return A != B.getLarge(); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator>(int64_t A, const DynamicAPInt &B) { - if (LLVM_LIKELY(B.isSmall())) - return A > B.getSmall(); - return A > B.getLarge(); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator<(int64_t A, const DynamicAPInt &B) { - if (LLVM_LIKELY(B.isSmall())) - return A < B.getSmall(); - return A < B.getLarge(); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator<=(int64_t A, const DynamicAPInt &B) { - if (LLVM_LIKELY(B.isSmall())) - return A <= B.getSmall(); - return A <= B.getLarge(); -} -LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator>=(int64_t A, const DynamicAPInt &B) { - if (LLVM_LIKELY(B.isSmall())) - return A >= B.getSmall(); - return A >= B.getLarge(); -} -} // namespace llvm - -#endif // LLVM_ADT_DYNAMICAPINT_H diff --git a/llvm/include/llvm/ADT/SlowDynamicAPInt.h b/llvm/include/llvm/ADT/SlowDynamicAPInt.h deleted file mode 100644 index 009deab6c6c92..0000000000000 --- a/llvm/include/llvm/ADT/SlowDynamicAPInt.h +++ /dev/null @@ -1,140 +0,0 @@ -//===- SlowDynamicAPInt.h - SlowDynamicAPInt Class --------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is a simple class to represent arbitrary precision signed integers. -// Unlike APInt, one does not have to specify a fixed maximum size, and the -// integer can take on any arbitrary values. -// -// This class is to be used as a fallback slow path for the DynamicAPInt class, -// and is not intended to be used directly. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_ADT_SLOWDYNAMICAPINT_H -#define LLVM_ADT_SLOWDYNAMICAPINT_H - -#include "llvm/ADT/APInt.h" -#include "llvm/Support/raw_ostream.h" - -namespace llvm::detail { -/// A simple class providing dynamic arbitrary-precision arithmetic. Internally, -/// it stores an APInt, whose width is doubled whenever an overflow occurs at a -/// certain width. The default constructor sets the initial width to 64. -/// SlowDynamicAPInt is primarily intended to be used as a slow fallback path -/// for the upcoming DynamicAPInt class. -class SlowDynamicAPInt { - APInt Val; - -public: - explicit SlowDynamicAPInt(int64_t Val); - SlowDynamicAPInt(); - explicit SlowDynamicAPInt(const APInt &Val); - SlowDynamicAPInt &operator=(int64_t Val); - explicit operator int64_t() const; - SlowDynamicAPInt operator-() const; - bool operator==(const SlowDynamicAPInt &O) const; - bool operator!=(const SlowDynamicAPInt &O) const; - bool operator>(const SlowDynamicAPInt &O) const; - bool operator<(const SlowDynamicAPInt &O) const; - bool operator<=(const SlowDynamicAPInt &O) const; - bool operator>=(const SlowDynamicAPInt &O) const; - SlowDynamicAPInt operator+(const SlowDynamicAPInt &O) const; - SlowDynamicAPInt operator-(const SlowDynamicAPInt &O) const; - SlowDynamicAPInt operator*(const SlowDynamicAPInt &O) const; - SlowDynamicAPInt operator/(const SlowDynamicAPInt &O) const; - SlowDynamicAPInt operator%(const SlowDynamicAPInt &O) const; - SlowDynamicAPInt &operator+=(const SlowDynamicAPInt &O); - SlowDynamicAPInt &operator-=(const SlowDynamicAPInt &O); - SlowDynamicAPInt &operator*=(const SlowDynamicAPInt &O); - SlowDynamicAPInt &operator/=(const SlowDynamicAPInt &O); - SlowDynamicAPInt &operator%=(const SlowDynamicAPInt &O); - - SlowDynamicAPInt &operator++(); - SlowDynamicAPInt &operator--(); - - friend SlowDynamicAPInt abs(const SlowDynamicAPInt &X); - friend SlowDynamicAPInt ceilDiv(const SlowDynamicAPInt &LHS, - const SlowDynamicAPInt &RHS); - friend SlowDynamicAPInt floorDiv(const SlowDynamicAPInt &LHS, - const SlowDynamicAPInt &RHS); - /// The operands must be non-negative for gcd. - friend SlowDynamicAPInt gcd(const SlowDynamicAPInt &A, - const SlowDynamicAPInt &B); - - /// Overload to compute a hash_code for a SlowDynamicAPInt value. - friend hash_code hash_value(const SlowDynamicAPInt &X); // NOLINT - - unsigned getBitWidth() const { return Val.getBitWidth(); } - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) - void print(raw_ostream &OS) const; - LLVM_DUMP_METHOD void dump() const; -#endif -}; - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -inline raw_ostream &operator<<(raw_ostream &OS, const SlowDynamicAPInt &X) { - X.print(OS); - return OS; -} -#endif - -/// Returns the remainder of dividing LHS by RHS. -/// -/// The RHS is always expected to be positive, and the result -/// is always non-negative. -SlowDynamicAPInt mod(const SlowDynamicAPInt &LHS, const SlowDynamicAPInt &RHS); - -/// Returns the least common multiple of A and B. -SlowDynamicAPInt lcm(const SlowDynamicAPInt &A, const SlowDynamicAPInt &B); - -/// Redeclarations of friend declarations above to -/// make it discoverable by lookups. -SlowDynamicAPInt abs(const SlowDynamicAPInt &X); -SlowDynamicAPInt ceilDiv(const SlowDynamicAPInt &LHS, - const SlowDynamicAPInt &RHS); -SlowDynamicAPInt floorDiv(const SlowDynamicAPInt &LHS, - const SlowDynamicAPInt &RHS); -SlowDynamicAPInt gcd(const SlowDynamicAPInt &A, const SlowDynamicAPInt &B); -hash_code hash_value(const SlowDynamicAPInt &X); // NOLINT - -/// --------------------------------------------------------------------------- -/// Convenience operator overloads for int64_t. -/// --------------------------------------------------------------------------- -SlowDynamicAPInt &operator+=(SlowDynamicAPInt &A, int64_t B); -SlowDynamicAPInt &operator-=(SlowDynamicAPInt &A, int64_t B); -SlowDynamicAPInt &operator*=(SlowDynamicAPInt &A, int64_t B); -SlowDynamicAPInt &operator/=(SlowDynamicAPInt &A, int64_t B); -SlowDynamicAPInt &operator%=(SlowDynamicAPInt &A, int64_t B); - -bool operator==(const SlowDynamicAPInt &A, int64_t B); -bool operator!=(const SlowDynamicAPInt &A, int64_t B); -bool operator>(const SlowDynamicAPInt &A, int64_t B); -bool operator<(const SlowDynamicAPInt &A, int64_t B); -bool operator<=(const SlowDynamicAPInt &A, int64_t B); -bool operator>=(const SlowDynamicAPInt &A, int64_t B); -SlowDynamicAPInt operator+(const SlowDynamicAPInt &A, int64_t B); -SlowDynamicAPInt operator-(const SlowDynamicAPInt &A, int64_t B); -SlowDynamicAPInt operator*(const SlowDynamicAPInt &A, int64_t B); -SlowDynamicAPInt operator/(const SlowDynamicAPInt &A, int64_t B); -SlowDynamicAPInt operator%(const SlowDynamicAPInt &A, int64_t B); - -bool operator==(int64_t A, const SlowDynamicAPInt &B); -bool operator!=(int64_t A, const SlowDynamicAPInt &B); -bool operator>(int64_t A, const SlowDynamicAPInt &B); -bool operator<(int64_t A, const SlowDynamicAPInt &B); -bool operator<=(int64_t A, const SlowDynamicAPInt &B); -bool operator>=(int64_t A, const SlowDynamicAPInt &B); -SlowDynamicAPInt operator+(int64_t A, const SlowDynamicAPInt &B); -SlowDynamicAPInt operator-(int64_t A, const SlowDynamicAPInt &B); -SlowDynamicAPInt operator*(int64_t A, const SlowDynamicAPInt &B); -SlowDynamicAPInt operator/(int64_t A, const SlowDynamicAPInt &B); -SlowDynamicAPInt operator%(int64_t A, const SlowDynamicAPInt &B); -} // namespace llvm::detail - -#endif // LLVM_ADT_SLOWDYNAMICAPINT_H diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt index c7f8ac325a97a..5df36f811efe9 100644 --- a/llvm/lib/Support/CMakeLists.txt +++ b/llvm/lib/Support/CMakeLists.txt @@ -170,7 +170,6 @@ add_llvm_component_library(LLVMSupport DivisionByConstantInfo.cpp DAGDeltaAlgorithm.cpp DJB.cpp - DynamicAPInt.cpp ELFAttributeParser.cpp ELFAttributes.cpp Error.cpp @@ -224,7 +223,6 @@ add_llvm_component_library(LLVMSupport SHA1.cpp SHA256.cpp Signposts.cpp - SlowDynamicAPInt.cpp SmallPtrSet.cpp SmallVector.cpp SourceMgr.cpp diff --git a/llvm/lib/Support/DynamicAPInt.cpp b/llvm/lib/Support/DynamicAPInt.cpp deleted file mode 100644 index 6f7eecca36db0..0000000000000 --- a/llvm/lib/Support/DynamicAPInt.cpp +++ /dev/null @@ -1,29 +0,0 @@ -//===- DynamicAPInt.cpp - DynamicAPInt Implementation -----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#include "llvm/ADT/DynamicAPInt.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" - -using namespace llvm; - -hash_code llvm::hash_value(const DynamicAPInt &X) { - if (X.isSmall()) - return llvm::hash_value(X.getSmall()); - return detail::hash_value(X.getLarge()); -} - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -raw_ostream &DynamicAPInt::print(raw_ostream &OS) const { - if (isSmall()) - return OS << ValSmall; - return OS << ValLarge; -} - -void DynamicAPInt::dump() const { print(dbgs()); } -#endif diff --git a/llvm/lib/Support/SlowDynamicAPInt.cpp b/llvm/lib/Support/SlowDynamicAPInt.cpp deleted file mode 100644 index 5d88cc53d17ba..0000000000000 --- a/llvm/lib/Support/SlowDynamicAPInt.cpp +++ /dev/null @@ -1,288 +0,0 @@ -//===- SlowDynamicAPInt.cpp - SlowDynamicAPInt Implementation -------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/SlowDynamicAPInt.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" - -using namespace llvm; -using namespace detail; - -SlowDynamicAPInt::SlowDynamicAPInt(int64_t Val) - : Val(64, Val, /*isSigned=*/true) {} -SlowDynamicAPInt::SlowDynamicAPInt() : SlowDynamicAPInt(0) {} -SlowDynamicAPInt::SlowDynamicAPInt(const APInt &Val) : Val(Val) {} -SlowDynamicAPInt &SlowDynamicAPInt::operator=(int64_t Val) { - return *this = SlowDynamicAPInt(Val); -} -SlowDynamicAPInt::operator int64_t() const { return Val.getSExtValue(); } - -hash_code detail::hash_value(const SlowDynamicAPInt &X) { - return hash_value(X.Val); -} - -/// --------------------------------------------------------------------------- -/// Convenience operator overloads for int64_t. -/// --------------------------------------------------------------------------- -SlowDynamicAPInt &detail::operator+=(SlowDynamicAPInt &A, int64_t B) { - return A += SlowDynamicAPInt(B); -} -SlowDynamicAPInt &detail::operator-=(SlowDynamicAPInt &A, int64_t B) { - return A -= SlowDynamicAPInt(B); -} -SlowDynamicAPInt &detail::operator*=(SlowDynamicAPInt &A, int64_t B) { - return A *= SlowDynamicAPInt(B); -} -SlowDynamicAPInt &detail::operator/=(SlowDynamicAPInt &A, int64_t B) { - return A /= SlowDynamicAPInt(B); -} -SlowDynamicAPInt &detail::operator%=(SlowDynamicAPInt &A, int64_t B) { - return A %= SlowDynamicAPInt(B); -} - -bool detail::operator==(const SlowDynamicAPInt &A, int64_t B) { - return A == SlowDynamicAPInt(B); -} -bool detail::operator!=(const SlowDynamicAPInt &A, int64_t B) { - return A != SlowDynamicAPInt(B); -} -bool detail::operator>(const SlowDynamicAPInt &A, int64_t B) { - return A > SlowDynamicAPInt(B); -} -bool detail::operator<(const SlowDynamicAPInt &A, int64_t B) { - return A < SlowDynamicAPInt(B); -} -bool detail::operator<=(const SlowDynamicAPInt &A, int64_t B) { - return A <= SlowDynamicAPInt(B); -} -bool detail::operator>=(const SlowDynamicAPInt &A, int64_t B) { - return A >= SlowDynamicAPInt(B); -} -SlowDynamicAPInt detail::operator+(const SlowDynamicAPInt &A, int64_t B) { - return A + SlowDynamicAPInt(B); -} -SlowDynamicAPInt detail::operator-(const SlowDynamicAPInt &A, int64_t B) { - return A - SlowDynamicAPInt(B); -} -SlowDynamicAPInt detail::operator*(const SlowDynamicAPInt &A, int64_t B) { - return A * SlowDynamicAPInt(B); -} -SlowDynamicAPInt detail::operator/(const SlowDynamicAPInt &A, int64_t B) { - return A / SlowDynamicAPInt(B); -} -SlowDynamicAPInt detail::operator%(const SlowDynamicAPInt &A, int64_t B) { - return A % SlowDynamicAPInt(B); -} - -bool detail::operator==(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) == B; -} -bool detail::operator!=(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) != B; -} -bool detail::operator>(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) > B; -} -bool detail::operator<(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) < B; -} -bool detail::operator<=(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) <= B; -} -bool detail::operator>=(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) >= B; -} -SlowDynamicAPInt detail::operator+(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) + B; -} -SlowDynamicAPInt detail::operator-(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) - B; -} -SlowDynamicAPInt detail::operator*(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) * B; -} -SlowDynamicAPInt detail::operator/(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) / B; -} -SlowDynamicAPInt detail::operator%(int64_t A, const SlowDynamicAPInt &B) { - return SlowDynamicAPInt(A) % B; -} - -static unsigned getMaxWidth(const APInt &A, const APInt &B) { - return std::max(A.getBitWidth(), B.getBitWidth()); -} - -/// --------------------------------------------------------------------------- -/// Comparison operators. -/// --------------------------------------------------------------------------- - -// TODO: consider instead making APInt::compare available and using that. -bool SlowDynamicAPInt::operator==(const SlowDynamicAPInt &O) const { - unsigned Width = getMaxWidth(Val, O.Val); - return Val.sext(Width) == O.Val.sext(Width); -} -bool SlowDynamicAPInt::operator!=(const SlowDynamicAPInt &O) const { - unsigned Width = getMaxWidth(Val, O.Val); - return Val.sext(Width) != O.Val.sext(Width); -} -bool SlowDynamicAPInt::operator>(const SlowDynamicAPInt &O) const { - unsigned Width = getMaxWidth(Val, O.Val); - return Val.sext(Width).sgt(O.Val.sext(Width)); -} -bool SlowDynamicAPInt::operator<(const SlowDynamicAPInt &O) const { - unsigned Width = getMaxWidth(Val, O.Val); - return Val.sext(Width).slt(O.Val.sext(Width)); -} -bool SlowDynamicAPInt::operator<=(const SlowDynamicAPInt &O) const { - unsigned Width = getMaxWidth(Val, O.Val); - return Val.sext(Width).sle(O.Val.sext(Width)); -} -bool SlowDynamicAPInt::operator>=(const SlowDynamicAPInt &O) const { - unsigned Width = getMaxWidth(Val, O.Val); - return Val.sext(Width).sge(O.Val.sext(Width)); -} - -/// --------------------------------------------------------------------------- -/// Arithmetic operators. -/// --------------------------------------------------------------------------- - -/// Bring a and b to have the same width and then call op(a, b, overflow). -/// If the overflow bit becomes set, resize a and b to double the width and -/// call op(a, b, overflow), returning its result. The operation with double -/// widths should not also overflow. -APInt runOpWithExpandOnOverflow( - const APInt &A, const APInt &B, - function_ref Op) { - bool Overflow; - unsigned Width = getMaxWidth(A, B); - APInt Ret = Op(A.sext(Width), B.sext(Width), Overflow); - if (!Overflow) - return Ret; - - Width *= 2; - Ret = Op(A.sext(Width), B.sext(Width), Overflow); - assert(!Overflow && "double width should be sufficient to avoid overflow!"); - return Ret; -} - -SlowDynamicAPInt SlowDynamicAPInt::operator+(const SlowDynamicAPInt &O) const { - return SlowDynamicAPInt( - runOpWithExpandOnOverflow(Val, O.Val, std::mem_fn(&APInt::sadd_ov))); -} -SlowDynamicAPInt SlowDynamicAPInt::operator-(const SlowDynamicAPInt &O) const { - return SlowDynamicAPInt( - runOpWithExpandOnOverflow(Val, O.Val, std::mem_fn(&APInt::ssub_ov))); -} -SlowDynamicAPInt SlowDynamicAPInt::operator*(const SlowDynamicAPInt &O) const { - return SlowDynamicAPInt( - runOpWithExpandOnOverflow(Val, O.Val, std::mem_fn(&APInt::smul_ov))); -} -SlowDynamicAPInt SlowDynamicAPInt::operator/(const SlowDynamicAPInt &O) const { - return SlowDynamicAPInt( - runOpWithExpandOnOverflow(Val, O.Val, std::mem_fn(&APInt::sdiv_ov))); -} -SlowDynamicAPInt detail::abs(const SlowDynamicAPInt &X) { - return X >= 0 ? X : -X; -} -SlowDynamicAPInt detail::ceilDiv(const SlowDynamicAPInt &LHS, - const SlowDynamicAPInt &RHS) { - if (RHS == -1) - return -LHS; - unsigned Width = getMaxWidth(LHS.Val, RHS.Val); - return SlowDynamicAPInt(APIntOps::RoundingSDiv( - LHS.Val.sext(Width), RHS.Val.sext(Width), APInt::Rounding::UP)); -} -SlowDynamicAPInt detail::floorDiv(const SlowDynamicAPInt &LHS, - const SlowDynamicAPInt &RHS) { - if (RHS == -1) - return -LHS; - unsigned Width = getMaxWidth(LHS.Val, RHS.Val); - return SlowDynamicAPInt(APIntOps::RoundingSDiv( - LHS.Val.sext(Width), RHS.Val.sext(Width), APInt::Rounding::DOWN)); -} -// The RHS is always expected to be positive, and the result -/// is always non-negative. -SlowDynamicAPInt detail::mod(const SlowDynamicAPInt &LHS, - const SlowDynamicAPInt &RHS) { - assert(RHS >= 1 && "mod is only supported for positive divisors!"); - return LHS % RHS < 0 ? LHS % RHS + RHS : LHS % RHS; -} - -SlowDynamicAPInt detail::gcd(const SlowDynamicAPInt &A, - const SlowDynamicAPInt &B) { - assert(A >= 0 && B >= 0 && "operands must be non-negative!"); - unsigned Width = getMaxWidth(A.Val, B.Val); - return SlowDynamicAPInt( - APIntOps::GreatestCommonDivisor(A.Val.sext(Width), B.Val.sext(Width))); -} - -/// Returns the least common multiple of A and B. -SlowDynamicAPInt detail::lcm(const SlowDynamicAPInt &A, - const SlowDynamicAPInt &B) { - SlowDynamicAPInt X = abs(A); - SlowDynamicAPInt Y = abs(B); - return (X * Y) / gcd(X, Y); -} - -/// This operation cannot overflow. -SlowDynamicAPInt SlowDynamicAPInt::operator%(const SlowDynamicAPInt &O) const { - unsigned Width = std::max(Val.getBitWidth(), O.Val.getBitWidth()); - return SlowDynamicAPInt(Val.sext(Width).srem(O.Val.sext(Width))); -} - -SlowDynamicAPInt SlowDynamicAPInt::operator-() const { - if (Val.isMinSignedValue()) { - /// Overflow only occurs when the value is the minimum possible value. - APInt Ret = Val.sext(2 * Val.getBitWidth()); - return SlowDynamicAPInt(-Ret); - } - return SlowDynamicAPInt(-Val); -} - -/// --------------------------------------------------------------------------- -/// Assignment operators, preincrement, predecrement. -/// --------------------------------------------------------------------------- -SlowDynamicAPInt &SlowDynamicAPInt::operator+=(const SlowDynamicAPInt &O) { - *this = *this + O; - return *this; -} -SlowDynamicAPInt &SlowDynamicAPInt::operator-=(const SlowDynamicAPInt &O) { - *this = *this - O; - return *this; -} -SlowDynamicAPInt &SlowDynamicAPInt::operator*=(const SlowDynamicAPInt &O) { - *this = *this * O; - return *this; -} -SlowDynamicAPInt &SlowDynamicAPInt::operator/=(const SlowDynamicAPInt &O) { - *this = *this / O; - return *this; -} -SlowDynamicAPInt &SlowDynamicAPInt::operator%=(const SlowDynamicAPInt &O) { - *this = *this % O; - return *this; -} -SlowDynamicAPInt &SlowDynamicAPInt::operator++() { - *this += 1; - return *this; -} - -SlowDynamicAPInt &SlowDynamicAPInt::operator--() { - *this -= 1; - return *this; -} - -/// --------------------------------------------------------------------------- -/// Printing. -/// --------------------------------------------------------------------------- -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void SlowDynamicAPInt::print(raw_ostream &OS) const { OS << Val; } - -void SlowDynamicAPInt::dump() const { print(dbgs()); } -#endif diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt index f48d840a10595..85c140e63fecd 100644 --- a/llvm/unittests/ADT/CMakeLists.txt +++ b/llvm/unittests/ADT/CMakeLists.txt @@ -25,7 +25,6 @@ add_llvm_unittest(ADTTests DenseSetTest.cpp DepthFirstIteratorTest.cpp DirectedGraphTest.cpp - DynamicAPIntTest.cpp EditDistanceTest.cpp EnumeratedArrayTest.cpp EquivalenceClassesTest.cpp diff --git a/llvm/unittests/ADT/DynamicAPIntTest.cpp b/llvm/unittests/ADT/DynamicAPIntTest.cpp deleted file mode 100644 index 932b750608b3e..0000000000000 --- a/llvm/unittests/ADT/DynamicAPIntTest.cpp +++ /dev/null @@ -1,200 +0,0 @@ -//===- MPIntTest.cpp - Tests for MPInt ------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/DynamicAPInt.h" -#include "llvm/ADT/SlowDynamicAPInt.h" -#include "gtest/gtest.h" - -using namespace llvm; - -namespace { -// googletest boilerplate to run the same tests with both MPInt and SlowMPInt. -template class IntTest : public testing::Test {}; -using TypeList = testing::Types; - -// This is for pretty-printing the test name with the name of the class in use. -class TypeNames { -public: - template - static std::string GetName(int) { // NOLINT; gtest mandates this name. - if (std::is_same()) - return "MPInt"; - if (std::is_same()) - return "SlowMPInt"; - llvm_unreachable("Unknown class!"); - } -}; -TYPED_TEST_SUITE(IntTest, TypeList, TypeNames); - -TYPED_TEST(IntTest, ops) { - TypeParam Two(2), Five(5), Seven(7), Ten(10); - EXPECT_EQ(Five + Five, Ten); - EXPECT_EQ(Five * Five, 2 * Ten + Five); - EXPECT_EQ(Five * Five, 3 * Ten - Five); - EXPECT_EQ(Five * Two, Ten); - EXPECT_EQ(Five / Two, Two); - EXPECT_EQ(Five % Two, Two / Two); - - EXPECT_EQ(-Ten % Seven, -10 % 7); - EXPECT_EQ(Ten % -Seven, 10 % -7); - EXPECT_EQ(-Ten % -Seven, -10 % -7); - EXPECT_EQ(Ten % Seven, 10 % 7); - - EXPECT_EQ(-Ten / Seven, -10 / 7); - EXPECT_EQ(Ten / -Seven, 10 / -7); - EXPECT_EQ(-Ten / -Seven, -10 / -7); - EXPECT_EQ(Ten / Seven, 10 / 7); - - TypeParam X = Ten; - X += Five; - EXPECT_EQ(X, 15); - X *= Two; - EXPECT_EQ(X, 30); - X /= Seven; - EXPECT_EQ(X, 4); - X -= Two * 10; - EXPECT_EQ(X, -16); - X *= 2 * Two; - EXPECT_EQ(X, -64); - X /= Two / -2; - EXPECT_EQ(X, 64); - - EXPECT_LE(Ten, Ten); - EXPECT_GE(Ten, Ten); - EXPECT_EQ(Ten, Ten); - EXPECT_FALSE(Ten != Ten); - EXPECT_FALSE(Ten < Ten); - EXPECT_FALSE(Ten > Ten); - EXPECT_LT(Five, Ten); - EXPECT_GT(Ten, Five); -} - -TYPED_TEST(IntTest, ops64Overloads) { - TypeParam Two(2), Five(5), Seven(7), Ten(10); - EXPECT_EQ(Five + 5, Ten); - EXPECT_EQ(Five + 5, 5 + Five); - EXPECT_EQ(Five * 5, 2 * Ten + 5); - EXPECT_EQ(Five * 5, 3 * Ten - 5); - EXPECT_EQ(Five * Two, Ten); - EXPECT_EQ(5 / Two, 2); - EXPECT_EQ(Five / 2, 2); - EXPECT_EQ(2 % Two, 0); - EXPECT_EQ(2 - Two, 0); - EXPECT_EQ(2 % Two, Two % 2); - - TypeParam X = Ten; - X += 5; - EXPECT_EQ(X, 15); - X *= 2; - EXPECT_EQ(X, 30); - X /= 7; - EXPECT_EQ(X, 4); - X -= 20; - EXPECT_EQ(X, -16); - X *= 4; - EXPECT_EQ(X, -64); - X /= -1; - EXPECT_EQ(X, 64); - - EXPECT_LE(Ten, 10); - EXPECT_GE(Ten, 10); - EXPECT_EQ(Ten, 10); - EXPECT_FALSE(Ten != 10); - EXPECT_FALSE(Ten < 10); - EXPECT_FALSE(Ten > 10); - EXPECT_LT(Five, 10); - EXPECT_GT(Ten, 5); - - EXPECT_LE(10, Ten); - EXPECT_GE(10, Ten); - EXPECT_EQ(10, Ten); - EXPECT_FALSE(10 != Ten); - EXPECT_FALSE(10 < Ten); - EXPECT_FALSE(10 > Ten); - EXPECT_LT(5, Ten); - EXPECT_GT(10, Five); -} - -TYPED_TEST(IntTest, overflows) { - TypeParam X(1ll << 60); - EXPECT_EQ((X * X - X * X * X * X) / (X * X * X), 1 - (1ll << 60)); - TypeParam Y(1ll << 62); - EXPECT_EQ((Y + Y + Y + Y + Y + Y) / Y, 6); - EXPECT_EQ(-(2 * (-Y)), 2 * Y); // -(-2^63) overflow. - X *= X; - EXPECT_EQ(X, (Y * Y) / 16); - Y += Y; - Y += Y; - Y += Y; - Y /= 8; - EXPECT_EQ(Y, 1ll << 62); - - TypeParam Min(std::numeric_limits::min()); - TypeParam One(1); - EXPECT_EQ(floorDiv(Min, -One), -Min); - EXPECT_EQ(ceilDiv(Min, -One), -Min); - EXPECT_EQ(abs(Min), -Min); - - TypeParam Z = Min; - Z /= -1; - EXPECT_EQ(Z, -Min); - TypeParam W(Min); - --W; - EXPECT_EQ(W, TypeParam(Min) - 1); - TypeParam U(Min); - U -= 1; - EXPECT_EQ(U, W); - - TypeParam Max(std::numeric_limits::max()); - TypeParam V = Max; - ++V; - EXPECT_EQ(V, Max + 1); - TypeParam T = Max; - T += 1; - EXPECT_EQ(T, V); -} - -TYPED_TEST(IntTest, floorCeilModAbsLcmGcd) { - TypeParam X(1ll << 50), One(1), Two(2), Three(3); - - // Run on small values and large values. - for (const TypeParam &Y : {X, X * X}) { - EXPECT_EQ(floorDiv(3 * Y, Three), Y); - EXPECT_EQ(ceilDiv(3 * Y, Three), Y); - EXPECT_EQ(floorDiv(3 * Y - 1, Three), Y - 1); - EXPECT_EQ(ceilDiv(3 * Y - 1, Three), Y); - EXPECT_EQ(floorDiv(3 * Y - 2, Three), Y - 1); - EXPECT_EQ(ceilDiv(3 * Y - 2, Three), Y); - - EXPECT_EQ(mod(3 * Y, Three), 0); - EXPECT_EQ(mod(3 * Y + 1, Three), One); - EXPECT_EQ(mod(3 * Y + 2, Three), Two); - - EXPECT_EQ(floorDiv(3 * Y, Y), 3); - EXPECT_EQ(ceilDiv(3 * Y, Y), 3); - EXPECT_EQ(floorDiv(3 * Y - 1, Y), 2); - EXPECT_EQ(ceilDiv(3 * Y - 1, Y), 3); - EXPECT_EQ(floorDiv(3 * Y - 2, Y), 2); - EXPECT_EQ(ceilDiv(3 * Y - 2, Y), 3); - - EXPECT_EQ(mod(3 * Y, Y), 0); - EXPECT_EQ(mod(3 * Y + 1, Y), 1); - EXPECT_EQ(mod(3 * Y + 2, Y), 2); - - EXPECT_EQ(abs(Y), Y); - EXPECT_EQ(abs(-Y), Y); - - EXPECT_EQ(gcd(3 * Y, Three), Three); - EXPECT_EQ(lcm(Y, Three), 3 * Y); - EXPECT_EQ(gcd(2 * Y, 3 * Y), Y); - EXPECT_EQ(lcm(2 * Y, 3 * Y), 6 * Y); - EXPECT_EQ(gcd(15 * Y, 6 * Y), 3 * Y); - EXPECT_EQ(lcm(15 * Y, 6 * Y), 30 * Y); - } -} -} // namespace diff --git a/mlir/include/mlir/Analysis/Presburger/Barvinok.h b/mlir/include/mlir/Analysis/Presburger/Barvinok.h index c9a1645b5e632..cd1ea3a9571ba 100644 --- a/mlir/include/mlir/Analysis/Presburger/Barvinok.h +++ b/mlir/include/mlir/Analysis/Presburger/Barvinok.h @@ -74,7 +74,7 @@ inline PolyhedronH defineHRep(int numVars, int numSymbols = 0) { /// Barvinok, A., and J. E. Pommersheim. "An algorithmic theory of lattice /// points in polyhedra." p. 107 If it has more rays than the dimension, return /// 0. -DynamicAPInt getIndex(const ConeV &cone); +MPInt getIndex(const ConeV &cone); /// Given a cone in H-representation, return its dual. The dual cone is in /// V-representation. diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h index 6be132058e6ca..f76ba3f006d07 100644 --- a/mlir/include/mlir/Analysis/Presburger/Fraction.h +++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h @@ -14,11 +14,10 @@ #ifndef MLIR_ANALYSIS_PRESBURGER_FRACTION_H #define MLIR_ANALYSIS_PRESBURGER_FRACTION_H -#include "llvm/ADT/DynamicAPInt.h" +#include "mlir/Analysis/Presburger/MPInt.h" namespace mlir { namespace presburger { -using llvm::DynamicAPInt; /// A class to represent fractions. The sign of the fraction is represented /// in the sign of the numerator; the denominator is always positive. @@ -30,7 +29,7 @@ struct Fraction { Fraction() = default; /// Construct a Fraction from a numerator and denominator. - Fraction(const DynamicAPInt &oNum, const DynamicAPInt &oDen = DynamicAPInt(1)) + Fraction(const MPInt &oNum, const MPInt &oDen = MPInt(1)) : num(oNum), den(oDen) { if (den < 0) { num = -num; @@ -38,43 +37,32 @@ struct Fraction { } } /// Overloads for passing literals. - Fraction(const DynamicAPInt &num, int64_t den) - : Fraction(num, DynamicAPInt(den)) {} - Fraction(int64_t num, const DynamicAPInt &den = DynamicAPInt(1)) - : Fraction(DynamicAPInt(num), den) {} - Fraction(int64_t num, int64_t den) - : Fraction(DynamicAPInt(num), DynamicAPInt(den)) {} + Fraction(const MPInt &num, int64_t den) : Fraction(num, MPInt(den)) {} + Fraction(int64_t num, const MPInt &den = MPInt(1)) + : Fraction(MPInt(num), den) {} + Fraction(int64_t num, int64_t den) : Fraction(MPInt(num), MPInt(den)) {} // Return the value of the fraction as an integer. This should only be called // when the fraction's value is really an integer. - DynamicAPInt getAsInteger() const { + MPInt getAsInteger() const { assert(num % den == 0 && "Get as integer called on non-integral fraction!"); return num / den; } - /// The numerator and denominator, respectively. The denominator is always - /// positive. - DynamicAPInt num{0}, den{1}; - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) llvm::raw_ostream &print(llvm::raw_ostream &os) const { return os << "(" << num << "/" << den << ")"; } -#endif -}; -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Fraction &x) { - x.print(os); - return os; -} -#endif + /// The numerator and denominator, respectively. The denominator is always + /// positive. + MPInt num{0}, den{1}; +}; /// Three-way comparison between two fractions. /// Returns +1, 0, and -1 if the first fraction is greater than, equal to, or /// less than the second fraction, respectively. inline int compare(const Fraction &x, const Fraction &y) { - DynamicAPInt diff = x.num * y.den - y.num * x.den; + MPInt diff = x.num * y.den - y.num * x.den; if (diff > 0) return +1; if (diff < 0) @@ -82,9 +70,9 @@ inline int compare(const Fraction &x, const Fraction &y) { return 0; } -inline DynamicAPInt floor(const Fraction &f) { return floorDiv(f.num, f.den); } +inline MPInt floor(const Fraction &f) { return floorDiv(f.num, f.den); } -inline DynamicAPInt ceil(const Fraction &f) { return ceilDiv(f.num, f.den); } +inline MPInt ceil(const Fraction &f) { return ceilDiv(f.num, f.den); } inline Fraction operator-(const Fraction &x) { return Fraction(-x.num, x.den); } @@ -120,7 +108,7 @@ inline Fraction abs(const Fraction &f) { inline Fraction reduce(const Fraction &f) { if (f == Fraction(0)) return Fraction(0, 1); - DynamicAPInt g = gcd(abs(f.num), abs(f.den)); + MPInt g = gcd(abs(f.num), abs(f.den)); return Fraction(f.num / g, f.den / g); } @@ -141,8 +129,8 @@ inline Fraction operator-(const Fraction &x, const Fraction &y) { } // Find the integer nearest to a given fraction. -inline DynamicAPInt round(const Fraction &f) { - DynamicAPInt rem = f.num % f.den; +inline MPInt round(const Fraction &f) { + MPInt rem = f.num % f.den; return (f.num / f.den) + (rem > f.den / 2); } @@ -165,6 +153,12 @@ inline Fraction &operator*=(Fraction &x, const Fraction &y) { x = x * y; return x; } + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Fraction &x) { + x.print(os); + return os; +} + } // namespace presburger } // namespace mlir diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index 40e96e2583d22..163f365c623d7 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -20,13 +20,10 @@ #include "mlir/Analysis/Presburger/PresburgerSpace.h" #include "mlir/Analysis/Presburger/Utils.h" #include "mlir/Support/LogicalResult.h" -#include "llvm/ADT/DynamicAPInt.h" #include namespace mlir { namespace presburger { -using llvm::DynamicAPInt; -using llvm::int64fromDynamicAPInt; class IntegerRelation; class IntegerPolyhedron; @@ -169,18 +166,16 @@ class IntegerRelation { bool isSubsetOf(const IntegerRelation &other) const; /// Returns the value at the specified equality row and column. - inline DynamicAPInt atEq(unsigned i, unsigned j) const { - return equalities(i, j); - } + inline MPInt atEq(unsigned i, unsigned j) const { return equalities(i, j); } /// The same, but casts to int64_t. This is unsafe and will assert-fail if the /// value does not fit in an int64_t. inline int64_t atEq64(unsigned i, unsigned j) const { return int64_t(equalities(i, j)); } - inline DynamicAPInt &atEq(unsigned i, unsigned j) { return equalities(i, j); } + inline MPInt &atEq(unsigned i, unsigned j) { return equalities(i, j); } /// Returns the value at the specified inequality row and column. - inline DynamicAPInt atIneq(unsigned i, unsigned j) const { + inline MPInt atIneq(unsigned i, unsigned j) const { return inequalities(i, j); } /// The same, but casts to int64_t. This is unsafe and will assert-fail if the @@ -188,9 +183,7 @@ class IntegerRelation { inline int64_t atIneq64(unsigned i, unsigned j) const { return int64_t(inequalities(i, j)); } - inline DynamicAPInt &atIneq(unsigned i, unsigned j) { - return inequalities(i, j); - } + inline MPInt &atIneq(unsigned i, unsigned j) { return inequalities(i, j); } unsigned getNumConstraints() const { return getNumInequalities() + getNumEqualities(); @@ -224,10 +217,10 @@ class IntegerRelation { return inequalities.getNumReservedRows(); } - inline ArrayRef getEquality(unsigned idx) const { + inline ArrayRef getEquality(unsigned idx) const { return equalities.getRow(idx); } - inline ArrayRef getInequality(unsigned idx) const { + inline ArrayRef getInequality(unsigned idx) const { return inequalities.getRow(idx); } /// The same, but casts to int64_t. This is unsafe and will assert-fail if the @@ -304,15 +297,13 @@ class IntegerRelation { unsigned appendVar(VarKind kind, unsigned num = 1); /// Adds an inequality (>= 0) from the coefficients specified in `inEq`. - void addInequality(ArrayRef inEq); + void addInequality(ArrayRef inEq); void addInequality(ArrayRef inEq) { - addInequality(getDynamicAPIntVec(inEq)); + addInequality(getMPIntVec(inEq)); } /// Adds an equality from the coefficients specified in `eq`. - void addEquality(ArrayRef eq); - void addEquality(ArrayRef eq) { - addEquality(getDynamicAPIntVec(eq)); - } + void addEquality(ArrayRef eq); + void addEquality(ArrayRef eq) { addEquality(getMPIntVec(eq)); } /// Eliminate the `posB^th` local variable, replacing every instance of it /// with the `posA^th` local variable. This should be used when the two @@ -347,7 +338,7 @@ class IntegerRelation { /// For a generic integer sampling operation, findIntegerSample is more /// robust and should be preferred. Note that Domain is minimized first, then /// range. - MaybeOptimum> findIntegerLexMin() const; + MaybeOptimum> findIntegerLexMin() const; /// Swap the posA^th variable with the posB^th variable. virtual void swapVar(unsigned posA, unsigned posB); @@ -357,9 +348,9 @@ class IntegerRelation { /// Sets the `values.size()` variables starting at `po`s to the specified /// values and removes them. - void setAndEliminate(unsigned pos, ArrayRef values); + void setAndEliminate(unsigned pos, ArrayRef values); void setAndEliminate(unsigned pos, ArrayRef values) { - setAndEliminate(pos, getDynamicAPIntVec(values)); + setAndEliminate(pos, getMPIntVec(values)); } /// Replaces the contents of this IntegerRelation with `other`. @@ -408,26 +399,26 @@ class IntegerRelation { /// /// Returns an integer sample point if one exists, or an empty Optional /// otherwise. The returned value also includes values of local ids. - std::optional> findIntegerSample() const; + std::optional> findIntegerSample() const; /// Compute an overapproximation of the number of integer points in the /// relation. Symbol vars currently not supported. If the computed /// overapproximation is infinite, an empty optional is returned. - std::optional computeVolume() const; + std::optional computeVolume() const; /// Returns true if the given point satisfies the constraints, or false /// otherwise. Takes the values of all vars including locals. - bool containsPoint(ArrayRef point) const; + bool containsPoint(ArrayRef point) const; bool containsPoint(ArrayRef point) const { - return containsPoint(getDynamicAPIntVec(point)); + return containsPoint(getMPIntVec(point)); } /// Given the values of non-local vars, return a satisfying assignment to the /// local if one exists, or an empty optional otherwise. - std::optional> - containsPointNoLocal(ArrayRef point) const; - std::optional> + std::optional> + containsPointNoLocal(ArrayRef point) const; + std::optional> containsPointNoLocal(ArrayRef point) const { - return containsPointNoLocal(getDynamicAPIntVec(point)); + return containsPointNoLocal(getMPIntVec(point)); } /// Returns a `DivisonRepr` representing the division representation of local @@ -442,16 +433,15 @@ class IntegerRelation { DivisionRepr getLocalReprs(std::vector *repr = nullptr) const; /// Adds a constant bound for the specified variable. - void addBound(BoundType type, unsigned pos, const DynamicAPInt &value); + void addBound(BoundType type, unsigned pos, const MPInt &value); void addBound(BoundType type, unsigned pos, int64_t value) { - addBound(type, pos, DynamicAPInt(value)); + addBound(type, pos, MPInt(value)); } /// Adds a constant bound for the specified expression. - void addBound(BoundType type, ArrayRef expr, - const DynamicAPInt &value); + void addBound(BoundType type, ArrayRef expr, const MPInt &value); void addBound(BoundType type, ArrayRef expr, int64_t value) { - addBound(type, getDynamicAPIntVec(expr), DynamicAPInt(value)); + addBound(type, getMPIntVec(expr), MPInt(value)); } /// Adds a new local variable as the floordiv of an affine function of other @@ -459,10 +449,9 @@ class IntegerRelation { /// respect to a positive constant `divisor`. Two constraints are added to the /// system to capture equivalence with the floordiv: /// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1. - void addLocalFloorDiv(ArrayRef dividend, - const DynamicAPInt &divisor); + void addLocalFloorDiv(ArrayRef dividend, const MPInt &divisor); void addLocalFloorDiv(ArrayRef dividend, int64_t divisor) { - addLocalFloorDiv(getDynamicAPIntVec(dividend), DynamicAPInt(divisor)); + addLocalFloorDiv(getMPIntVec(dividend), MPInt(divisor)); } /// Projects out (aka eliminates) `num` variables starting at position @@ -518,11 +507,10 @@ class IntegerRelation { /// equality). Ex: if the lower bound is [(s0 + s2 - 1) floordiv 32] for a /// system with three symbolic variables, *lb = [1, 0, 1], lbDivisor = 32. See /// comments at function definition for examples. - std::optional getConstantBoundOnDimSize( - unsigned pos, SmallVectorImpl *lb = nullptr, - DynamicAPInt *boundFloorDivisor = nullptr, - SmallVectorImpl *ub = nullptr, unsigned *minLbPos = nullptr, - unsigned *minUbPos = nullptr) const; + std::optional getConstantBoundOnDimSize( + unsigned pos, SmallVectorImpl *lb = nullptr, + MPInt *boundFloorDivisor = nullptr, SmallVectorImpl *ub = nullptr, + unsigned *minLbPos = nullptr, unsigned *minUbPos = nullptr) const; /// The same, but casts to int64_t. This is unsafe and will assert-fail if the /// value does not fit in an int64_t. std::optional getConstantBoundOnDimSize64( @@ -530,30 +518,27 @@ class IntegerRelation { int64_t *boundFloorDivisor = nullptr, SmallVectorImpl *ub = nullptr, unsigned *minLbPos = nullptr, unsigned *minUbPos = nullptr) const { - SmallVector ubDynamicAPInt, lbDynamicAPInt; - DynamicAPInt boundFloorDivisorDynamicAPInt; - std::optional result = getConstantBoundOnDimSize( - pos, &lbDynamicAPInt, &boundFloorDivisorDynamicAPInt, &ubDynamicAPInt, - minLbPos, minUbPos); + SmallVector ubMPInt, lbMPInt; + MPInt boundFloorDivisorMPInt; + std::optional result = getConstantBoundOnDimSize( + pos, &lbMPInt, &boundFloorDivisorMPInt, &ubMPInt, minLbPos, minUbPos); if (lb) - *lb = getInt64Vec(lbDynamicAPInt); + *lb = getInt64Vec(lbMPInt); if (ub) - *ub = getInt64Vec(ubDynamicAPInt); + *ub = getInt64Vec(ubMPInt); if (boundFloorDivisor) - *boundFloorDivisor = static_cast(boundFloorDivisorDynamicAPInt); - return llvm::transformOptional(result, int64fromDynamicAPInt); + *boundFloorDivisor = static_cast(boundFloorDivisorMPInt); + return llvm::transformOptional(result, int64FromMPInt); } /// Returns the constant bound for the pos^th variable if there is one; /// std::nullopt otherwise. - std::optional getConstantBound(BoundType type, - unsigned pos) const; + std::optional getConstantBound(BoundType type, unsigned pos) const; /// The same, but casts to int64_t. This is unsafe and will assert-fail if the /// value does not fit in an int64_t. std::optional getConstantBound64(BoundType type, unsigned pos) const { - return llvm::transformOptional(getConstantBound(type, pos), - int64fromDynamicAPInt); + return llvm::transformOptional(getConstantBound(type, pos), int64FromMPInt); } /// Removes constraints that are independent of (i.e., do not have a @@ -761,13 +746,12 @@ class IntegerRelation { /// Returns the constant lower bound if isLower is true, and the upper /// bound if isLower is false. template - std::optional computeConstantLowerOrUpperBound(unsigned pos); + std::optional computeConstantLowerOrUpperBound(unsigned pos); /// The same, but casts to int64_t. This is unsafe and will assert-fail if the /// value does not fit in an int64_t. template std::optional computeConstantLowerOrUpperBound64(unsigned pos) { - return computeConstantLowerOrUpperBound(pos).map( - int64fromDynamicAPInt); + return computeConstantLowerOrUpperBound(pos).map(int64FromMPInt); } /// Eliminates a single variable at `position` from equality and inequality diff --git a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h index eeac9f1dc3938..b5c761439f0b7 100644 --- a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h +++ b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h @@ -40,15 +40,13 @@ class LinearTransform { // The given vector is interpreted as a row vector v. Post-multiply v with // this transform, say T, and return vT. - SmallVector - preMultiplyWithRow(ArrayRef rowVec) const { + SmallVector preMultiplyWithRow(ArrayRef rowVec) const { return matrix.preMultiplyWithRow(rowVec); } // The given vector is interpreted as a column vector v. Pre-multiply v with // this transform, say T, and return Tv. - SmallVector - postMultiplyWithColumn(ArrayRef colVec) const { + SmallVector postMultiplyWithColumn(ArrayRef colVec) const { return matrix.postMultiplyWithColumn(colVec); } diff --git a/mlir/include/mlir/Analysis/Presburger/MPInt.h b/mlir/include/mlir/Analysis/Presburger/MPInt.h new file mode 100644 index 0000000000000..f7678967190a0 --- /dev/null +++ b/mlir/include/mlir/Analysis/Presburger/MPInt.h @@ -0,0 +1,611 @@ +//===- MPInt.h - MLIR MPInt Class -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a simple class to represent arbitrary precision signed integers. +// Unlike APInt, one does not have to specify a fixed maximum size, and the +// integer can take on any arbitrary values. This is optimized for small-values +// by providing fast-paths for the cases when the value stored fits in 64-bits. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_MPINT_H +#define MLIR_ANALYSIS_PRESBURGER_MPINT_H + +#include "mlir/Analysis/Presburger/SlowMPInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace mlir { +namespace presburger { +using ::llvm::ArrayRef; +using ::llvm::divideCeilSigned; +using ::llvm::divideFloorSigned; +using ::llvm::mod; + +namespace detail { +/// If builtin intrinsics for overflow-checked arithmetic are available, +/// use them. Otherwise, call through to LLVM's overflow-checked arithmetic +/// functionality. Those functions also have such macro-gated uses of intrinsics +/// but they are not always_inlined, which is important for us to achieve +/// high-performance; calling the functions directly would result in a slowdown +/// of 1.15x. +LLVM_ATTRIBUTE_ALWAYS_INLINE bool addOverflow(int64_t x, int64_t y, + int64_t &result) { +#if __has_builtin(__builtin_add_overflow) + return __builtin_add_overflow(x, y, &result); +#else + return llvm::AddOverflow(x, y, result); +#endif +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool subOverflow(int64_t x, int64_t y, + int64_t &result) { +#if __has_builtin(__builtin_sub_overflow) + return __builtin_sub_overflow(x, y, &result); +#else + return llvm::SubOverflow(x, y, result); +#endif +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool mulOverflow(int64_t x, int64_t y, + int64_t &result) { +#if __has_builtin(__builtin_mul_overflow) + return __builtin_mul_overflow(x, y, &result); +#else + return llvm::MulOverflow(x, y, result); +#endif +} +} // namespace detail + +/// This class provides support for multi-precision arithmetic. +/// +/// Unlike APInt, this extends the precision as necessary to prevent overflows +/// and supports operations between objects with differing internal precisions. +/// +/// This is optimized for small-values by providing fast-paths for the cases +/// when the value stored fits in 64-bits. We annotate all fastpaths by using +/// the LLVM_LIKELY/LLVM_UNLIKELY annotations. Removing these would result in +/// a 1.2x performance slowdown. +/// +/// We always_inline all operations; removing these results in a 1.5x +/// performance slowdown. +/// +/// When holdsLarge is true, a SlowMPInt is held in the union. If it is false, +/// the int64_t is held. Using std::variant instead would lead to significantly +/// worse performance. +class MPInt { +private: + union { + int64_t valSmall; + detail::SlowMPInt valLarge; + }; + unsigned holdsLarge; + + LLVM_ATTRIBUTE_ALWAYS_INLINE void initSmall(int64_t o) { + if (LLVM_UNLIKELY(isLarge())) + valLarge.detail::SlowMPInt::~SlowMPInt(); + valSmall = o; + holdsLarge = false; + } + LLVM_ATTRIBUTE_ALWAYS_INLINE void initLarge(const detail::SlowMPInt &o) { + if (LLVM_LIKELY(isSmall())) { + // The data in memory could be in an arbitrary state, not necessarily + // corresponding to any valid state of valLarge; we cannot call any member + // functions, e.g. the assignment operator on it, as they may access the + // invalid internal state. We instead construct a new object using + // placement new. + new (&valLarge) detail::SlowMPInt(o); + } else { + // In this case, we need to use the assignment operator, because if we use + // placement-new as above we would lose track of allocated memory + // and leak it. + valLarge = o; + } + holdsLarge = true; + } + + LLVM_ATTRIBUTE_ALWAYS_INLINE explicit MPInt(const detail::SlowMPInt &val) + : valLarge(val), holdsLarge(true) {} + LLVM_ATTRIBUTE_ALWAYS_INLINE bool isSmall() const { return !holdsLarge; } + LLVM_ATTRIBUTE_ALWAYS_INLINE bool isLarge() const { return holdsLarge; } + /// Get the stored value. For getSmall/Large, + /// the stored value should be small/large. + LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t getSmall() const { + assert(isSmall() && + "getSmall should only be called when the value stored is small!"); + return valSmall; + } + LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t &getSmall() { + assert(isSmall() && + "getSmall should only be called when the value stored is small!"); + return valSmall; + } + LLVM_ATTRIBUTE_ALWAYS_INLINE const detail::SlowMPInt &getLarge() const { + assert(isLarge() && + "getLarge should only be called when the value stored is large!"); + return valLarge; + } + LLVM_ATTRIBUTE_ALWAYS_INLINE detail::SlowMPInt &getLarge() { + assert(isLarge() && + "getLarge should only be called when the value stored is large!"); + return valLarge; + } + explicit operator detail::SlowMPInt() const { + if (isSmall()) + return detail::SlowMPInt(getSmall()); + return getLarge(); + } + +public: + LLVM_ATTRIBUTE_ALWAYS_INLINE explicit MPInt(int64_t val) + : valSmall(val), holdsLarge(false) {} + LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt() : MPInt(0) {} + LLVM_ATTRIBUTE_ALWAYS_INLINE ~MPInt() { + if (LLVM_UNLIKELY(isLarge())) + valLarge.detail::SlowMPInt::~SlowMPInt(); + } + LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt(const MPInt &o) + : valSmall(o.valSmall), holdsLarge(false) { + if (LLVM_UNLIKELY(o.isLarge())) + initLarge(o.valLarge); + } + LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator=(const MPInt &o) { + if (LLVM_LIKELY(o.isSmall())) { + initSmall(o.valSmall); + return *this; + } + initLarge(o.valLarge); + return *this; + } + LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator=(int x) { + initSmall(x); + return *this; + } + LLVM_ATTRIBUTE_ALWAYS_INLINE explicit operator int64_t() const { + if (isSmall()) + return getSmall(); + return static_cast(getLarge()); + } + + bool operator==(const MPInt &o) const; + bool operator!=(const MPInt &o) const; + bool operator>(const MPInt &o) const; + bool operator<(const MPInt &o) const; + bool operator<=(const MPInt &o) const; + bool operator>=(const MPInt &o) const; + MPInt operator+(const MPInt &o) const; + MPInt operator-(const MPInt &o) const; + MPInt operator*(const MPInt &o) const; + MPInt operator/(const MPInt &o) const; + MPInt operator%(const MPInt &o) const; + MPInt &operator+=(const MPInt &o); + MPInt &operator-=(const MPInt &o); + MPInt &operator*=(const MPInt &o); + MPInt &operator/=(const MPInt &o); + MPInt &operator%=(const MPInt &o); + MPInt operator-() const; + MPInt &operator++(); + MPInt &operator--(); + + // Divide by a number that is known to be positive. + // This is slightly more efficient because it saves an overflow check. + MPInt divByPositive(const MPInt &o) const; + MPInt &divByPositiveInPlace(const MPInt &o); + + friend MPInt abs(const MPInt &x); + friend MPInt gcdRange(ArrayRef range); + friend MPInt ceilDiv(const MPInt &lhs, const MPInt &rhs); + friend MPInt floorDiv(const MPInt &lhs, const MPInt &rhs); + // The operands must be non-negative for gcd. + friend MPInt gcd(const MPInt &a, const MPInt &b); + friend MPInt lcm(const MPInt &a, const MPInt &b); + friend MPInt mod(const MPInt &lhs, const MPInt &rhs); + + llvm::raw_ostream &print(llvm::raw_ostream &os) const; + void dump() const; + + /// --------------------------------------------------------------------------- + /// Convenience operator overloads for int64_t. + /// --------------------------------------------------------------------------- + friend MPInt &operator+=(MPInt &a, int64_t b); + friend MPInt &operator-=(MPInt &a, int64_t b); + friend MPInt &operator*=(MPInt &a, int64_t b); + friend MPInt &operator/=(MPInt &a, int64_t b); + friend MPInt &operator%=(MPInt &a, int64_t b); + + friend bool operator==(const MPInt &a, int64_t b); + friend bool operator!=(const MPInt &a, int64_t b); + friend bool operator>(const MPInt &a, int64_t b); + friend bool operator<(const MPInt &a, int64_t b); + friend bool operator<=(const MPInt &a, int64_t b); + friend bool operator>=(const MPInt &a, int64_t b); + friend MPInt operator+(const MPInt &a, int64_t b); + friend MPInt operator-(const MPInt &a, int64_t b); + friend MPInt operator*(const MPInt &a, int64_t b); + friend MPInt operator/(const MPInt &a, int64_t b); + friend MPInt operator%(const MPInt &a, int64_t b); + + friend bool operator==(int64_t a, const MPInt &b); + friend bool operator!=(int64_t a, const MPInt &b); + friend bool operator>(int64_t a, const MPInt &b); + friend bool operator<(int64_t a, const MPInt &b); + friend bool operator<=(int64_t a, const MPInt &b); + friend bool operator>=(int64_t a, const MPInt &b); + friend MPInt operator+(int64_t a, const MPInt &b); + friend MPInt operator-(int64_t a, const MPInt &b); + friend MPInt operator*(int64_t a, const MPInt &b); + friend MPInt operator/(int64_t a, const MPInt &b); + friend MPInt operator%(int64_t a, const MPInt &b); + + friend llvm::hash_code hash_value(const MPInt &x); // NOLINT +}; + +/// Redeclarations of friend declaration above to +/// make it discoverable by lookups. +llvm::hash_code hash_value(const MPInt &x); // NOLINT + +/// This just calls through to the operator int64_t, but it's useful when a +/// function pointer is required. (Although this is marked inline, it is still +/// possible to obtain and use a function pointer to this.) +static inline int64_t int64FromMPInt(const MPInt &x) { return int64_t(x); } +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mpintFromInt64(int64_t x) { + return MPInt(x); +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const MPInt &x); + +// The RHS is always expected to be positive, and the result +/// is always non-negative. +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mod(const MPInt &lhs, const MPInt &rhs); + +namespace detail { +// Division overflows only when trying to negate the minimal signed value. +LLVM_ATTRIBUTE_ALWAYS_INLINE bool divWouldOverflow(int64_t x, int64_t y) { + return x == std::numeric_limits::min() && y == -1; +} +} // namespace detail + +/// We define the operations here in the header to facilitate inlining. + +/// --------------------------------------------------------------------------- +/// Comparison operators. +/// --------------------------------------------------------------------------- +LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator==(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return getSmall() == o.getSmall(); + return detail::SlowMPInt(*this) == detail::SlowMPInt(o); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator!=(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return getSmall() != o.getSmall(); + return detail::SlowMPInt(*this) != detail::SlowMPInt(o); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator>(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return getSmall() > o.getSmall(); + return detail::SlowMPInt(*this) > detail::SlowMPInt(o); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator<(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return getSmall() < o.getSmall(); + return detail::SlowMPInt(*this) < detail::SlowMPInt(o); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator<=(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return getSmall() <= o.getSmall(); + return detail::SlowMPInt(*this) <= detail::SlowMPInt(o); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool MPInt::operator>=(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return getSmall() >= o.getSmall(); + return detail::SlowMPInt(*this) >= detail::SlowMPInt(o); +} + +/// --------------------------------------------------------------------------- +/// Arithmetic operators. +/// --------------------------------------------------------------------------- +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator+(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + MPInt result; + bool overflow = + detail::addOverflow(getSmall(), o.getSmall(), result.getSmall()); + if (LLVM_LIKELY(!overflow)) + return result; + return MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(o)); + } + return MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(o)); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator-(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + MPInt result; + bool overflow = + detail::subOverflow(getSmall(), o.getSmall(), result.getSmall()); + if (LLVM_LIKELY(!overflow)) + return result; + return MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(o)); + } + return MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(o)); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator*(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + MPInt result; + bool overflow = + detail::mulOverflow(getSmall(), o.getSmall(), result.getSmall()); + if (LLVM_LIKELY(!overflow)) + return result; + return MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(o)); + } + return MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(o)); +} + +// Division overflows only occur when negating the minimal possible value. +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::divByPositive(const MPInt &o) const { + assert(o > 0); + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return MPInt(getSmall() / o.getSmall()); + return MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(o)); +} + +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator/(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + // Division overflows only occur when negating the minimal possible value. + if (LLVM_UNLIKELY(detail::divWouldOverflow(getSmall(), o.getSmall()))) + return -*this; + return MPInt(getSmall() / o.getSmall()); + } + return MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(o)); +} + +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x) { + return MPInt(x >= 0 ? x : -x); +} +// Division overflows only occur when negating the minimal possible value. +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt ceilDiv(const MPInt &lhs, const MPInt &rhs) { + if (LLVM_LIKELY(lhs.isSmall() && rhs.isSmall())) { + if (LLVM_UNLIKELY(detail::divWouldOverflow(lhs.getSmall(), rhs.getSmall()))) + return -lhs; + return MPInt(divideCeilSigned(lhs.getSmall(), rhs.getSmall())); + } + return MPInt(ceilDiv(detail::SlowMPInt(lhs), detail::SlowMPInt(rhs))); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt floorDiv(const MPInt &lhs, + const MPInt &rhs) { + if (LLVM_LIKELY(lhs.isSmall() && rhs.isSmall())) { + if (LLVM_UNLIKELY(detail::divWouldOverflow(lhs.getSmall(), rhs.getSmall()))) + return -lhs; + return MPInt(divideFloorSigned(lhs.getSmall(), rhs.getSmall())); + } + return MPInt(floorDiv(detail::SlowMPInt(lhs), detail::SlowMPInt(rhs))); +} +// The RHS is always expected to be positive, and the result +/// is always non-negative. +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt mod(const MPInt &lhs, const MPInt &rhs) { + if (LLVM_LIKELY(lhs.isSmall() && rhs.isSmall())) + return MPInt(mod(lhs.getSmall(), rhs.getSmall())); + return MPInt(mod(detail::SlowMPInt(lhs), detail::SlowMPInt(rhs))); +} + +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b) { + assert(a >= 0 && b >= 0 && "operands must be non-negative!"); + if (LLVM_LIKELY(a.isSmall() && b.isSmall())) + return MPInt(std::gcd(a.getSmall(), b.getSmall())); + return MPInt(gcd(detail::SlowMPInt(a), detail::SlowMPInt(b))); +} + +/// Returns the least common multiple of 'a' and 'b'. +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt lcm(const MPInt &a, const MPInt &b) { + MPInt x = abs(a); + MPInt y = abs(b); + return (x * y) / gcd(x, y); +} + +/// This operation cannot overflow. +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator%(const MPInt &o) const { + if (LLVM_LIKELY(isSmall() && o.isSmall())) + return MPInt(getSmall() % o.getSmall()); + return MPInt(detail::SlowMPInt(*this) % detail::SlowMPInt(o)); +} + +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt MPInt::operator-() const { + if (LLVM_LIKELY(isSmall())) { + if (LLVM_LIKELY(getSmall() != std::numeric_limits::min())) + return MPInt(-getSmall()); + return MPInt(-detail::SlowMPInt(*this)); + } + return MPInt(-detail::SlowMPInt(*this)); +} + +/// --------------------------------------------------------------------------- +/// Assignment operators, preincrement, predecrement. +/// --------------------------------------------------------------------------- +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator+=(const MPInt &o) { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + int64_t result = getSmall(); + bool overflow = detail::addOverflow(getSmall(), o.getSmall(), result); + if (LLVM_LIKELY(!overflow)) { + getSmall() = result; + return *this; + } + // Note: this return is not strictly required but + // removing it leads to a performance regression. + return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(o)); + } + return *this = MPInt(detail::SlowMPInt(*this) + detail::SlowMPInt(o)); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator-=(const MPInt &o) { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + int64_t result = getSmall(); + bool overflow = detail::subOverflow(getSmall(), o.getSmall(), result); + if (LLVM_LIKELY(!overflow)) { + getSmall() = result; + return *this; + } + // Note: this return is not strictly required but + // removing it leads to a performance regression. + return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(o)); + } + return *this = MPInt(detail::SlowMPInt(*this) - detail::SlowMPInt(o)); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator*=(const MPInt &o) { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + int64_t result = getSmall(); + bool overflow = detail::mulOverflow(getSmall(), o.getSmall(), result); + if (LLVM_LIKELY(!overflow)) { + getSmall() = result; + return *this; + } + // Note: this return is not strictly required but + // removing it leads to a performance regression. + return *this = MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(o)); + } + return *this = MPInt(detail::SlowMPInt(*this) * detail::SlowMPInt(o)); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator/=(const MPInt &o) { + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + // Division overflows only occur when negating the minimal possible value. + if (LLVM_UNLIKELY(detail::divWouldOverflow(getSmall(), o.getSmall()))) + return *this = -*this; + getSmall() /= o.getSmall(); + return *this; + } + return *this = MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(o)); +} + +// Division overflows only occur when the divisor is -1. +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt & +MPInt::divByPositiveInPlace(const MPInt &o) { + assert(o > 0); + if (LLVM_LIKELY(isSmall() && o.isSmall())) { + getSmall() /= o.getSmall(); + return *this; + } + return *this = MPInt(detail::SlowMPInt(*this) / detail::SlowMPInt(o)); +} + +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator%=(const MPInt &o) { + return *this = *this % o; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator++() { return *this += 1; } +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &MPInt::operator--() { return *this -= 1; } + +/// ---------------------------------------------------------------------------- +/// Convenience operator overloads for int64_t. +/// ---------------------------------------------------------------------------- +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator+=(MPInt &a, int64_t b) { + return a = a + b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator-=(MPInt &a, int64_t b) { + return a = a - b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator*=(MPInt &a, int64_t b) { + return a = a * b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator/=(MPInt &a, int64_t b) { + return a = a / b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt &operator%=(MPInt &a, int64_t b) { + return a = a % b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator+(const MPInt &a, int64_t b) { + return a + MPInt(b); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator-(const MPInt &a, int64_t b) { + return a - MPInt(b); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator*(const MPInt &a, int64_t b) { + return a * MPInt(b); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator/(const MPInt &a, int64_t b) { + return a / MPInt(b); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator%(const MPInt &a, int64_t b) { + return a % MPInt(b); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator+(int64_t a, const MPInt &b) { + return MPInt(a) + b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator-(int64_t a, const MPInt &b) { + return MPInt(a) - b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator*(int64_t a, const MPInt &b) { + return MPInt(a) * b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator/(int64_t a, const MPInt &b) { + return MPInt(a) / b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt operator%(int64_t a, const MPInt &b) { + return MPInt(a) % b; +} + +/// We provide special implementations of the comparison operators rather than +/// calling through as above, as this would result in a 1.2x slowdown. +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator==(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.getSmall() == b; + return a.getLarge() == b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator!=(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.getSmall() != b; + return a.getLarge() != b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator>(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.getSmall() > b; + return a.getLarge() > b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator<(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.getSmall() < b; + return a.getLarge() < b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator<=(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.getSmall() <= b; + return a.getLarge() <= b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator>=(const MPInt &a, int64_t b) { + if (LLVM_LIKELY(a.isSmall())) + return a.getSmall() >= b; + return a.getLarge() >= b; +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator==(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a == b.getSmall(); + return a == b.getLarge(); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator!=(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a != b.getSmall(); + return a != b.getLarge(); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator>(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a > b.getSmall(); + return a > b.getLarge(); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator<(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a < b.getSmall(); + return a < b.getLarge(); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator<=(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a <= b.getSmall(); + return a <= b.getLarge(); +} +LLVM_ATTRIBUTE_ALWAYS_INLINE bool operator>=(int64_t a, const MPInt &b) { + if (LLVM_LIKELY(b.isSmall())) + return a >= b.getSmall(); + return a >= b.getLarge(); +} + +} // namespace presburger +} // namespace mlir + +#endif // MLIR_ANALYSIS_PRESBURGER_MPINT_H diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h index e232ecd5e1509..c20a7bcecd52d 100644 --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// // // This is a simple 2D matrix class that supports reading, writing, resizing, -// swapping rows, and swapping columns. It can hold integers (DynamicAPInt) or -// rational numbers (Fraction). +// swapping rows, and swapping columns. It can hold integers (MPInt) or rational +// numbers (Fraction). // //===----------------------------------------------------------------------===// @@ -34,13 +34,13 @@ namespace presburger { /// (i, j) is stored at data[i*nReservedColumns + j]. The reserved but unused /// columns always have all zero values. The reserved rows are just reserved /// space in the underlying SmallVector's capacity. -/// This class only works for the types DynamicAPInt and Fraction, since the -/// method implementations are in the Matrix.cpp file. Only these two types have +/// This class only works for the types MPInt and Fraction, since the method +/// implementations are in the Matrix.cpp file. Only these two types have /// been explicitly instantiated there. template class Matrix { - static_assert(std::is_same_v || std::is_same_v, - "T must be DynamicAPInt or Fraction."); + static_assert(std::is_same_v || std::is_same_v, + "T must be MPInt or Fraction."); public: Matrix() = delete; @@ -244,19 +244,19 @@ class Matrix { SmallVector data; }; -extern template class Matrix; +extern template class Matrix; extern template class Matrix; // An inherited class for integer matrices, with no new data attributes. // This is only used for the matrix-related methods which apply only // to integers (hermite normal form computation and row normalisation). -class IntMatrix : public Matrix { +class IntMatrix : public Matrix { public: IntMatrix(unsigned rows, unsigned columns, unsigned reservedRows = 0, unsigned reservedColumns = 0) - : Matrix(rows, columns, reservedRows, reservedColumns) {} + : Matrix(rows, columns, reservedRows, reservedColumns){}; - IntMatrix(Matrix m) : Matrix(std::move(m)) {} + IntMatrix(Matrix m) : Matrix(std::move(m)){}; /// Return the identity matrix of the specified dimension. static IntMatrix identity(unsigned dimension); @@ -275,10 +275,10 @@ class IntMatrix : public Matrix { /// Divide the first `nCols` of the specified row by their GCD. /// Returns the GCD of the first `nCols` of the specified row. - DynamicAPInt normalizeRow(unsigned row, unsigned nCols); + MPInt normalizeRow(unsigned row, unsigned nCols); /// Divide the columns of the specified row by their GCD. /// Returns the GCD of the columns of the specified row. - DynamicAPInt normalizeRow(unsigned row); + MPInt normalizeRow(unsigned row); // Compute the determinant of the matrix (cubic time). // Stores the integer inverse of the matrix in the pointer @@ -287,7 +287,7 @@ class IntMatrix : public Matrix { // For a matrix M, the integer inverse is the matrix M' such that // M x M' = M'  M = det(M) x I. // Assert-fails if the matrix is not square. - DynamicAPInt determinant(IntMatrix *inverse = nullptr) const; + MPInt determinant(IntMatrix *inverse = nullptr) const; }; // An inherited class for rational matrices, with no new data attributes. diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h index fcc39bf0e0537..236cc90ad66ac 100644 --- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h +++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h @@ -67,9 +67,7 @@ class MultiAffineFunction { /// Get a matrix with each row representing row^th output expression. const IntMatrix &getOutputMatrix() const { return output; } /// Get the `i^th` output expression. - ArrayRef getOutputExpr(unsigned i) const { - return output.getRow(i); - } + ArrayRef getOutputExpr(unsigned i) const { return output.getRow(i); } /// Get the divisions used in this function. const DivisionRepr &getDivs() const { return divs; } @@ -82,9 +80,9 @@ class MultiAffineFunction { void mergeDivs(MultiAffineFunction &other); //// Return the output of the function at the given point. - SmallVector valueAt(ArrayRef point) const; - SmallVector valueAt(ArrayRef point) const { - return valueAt(getDynamicAPIntVec(point)); + SmallVector valueAt(ArrayRef point) const; + SmallVector valueAt(ArrayRef point) const { + return valueAt(getMPIntVec(point)); } /// Return whether the `this` and `other` are equal when the domain is @@ -193,11 +191,9 @@ class PWMAFunction { PresburgerSet getDomain() const; /// Return the output of the function at the given point. - std::optional> - valueAt(ArrayRef point) const; - std::optional> - valueAt(ArrayRef point) const { - return valueAt(getDynamicAPIntVec(point)); + std::optional> valueAt(ArrayRef point) const; + std::optional> valueAt(ArrayRef point) const { + return valueAt(getMPIntVec(point)); } /// Return all the pieces of this piece-wise function. diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h index f7e06a6b22a95..9634df6d58a1a 100644 --- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h @@ -142,9 +142,9 @@ class PresburgerRelation { SymbolicLexOpt findSymbolicIntegerLexMax() const; /// Return true if the set contains the given point, and false otherwise. - bool containsPoint(ArrayRef point) const; + bool containsPoint(ArrayRef point) const; bool containsPoint(ArrayRef point) const { - return containsPoint(getDynamicAPIntVec(point)); + return containsPoint(getMPIntVec(point)); } /// Return the complement of this set. All local variables in the set must @@ -187,7 +187,7 @@ class PresburgerRelation { /// Find an integer sample from the given set. This should not be called if /// any of the disjuncts in the union are unbounded. - bool findIntegerSample(SmallVectorImpl &sample); + bool findIntegerSample(SmallVectorImpl &sample); /// Compute an overapproximation of the number of integer points in the /// disjunct. Symbol vars are currently not supported. If the computed @@ -196,7 +196,7 @@ class PresburgerRelation { /// This currently just sums up the overapproximations of the volumes of the /// disjuncts, so the approximation might be far from the true volume in the /// case when there is a lot of overlap between disjuncts. - std::optional computeVolume() const; + std::optional computeVolume() const; /// Simplifies the representation of a PresburgerRelation. /// diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h index ff26e94e019c8..7ee74c150867c 100644 --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -166,7 +166,7 @@ class SimplexBase { /// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n /// is the current number of variables, then the corresponding inequality is /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0. - virtual void addInequality(ArrayRef coeffs) = 0; + virtual void addInequality(ArrayRef coeffs) = 0; /// Returns the number of variables in the tableau. unsigned getNumVariables() const; @@ -177,7 +177,7 @@ class SimplexBase { /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n /// is the current number of variables, then the corresponding equality is /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0. - void addEquality(ArrayRef coeffs); + void addEquality(ArrayRef coeffs); /// Add new variables to the end of the list of variables. void appendVariable(unsigned count = 1); @@ -186,8 +186,7 @@ class SimplexBase { /// integer value is the floor div of `coeffs` and `denom`. /// /// `denom` must be positive. - void addDivisionVariable(ArrayRef coeffs, - const DynamicAPInt &denom); + void addDivisionVariable(ArrayRef coeffs, const MPInt &denom); /// Mark the tableau as being empty. void markEmpty(); @@ -296,7 +295,7 @@ class SimplexBase { /// con. /// /// Returns the index of the new Unknown in con. - unsigned addRow(ArrayRef coeffs, bool makeRestricted = false); + unsigned addRow(ArrayRef coeffs, bool makeRestricted = false); /// Swap the two rows/columns in the tableau and associated data structures. void swapRows(unsigned i, unsigned j); @@ -424,7 +423,7 @@ class LexSimplexBase : public SimplexBase { /// /// This just adds the inequality to the tableau and does not try to create a /// consistent tableau configuration. - void addInequality(ArrayRef coeffs) final; + void addInequality(ArrayRef coeffs) final; /// Get a snapshot of the current state. This is used for rolling back. unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); } @@ -496,15 +495,15 @@ class LexSimplex : public LexSimplexBase { /// /// Note: this should be used only when the lexmin is really needed. To obtain /// any integer sample, use Simplex::findIntegerSample as that is more robust. - MaybeOptimum> findIntegerLexMin(); + MaybeOptimum> findIntegerLexMin(); /// Return whether the specified inequality is redundant/separate for the /// polytope. Redundant means every point satisfies the given inequality, and /// separate means no point satisfies it. /// /// These checks are integer-exact. - bool isSeparateInequality(ArrayRef coeffs); - bool isRedundantInequality(ArrayRef coeffs); + bool isSeparateInequality(ArrayRef coeffs); + bool isRedundantInequality(ArrayRef coeffs); private: /// Returns the current sample point, which may contain non-integer (rational) @@ -657,11 +656,11 @@ class SymbolicLexSimplex : public LexSimplexBase { /// Get the numerator of the symbolic sample of the specific row. /// This is an affine expression in the symbols with integer coefficients. /// The last element is the constant term. This ignores the big M coefficient. - SmallVector getSymbolicSampleNumerator(unsigned row) const; + SmallVector getSymbolicSampleNumerator(unsigned row) const; /// Get an affine inequality in the symbols with integer coefficients that /// holds iff the symbolic sample of the specified row is non-negative. - SmallVector getSymbolicSampleIneq(unsigned row) const; + SmallVector getSymbolicSampleIneq(unsigned row) const; /// Return whether all the coefficients of the symbolic sample are integers. /// @@ -711,7 +710,7 @@ class Simplex : public SimplexBase { /// /// This also tries to restore the tableau configuration to a consistent /// state and marks the Simplex empty if this is not possible. - void addInequality(ArrayRef coeffs) final; + void addInequality(ArrayRef coeffs) final; /// Compute the maximum or minimum value of the given row, depending on /// direction. The specified row is never pivoted. On return, the row may @@ -727,7 +726,7 @@ class Simplex : public SimplexBase { /// Returns a Fraction denoting the optimum, or a null value if no optimum /// exists, i.e., if the expression is unbounded in this direction. MaybeOptimum computeOptimum(Direction direction, - ArrayRef coeffs); + ArrayRef coeffs); /// Returns whether the perpendicular of the specified constraint is a /// is a direction along which the polytope is bounded. @@ -769,14 +768,14 @@ class Simplex : public SimplexBase { /// Returns a (min, max) pair denoting the minimum and maximum integer values /// of the given expression. If no integer value exists, both results will be /// of kind Empty. - std::pair, MaybeOptimum> - computeIntegerBounds(ArrayRef coeffs); + std::pair, MaybeOptimum> + computeIntegerBounds(ArrayRef coeffs); /// Check if the simplex takes only one rational value along the /// direction of `coeffs`. /// /// `this` must be nonempty. - bool isFlatAlong(ArrayRef coeffs); + bool isFlatAlong(ArrayRef coeffs); /// Returns true if the polytope is unbounded, i.e., extends to infinity in /// some direction. Otherwise, returns false. @@ -788,7 +787,7 @@ class Simplex : public SimplexBase { /// Returns an integer sample point if one exists, or std::nullopt /// otherwise. This should only be called for bounded sets. - std::optional> findIntegerSample(); + std::optional> findIntegerSample(); enum class IneqType { Redundant, Cut, Separate }; @@ -798,13 +797,13 @@ class Simplex : public SimplexBase { /// Redundant The inequality is satisfied in the polytope /// Cut The inequality is satisfied by some points, but not by others /// Separate The inequality is not satisfied by any point - IneqType findIneqType(ArrayRef coeffs); + IneqType findIneqType(ArrayRef coeffs); /// Check if the specified inequality already holds in the polytope. - bool isRedundantInequality(ArrayRef coeffs); + bool isRedundantInequality(ArrayRef coeffs); /// Check if the specified equality already holds in the polytope. - bool isRedundantEquality(ArrayRef coeffs); + bool isRedundantEquality(ArrayRef coeffs); /// Returns true if this Simplex's polytope is a rational subset of `rel`. /// Otherwise, returns false. @@ -812,7 +811,7 @@ class Simplex : public SimplexBase { /// Returns the current sample point if it is integral. Otherwise, returns /// std::nullopt. - std::optional> getSamplePointIfIntegral() const; + std::optional> getSamplePointIfIntegral() const; /// Returns the current sample point, which may contain non-integer (rational) /// coordinates. Returns an empty optional when the tableau is empty. diff --git a/mlir/include/mlir/Analysis/Presburger/SlowMPInt.h b/mlir/include/mlir/Analysis/Presburger/SlowMPInt.h new file mode 100644 index 0000000000000..482581c573cea --- /dev/null +++ b/mlir/include/mlir/Analysis/Presburger/SlowMPInt.h @@ -0,0 +1,135 @@ +//===- SlowMPInt.h - MLIR SlowMPInt Class -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a simple class to represent arbitrary precision signed integers. +// Unlike APInt, one does not have to specify a fixed maximum size, and the +// integer can take on any arbitrary values. +// +// This class is to be used as a fallback slow path for the MPInt class, and +// is not intended to be used directly. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_SLOWMPINT_H +#define MLIR_ANALYSIS_PRESBURGER_SLOWMPINT_H + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace presburger { +namespace detail { + +/// A simple class providing multi-precision arithmetic. Internally, it stores +/// an APInt, whose width is doubled whenever an overflow occurs at a certain +/// width. The default constructor sets the initial width to 64. SlowMPInt is +/// primarily intended to be used as a slow fallback path for the upcoming MPInt +/// class. +class SlowMPInt { +private: + llvm::APInt val; + +public: + explicit SlowMPInt(int64_t val); + SlowMPInt(); + explicit SlowMPInt(const llvm::APInt &val); + SlowMPInt &operator=(int64_t val); + explicit operator int64_t() const; + SlowMPInt operator-() const; + bool operator==(const SlowMPInt &o) const; + bool operator!=(const SlowMPInt &o) const; + bool operator>(const SlowMPInt &o) const; + bool operator<(const SlowMPInt &o) const; + bool operator<=(const SlowMPInt &o) const; + bool operator>=(const SlowMPInt &o) const; + SlowMPInt operator+(const SlowMPInt &o) const; + SlowMPInt operator-(const SlowMPInt &o) const; + SlowMPInt operator*(const SlowMPInt &o) const; + SlowMPInt operator/(const SlowMPInt &o) const; + SlowMPInt operator%(const SlowMPInt &o) const; + SlowMPInt &operator+=(const SlowMPInt &o); + SlowMPInt &operator-=(const SlowMPInt &o); + SlowMPInt &operator*=(const SlowMPInt &o); + SlowMPInt &operator/=(const SlowMPInt &o); + SlowMPInt &operator%=(const SlowMPInt &o); + + SlowMPInt &operator++(); + SlowMPInt &operator--(); + + friend SlowMPInt abs(const SlowMPInt &x); + friend SlowMPInt ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs); + friend SlowMPInt floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs); + /// The operands must be non-negative for gcd. + friend SlowMPInt gcd(const SlowMPInt &a, const SlowMPInt &b); + + /// Overload to compute a hash_code for a SlowMPInt value. + friend llvm::hash_code hash_value(const SlowMPInt &x); // NOLINT + + void print(llvm::raw_ostream &os) const; + void dump() const; + + unsigned getBitWidth() const { return val.getBitWidth(); } +}; + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const SlowMPInt &x); + +/// Returns the remainder of dividing LHS by RHS. +/// +/// The RHS is always expected to be positive, and the result +/// is always non-negative. +SlowMPInt mod(const SlowMPInt &lhs, const SlowMPInt &rhs); + +/// Returns the least common multiple of 'a' and 'b'. +SlowMPInt lcm(const SlowMPInt &a, const SlowMPInt &b); + +/// Redeclarations of friend declarations above to +/// make it discoverable by lookups. +SlowMPInt abs(const SlowMPInt &x); +SlowMPInt ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs); +SlowMPInt floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs); +SlowMPInt gcd(const SlowMPInt &a, const SlowMPInt &b); +llvm::hash_code hash_value(const SlowMPInt &x); // NOLINT + +/// --------------------------------------------------------------------------- +/// Convenience operator overloads for int64_t. +/// --------------------------------------------------------------------------- +SlowMPInt &operator+=(SlowMPInt &a, int64_t b); +SlowMPInt &operator-=(SlowMPInt &a, int64_t b); +SlowMPInt &operator*=(SlowMPInt &a, int64_t b); +SlowMPInt &operator/=(SlowMPInt &a, int64_t b); +SlowMPInt &operator%=(SlowMPInt &a, int64_t b); + +bool operator==(const SlowMPInt &a, int64_t b); +bool operator!=(const SlowMPInt &a, int64_t b); +bool operator>(const SlowMPInt &a, int64_t b); +bool operator<(const SlowMPInt &a, int64_t b); +bool operator<=(const SlowMPInt &a, int64_t b); +bool operator>=(const SlowMPInt &a, int64_t b); +SlowMPInt operator+(const SlowMPInt &a, int64_t b); +SlowMPInt operator-(const SlowMPInt &a, int64_t b); +SlowMPInt operator*(const SlowMPInt &a, int64_t b); +SlowMPInt operator/(const SlowMPInt &a, int64_t b); +SlowMPInt operator%(const SlowMPInt &a, int64_t b); + +bool operator==(int64_t a, const SlowMPInt &b); +bool operator!=(int64_t a, const SlowMPInt &b); +bool operator>(int64_t a, const SlowMPInt &b); +bool operator<(int64_t a, const SlowMPInt &b); +bool operator<=(int64_t a, const SlowMPInt &b); +bool operator>=(int64_t a, const SlowMPInt &b); +SlowMPInt operator+(int64_t a, const SlowMPInt &b); +SlowMPInt operator-(int64_t a, const SlowMPInt &b); +SlowMPInt operator*(int64_t a, const SlowMPInt &b); +SlowMPInt operator/(int64_t a, const SlowMPInt &b); +SlowMPInt operator%(int64_t a, const SlowMPInt &b); +} // namespace detail +} // namespace presburger +} // namespace mlir + +#endif // MLIR_ANALYSIS_PRESBURGER_SLOWMPINT_H diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h index 9b93e52b48490..38262a65f9754 100644 --- a/mlir/include/mlir/Analysis/Presburger/Utils.h +++ b/mlir/include/mlir/Analysis/Presburger/Utils.h @@ -13,8 +13,8 @@ #ifndef MLIR_ANALYSIS_PRESBURGER_UTILS_H #define MLIR_ANALYSIS_PRESBURGER_UTILS_H +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Support/LLVM.h" -#include "llvm/ADT/DynamicAPInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -23,6 +23,7 @@ namespace mlir { namespace presburger { + class IntegerRelation; /// This class represents the result of operations optimizing something subject @@ -117,7 +118,7 @@ struct MaybeLocalRepr { class DivisionRepr { public: DivisionRepr(unsigned numVars, unsigned numDivs) - : dividends(numDivs, numVars + 1), denoms(numDivs, DynamicAPInt(0)) {} + : dividends(numDivs, numVars + 1), denoms(numDivs, MPInt(0)) {} DivisionRepr(unsigned numVars) : dividends(0, numVars + 1) {} @@ -136,27 +137,21 @@ class DivisionRepr { void clearRepr(unsigned i) { denoms[i] = 0; } // Get the dividend of the `i^th` division. - MutableArrayRef getDividend(unsigned i) { - return dividends.getRow(i); - } - ArrayRef getDividend(unsigned i) const { - return dividends.getRow(i); - } + MutableArrayRef getDividend(unsigned i) { return dividends.getRow(i); } + ArrayRef getDividend(unsigned i) const { return dividends.getRow(i); } // For a given point containing values for each variable other than the // division variables, try to find the values for each division variable from // their division representation. - SmallVector, 4> - divValuesAt(ArrayRef point) const; + SmallVector, 4> divValuesAt(ArrayRef point) const; // Get the `i^th` denominator. - DynamicAPInt &getDenom(unsigned i) { return denoms[i]; } - DynamicAPInt getDenom(unsigned i) const { return denoms[i]; } + MPInt &getDenom(unsigned i) { return denoms[i]; } + MPInt getDenom(unsigned i) const { return denoms[i]; } - ArrayRef getDenoms() const { return denoms; } + ArrayRef getDenoms() const { return denoms; } - void setDiv(unsigned i, ArrayRef dividend, - const DynamicAPInt &divisor) { + void setDiv(unsigned i, ArrayRef dividend, const MPInt &divisor) { dividends.setRow(i, dividend); denoms[i] = divisor; } @@ -166,8 +161,7 @@ class DivisionRepr { // simplify the expression. void normalizeDivs(); - void insertDiv(unsigned pos, ArrayRef dividend, - const DynamicAPInt &divisor); + void insertDiv(unsigned pos, ArrayRef dividend, const MPInt &divisor); void insertDiv(unsigned pos, unsigned num = 1); /// Removes duplicate divisions. On every possible duplicate division found, @@ -193,7 +187,7 @@ class DivisionRepr { /// Denominators of each division. If a denominator of a division is `0`, the /// division variable is considered to not have a division representation. /// Otherwise, the denominator is positive. - SmallVector denoms; + SmallVector denoms; }; /// If `q` is defined to be equal to `expr floordiv d`, this equivalent to @@ -211,12 +205,12 @@ class DivisionRepr { /// The coefficient of `q` in `dividend` must be zero, as it is not allowed for /// local variable to be a floor division of an expression involving itself. /// The divisor must be positive. -SmallVector getDivUpperBound(ArrayRef dividend, - const DynamicAPInt &divisor, - unsigned localVarIdx); -SmallVector getDivLowerBound(ArrayRef dividend, - const DynamicAPInt &divisor, - unsigned localVarIdx); +SmallVector getDivUpperBound(ArrayRef dividend, + const MPInt &divisor, + unsigned localVarIdx); +SmallVector getDivLowerBound(ArrayRef dividend, + const MPInt &divisor, + unsigned localVarIdx); llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset, unsigned numSet); @@ -225,10 +219,10 @@ llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset, /// function of other variables (where the divisor is a positive constant). /// `foundRepr` contains a boolean for each variable indicating if the /// explicit representation for that variable has already been computed. -/// Return the given array as an array of DynamicAPInts. -SmallVector getDynamicAPIntVec(ArrayRef range); +/// Return the given array as an array of MPInts. +SmallVector getMPIntVec(ArrayRef range); /// Return the given array as an array of int64_t. -SmallVector getInt64Vec(ArrayRef range); +SmallVector getInt64Vec(ArrayRef range); /// Returns the `MaybeLocalRepr` struct which contains the indices of the /// constraints that can be expressed as a floordiv of an affine function. If @@ -237,8 +231,8 @@ SmallVector getInt64Vec(ArrayRef range); /// not be computed, the kind attribute in `MaybeLocalRepr` is set to None. MaybeLocalRepr computeSingleVarRepr(const IntegerRelation &cst, ArrayRef foundRepr, unsigned pos, - MutableArrayRef dividend, - DynamicAPInt &divisor); + MutableArrayRef dividend, + MPInt &divisor); /// The following overload using int64_t is required for a callsite in /// AffineStructures.h. @@ -263,25 +257,25 @@ void mergeLocalVars(IntegerRelation &relA, IntegerRelation &relB, llvm::function_ref merge); /// Compute the gcd of the range. -DynamicAPInt gcdRange(ArrayRef range); +MPInt gcdRange(ArrayRef range); /// Divide the range by its gcd and return the gcd. -DynamicAPInt normalizeRange(MutableArrayRef range); +MPInt normalizeRange(MutableArrayRef range); /// Normalize the given (numerator, denominator) pair by dividing out the /// common factors between them. The numerator here is an affine expression /// with integer coefficients. The denominator must be positive. -void normalizeDiv(MutableArrayRef num, DynamicAPInt &denom); +void normalizeDiv(MutableArrayRef num, MPInt &denom); /// Return `coeffs` with all the elements negated. -SmallVector getNegatedCoeffs(ArrayRef coeffs); +SmallVector getNegatedCoeffs(ArrayRef coeffs); /// Return the complement of the given inequality. /// /// The complement of a_1 x_1 + ... + a_n x_ + c >= 0 is /// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0, /// since all the variables are constrained to be integers. -SmallVector getComplementIneq(ArrayRef ineq); +SmallVector getComplementIneq(ArrayRef ineq); /// Compute the dot product of two vectors. /// The vectors must have the same sizes. diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h index 7baca03998f5b..235d84c5beff1 100644 --- a/mlir/include/mlir/Support/LLVM.h +++ b/mlir/include/mlir/Support/LLVM.h @@ -80,7 +80,6 @@ class TypeSwitch; // Other common classes. class APInt; -class DynamicAPInt; class APSInt; class APFloat; template @@ -144,7 +143,6 @@ using TypeSwitch = llvm::TypeSwitch; using llvm::APFloat; using llvm::APInt; using llvm::APSInt; -using llvm::DynamicAPInt; template using function_ref = llvm::function_ref; using llvm::iterator_range; diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index e628fb152b52f..5c4f353f310d6 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -1317,8 +1317,7 @@ mlir::getMultiAffineFunctionFromMap(AffineMap map, "AffineMap cannot produce divs without local representation"); // TODO: We shouldn't have to do this conversion. - Matrix mat(map.getNumResults(), - map.getNumInputs() + divs.getNumDivs() + 1); + Matrix mat(map.getNumResults(), map.getNumInputs() + divs.getNumDivs() + 1); for (unsigned i = 0, e = flattenedExprs.size(); i < e; ++i) for (unsigned j = 0, f = flattenedExprs[i].size(); j < f; ++j) mat(i, j) = flattenedExprs[i][j]; diff --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp index e16b9269b75fc..0b55185c43782 100644 --- a/mlir/lib/Analysis/Presburger/Barvinok.cpp +++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp @@ -61,9 +61,9 @@ ConeH mlir::presburger::detail::getDual(ConeV cone) { } /// Find the index of a cone in V-representation. -DynamicAPInt mlir::presburger::detail::getIndex(const ConeV &cone) { +MPInt mlir::presburger::detail::getIndex(const ConeV &cone) { if (cone.getNumRows() > cone.getNumColumns()) - return DynamicAPInt(0); + return MPInt(0); return cone.determinant(); } @@ -413,7 +413,7 @@ mlir::presburger::detail::computePolytopeGeneratingFunction( // constant terms zero. ConeH tangentCone = defineHRep(numVars); for (unsigned j = 0, e = subset.getNumRows(); j < e; ++j) { - SmallVector ineq(numVars + 1); + SmallVector ineq(numVars + 1); for (unsigned k = 0; k < numVars; ++k) ineq[k] = subset(j, k); tangentCone.addInequality(ineq); diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt index 1d30dd38ccd1b..83d0514c9e7d1 100644 --- a/mlir/lib/Analysis/Presburger/CMakeLists.txt +++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt @@ -3,11 +3,13 @@ add_mlir_library(MLIRPresburger IntegerRelation.cpp LinearTransform.cpp Matrix.cpp + MPInt.cpp PresburgerRelation.cpp PresburgerSpace.cpp PWMAFunction.cpp QuasiPolynomial.cpp Simplex.cpp + SlowMPInt.cpp Utils.cpp LINK_LIBS PUBLIC diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 75215fbab5282..b5a2ed6ccc369 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -15,6 +15,7 @@ #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/Fraction.h" #include "mlir/Analysis/Presburger/LinearTransform.h" +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h" #include "mlir/Analysis/Presburger/PresburgerSpace.h" @@ -156,10 +157,9 @@ IntegerRelation::findRationalLexMin() const { return maybeLexMin; } -MaybeOptimum> -IntegerRelation::findIntegerLexMin() const { +MaybeOptimum> IntegerRelation::findIntegerLexMin() const { assert(getNumSymbolVars() == 0 && "Symbols are not supported!"); - MaybeOptimum> maybeLexMin = + MaybeOptimum> maybeLexMin = LexSimplex(*this).findIntegerLexMin(); if (!maybeLexMin.isBounded()) @@ -176,8 +176,8 @@ IntegerRelation::findIntegerLexMin() const { return maybeLexMin; } -static bool rangeIsZero(ArrayRef range) { - return llvm::all_of(range, [](const DynamicAPInt &x) { return x == 0; }); +static bool rangeIsZero(ArrayRef range) { + return llvm::all_of(range, [](const MPInt &x) { return x == 0; }); } static void removeConstraintsInvolvingVarRange(IntegerRelation &poly, @@ -363,14 +363,14 @@ unsigned IntegerRelation::appendVar(VarKind kind, unsigned num) { return insertVar(kind, pos, num); } -void IntegerRelation::addEquality(ArrayRef eq) { +void IntegerRelation::addEquality(ArrayRef eq) { assert(eq.size() == getNumCols()); unsigned row = equalities.appendExtraRow(); for (unsigned i = 0, e = eq.size(); i < e; ++i) equalities(row, i) = eq[i]; } -void IntegerRelation::addInequality(ArrayRef inEq) { +void IntegerRelation::addInequality(ArrayRef inEq) { assert(inEq.size() == getNumCols()); unsigned row = inequalities.appendExtraRow(); for (unsigned i = 0, e = inEq.size(); i < e; ++i) @@ -541,8 +541,7 @@ bool IntegerRelation::hasConsistentState() const { return true; } -void IntegerRelation::setAndEliminate(unsigned pos, - ArrayRef values) { +void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef values) { if (values.empty()) return; assert(pos + values.size() <= getNumVars() && @@ -568,7 +567,7 @@ void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) { bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq, unsigned *rowIdx) const { assert(colIdx < getNumCols() && "position out of bounds"); - auto at = [&](unsigned rowIdx) -> DynamicAPInt { + auto at = [&](unsigned rowIdx) -> MPInt { return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx); }; unsigned e = isEq ? getNumEqualities() : getNumInequalities(); @@ -595,7 +594,7 @@ bool IntegerRelation::hasInvalidConstraint() const { for (unsigned i = 0, e = numRows; i < e; ++i) { unsigned j; for (j = 0; j < numCols - 1; ++j) { - DynamicAPInt v = isEq ? atEq(i, j) : atIneq(i, j); + MPInt v = isEq ? atEq(i, j) : atIneq(i, j); // Skip rows with non-zero variable coefficients. if (v != 0) break; @@ -605,7 +604,7 @@ bool IntegerRelation::hasInvalidConstraint() const { } // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. // Example invalid constraints include: '1 == 0' or '-1 >= 0' - DynamicAPInt v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); + MPInt v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); if ((isEq && v != 0) || (!isEq && v < 0)) { return true; } @@ -627,26 +626,26 @@ static void eliminateFromConstraint(IntegerRelation *constraints, // Skip if equality 'rowIdx' if same as 'pivotRow'. if (isEq && rowIdx == pivotRow) return; - auto at = [&](unsigned i, unsigned j) -> DynamicAPInt { + auto at = [&](unsigned i, unsigned j) -> MPInt { return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); }; - DynamicAPInt leadCoeff = at(rowIdx, pivotCol); + MPInt leadCoeff = at(rowIdx, pivotCol); // Skip if leading coefficient at 'rowIdx' is already zero. if (leadCoeff == 0) return; - DynamicAPInt pivotCoeff = constraints->atEq(pivotRow, pivotCol); + MPInt pivotCoeff = constraints->atEq(pivotRow, pivotCol); int sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; - DynamicAPInt lcm = llvm::lcm(pivotCoeff, leadCoeff); - DynamicAPInt pivotMultiplier = sign * (lcm / abs(pivotCoeff)); - DynamicAPInt rowMultiplier = lcm / abs(leadCoeff); + MPInt lcm = presburger::lcm(pivotCoeff, leadCoeff); + MPInt pivotMultiplier = sign * (lcm / abs(pivotCoeff)); + MPInt rowMultiplier = lcm / abs(leadCoeff); unsigned numCols = constraints->getNumCols(); for (unsigned j = 0; j < numCols; ++j) { // Skip updating column 'j' if it was just eliminated. if (j >= elimColStart && j < pivotCol) continue; - DynamicAPInt v = pivotMultiplier * constraints->atEq(pivotRow, j) + - rowMultiplier * at(rowIdx, j); + MPInt v = pivotMultiplier * constraints->atEq(pivotRow, j) + + rowMultiplier * at(rowIdx, j); isEq ? constraints->atEq(rowIdx, j) = v : constraints->atIneq(rowIdx, j) = v; } @@ -758,11 +757,11 @@ bool IntegerRelation::isEmptyByGCDTest() const { assert(hasConsistentState()); unsigned numCols = getNumCols(); for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { - DynamicAPInt gcd = abs(atEq(i, 0)); + MPInt gcd = abs(atEq(i, 0)); for (unsigned j = 1; j < numCols - 1; ++j) { - gcd = llvm::gcd(gcd, abs(atEq(i, j))); + gcd = presburger::gcd(gcd, abs(atEq(i, j))); } - DynamicAPInt v = abs(atEq(i, numCols - 1)); + MPInt v = abs(atEq(i, numCols - 1)); if (gcd > 0 && (v % gcd != 0)) { return true; } @@ -865,7 +864,7 @@ bool IntegerRelation::isIntegerEmpty() const { return !findIntegerSample(); } /// /// Concatenating the samples from B and C gives a sample v in S*T, so the /// returned sample T*v is a sample in S. -std::optional> +std::optional> IntegerRelation::findIntegerSample() const { // First, try the GCD test heuristic. if (isEmptyByGCDTest()) @@ -905,7 +904,7 @@ IntegerRelation::findIntegerSample() const { boundedSet.removeVarRange(numBoundedDims, boundedSet.getNumVars()); // 3) Try to obtain a sample from the bounded set. - std::optional> boundedSample = + std::optional> boundedSample = Simplex(boundedSet).findIntegerSample(); if (!boundedSample) return {}; @@ -944,7 +943,7 @@ IntegerRelation::findIntegerSample() const { // amount for the shrunken cone. for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) { for (unsigned j = 0; j < cone.getNumVars(); ++j) { - DynamicAPInt coeff = cone.atIneq(i, j); + MPInt coeff = cone.atIneq(i, j); if (coeff < 0) cone.atIneq(i, cone.getNumVars()) += coeff; } @@ -961,11 +960,10 @@ IntegerRelation::findIntegerSample() const { SmallVector shrunkenConeSample = *shrunkenConeSimplex.getRationalSample(); - SmallVector coneSample( - llvm::map_range(shrunkenConeSample, ceil)); + SmallVector coneSample(llvm::map_range(shrunkenConeSample, ceil)); // 6) Return transform * concat(boundedSample, coneSample). - SmallVector &sample = *boundedSample; + SmallVector &sample = *boundedSample; sample.append(coneSample.begin(), coneSample.end()); return transform.postMultiplyWithColumn(sample); } @@ -973,11 +971,10 @@ IntegerRelation::findIntegerSample() const { /// Helper to evaluate an affine expression at a point. /// The expression is a list of coefficients for the dimensions followed by the /// constant term. -static DynamicAPInt valueAt(ArrayRef expr, - ArrayRef point) { +static MPInt valueAt(ArrayRef expr, ArrayRef point) { assert(expr.size() == 1 + point.size() && "Dimensionalities of point and expression don't match!"); - DynamicAPInt value = expr.back(); + MPInt value = expr.back(); for (unsigned i = 0; i < point.size(); ++i) value += expr[i] * point[i]; return value; @@ -986,7 +983,7 @@ static DynamicAPInt valueAt(ArrayRef expr, /// A point satisfies an equality iff the value of the equality at the /// expression is zero, and it satisfies an inequality iff the value of the /// inequality at that point is non-negative. -bool IntegerRelation::containsPoint(ArrayRef point) const { +bool IntegerRelation::containsPoint(ArrayRef point) const { for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { if (valueAt(getEquality(i), point) != 0) return false; @@ -1006,8 +1003,8 @@ bool IntegerRelation::containsPoint(ArrayRef point) const { /// compute the values of the locals that have division representations and /// only use the integer emptiness check for the locals that don't have this. /// Handling this correctly requires ordering the divs, though. -std::optional> -IntegerRelation::containsPointNoLocal(ArrayRef point) const { +std::optional> +IntegerRelation::containsPointNoLocal(ArrayRef point) const { assert(point.size() == getNumVars() - getNumLocalVars() && "Point should contain all vars except locals!"); assert(getVarKindOffset(VarKind::Local) == getNumVars() - getNumLocalVars() && @@ -1064,7 +1061,7 @@ void IntegerRelation::gcdTightenInequalities() { unsigned numCols = getNumCols(); for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { // Normalize the constraint and tighten the constant term by the GCD. - DynamicAPInt gcd = inequalities.normalizeRow(i, getNumCols() - 1); + MPInt gcd = inequalities.normalizeRow(i, getNumCols() - 1); if (gcd > 1) atIneq(i, numCols - 1) = floorDiv(atIneq(i, numCols - 1), gcd); } @@ -1236,14 +1233,14 @@ void IntegerRelation::removeRedundantConstraints() { equalities.resizeVertically(pos); } -std::optional IntegerRelation::computeVolume() const { +std::optional IntegerRelation::computeVolume() const { assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!"); Simplex simplex(*this); // If the polytope is rationally empty, there are certainly no integer // points. if (simplex.isEmpty()) - return DynamicAPInt(0); + return MPInt(0); // Just find the maximum and minimum integer value of each non-local var // separately, thus finding the number of integer values each such var can @@ -1259,8 +1256,8 @@ std::optional IntegerRelation::computeVolume() const { // // If there is no such empty dimension, if any dimension is unbounded we // just return the result as unbounded. - DynamicAPInt count(1); - SmallVector dim(getNumVars() + 1); + MPInt count(1); + SmallVector dim(getNumVars() + 1); bool hasUnboundedVar = false; for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i) { dim[i] = 1; @@ -1280,13 +1277,13 @@ std::optional IntegerRelation::computeVolume() const { // In this case there are no valid integer points and the volume is // definitely zero. if (min.getBoundedOptimum() > max.getBoundedOptimum()) - return DynamicAPInt(0); + return MPInt(0); count *= (*max - *min + 1); } if (count == 0) - return DynamicAPInt(0); + return MPInt(0); if (hasUnboundedVar) return {}; return count; @@ -1479,7 +1476,7 @@ void IntegerRelation::convertVarKind(VarKind srcKind, unsigned varStart, } void IntegerRelation::addBound(BoundType type, unsigned pos, - const DynamicAPInt &value) { + const MPInt &value) { assert(pos < getNumCols()); if (type == BoundType::EQ) { unsigned row = equalities.appendExtraRow(); @@ -1493,8 +1490,8 @@ void IntegerRelation::addBound(BoundType type, unsigned pos, } } -void IntegerRelation::addBound(BoundType type, ArrayRef expr, - const DynamicAPInt &value) { +void IntegerRelation::addBound(BoundType type, ArrayRef expr, + const MPInt &value) { assert(type != BoundType::EQ && "EQ not implemented"); assert(expr.size() == getNumCols()); unsigned row = inequalities.appendExtraRow(); @@ -1509,15 +1506,15 @@ void IntegerRelation::addBound(BoundType type, ArrayRef expr, /// respect to a positive constant 'divisor'. Two constraints are added to the /// system to capture equivalence with the floordiv. /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. -void IntegerRelation::addLocalFloorDiv(ArrayRef dividend, - const DynamicAPInt &divisor) { +void IntegerRelation::addLocalFloorDiv(ArrayRef dividend, + const MPInt &divisor) { assert(dividend.size() == getNumCols() && "incorrect dividend size"); assert(divisor > 0 && "positive divisor expected"); appendVar(VarKind::Local); - SmallVector dividendCopy(dividend.begin(), dividend.end()); - dividendCopy.insert(dividendCopy.end() - 1, DynamicAPInt(0)); + SmallVector dividendCopy(dividend.begin(), dividend.end()); + dividendCopy.insert(dividendCopy.end() - 1, MPInt(0)); addInequality( getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2)); addInequality( @@ -1533,7 +1530,7 @@ static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos, bool symbolic = false) { assert(pos < cst.getNumVars() && "invalid position"); for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { - DynamicAPInt v = cst.atEq(r, pos); + MPInt v = cst.atEq(r, pos); if (v * v != 1) continue; unsigned c; @@ -1562,7 +1559,7 @@ LogicalResult IntegerRelation::constantFoldVar(unsigned pos) { // atEq(rowIdx, pos) is either -1 or 1. assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); - DynamicAPInt constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); + MPInt constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); setAndEliminate(pos, constVal); return success(); } @@ -1588,10 +1585,9 @@ void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) { // s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. // s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = // ceil(s0 - 7 / 8) = floor(s0 / 8)). -std::optional IntegerRelation::getConstantBoundOnDimSize( - unsigned pos, SmallVectorImpl *lb, - DynamicAPInt *boundFloorDivisor, SmallVectorImpl *ub, - unsigned *minLbPos, unsigned *minUbPos) const { +std::optional IntegerRelation::getConstantBoundOnDimSize( + unsigned pos, SmallVectorImpl *lb, MPInt *boundFloorDivisor, + SmallVectorImpl *ub, unsigned *minLbPos, unsigned *minUbPos) const { assert(pos < getNumDimVars() && "Invalid variable position"); // Find an equality for 'pos'^th variable that equates it to some function @@ -1603,7 +1599,7 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( // TODO: this can be handled in the future by using the explicit // representation of the local vars. if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1, - [](const DynamicAPInt &coeff) { return coeff == 0; })) + [](const MPInt &coeff) { return coeff == 0; })) return std::nullopt; // This variable can only take a single value. @@ -1613,7 +1609,7 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( if (ub) ub->resize(getNumSymbolVars() + 1); for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) { - DynamicAPInt v = atEq(eqPos, pos); + MPInt v = atEq(eqPos, pos); // atEq(eqRow, pos) is either -1 or 1. assert(v * v == 1); (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v @@ -1630,7 +1626,7 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( *minLbPos = eqPos; if (minUbPos) *minUbPos = eqPos; - return DynamicAPInt(1); + return MPInt(1); } // Check if the variable appears at all in any of the inequalities. @@ -1654,7 +1650,7 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( /*eqIndices=*/nullptr, /*offset=*/0, /*num=*/getNumDimVars()); - std::optional minDiff; + std::optional minDiff; unsigned minLbPosition = 0, minUbPosition = 0; for (auto ubPos : ubIndices) { for (auto lbPos : lbIndices) { @@ -1671,11 +1667,11 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( } if (j < getNumCols() - 1) continue; - DynamicAPInt diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) + - atIneq(lbPos, getNumCols() - 1) + 1, - atIneq(lbPos, pos)); + MPInt diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) + + atIneq(lbPos, getNumCols() - 1) + 1, + atIneq(lbPos, pos)); // This bound is non-negative by definition. - diff = std::max(diff, DynamicAPInt(0)); + diff = std::max(diff, MPInt(0)); if (minDiff == std::nullopt || diff < minDiff) { minDiff = diff; minLbPosition = lbPos; @@ -1715,7 +1711,7 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( } template -std::optional +std::optional IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { assert(pos < getNumVars() && "invalid position"); // Project to 'pos'. @@ -1737,7 +1733,7 @@ IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { // If it doesn't, there isn't a bound on it. return std::nullopt; - std::optional minOrMaxConst; + std::optional minOrMaxConst; // Take the max across all const lower bounds (or min across all constant // upper bounds). @@ -1758,7 +1754,7 @@ IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { // Not a constant bound. continue; - DynamicAPInt boundConst = + MPInt boundConst = isLower ? ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) : floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); if (isLower) { @@ -1772,8 +1768,8 @@ IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { return minOrMaxConst; } -std::optional -IntegerRelation::getConstantBound(BoundType type, unsigned pos) const { +std::optional IntegerRelation::getConstantBound(BoundType type, + unsigned pos) const { if (type == BoundType::LB) return IntegerRelation(*this) .computeConstantLowerOrUpperBound(pos); @@ -1782,14 +1778,13 @@ IntegerRelation::getConstantBound(BoundType type, unsigned pos) const { .computeConstantLowerOrUpperBound(pos); assert(type == BoundType::EQ && "expected EQ"); - std::optional lb = + std::optional lb = IntegerRelation(*this).computeConstantLowerOrUpperBound( pos); - std::optional ub = + std::optional ub = IntegerRelation(*this) .computeConstantLowerOrUpperBound(pos); - return (lb && ub && *lb == *ub) ? std::optional(*ub) - : std::nullopt; + return (lb && ub && *lb == *ub) ? std::optional(*ub) : std::nullopt; } // A simple (naive and conservative) check for hyper-rectangularity. @@ -1830,10 +1825,10 @@ void IntegerRelation::removeTrivialRedundancy() { // A map used to detect redundancy stemming from constraints that only differ // in their constant term. The value stored is // for a given row. - SmallDenseMap, std::pair> + SmallDenseMap, std::pair> rowsWithoutConstTerm; // To unique rows. - SmallDenseSet, 8> rowSet; + SmallDenseSet, 8> rowSet; // Check if constraint is of the form >= 0. auto isTriviallyValid = [&](unsigned r) -> bool { @@ -1847,8 +1842,8 @@ void IntegerRelation::removeTrivialRedundancy() { // Detect and mark redundant constraints. SmallVector redunIneq(getNumInequalities(), false); for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { - DynamicAPInt *rowStart = &inequalities(r, 0); - auto row = ArrayRef(rowStart, getNumCols()); + MPInt *rowStart = &inequalities(r, 0); + auto row = ArrayRef(rowStart, getNumCols()); if (isTriviallyValid(r) || !rowSet.insert(row).second) { redunIneq[r] = true; continue; @@ -1858,9 +1853,8 @@ void IntegerRelation::removeTrivialRedundancy() { // everything other than the one with the smallest constant term redundant. // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the // former two are redundant). - DynamicAPInt constTerm = atIneq(r, getNumCols() - 1); - auto rowWithoutConstTerm = - ArrayRef(rowStart, getNumCols() - 1); + MPInt constTerm = atIneq(r, getNumCols() - 1); + auto rowWithoutConstTerm = ArrayRef(rowStart, getNumCols() - 1); const auto &ret = rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}}); if (!ret.second) { @@ -2016,19 +2010,19 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow, // integer exact. for (auto ubPos : ubIndices) { for (auto lbPos : lbIndices) { - SmallVector ineq; + SmallVector ineq; ineq.reserve(newRel.getNumCols()); - DynamicAPInt lbCoeff = atIneq(lbPos, pos); + MPInt lbCoeff = atIneq(lbPos, pos); // Note that in the comments above, ubCoeff is the negation of the // coefficient in the canonical form as the view taken here is that of the // term being moved to the other size of '>='. - DynamicAPInt ubCoeff = -atIneq(ubPos, pos); + MPInt ubCoeff = -atIneq(ubPos, pos); // TODO: refactor this loop to avoid all branches inside. for (unsigned l = 0, e = getNumCols(); l < e; l++) { if (l == pos) continue; assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); - DynamicAPInt lcm = llvm::lcm(lbCoeff, ubCoeff); + MPInt lcm = presburger::lcm(lbCoeff, ubCoeff); ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + atIneq(lbPos, l) * (lcm / lbCoeff)); assert(lcm > 0 && "lcm should be positive!"); @@ -2053,7 +2047,7 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow, // Copy over the constraints not involving this variable. for (auto nbPos : nbIndices) { - SmallVector ineq; + SmallVector ineq; ineq.reserve(getNumCols() - 1); for (unsigned l = 0, e = getNumCols(); l < e; l++) { if (l == pos) @@ -2068,7 +2062,7 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow, // Copy over the equalities. for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { - SmallVector eq; + SmallVector eq; eq.reserve(newRel.getNumCols()); for (unsigned l = 0, e = getNumCols(); l < e; l++) { if (l == pos) @@ -2132,8 +2126,7 @@ enum BoundCmpResult { Greater, Less, Equal, Unknown }; /// Compares two affine bounds whose coefficients are provided in 'first' and /// 'second'. The last coefficient is the constant term. -static BoundCmpResult compareBounds(ArrayRef a, - ArrayRef b) { +static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { assert(a.size() == b.size()); // For the bounds to be comparable, their corresponding variable @@ -2185,20 +2178,20 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { IntegerRelation commonCst(PresburgerSpace::getRelationSpace()); getCommonConstraints(*this, otherCst, commonCst); - std::vector> boundingLbs; - std::vector> boundingUbs; + std::vector> boundingLbs; + std::vector> boundingUbs; boundingLbs.reserve(2 * getNumDimVars()); boundingUbs.reserve(2 * getNumDimVars()); // To hold lower and upper bounds for each dimension. - SmallVector lb, otherLb, ub, otherUb; + SmallVector lb, otherLb, ub, otherUb; // To compute min of lower bounds and max of upper bounds for each dimension. - SmallVector minLb(getNumSymbolVars() + 1); - SmallVector maxUb(getNumSymbolVars() + 1); + SmallVector minLb(getNumSymbolVars() + 1); + SmallVector maxUb(getNumSymbolVars() + 1); // To compute final new lower and upper bounds for the union. - SmallVector newLb(getNumCols()), newUb(getNumCols()); + SmallVector newLb(getNumCols()), newUb(getNumCols()); - DynamicAPInt lbFloorDivisor, otherLbFloorDivisor; + MPInt lbFloorDivisor, otherLbFloorDivisor; for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) { auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub); if (!extent.has_value()) @@ -2261,8 +2254,7 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { // Copy over the symbolic part + constant term. std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars()); std::transform(newLb.begin() + getNumDimVars(), newLb.end(), - newLb.begin() + getNumDimVars(), - std::negate()); + newLb.begin() + getNumDimVars(), std::negate()); std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars()); boundingLbs.push_back(newLb); @@ -2357,14 +2349,14 @@ IntegerPolyhedron IntegerRelation::getDomainSet() const { bool IntegerRelation::removeDuplicateConstraints() { bool changed = false; - SmallDenseMap, unsigned> hashTable; + SmallDenseMap, unsigned> hashTable; unsigned ineqs = getNumInequalities(), cols = getNumCols(); if (ineqs <= 1) return changed; // Check if the non-constant part of the constraint is the same. - ArrayRef row = getInequality(0).drop_back(); + ArrayRef row = getInequality(0).drop_back(); hashTable.insert({row, 0}); for (unsigned k = 1; k < ineqs; ++k) { row = getInequality(k).drop_back(); @@ -2384,11 +2376,11 @@ bool IntegerRelation::removeDuplicateConstraints() { } // Check the neg form of each inequality, need an extra vector to store it. - SmallVector negIneq(cols - 1); + SmallVector negIneq(cols - 1); for (unsigned k = 0; k < ineqs; ++k) { row = getInequality(k).drop_back(); negIneq.assign(row.begin(), row.end()); - for (DynamicAPInt &ele : negIneq) + for (MPInt &ele : negIneq) ele = -ele; if (!hashTable.contains(negIneq)) continue; diff --git a/mlir/lib/Analysis/Presburger/LinearTransform.cpp b/mlir/lib/Analysis/Presburger/LinearTransform.cpp index ecab634967694..3e080e698b199 100644 --- a/mlir/lib/Analysis/Presburger/LinearTransform.cpp +++ b/mlir/lib/Analysis/Presburger/LinearTransform.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/Presburger/LinearTransform.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Analysis/Presburger/Matrix.h" #include "mlir/Support/LLVM.h" #include @@ -47,21 +48,21 @@ IntegerRelation LinearTransform::applyTo(const IntegerRelation &rel) const { IntegerRelation result(rel.getSpace()); for (unsigned i = 0, e = rel.getNumEqualities(); i < e; ++i) { - ArrayRef eq = rel.getEquality(i); + ArrayRef eq = rel.getEquality(i); - const DynamicAPInt &c = eq.back(); + const MPInt &c = eq.back(); - SmallVector newEq = preMultiplyWithRow(eq.drop_back()); + SmallVector newEq = preMultiplyWithRow(eq.drop_back()); newEq.push_back(c); result.addEquality(newEq); } for (unsigned i = 0, e = rel.getNumInequalities(); i < e; ++i) { - ArrayRef ineq = rel.getInequality(i); + ArrayRef ineq = rel.getInequality(i); - const DynamicAPInt &c = ineq.back(); + const MPInt &c = ineq.back(); - SmallVector newIneq = preMultiplyWithRow(ineq.drop_back()); + SmallVector newIneq = preMultiplyWithRow(ineq.drop_back()); newIneq.push_back(c); result.addInequality(newIneq); } diff --git a/mlir/lib/Analysis/Presburger/MPInt.cpp b/mlir/lib/Analysis/Presburger/MPInt.cpp new file mode 100644 index 0000000000000..587e2b572facf --- /dev/null +++ b/mlir/lib/Analysis/Presburger/MPInt.cpp @@ -0,0 +1,38 @@ +//===- MPInt.cpp - MLIR MPInt Class ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Presburger/MPInt.h" +#include "mlir/Analysis/Presburger/SlowMPInt.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace presburger; + +llvm::hash_code mlir::presburger::hash_value(const MPInt &x) { + if (x.isSmall()) + return llvm::hash_value(x.getSmall()); + return detail::hash_value(x.getLarge()); +} + +/// --------------------------------------------------------------------------- +/// Printing. +/// --------------------------------------------------------------------------- +llvm::raw_ostream &MPInt::print(llvm::raw_ostream &os) const { + if (isSmall()) + return os << valSmall; + return os << valLarge; +} + +void MPInt::dump() const { print(llvm::errs()); } + +llvm::raw_ostream &mlir::presburger::operator<<(llvm::raw_ostream &os, + const MPInt &x) { + x.print(os); + return os; +} diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp index 134b805648d9f..4cb6e6b16bc87 100644 --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/Presburger/Matrix.h" #include "mlir/Analysis/Presburger/Fraction.h" +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Analysis/Presburger/Utils.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/MathExtras.h" @@ -371,12 +372,12 @@ SmallVector Matrix::postMultiplyWithColumn(ArrayRef colVec) const { /// sourceCol. This brings M(row, targetCol) to the range [0, M(row, /// sourceCol)). Apply the same column operation to otherMatrix, with the same /// integer multiple. -static void modEntryColumnOperation(Matrix &m, unsigned row, +static void modEntryColumnOperation(Matrix &m, unsigned row, unsigned sourceCol, unsigned targetCol, - Matrix &otherMatrix) { + Matrix &otherMatrix) { assert(m(row, sourceCol) != 0 && "Cannot divide by zero!"); assert(m(row, sourceCol) > 0 && "Source must be positive!"); - DynamicAPInt ratio = -floorDiv(m(row, targetCol), m(row, sourceCol)); + MPInt ratio = -floorDiv(m(row, targetCol), m(row, sourceCol)); m.addToColumn(sourceCol, targetCol, ratio); otherMatrix.addToColumn(sourceCol, targetCol, ratio); } @@ -443,7 +444,7 @@ bool Matrix::hasConsistentState() const { namespace mlir { namespace presburger { -template class Matrix; +template class Matrix; template class Matrix; } // namespace presburger } // namespace mlir @@ -541,25 +542,25 @@ std::pair IntMatrix::computeHermiteNormalForm() const { return {h, u}; } -DynamicAPInt IntMatrix::normalizeRow(unsigned row, unsigned cols) { +MPInt IntMatrix::normalizeRow(unsigned row, unsigned cols) { return normalizeRange(getRow(row).slice(0, cols)); } -DynamicAPInt IntMatrix::normalizeRow(unsigned row) { +MPInt IntMatrix::normalizeRow(unsigned row) { return normalizeRow(row, getNumColumns()); } -DynamicAPInt IntMatrix::determinant(IntMatrix *inverse) const { +MPInt IntMatrix::determinant(IntMatrix *inverse) const { assert(nRows == nColumns && "determinant can only be calculated for square matrices!"); FracMatrix m(*this); FracMatrix fracInverse(nRows, nColumns); - DynamicAPInt detM = m.determinant(&fracInverse).getAsInteger(); + MPInt detM = m.determinant(&fracInverse).getAsInteger(); if (detM == 0) - return DynamicAPInt(0); + return MPInt(0); if (!inverse) return detM; @@ -717,7 +718,7 @@ FracMatrix FracMatrix::gramSchmidt() const { // // We repeat this until k = n and return. void FracMatrix::LLL(Fraction delta) { - DynamicAPInt nearest; + MPInt nearest; Fraction mu; // `gsOrth` holds the Gram-Schmidt orthogonalisation @@ -761,7 +762,7 @@ IntMatrix FracMatrix::normalizeRows() const { unsigned numColumns = getNumColumns(); IntMatrix normalized(numRows, numColumns); - DynamicAPInt lcmDenoms = DynamicAPInt(1); + MPInt lcmDenoms = MPInt(1); for (unsigned i = 0; i < numRows; i++) { // For a row, first compute the LCM of the denominators. for (unsigned j = 0; j < numColumns; j++) diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp index 664670d506d53..d55962616de17 100644 --- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp +++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h" #include "mlir/Analysis/Presburger/PresburgerSpace.h" #include "mlir/Analysis/Presburger/Utils.h" @@ -40,11 +41,11 @@ void MultiAffineFunction::assertIsConsistent() const { // Return the result of subtracting the two given vectors pointwise. // The vectors must be of the same size. // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. -static SmallVector subtractExprs(ArrayRef vecA, - ArrayRef vecB) { +static SmallVector subtractExprs(ArrayRef vecA, + ArrayRef vecB) { assert(vecA.size() == vecB.size() && "Cannot subtract vectors of differing lengths!"); - SmallVector result; + SmallVector result; result.reserve(vecA.size()); for (unsigned i = 0, e = vecA.size(); i < e; ++i) result.push_back(vecA[i] - vecB[i]); @@ -66,19 +67,18 @@ void MultiAffineFunction::print(raw_ostream &os) const { output.print(os); } -SmallVector -MultiAffineFunction::valueAt(ArrayRef point) const { +SmallVector +MultiAffineFunction::valueAt(ArrayRef point) const { assert(point.size() == getNumDomainVars() + getNumSymbolVars() && "Point has incorrect dimensionality!"); - SmallVector pointHomogenous{llvm::to_vector(point)}; + SmallVector pointHomogenous{llvm::to_vector(point)}; // Get the division values at this point. - SmallVector, 8> divValues = - divs.divValuesAt(point); + SmallVector, 8> divValues = divs.divValuesAt(point); // The given point didn't include the values of the divs which the output is a // function of; we have computed one possible set of values and use them here. pointHomogenous.reserve(pointHomogenous.size() + divValues.size()); - for (const std::optional &divVal : divValues) + for (const std::optional &divVal : divValues) pointHomogenous.push_back(*divVal); // The matrix `output` has an affine expression in the ith row, corresponding // to the expression for the ith value in the output vector. The last column @@ -86,8 +86,7 @@ MultiAffineFunction::valueAt(ArrayRef point) const { // a 1 appended at the end. We can see that output * v gives the desired // output vector. pointHomogenous.emplace_back(1); - SmallVector result = - output.postMultiplyWithColumn(pointHomogenous); + SmallVector result = output.postMultiplyWithColumn(pointHomogenous); assert(result.size() == getNumOutputs()); return result; } @@ -139,7 +138,7 @@ void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) { other.divs.insertDiv(0, nDivs); - SmallVector div(other.divs.getNumVars() + 1); + SmallVector div(other.divs.getNumVars() + 1); for (unsigned i = 0; i < nDivs; ++i) { // Zero fill. std::fill(div.begin(), div.end(), 0); @@ -233,7 +232,7 @@ MultiAffineFunction::getLexSet(OrderingKind comp, for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) { // Create the expression `outA - outB` for this level. - SmallVector subExpr = + SmallVector subExpr = subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level)); // TODO: Implement all comparison cases. @@ -243,14 +242,14 @@ MultiAffineFunction::getLexSet(OrderingKind comp, // outA - outB <= -1 // outA <= outB - 1 // outA < outB - levelSet.addBound(BoundType::UB, subExpr, DynamicAPInt(-1)); + levelSet.addBound(BoundType::UB, subExpr, MPInt(-1)); break; case OrderingKind::GT: // For greater than, we add a lower bound of 1: // outA - outB >= 1 // outA > outB + 1 // outA > outB - levelSet.addBound(BoundType::LB, subExpr, DynamicAPInt(1)); + levelSet.addBound(BoundType::LB, subExpr, MPInt(1)); break; case OrderingKind::GE: case OrderingKind::LE: @@ -390,7 +389,7 @@ void MultiAffineFunction::subtract(const MultiAffineFunction &other) { MultiAffineFunction copyOther = other; mergeDivs(copyOther); for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) - output.addToRow(i, copyOther.getOutputExpr(i), DynamicAPInt(-1)); + output.addToRow(i, copyOther.getOutputExpr(i), MPInt(-1)); // Check consistency. assertIsConsistent(); @@ -430,14 +429,14 @@ IntegerRelation MultiAffineFunction::getAsRelation() const { // Add equalities such that the i^th range variable is equal to the i^th // output expression. - SmallVector eq(result.getNumCols()); + SmallVector eq(result.getNumCols()); for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) { // TODO: Add functions to get VarKind offsets in output in MAF and use them // here. // The output expression does not contain range variables, while the // equality does. So, we need to copy all variables and mark all range // variables as 0 in the equality. - ArrayRef expr = getOutputExpr(i); + ArrayRef expr = getOutputExpr(i); // Copy domain variables in `expr` to domain variables in `eq`. std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin()); // Fill the range variables in `eq` as zero. @@ -463,8 +462,8 @@ void PWMAFunction::removeOutputs(unsigned start, unsigned end) { piece.output.removeOutputs(start, end); } -std::optional> -PWMAFunction::valueAt(ArrayRef point) const { +std::optional> +PWMAFunction::valueAt(ArrayRef point) const { assert(point.size() == getNumDomainVars() + getNumSymbolVars()); for (const Piece &piece : pieces) diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp index 6173f774d0475..3af6baae0e700 100644 --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/Presburger/PresburgerRelation.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/PresburgerSpace.h" #include "mlir/Analysis/Presburger/Simplex.h" @@ -121,7 +122,7 @@ PresburgerRelation::unionSet(const PresburgerRelation &set) const { } /// A point is contained in the union iff any of the parts contain the point. -bool PresburgerRelation::containsPoint(ArrayRef point) const { +bool PresburgerRelation::containsPoint(ArrayRef point) const { return llvm::any_of(disjuncts, [&](const IntegerRelation &disjunct) { return (disjunct.containsPointNoLocal(point)); }); @@ -285,15 +286,15 @@ SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMax() const { /// /// For every eq `coeffs == 0` there are two possible ineqs to index into. /// The first is coeffs >= 0 and the second is coeffs <= 0. -static SmallVector -getIneqCoeffsFromIdx(const IntegerRelation &rel, unsigned idx) { +static SmallVector getIneqCoeffsFromIdx(const IntegerRelation &rel, + unsigned idx) { assert(idx < rel.getNumInequalities() + 2 * rel.getNumEqualities() && "idx out of bounds!"); if (idx < rel.getNumInequalities()) return llvm::to_vector<8>(rel.getInequality(idx)); idx -= rel.getNumInequalities(); - ArrayRef eqCoeffs = rel.getEquality(idx / 2); + ArrayRef eqCoeffs = rel.getEquality(idx / 2); if (idx % 2 == 0) return llvm::to_vector<8>(eqCoeffs); @@ -553,7 +554,7 @@ static PresburgerRelation getSetDifference(IntegerRelation b, // state before adding this complement constraint, and add s_ij to b. simplex.rollback(frame.simplexSnapshot); b.truncate(frame.bCounts); - SmallVector ineq = + SmallVector ineq = getIneqCoeffsFromIdx(frame.sI, *frame.lastIneqProcessed); b.addInequality(ineq); simplex.addInequality(ineq); @@ -571,7 +572,7 @@ static PresburgerRelation getSetDifference(IntegerRelation b, frame.simplexSnapshot = simplex.getSnapshot(); unsigned idx = frame.ineqsToProcess.back(); - SmallVector ineq = + SmallVector ineq = getComplementIneq(getIneqCoeffsFromIdx(frame.sI, idx)); b.addInequality(ineq); simplex.addInequality(ineq); @@ -669,11 +670,10 @@ bool PresburgerRelation::isIntegerEmpty() const { return llvm::all_of(disjuncts, std::mem_fn(&IntegerRelation::isIntegerEmpty)); } -bool PresburgerRelation::findIntegerSample( - SmallVectorImpl &sample) { +bool PresburgerRelation::findIntegerSample(SmallVectorImpl &sample) { // A sample exists iff any of the disjuncts contains a sample. for (const IntegerRelation &disjunct : disjuncts) { - if (std::optional> opt = + if (std::optional> opt = disjunct.findIntegerSample()) { sample = std::move(*opt); return true; @@ -682,13 +682,13 @@ bool PresburgerRelation::findIntegerSample( return false; } -std::optional PresburgerRelation::computeVolume() const { +std::optional PresburgerRelation::computeVolume() const { assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!"); // The sum of the volumes of the disjuncts is a valid overapproximation of the // volume of their union, even if they overlap. - DynamicAPInt result(0); + MPInt result(0); for (const IntegerRelation &disjunct : disjuncts) { - std::optional volume = disjunct.computeVolume(); + std::optional volume = disjunct.computeVolume(); if (!volume) return {}; result += *volume; @@ -723,20 +723,20 @@ class presburger::SetCoalescer { /// The list of all inversed equalities during typing. This ensures that /// the constraints exist even after the typing function has concluded. - SmallVector, 2> negEqs; + SmallVector, 2> negEqs; /// `redundantIneqsA` is the inequalities of `a` that are redundant for `b` /// (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`). - SmallVector, 2> redundantIneqsA; - SmallVector, 2> cuttingIneqsA; + SmallVector, 2> redundantIneqsA; + SmallVector, 2> cuttingIneqsA; - SmallVector, 2> redundantIneqsB; - SmallVector, 2> cuttingIneqsB; + SmallVector, 2> redundantIneqsB; + SmallVector, 2> cuttingIneqsB; /// Given a Simplex `simp` and one of its inequalities `ineq`, check /// that the facet of `simp` where `ineq` holds as an equality is contained /// within `a`. - bool isFacetContained(ArrayRef ineq, Simplex &simp); + bool isFacetContained(ArrayRef ineq, Simplex &simp); /// Removes redundant constraints from `disjunct`, adds it to `disjuncts` and /// removes the disjuncts at position `i` and `j`. Updates `simplices` to @@ -760,13 +760,13 @@ class presburger::SetCoalescer { /// Types the inequality `ineq` according to its `IneqType` for `simp` into /// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate /// inequalities were encountered. Otherwise, returns failure. - LogicalResult typeInequality(ArrayRef ineq, Simplex &simp); + LogicalResult typeInequality(ArrayRef ineq, Simplex &simp); /// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and /// -`eq` >= 0 according to their `IneqType` for `simp` into /// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate /// inequalities were encountered. Otherwise, returns failure. - LogicalResult typeEquality(ArrayRef eq, Simplex &simp); + LogicalResult typeEquality(ArrayRef eq, Simplex &simp); /// Replaces the element at position `i` with the last element and erases /// the last element for both `disjuncts` and `simplices`. @@ -843,11 +843,10 @@ PresburgerRelation SetCoalescer::coalesce() { /// Given a Simplex `simp` and one of its inequalities `ineq`, check /// that all inequalities of `cuttingIneqsB` are redundant for the facet of /// `simp` where `ineq` holds as an equality is contained within `a`. -bool SetCoalescer::isFacetContained(ArrayRef ineq, - Simplex &simp) { +bool SetCoalescer::isFacetContained(ArrayRef ineq, Simplex &simp) { SimplexRollbackScopeExit scopeExit(simp); simp.addEquality(ineq); - return llvm::all_of(cuttingIneqsB, [&simp](ArrayRef curr) { + return llvm::all_of(cuttingIneqsB, [&simp](ArrayRef curr) { return simp.isRedundantInequality(curr); }); } @@ -909,23 +908,23 @@ LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) { /// redundant ones are, so only the cutting ones remain to be checked. Simplex &simp = simplices[i]; IntegerRelation &disjunct = disjuncts[i]; - if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef curr) { + if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef curr) { return !isFacetContained(curr, simp); })) return failure(); IntegerRelation newSet(disjunct.getSpace()); - for (ArrayRef curr : redundantIneqsA) + for (ArrayRef curr : redundantIneqsA) newSet.addInequality(curr); - for (ArrayRef curr : redundantIneqsB) + for (ArrayRef curr : redundantIneqsB) newSet.addInequality(curr); addCoalescedDisjunct(i, j, newSet); return success(); } -LogicalResult SetCoalescer::typeInequality(ArrayRef ineq, +LogicalResult SetCoalescer::typeInequality(ArrayRef ineq, Simplex &simp) { Simplex::IneqType type = simp.findIneqType(ineq); if (type == Simplex::IneqType::Redundant) @@ -937,12 +936,11 @@ LogicalResult SetCoalescer::typeInequality(ArrayRef ineq, return success(); } -LogicalResult SetCoalescer::typeEquality(ArrayRef eq, - Simplex &simp) { +LogicalResult SetCoalescer::typeEquality(ArrayRef eq, Simplex &simp) { if (typeInequality(eq, simp).failed()) return failure(); negEqs.push_back(getNegatedCoeffs(eq)); - ArrayRef inv(negEqs.back()); + ArrayRef inv(negEqs.back()); if (typeInequality(inv, simp).failed()) return failure(); return success(); diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index 2cdd79d42732d..1969cce93ad2e 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -9,6 +9,7 @@ #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Analysis/Presburger/Fraction.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Analysis/Presburger/Matrix.h" #include "mlir/Analysis/Presburger/PresburgerSpace.h" #include "mlir/Analysis/Presburger/Utils.h" @@ -36,11 +37,10 @@ const int nullIndex = std::numeric_limits::max(); // Return a + scale*b; LLVM_ATTRIBUTE_UNUSED -static SmallVector -scaleAndAddForAssert(ArrayRef a, const DynamicAPInt &scale, - ArrayRef b) { +static SmallVector +scaleAndAddForAssert(ArrayRef a, const MPInt &scale, ArrayRef b) { assert(a.size() == b.size()); - SmallVector res; + SmallVector res; res.reserve(a.size()); for (unsigned i = 0, e = a.size(); i < e; ++i) res.push_back(a[i] + scale * b[i]); @@ -116,8 +116,7 @@ unsigned SimplexBase::addZeroRow(bool makeRestricted) { /// Add a new row to the tableau corresponding to the given constant term and /// list of coefficients. The coefficients are specified as a vector of /// (variable index, coefficient) pairs. -unsigned SimplexBase::addRow(ArrayRef coeffs, - bool makeRestricted) { +unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { assert(coeffs.size() == var.size() + 1 && "Incorrect number of coefficients!"); assert(var.size() + getNumFixedCols() == getNumColumns() && @@ -140,7 +139,7 @@ unsigned SimplexBase::addRow(ArrayRef coeffs, // // Symbols don't use the big M parameter since they do not get lex // optimized. - DynamicAPInt bigMCoeff(0); + MPInt bigMCoeff(0); for (unsigned i = 0; i < coeffs.size() - 1; ++i) if (!var[i].isSymbol) bigMCoeff -= coeffs[i]; @@ -166,9 +165,9 @@ unsigned SimplexBase::addRow(ArrayRef coeffs, // row, scaled by the coefficient for the variable, accounting for the two // rows potentially having different denominators. The new denominator is // the lcm of the two. - DynamicAPInt lcm = llvm::lcm(tableau(newRow, 0), tableau(pos, 0)); - DynamicAPInt nRowCoeff = lcm / tableau(newRow, 0); - DynamicAPInt idxRowCoeff = coeffs[i] * (lcm / tableau(pos, 0)); + MPInt lcm = presburger::lcm(tableau(newRow, 0), tableau(pos, 0)); + MPInt nRowCoeff = lcm / tableau(newRow, 0); + MPInt idxRowCoeff = coeffs[i] * (lcm / tableau(pos, 0)); tableau(newRow, 0) = lcm; for (unsigned col = 1, e = getNumColumns(); col < e; ++col) tableau(newRow, col) = @@ -181,7 +180,7 @@ unsigned SimplexBase::addRow(ArrayRef coeffs, } namespace { -bool signMatchesDirection(const DynamicAPInt &elem, Direction direction) { +bool signMatchesDirection(const MPInt &elem, Direction direction) { assert(elem != 0 && "elem should not be 0"); return direction == Direction::Up ? elem > 0 : elem < 0; } @@ -277,7 +276,7 @@ MaybeOptimum> LexSimplex::findRationalLexMin() { /// The constraint is violated when added (it would be useless otherwise) /// so we immediately try to move it to a column. LogicalResult LexSimplexBase::addCut(unsigned row) { - DynamicAPInt d = tableau(row, 0); + MPInt d = tableau(row, 0); unsigned cutRow = addZeroRow(/*makeRestricted=*/true); tableau(cutRow, 0) = d; tableau(cutRow, 1) = -mod(-tableau(row, 1), d); // -c%d. @@ -301,7 +300,7 @@ std::optional LexSimplex::maybeGetNonIntegralVarRow() const { return {}; } -MaybeOptimum> LexSimplex::findIntegerLexMin() { +MaybeOptimum> LexSimplex::findIntegerLexMin() { // We first try to make the tableau consistent. if (restoreRationalConsistency().failed()) return OptimumKind::Empty; @@ -332,19 +331,19 @@ MaybeOptimum> LexSimplex::findIntegerLexMin() { llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger))); } -bool LexSimplex::isSeparateInequality(ArrayRef coeffs) { +bool LexSimplex::isSeparateInequality(ArrayRef coeffs) { SimplexRollbackScopeExit scopeExit(*this); addInequality(coeffs); return findIntegerLexMin().isEmpty(); } -bool LexSimplex::isRedundantInequality(ArrayRef coeffs) { +bool LexSimplex::isRedundantInequality(ArrayRef coeffs) { return isSeparateInequality(getComplementIneq(coeffs)); } -SmallVector +SmallVector SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const { - SmallVector sample; + SmallVector sample; sample.reserve(nSymbol + 1); for (unsigned col = 3; col < 3 + nSymbol; ++col) sample.push_back(tableau(row, col)); @@ -352,9 +351,9 @@ SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const { return sample; } -SmallVector +SmallVector SymbolicLexSimplex::getSymbolicSampleIneq(unsigned row) const { - SmallVector sample = getSymbolicSampleNumerator(row); + SmallVector sample = getSymbolicSampleNumerator(row); // The inequality is equivalent to the GCD-normalized one. normalizeRange(sample); return sample; @@ -367,15 +366,14 @@ void LexSimplexBase::appendSymbol() { nSymbol++; } -static bool isRangeDivisibleBy(ArrayRef range, - const DynamicAPInt &divisor) { +static bool isRangeDivisibleBy(ArrayRef range, const MPInt &divisor) { assert(divisor > 0 && "divisor must be positive!"); - return llvm::all_of( - range, [divisor](const DynamicAPInt &x) { return x % divisor == 0; }); + return llvm::all_of(range, + [divisor](const MPInt &x) { return x % divisor == 0; }); } bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const { - DynamicAPInt denom = tableau(row, 0); + MPInt denom = tableau(row, 0); return tableau(row, 1) % denom == 0 && isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), denom); } @@ -414,7 +412,7 @@ bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const { /// This constraint is violated when added so we immediately try to move it to a /// column. LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { - DynamicAPInt d = tableau(row, 0); + MPInt d = tableau(row, 0); if (isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), d)) { // The coefficients of symbols in the symbol numerator are divisible // by the denominator, so we can add the constraint directly, @@ -423,9 +421,9 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { } // Construct the division variable `q = ((-c%d) + sum_i (-a_i%d)s_i)/d`. - SmallVector divCoeffs; + SmallVector divCoeffs; divCoeffs.reserve(nSymbol + 1); - DynamicAPInt divDenom = d; + MPInt divDenom = d; for (unsigned col = 3; col < 3 + nSymbol; ++col) divCoeffs.push_back(mod(-tableau(row, col), divDenom)); // (-a_i%d)s_i divCoeffs.push_back(mod(-tableau(row, 1), divDenom)); // -c%d. @@ -466,7 +464,7 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexOpt &result) const { return; } - DynamicAPInt denom = tableau(u.pos, 0); + MPInt denom = tableau(u.pos, 0); if (tableau(u.pos, 2) < denom) { // M + u has a sample value of fM + something, where f < 1, so // u = (f - 1)M + something, which has a negative coefficient for M, @@ -477,8 +475,8 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexOpt &result) const { assert(tableau(u.pos, 2) == denom && "Coefficient of M should not be greater than 1!"); - SmallVector sample = getSymbolicSampleNumerator(u.pos); - for (DynamicAPInt &elem : sample) { + SmallVector sample = getSymbolicSampleNumerator(u.pos); + for (MPInt &elem : sample) { assert(elem % denom == 0 && "coefficients must be integral!"); elem /= denom; } @@ -575,7 +573,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() { continue; } - SmallVector symbolicSample; + SmallVector symbolicSample; unsigned splitRow = 0; for (unsigned e = getNumRows(); splitRow < e; ++splitRow) { if (tableau(splitRow, 2) > 0) @@ -660,7 +658,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() { // was negative. assert(u.orientation == Orientation::Row && "The split row should have been returned to row orientation!"); - SmallVector splitIneq = + SmallVector splitIneq = getComplementIneq(getSymbolicSampleIneq(u.pos)); normalizeRange(splitIneq); if (moveRowUnknownToColumn(u.pos).failed()) { @@ -836,7 +834,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, // all possible values of the symbols. auto getSampleChangeCoeffForVar = [this, row](unsigned col, const Unknown &u) -> Fraction { - DynamicAPInt a = tableau(row, col); + MPInt a = tableau(row, col); if (u.orientation == Orientation::Column) { // Pivot column case. if (u.pos == col) @@ -851,7 +849,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, return {1, 1}; // Non-pivot row case. - DynamicAPInt c = tableau(u.pos, col); + MPInt c = tableau(u.pos, col); return {c, a}; }; @@ -885,7 +883,7 @@ std::optional Simplex::findPivot(int row, Direction direction) const { std::optional col; for (unsigned j = 2, e = getNumColumns(); j < e; ++j) { - DynamicAPInt elem = tableau(row, j); + MPInt elem = tableau(row, j); if (elem == 0) continue; @@ -1034,18 +1032,18 @@ std::optional Simplex::findPivotRow(std::optional skipRow, // retConst being used uninitialized in the initialization of `diff` below. In // reality, these are always initialized when that line is reached since these // are set whenever retRow is set. - DynamicAPInt retElem, retConst; + MPInt retElem, retConst; for (unsigned row = nRedundant, e = getNumRows(); row < e; ++row) { if (skipRow && row == *skipRow) continue; - DynamicAPInt elem = tableau(row, col); + MPInt elem = tableau(row, col); if (elem == 0) continue; if (!unknownFromRow(row).restricted) continue; if (signMatchesDirection(elem, direction)) continue; - DynamicAPInt constTerm = tableau(row, 1); + MPInt constTerm = tableau(row, 1); if (!retRow) { retRow = row; @@ -1054,7 +1052,7 @@ std::optional Simplex::findPivotRow(std::optional skipRow, continue; } - DynamicAPInt diff = retConst * elem - constTerm * retElem; + MPInt diff = retConst * elem - constTerm * retElem; if ((diff == 0 && rowUnknown[row] < rowUnknown[*retRow]) || (diff != 0 && !signMatchesDirection(diff, direction))) { retRow = row; @@ -1105,7 +1103,7 @@ void SimplexBase::markEmpty() { /// We add the inequality and mark it as restricted. We then try to make its /// sample value non-negative. If this is not possible, the tableau has become /// empty and we mark it as such. -void Simplex::addInequality(ArrayRef coeffs) { +void Simplex::addInequality(ArrayRef coeffs) { unsigned conIndex = addRow(coeffs, /*makeRestricted=*/true); LogicalResult result = restoreRow(con[conIndex]); if (failed(result)) @@ -1118,10 +1116,10 @@ void Simplex::addInequality(ArrayRef coeffs) { /// /// We simply add two opposing inequalities, which force the expression to /// be zero. -void SimplexBase::addEquality(ArrayRef coeffs) { +void SimplexBase::addEquality(ArrayRef coeffs) { addInequality(coeffs); - SmallVector negatedCoeffs; - for (const DynamicAPInt &coeff : coeffs) + SmallVector negatedCoeffs; + for (const MPInt &coeff : coeffs) negatedCoeffs.emplace_back(-coeff); addInequality(negatedCoeffs); } @@ -1297,18 +1295,18 @@ void SimplexBase::rollback(unsigned snapshot) { /// /// This constrains the remainder `coeffs - denom*q` to be in the /// range `[0, denom - 1]`, which fixes the integer value of the quotient `q`. -void SimplexBase::addDivisionVariable(ArrayRef coeffs, - const DynamicAPInt &denom) { +void SimplexBase::addDivisionVariable(ArrayRef coeffs, + const MPInt &denom) { assert(denom > 0 && "Denominator must be positive!"); appendVariable(); - SmallVector ineq(coeffs.begin(), coeffs.end()); - DynamicAPInt constTerm = ineq.back(); + SmallVector ineq(coeffs.begin(), coeffs.end()); + MPInt constTerm = ineq.back(); ineq.back() = -denom; ineq.push_back(constTerm); addInequality(ineq); - for (DynamicAPInt &coeff : ineq) + for (MPInt &coeff : ineq) coeff = -coeff; ineq.back() += denom - 1; addInequality(ineq); @@ -1358,7 +1356,7 @@ MaybeOptimum Simplex::computeRowOptimum(Direction direction, /// Compute the optimum of the specified expression in the specified direction, /// or std::nullopt if it is unbounded. MaybeOptimum Simplex::computeOptimum(Direction direction, - ArrayRef coeffs) { + ArrayRef coeffs) { if (empty) return OptimumKind::Empty; @@ -1468,7 +1466,7 @@ bool Simplex::isUnbounded() { if (empty) return false; - SmallVector dir(var.size() + 1); + SmallVector dir(var.size() + 1); for (unsigned i = 0; i < var.size(); ++i) { dir[i] = 1; @@ -1578,14 +1576,14 @@ std::optional> Simplex::getRationalSample() const { } else { // If the variable is in row position, its sample value is the // entry in the constant column divided by the denominator. - DynamicAPInt denom = tableau(u.pos, 0); + MPInt denom = tableau(u.pos, 0); sample.emplace_back(tableau(u.pos, 1), denom); } } return sample; } -void LexSimplexBase::addInequality(ArrayRef coeffs) { +void LexSimplexBase::addInequality(ArrayRef coeffs) { addRow(coeffs, /*makeRestricted=*/true); } @@ -1610,7 +1608,7 @@ MaybeOptimum> LexSimplex::getRationalSample() const { // If the variable is in row position, its sample value is the // entry in the constant column divided by the denominator. - DynamicAPInt denom = tableau(u.pos, 0); + MPInt denom = tableau(u.pos, 0); if (usingBigM) if (tableau(u.pos, 2) != denom) return OptimumKind::Unbounded; @@ -1619,15 +1617,14 @@ MaybeOptimum> LexSimplex::getRationalSample() const { return sample; } -std::optional> -Simplex::getSamplePointIfIntegral() const { +std::optional> Simplex::getSamplePointIfIntegral() const { // If the tableau is empty, no sample point exists. if (empty) return {}; // The value will always exist since the Simplex is non-empty. SmallVector rationalSample = *getRationalSample(); - SmallVector integerSample; + SmallVector integerSample; integerSample.reserve(var.size()); for (const Fraction &coord : rationalSample) { // If the sample is non-integral, return std::nullopt. @@ -1659,14 +1656,14 @@ class presburger::GBRSimplex { /// Add an equality dotProduct(dir, x - y) == 0. /// First pushes a snapshot for the current simplex state to the stack so /// that this can be rolled back later. - void addEqualityForDirection(ArrayRef dir) { - assert(llvm::any_of(dir, [](const DynamicAPInt &x) { return x != 0; }) && + void addEqualityForDirection(ArrayRef dir) { + assert(llvm::any_of(dir, [](const MPInt &x) { return x != 0; }) && "Direction passed is the zero vector!"); snapshotStack.push_back(simplex.getSnapshot()); simplex.addEquality(getCoeffsForDirection(dir)); } /// Compute max(dotProduct(dir, x - y)). - Fraction computeWidth(ArrayRef dir) { + Fraction computeWidth(ArrayRef dir) { MaybeOptimum maybeWidth = simplex.computeOptimum(Direction::Up, getCoeffsForDirection(dir)); assert(maybeWidth.isBounded() && "Width should be bounded!"); @@ -1675,9 +1672,9 @@ class presburger::GBRSimplex { /// Compute max(dotProduct(dir, x - y)) and save the dual variables for only /// the direction equalities to `dual`. - Fraction computeWidthAndDuals(ArrayRef dir, - SmallVectorImpl &dual, - DynamicAPInt &dualDenom) { + Fraction computeWidthAndDuals(ArrayRef dir, + SmallVectorImpl &dual, + MPInt &dualDenom) { // We can't just call into computeWidth or computeOptimum since we need to // access the state of the tableau after computing the optimum, and these // functions rollback the insertion of the objective function into the @@ -1745,13 +1742,12 @@ class presburger::GBRSimplex { /// i.e., dir_1 * x_1 + dir_2 * x_2 + ... + dir_n * x_n /// - dir_1 * y_1 - dir_2 * y_2 - ... - dir_n * y_n, /// where n is the dimension of the original polytope. - SmallVector - getCoeffsForDirection(ArrayRef dir) { + SmallVector getCoeffsForDirection(ArrayRef dir) { assert(2 * dir.size() == simplex.getNumVariables() && "Direction vector has wrong dimensionality"); - SmallVector coeffs(dir.begin(), dir.end()); + SmallVector coeffs(dir.begin(), dir.end()); coeffs.reserve(2 * dir.size()); - for (const DynamicAPInt &coeff : dir) + for (const MPInt &coeff : dir) coeffs.push_back(-coeff); coeffs.emplace_back(0); // constant term return coeffs; @@ -1828,8 +1824,8 @@ void Simplex::reduceBasis(IntMatrix &basis, unsigned level) { GBRSimplex gbrSimplex(*this); SmallVector width; - SmallVector dual; - DynamicAPInt dualDenom; + SmallVector dual; + MPInt dualDenom; // Finds the value of u that minimizes width_i(b_{i+1} + u*b_i), caches the // duals from this computation, sets b_{i+1} to b_{i+1} + u*b_i, and returns @@ -1852,11 +1848,11 @@ void Simplex::reduceBasis(IntMatrix &basis, unsigned level) { auto updateBasisWithUAndGetFCandidate = [&](unsigned i) -> Fraction { assert(i < level + dual.size() && "dual_i is not known!"); - DynamicAPInt u = floorDiv(dual[i - level], dualDenom); + MPInt u = floorDiv(dual[i - level], dualDenom); basis.addToRow(i, i + 1, u); if (dual[i - level] % dualDenom != 0) { - SmallVector candidateDual[2]; - DynamicAPInt candidateDualDenom[2]; + SmallVector candidateDual[2]; + MPInt candidateDualDenom[2]; Fraction widthI[2]; // Initially u is floor(dual) and basis reflects this. @@ -1883,12 +1879,12 @@ void Simplex::reduceBasis(IntMatrix &basis, unsigned level) { // Check the value at u - 1. assert(gbrSimplex.computeWidth(scaleAndAddForAssert( - basis.getRow(i + 1), DynamicAPInt(-1), basis.getRow(i))) >= + basis.getRow(i + 1), MPInt(-1), basis.getRow(i))) >= widthI[j] && "Computed u value does not minimize the width!"); // Check the value at u + 1. assert(gbrSimplex.computeWidth(scaleAndAddForAssert( - basis.getRow(i + 1), DynamicAPInt(+1), basis.getRow(i))) >= + basis.getRow(i + 1), MPInt(+1), basis.getRow(i))) >= widthI[j] && "Computed u value does not minimize the width!"); @@ -1989,7 +1985,7 @@ void Simplex::reduceBasis(IntMatrix &basis, unsigned level) { /// /// To avoid potentially arbitrarily large recursion depths leading to stack /// overflows, this algorithm is implemented iteratively. -std::optional> Simplex::findIntegerSample() { +std::optional> Simplex::findIntegerSample() { if (empty) return {}; @@ -2000,9 +1996,9 @@ std::optional> Simplex::findIntegerSample() { // The snapshot just before constraining a direction to a value at each level. SmallVector snapshotStack; // The maximum value in the range of the direction for each level. - SmallVector upperBoundStack; + SmallVector upperBoundStack; // The next value to try constraining the basis vector to at each level. - SmallVector nextValueStack; + SmallVector nextValueStack; snapshotStack.reserve(basis.getNumRows()); upperBoundStack.reserve(basis.getNumRows()); @@ -2022,7 +2018,7 @@ std::optional> Simplex::findIntegerSample() { // just come down a level ("recursed"). Find the lower and upper bounds. // If there is more than one integer point in the range, perform // generalized basis reduction. - SmallVector basisCoeffs = + SmallVector basisCoeffs = llvm::to_vector<8>(basis.getRow(level)); basisCoeffs.emplace_back(0); @@ -2074,7 +2070,7 @@ std::optional> Simplex::findIntegerSample() { // to the snapshot of the starting state at this level. (in the "recursed" // case this has no effect) rollback(snapshotStack.back()); - DynamicAPInt nextValue = nextValueStack.back(); + MPInt nextValue = nextValueStack.back(); ++nextValueStack.back(); if (nextValue > upperBoundStack.back()) { // We have exhausted the range and found no solution. Pop the stack and @@ -2087,8 +2083,8 @@ std::optional> Simplex::findIntegerSample() { } // Try the next value in the range and "recurse" into the next level. - SmallVector basisCoeffs(basis.getRow(level).begin(), - basis.getRow(level).end()); + SmallVector basisCoeffs(basis.getRow(level).begin(), + basis.getRow(level).end()); basisCoeffs.push_back(-nextValue); addEquality(basisCoeffs); level++; @@ -2099,16 +2095,16 @@ std::optional> Simplex::findIntegerSample() { /// Compute the minimum and maximum integer values the expression can take. We /// compute each separately. -std::pair, MaybeOptimum> -Simplex::computeIntegerBounds(ArrayRef coeffs) { - MaybeOptimum minRoundedUp( +std::pair, MaybeOptimum> +Simplex::computeIntegerBounds(ArrayRef coeffs) { + MaybeOptimum minRoundedUp( computeOptimum(Simplex::Direction::Down, coeffs).map(ceil)); - MaybeOptimum maxRoundedDown( + MaybeOptimum maxRoundedDown( computeOptimum(Simplex::Direction::Up, coeffs).map(floor)); return {minRoundedUp, maxRoundedDown}; } -bool Simplex::isFlatAlong(ArrayRef coeffs) { +bool Simplex::isFlatAlong(ArrayRef coeffs) { assert(!isEmpty() && "cannot check for flatness of empty simplex!"); auto upOpt = computeOptimum(Simplex::Direction::Up, coeffs); auto downOpt = computeOptimum(Simplex::Direction::Down, coeffs); @@ -2187,7 +2183,7 @@ bool Simplex::isRationalSubsetOf(const IntegerRelation &rel) { /// maximum satisfy it. Hence, it is a cut inequality. If both are < 0, no /// points of the polytope satisfy the inequality, which means it is a separate /// inequality. -Simplex::IneqType Simplex::findIneqType(ArrayRef coeffs) { +Simplex::IneqType Simplex::findIneqType(ArrayRef coeffs) { MaybeOptimum minimum = computeOptimum(Direction::Down, coeffs); if (minimum.isBounded() && *minimum >= Fraction(0, 1)) { return IneqType::Redundant; @@ -2202,7 +2198,7 @@ Simplex::IneqType Simplex::findIneqType(ArrayRef coeffs) { /// Checks whether the type of the inequality with coefficients `coeffs` /// is Redundant. -bool Simplex::isRedundantInequality(ArrayRef coeffs) { +bool Simplex::isRedundantInequality(ArrayRef coeffs) { assert(!empty && "It is not meaningful to ask about redundancy in an empty set!"); return findIneqType(coeffs) == IneqType::Redundant; @@ -2212,7 +2208,7 @@ bool Simplex::isRedundantInequality(ArrayRef coeffs) { /// the existing constraints. This is redundant when `coeffs` is already /// always zero under the existing constraints. `coeffs` is always zero /// when the minimum and maximum value that `coeffs` can take are both zero. -bool Simplex::isRedundantEquality(ArrayRef coeffs) { +bool Simplex::isRedundantEquality(ArrayRef coeffs) { assert(!empty && "It is not meaningful to ask about redundancy in an empty set!"); MaybeOptimum minimum = computeOptimum(Direction::Down, coeffs); diff --git a/mlir/lib/Analysis/Presburger/SlowMPInt.cpp b/mlir/lib/Analysis/Presburger/SlowMPInt.cpp new file mode 100644 index 0000000000000..ae6f2827be926 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/SlowMPInt.cpp @@ -0,0 +1,290 @@ +//===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Presburger/SlowMPInt.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include + +using namespace mlir; +using namespace presburger; +using namespace detail; + +SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {} +SlowMPInt::SlowMPInt() : SlowMPInt(0) {} +SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {} +SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); } +SlowMPInt::operator int64_t() const { return val.getSExtValue(); } + +llvm::hash_code detail::hash_value(const SlowMPInt &x) { + return hash_value(x.val); +} + +/// --------------------------------------------------------------------------- +/// Printing. +/// --------------------------------------------------------------------------- +void SlowMPInt::print(llvm::raw_ostream &os) const { os << val; } + +void SlowMPInt::dump() const { print(llvm::errs()); } + +llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os, + const SlowMPInt &x) { + x.print(os); + return os; +} + +/// --------------------------------------------------------------------------- +/// Convenience operator overloads for int64_t. +/// --------------------------------------------------------------------------- +SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) { + return a += SlowMPInt(b); +} +SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) { + return a -= SlowMPInt(b); +} +SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) { + return a *= SlowMPInt(b); +} +SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) { + return a /= SlowMPInt(b); +} +SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) { + return a %= SlowMPInt(b); +} + +bool detail::operator==(const SlowMPInt &a, int64_t b) { + return a == SlowMPInt(b); +} +bool detail::operator!=(const SlowMPInt &a, int64_t b) { + return a != SlowMPInt(b); +} +bool detail::operator>(const SlowMPInt &a, int64_t b) { + return a > SlowMPInt(b); +} +bool detail::operator<(const SlowMPInt &a, int64_t b) { + return a < SlowMPInt(b); +} +bool detail::operator<=(const SlowMPInt &a, int64_t b) { + return a <= SlowMPInt(b); +} +bool detail::operator>=(const SlowMPInt &a, int64_t b) { + return a >= SlowMPInt(b); +} +SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) { + return a + SlowMPInt(b); +} +SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) { + return a - SlowMPInt(b); +} +SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) { + return a * SlowMPInt(b); +} +SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) { + return a / SlowMPInt(b); +} +SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) { + return a % SlowMPInt(b); +} + +bool detail::operator==(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) == b; +} +bool detail::operator!=(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) != b; +} +bool detail::operator>(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) > b; +} +bool detail::operator<(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) < b; +} +bool detail::operator<=(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) <= b; +} +bool detail::operator>=(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) >= b; +} +SlowMPInt detail::operator+(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) + b; +} +SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) - b; +} +SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) * b; +} +SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) / b; +} +SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) { + return SlowMPInt(a) % b; +} + +static unsigned getMaxWidth(const APInt &a, const APInt &b) { + return std::max(a.getBitWidth(), b.getBitWidth()); +} + +/// --------------------------------------------------------------------------- +/// Comparison operators. +/// --------------------------------------------------------------------------- + +// TODO: consider instead making APInt::compare available and using that. +bool SlowMPInt::operator==(const SlowMPInt &o) const { + unsigned width = getMaxWidth(val, o.val); + return val.sext(width) == o.val.sext(width); +} +bool SlowMPInt::operator!=(const SlowMPInt &o) const { + unsigned width = getMaxWidth(val, o.val); + return val.sext(width) != o.val.sext(width); +} +bool SlowMPInt::operator>(const SlowMPInt &o) const { + unsigned width = getMaxWidth(val, o.val); + return val.sext(width).sgt(o.val.sext(width)); +} +bool SlowMPInt::operator<(const SlowMPInt &o) const { + unsigned width = getMaxWidth(val, o.val); + return val.sext(width).slt(o.val.sext(width)); +} +bool SlowMPInt::operator<=(const SlowMPInt &o) const { + unsigned width = getMaxWidth(val, o.val); + return val.sext(width).sle(o.val.sext(width)); +} +bool SlowMPInt::operator>=(const SlowMPInt &o) const { + unsigned width = getMaxWidth(val, o.val); + return val.sext(width).sge(o.val.sext(width)); +} + +/// --------------------------------------------------------------------------- +/// Arithmetic operators. +/// --------------------------------------------------------------------------- + +/// Bring a and b to have the same width and then call op(a, b, overflow). +/// If the overflow bit becomes set, resize a and b to double the width and +/// call op(a, b, overflow), returning its result. The operation with double +/// widths should not also overflow. +APInt runOpWithExpandOnOverflow( + const APInt &a, const APInt &b, + llvm::function_ref + op) { + bool overflow; + unsigned width = getMaxWidth(a, b); + APInt ret = op(a.sext(width), b.sext(width), overflow); + if (!overflow) + return ret; + + width *= 2; + ret = op(a.sext(width), b.sext(width), overflow); + assert(!overflow && "double width should be sufficient to avoid overflow!"); + return ret; +} + +SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const { + return SlowMPInt( + runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sadd_ov))); +} +SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const { + return SlowMPInt( + runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::ssub_ov))); +} +SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const { + return SlowMPInt( + runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::smul_ov))); +} +SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const { + return SlowMPInt( + runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sdiv_ov))); +} +SlowMPInt detail::abs(const SlowMPInt &x) { return x >= 0 ? x : -x; } +SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) { + if (rhs == -1) + return -lhs; + unsigned width = getMaxWidth(lhs.val, rhs.val); + return SlowMPInt(llvm::APIntOps::RoundingSDiv( + lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::UP)); +} +SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) { + if (rhs == -1) + return -lhs; + unsigned width = getMaxWidth(lhs.val, rhs.val); + return SlowMPInt(llvm::APIntOps::RoundingSDiv( + lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::DOWN)); +} +// The RHS is always expected to be positive, and the result +/// is always non-negative. +SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) { + assert(rhs >= 1 && "mod is only supported for positive divisors!"); + return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs; +} + +SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) { + assert(a >= 0 && b >= 0 && "operands must be non-negative!"); + unsigned width = getMaxWidth(a.val, b.val); + return SlowMPInt(llvm::APIntOps::GreatestCommonDivisor(a.val.sext(width), + b.val.sext(width))); +} + +/// Returns the least common multiple of 'a' and 'b'. +SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) { + SlowMPInt x = abs(a); + SlowMPInt y = abs(b); + return (x * y) / gcd(x, y); +} + +/// This operation cannot overflow. +SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const { + unsigned width = std::max(val.getBitWidth(), o.val.getBitWidth()); + return SlowMPInt(val.sext(width).srem(o.val.sext(width))); +} + +SlowMPInt SlowMPInt::operator-() const { + if (val.isMinSignedValue()) { + /// Overflow only occurs when the value is the minimum possible value. + APInt ret = val.sext(2 * val.getBitWidth()); + return SlowMPInt(-ret); + } + return SlowMPInt(-val); +} + +/// --------------------------------------------------------------------------- +/// Assignment operators, preincrement, predecrement. +/// --------------------------------------------------------------------------- +SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) { + *this = *this + o; + return *this; +} +SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) { + *this = *this - o; + return *this; +} +SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) { + *this = *this * o; + return *this; +} +SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) { + *this = *this / o; + return *this; +} +SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) { + *this = *this % o; + return *this; +} +SlowMPInt &SlowMPInt::operator++() { + *this += 1; + return *this; +} + +SlowMPInt &SlowMPInt::operator--() { + *this -= 1; + return *this; +} diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp index 1fab4c4dcca33..f717a4de5d728 100644 --- a/mlir/lib/Analysis/Presburger/Utils.cpp +++ b/mlir/lib/Analysis/Presburger/Utils.cpp @@ -12,6 +12,7 @@ #include "mlir/Analysis/Presburger/Utils.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/MPInt.h" #include "mlir/Analysis/Presburger/PresburgerSpace.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -30,20 +31,19 @@ using namespace mlir; using namespace presburger; -using llvm::dynamicAPIntFromInt64; /// Normalize a division's `dividend` and the `divisor` by their GCD. For /// example: if the dividend and divisor are [2,0,4] and 4 respectively, /// they get normalized to [1,0,2] and 2. The divisor must be non-negative; /// it is allowed for the divisor to be zero, but nothing is done in this case. -static void normalizeDivisionByGCD(MutableArrayRef dividend, - DynamicAPInt &divisor) { +static void normalizeDivisionByGCD(MutableArrayRef dividend, + MPInt &divisor) { assert(divisor > 0 && "divisor must be non-negative!"); if (divisor == 0 || dividend.empty()) return; // We take the absolute value of dividend's coefficients to make sure that // `gcd` is positive. - DynamicAPInt gcd = llvm::gcd(abs(dividend.front()), divisor); + MPInt gcd = presburger::gcd(abs(dividend.front()), divisor); // The reason for ignoring the constant term is as follows. // For a division: @@ -53,14 +53,14 @@ static void normalizeDivisionByGCD(MutableArrayRef dividend, // Since `{a/m}/d` in the dividend satisfies 0 <= {a/m}/d < 1/d, it will not // influence the result of the floor division and thus, can be ignored. for (size_t i = 1, m = dividend.size() - 1; i < m; i++) { - gcd = llvm::gcd(abs(dividend[i]), gcd); + gcd = presburger::gcd(abs(dividend[i]), gcd); if (gcd == 1) return; } // Normalize the dividend and the denominator. std::transform(dividend.begin(), dividend.end(), dividend.begin(), - [gcd](DynamicAPInt &n) { return floorDiv(n, gcd); }); + [gcd](MPInt &n) { return floorDiv(n, gcd); }); divisor /= gcd; } @@ -104,8 +104,7 @@ static void normalizeDivisionByGCD(MutableArrayRef dividend, /// The final division expression is normalized by GCD. static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned ubIneq, unsigned lbIneq, - MutableArrayRef expr, - DynamicAPInt &divisor) { + MutableArrayRef expr, MPInt &divisor) { assert(pos <= cst.getNumVars() && "Invalid variable position"); assert(ubIneq <= cst.getNumInequalities() && @@ -132,9 +131,9 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, // Then, check if the constant term is of the proper form. // Due to the form of the upper/lower bound inequalities, the sum of their // constants is `divisor - 1 - c`. From this, we can extract c: - DynamicAPInt constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) + - cst.atIneq(ubIneq, cst.getNumCols() - 1); - DynamicAPInt c = divisor - 1 - constantSum; + MPInt constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) + + cst.atIneq(ubIneq, cst.getNumCols() - 1); + MPInt c = divisor - 1 - constantSum; // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. // This also implictly checks that `divisor` is positive. @@ -169,9 +168,8 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, /// set to the denominator of the division. The final division expression is /// normalized by GCD. static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, - unsigned eqInd, - MutableArrayRef expr, - DynamicAPInt &divisor) { + unsigned eqInd, MutableArrayRef expr, + MPInt &divisor) { assert(pos <= cst.getNumVars() && "Invalid variable position"); assert(eqInd <= cst.getNumEqualities() && "Invalid equality position"); @@ -180,7 +178,7 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, // Extract divisor, the divisor can be negative and hence its sign information // is stored in `signDiv` to reverse the sign of dividend's coefficients. // Equality must involve the pos-th variable and hence `tempDiv` != 0. - DynamicAPInt tempDiv = cst.atEq(eqInd, pos); + MPInt tempDiv = cst.atEq(eqInd, pos); if (tempDiv == 0) return failure(); int signDiv = tempDiv < 0 ? -1 : 1; @@ -202,7 +200,7 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, // explicit representation has not been found yet, otherwise returns `true`. static bool checkExplicitRepresentation(const IntegerRelation &cst, ArrayRef foundRepr, - ArrayRef dividend, + ArrayRef dividend, unsigned pos) { // Exit to avoid circular dependencies between divisions. for (unsigned c = 0, e = cst.getNumVars(); c < e; ++c) { @@ -231,9 +229,11 @@ static bool checkExplicitRepresentation(const IntegerRelation &cst, /// the representation could be computed, `dividend` and `denominator` are set. /// If the representation could not be computed, the kind attribute in /// `MaybeLocalRepr` is set to None. -MaybeLocalRepr presburger::computeSingleVarRepr( - const IntegerRelation &cst, ArrayRef foundRepr, unsigned pos, - MutableArrayRef dividend, DynamicAPInt &divisor) { +MaybeLocalRepr presburger::computeSingleVarRepr(const IntegerRelation &cst, + ArrayRef foundRepr, + unsigned pos, + MutableArrayRef dividend, + MPInt &divisor) { assert(pos < cst.getNumVars() && "invalid position"); assert(foundRepr.size() == cst.getNumVars() && "Size of foundRepr does not match total number of variables"); @@ -275,12 +275,12 @@ MaybeLocalRepr presburger::computeSingleVarRepr( MaybeLocalRepr presburger::computeSingleVarRepr( const IntegerRelation &cst, ArrayRef foundRepr, unsigned pos, SmallVector ÷nd, unsigned &divisor) { - SmallVector dividendDynamicAPInt(cst.getNumCols()); - DynamicAPInt divisorDynamicAPInt; - MaybeLocalRepr result = computeSingleVarRepr( - cst, foundRepr, pos, dividendDynamicAPInt, divisorDynamicAPInt); - dividend = getInt64Vec(dividendDynamicAPInt); - divisor = unsigned(int64_t(divisorDynamicAPInt)); + SmallVector dividendMPInt(cst.getNumCols()); + MPInt divisorMPInt; + MaybeLocalRepr result = + computeSingleVarRepr(cst, foundRepr, pos, dividendMPInt, divisorMPInt); + dividend = getInt64Vec(dividendMPInt); + divisor = unsigned(int64_t(divisorMPInt)); return result; } @@ -318,86 +318,80 @@ void presburger::mergeLocalVars( divsA.removeDuplicateDivs(merge); } -SmallVector -presburger::getDivUpperBound(ArrayRef dividend, - const DynamicAPInt &divisor, - unsigned localVarIdx) { +SmallVector presburger::getDivUpperBound(ArrayRef dividend, + const MPInt &divisor, + unsigned localVarIdx) { assert(divisor > 0 && "divisor must be positive!"); assert(dividend[localVarIdx] == 0 && "Local to be set to division must have zero coeff!"); - SmallVector ineq(dividend.begin(), dividend.end()); + SmallVector ineq(dividend.begin(), dividend.end()); ineq[localVarIdx] = -divisor; return ineq; } -SmallVector -presburger::getDivLowerBound(ArrayRef dividend, - const DynamicAPInt &divisor, - unsigned localVarIdx) { +SmallVector presburger::getDivLowerBound(ArrayRef dividend, + const MPInt &divisor, + unsigned localVarIdx) { assert(divisor > 0 && "divisor must be positive!"); assert(dividend[localVarIdx] == 0 && "Local to be set to division must have zero coeff!"); - SmallVector ineq(dividend.size()); + SmallVector ineq(dividend.size()); std::transform(dividend.begin(), dividend.end(), ineq.begin(), - std::negate()); + std::negate()); ineq[localVarIdx] = divisor; ineq.back() += divisor - 1; return ineq; } -DynamicAPInt presburger::gcdRange(ArrayRef range) { - DynamicAPInt gcd(0); - for (const DynamicAPInt &elem : range) { - gcd = llvm::gcd(gcd, abs(elem)); +MPInt presburger::gcdRange(ArrayRef range) { + MPInt gcd(0); + for (const MPInt &elem : range) { + gcd = presburger::gcd(gcd, abs(elem)); if (gcd == 1) return gcd; } return gcd; } -DynamicAPInt presburger::normalizeRange(MutableArrayRef range) { - DynamicAPInt gcd = gcdRange(range); +MPInt presburger::normalizeRange(MutableArrayRef range) { + MPInt gcd = gcdRange(range); if ((gcd == 0) || (gcd == 1)) return gcd; - for (DynamicAPInt &elem : range) + for (MPInt &elem : range) elem /= gcd; return gcd; } -void presburger::normalizeDiv(MutableArrayRef num, - DynamicAPInt &denom) { +void presburger::normalizeDiv(MutableArrayRef num, MPInt &denom) { assert(denom > 0 && "denom must be positive!"); - DynamicAPInt gcd = llvm::gcd(gcdRange(num), denom); - for (DynamicAPInt &coeff : num) + MPInt gcd = presburger::gcd(gcdRange(num), denom); + for (MPInt &coeff : num) coeff /= gcd; denom /= gcd; } -SmallVector -presburger::getNegatedCoeffs(ArrayRef coeffs) { - SmallVector negatedCoeffs; +SmallVector presburger::getNegatedCoeffs(ArrayRef coeffs) { + SmallVector negatedCoeffs; negatedCoeffs.reserve(coeffs.size()); - for (const DynamicAPInt &coeff : coeffs) + for (const MPInt &coeff : coeffs) negatedCoeffs.emplace_back(-coeff); return negatedCoeffs; } -SmallVector -presburger::getComplementIneq(ArrayRef ineq) { - SmallVector coeffs; +SmallVector presburger::getComplementIneq(ArrayRef ineq) { + SmallVector coeffs; coeffs.reserve(ineq.size()); - for (const DynamicAPInt &coeff : ineq) + for (const MPInt &coeff : ineq) coeffs.emplace_back(-coeff); --coeffs.back(); return coeffs; } -SmallVector, 4> -DivisionRepr::divValuesAt(ArrayRef point) const { +SmallVector, 4> +DivisionRepr::divValuesAt(ArrayRef point) const { assert(point.size() == getNumNonDivs() && "Incorrect point size"); - SmallVector, 4> divValues(getNumDivs(), - std::nullopt); + SmallVector, 4> divValues(getNumDivs(), std::nullopt); bool changed = true; while (changed) { changed = false; @@ -406,8 +400,8 @@ DivisionRepr::divValuesAt(ArrayRef point) const { if (divValues[i]) continue; - ArrayRef dividend = getDividend(i); - DynamicAPInt divVal(0); + ArrayRef dividend = getDividend(i); + MPInt divVal(0); // Check if we have all the division values required for this division. unsigned j, f; @@ -496,8 +490,8 @@ void DivisionRepr::normalizeDivs() { } } -void DivisionRepr::insertDiv(unsigned pos, ArrayRef dividend, - const DynamicAPInt &divisor) { +void DivisionRepr::insertDiv(unsigned pos, ArrayRef dividend, + const MPInt &divisor) { assert(pos <= getNumDivs() && "Invalid insertion position"); assert(dividend.size() == getNumVars() + 1 && "Incorrect dividend size"); @@ -510,32 +504,29 @@ void DivisionRepr::insertDiv(unsigned pos, unsigned num) { assert(pos <= getNumDivs() && "Invalid insertion position"); dividends.insertColumns(getDivOffset() + pos, num); dividends.insertRows(pos, num); - denoms.insert(denoms.begin() + pos, num, DynamicAPInt(0)); + denoms.insert(denoms.begin() + pos, num, MPInt(0)); } void DivisionRepr::print(raw_ostream &os) const { os << "Dividends:\n"; dividends.print(os); os << "Denominators\n"; - for (const DynamicAPInt &denom : denoms) + for (const MPInt &denom : denoms) os << denom << " "; os << "\n"; } void DivisionRepr::dump() const { print(llvm::errs()); } -SmallVector -presburger::getDynamicAPIntVec(ArrayRef range) { - SmallVector result(range.size()); - std::transform(range.begin(), range.end(), result.begin(), - dynamicAPIntFromInt64); +SmallVector presburger::getMPIntVec(ArrayRef range) { + SmallVector result(range.size()); + std::transform(range.begin(), range.end(), result.begin(), mpintFromInt64); return result; } -SmallVector presburger::getInt64Vec(ArrayRef range) { +SmallVector presburger::getInt64Vec(ArrayRef range) { SmallVector result(range.size()); - std::transform(range.begin(), range.end(), result.begin(), - int64fromDynamicAPInt); + std::transform(range.begin(), range.end(), result.begin(), int64FromMPInt); return result; } diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt index b69f514711337..c98668f63fa5d 100644 --- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt +++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_unittest(MLIRPresburgerTests IntegerRelationTest.cpp LinearTransformTest.cpp MatrixTest.cpp + MPIntTest.cpp Parser.h ParserTest.cpp PresburgerSetTest.cpp diff --git a/mlir/unittests/Analysis/Presburger/FractionTest.cpp b/mlir/unittests/Analysis/Presburger/FractionTest.cpp index c9fad953dacd5..5fee9de1994c8 100644 --- a/mlir/unittests/Analysis/Presburger/FractionTest.cpp +++ b/mlir/unittests/Analysis/Presburger/FractionTest.cpp @@ -8,7 +8,7 @@ using namespace presburger; TEST(FractionTest, getAsInteger) { Fraction f(3, 1); - EXPECT_EQ(f.getAsInteger(), DynamicAPInt(3)); + EXPECT_EQ(f.getAsInteger(), MPInt(3)); } TEST(FractionTest, nearIntegers) { diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp index f64bb240b4ee4..ba035e84ff1fd 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -41,8 +41,8 @@ makeSetFromConstraints(unsigned ids, ArrayRef> ineqs, return set; } -static void dump(ArrayRef vec) { - for (const DynamicAPInt &x : vec) +static void dump(ArrayRef vec) { + for (const MPInt &x : vec) llvm::errs() << x << ' '; llvm::errs() << '\n'; } @@ -60,8 +60,8 @@ static void dump(ArrayRef vec) { /// opposite of hasSample. static void checkSample(bool hasSample, const IntegerPolyhedron &poly, TestFunction fn = TestFunction::Sample) { - std::optional> maybeSample; - MaybeOptimum> maybeLexMin; + std::optional> maybeSample; + MaybeOptimum> maybeLexMin; switch (fn) { case TestFunction::Sample: maybeSample = poly.findIntegerSample(); @@ -585,12 +585,10 @@ TEST(IntegerPolyhedronTest, removeRedundantConstraintsTest) { // y >= 128x >= 0. poly5.removeRedundantConstraints(); EXPECT_EQ(poly5.getNumInequalities(), 3u); - SmallVector redundantConstraint = - getDynamicAPIntVec({0, 1, 0}); + SmallVector redundantConstraint = getMPIntVec({0, 1, 0}); for (unsigned i = 0; i < 3; ++i) { // Ensure that the removed constraint was the redundant constraint [3]. - EXPECT_NE(poly5.getInequality(i), - ArrayRef(redundantConstraint)); + EXPECT_NE(poly5.getInequality(i), ArrayRef(redundantConstraint)); } } @@ -633,7 +631,7 @@ static void checkDivisionRepresentation( DivisionRepr divs = poly.getLocalReprs(); // Check that the `denominators` and `expectedDenominators` match. - EXPECT_EQ(ArrayRef(getDynamicAPIntVec(expectedDenominators)), + EXPECT_EQ(ArrayRef(getMPIntVec(expectedDenominators)), divs.getDenoms()); // Check that the `dividends` and `expectedDividends` match. If the @@ -1168,9 +1166,9 @@ TEST(IntegerPolyhedronTest, findRationalLexMin) { } void expectIntegerLexMin(const IntegerPolyhedron &poly, ArrayRef min) { - MaybeOptimum> lexMin = poly.findIntegerLexMin(); + MaybeOptimum> lexMin = poly.findIntegerLexMin(); ASSERT_TRUE(lexMin.isBounded()); - EXPECT_EQ(*lexMin, getDynamicAPIntVec(min)); + EXPECT_EQ(*lexMin, getMPIntVec(min)); } void expectNoIntegerLexMin(OptimumKind kind, const IntegerPolyhedron &poly) { @@ -1465,7 +1463,7 @@ TEST(IntegerPolyhedronTest, computeVolume) { bool containsPointNoLocal(const IntegerPolyhedron &poly, ArrayRef point) { - return poly.containsPointNoLocal(getDynamicAPIntVec(point)).has_value(); + return poly.containsPointNoLocal(getMPIntVec(point)).has_value(); } TEST(IntegerPolyhedronTest, containsPointNoLocal) { diff --git a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp index 388ac1174dcdc..721d1fd6e2535 100644 --- a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp +++ b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp @@ -23,8 +23,7 @@ void testColumnEchelonForm(const IntMatrix &m, unsigned expectedRank) { // In column echelon form, each row's last non-zero value can be at most one // column to the right of the last non-zero column among the previous rows. for (unsigned row = 0, nRows = m.getNumRows(); row < nRows; ++row) { - SmallVector rowVec = - transform.preMultiplyWithRow(m.getRow(row)); + SmallVector rowVec = transform.preMultiplyWithRow(m.getRow(row)); for (unsigned col = lastAllowedNonZeroCol + 1, nCols = m.getNumColumns(); col < nCols; ++col) { EXPECT_EQ(rowVec[col], 0); diff --git a/mlir/unittests/Analysis/Presburger/MPIntTest.cpp b/mlir/unittests/Analysis/Presburger/MPIntTest.cpp new file mode 100644 index 0000000000000..3c145d39352c3 --- /dev/null +++ b/mlir/unittests/Analysis/Presburger/MPIntTest.cpp @@ -0,0 +1,200 @@ +//===- MPIntTest.cpp - Tests for MPInt ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Presburger/MPInt.h" +#include "mlir/Analysis/Presburger/SlowMPInt.h" +#include +#include + +using namespace mlir; +using namespace presburger; + +// googletest boilerplate to run the same tests with both MPInt and SlowMPInt. +template +class IntTest : public testing::Test {}; +using TypeList = testing::Types; +// This is for pretty-printing the test name with the name of the class in use. +class TypeNames { +public: + template + static std::string GetName(int) { // NOLINT; gtest mandates this name. + if (std::is_same()) + return "MPInt"; + if (std::is_same()) + return "SlowMPInt"; + llvm_unreachable("Unknown class!"); + } +}; +TYPED_TEST_SUITE(IntTest, TypeList, TypeNames); + +TYPED_TEST(IntTest, ops) { + TypeParam two(2), five(5), seven(7), ten(10); + EXPECT_EQ(five + five, ten); + EXPECT_EQ(five * five, 2 * ten + five); + EXPECT_EQ(five * five, 3 * ten - five); + EXPECT_EQ(five * two, ten); + EXPECT_EQ(five / two, two); + EXPECT_EQ(five % two, two / two); + + EXPECT_EQ(-ten % seven, -10 % 7); + EXPECT_EQ(ten % -seven, 10 % -7); + EXPECT_EQ(-ten % -seven, -10 % -7); + EXPECT_EQ(ten % seven, 10 % 7); + + EXPECT_EQ(-ten / seven, -10 / 7); + EXPECT_EQ(ten / -seven, 10 / -7); + EXPECT_EQ(-ten / -seven, -10 / -7); + EXPECT_EQ(ten / seven, 10 / 7); + + TypeParam x = ten; + x += five; + EXPECT_EQ(x, 15); + x *= two; + EXPECT_EQ(x, 30); + x /= seven; + EXPECT_EQ(x, 4); + x -= two * 10; + EXPECT_EQ(x, -16); + x *= 2 * two; + EXPECT_EQ(x, -64); + x /= two / -2; + EXPECT_EQ(x, 64); + + EXPECT_LE(ten, ten); + EXPECT_GE(ten, ten); + EXPECT_EQ(ten, ten); + EXPECT_FALSE(ten != ten); + EXPECT_FALSE(ten < ten); + EXPECT_FALSE(ten > ten); + EXPECT_LT(five, ten); + EXPECT_GT(ten, five); +} + +TYPED_TEST(IntTest, ops64Overloads) { + TypeParam two(2), five(5), seven(7), ten(10); + EXPECT_EQ(five + 5, ten); + EXPECT_EQ(five + 5, 5 + five); + EXPECT_EQ(five * 5, 2 * ten + 5); + EXPECT_EQ(five * 5, 3 * ten - 5); + EXPECT_EQ(five * two, ten); + EXPECT_EQ(5 / two, 2); + EXPECT_EQ(five / 2, 2); + EXPECT_EQ(2 % two, 0); + EXPECT_EQ(2 - two, 0); + EXPECT_EQ(2 % two, two % 2); + + TypeParam x = ten; + x += 5; + EXPECT_EQ(x, 15); + x *= 2; + EXPECT_EQ(x, 30); + x /= 7; + EXPECT_EQ(x, 4); + x -= 20; + EXPECT_EQ(x, -16); + x *= 4; + EXPECT_EQ(x, -64); + x /= -1; + EXPECT_EQ(x, 64); + + EXPECT_LE(ten, 10); + EXPECT_GE(ten, 10); + EXPECT_EQ(ten, 10); + EXPECT_FALSE(ten != 10); + EXPECT_FALSE(ten < 10); + EXPECT_FALSE(ten > 10); + EXPECT_LT(five, 10); + EXPECT_GT(ten, 5); + + EXPECT_LE(10, ten); + EXPECT_GE(10, ten); + EXPECT_EQ(10, ten); + EXPECT_FALSE(10 != ten); + EXPECT_FALSE(10 < ten); + EXPECT_FALSE(10 > ten); + EXPECT_LT(5, ten); + EXPECT_GT(10, five); +} + +TYPED_TEST(IntTest, overflows) { + TypeParam x(1ll << 60); + EXPECT_EQ((x * x - x * x * x * x) / (x * x * x), 1 - (1ll << 60)); + TypeParam y(1ll << 62); + EXPECT_EQ((y + y + y + y + y + y) / y, 6); + EXPECT_EQ(-(2 * (-y)), 2 * y); // -(-2^63) overflow. + x *= x; + EXPECT_EQ(x, (y * y) / 16); + y += y; + y += y; + y += y; + y /= 8; + EXPECT_EQ(y, 1ll << 62); + + TypeParam min(std::numeric_limits::min()); + TypeParam one(1); + EXPECT_EQ(floorDiv(min, -one), -min); + EXPECT_EQ(ceilDiv(min, -one), -min); + EXPECT_EQ(abs(min), -min); + + TypeParam z = min; + z /= -1; + EXPECT_EQ(z, -min); + TypeParam w(min); + --w; + EXPECT_EQ(w, TypeParam(min) - 1); + TypeParam u(min); + u -= 1; + EXPECT_EQ(u, w); + + TypeParam max(std::numeric_limits::max()); + TypeParam v = max; + ++v; + EXPECT_EQ(v, max + 1); + TypeParam t = max; + t += 1; + EXPECT_EQ(t, v); +} + +TYPED_TEST(IntTest, floorCeilModAbsLcmGcd) { + TypeParam x(1ll << 50), one(1), two(2), three(3); + + // Run on small values and large values. + for (const TypeParam &y : {x, x * x}) { + EXPECT_EQ(floorDiv(3 * y, three), y); + EXPECT_EQ(ceilDiv(3 * y, three), y); + EXPECT_EQ(floorDiv(3 * y - 1, three), y - 1); + EXPECT_EQ(ceilDiv(3 * y - 1, three), y); + EXPECT_EQ(floorDiv(3 * y - 2, three), y - 1); + EXPECT_EQ(ceilDiv(3 * y - 2, three), y); + + EXPECT_EQ(mod(3 * y, three), 0); + EXPECT_EQ(mod(3 * y + 1, three), one); + EXPECT_EQ(mod(3 * y + 2, three), two); + + EXPECT_EQ(floorDiv(3 * y, y), 3); + EXPECT_EQ(ceilDiv(3 * y, y), 3); + EXPECT_EQ(floorDiv(3 * y - 1, y), 2); + EXPECT_EQ(ceilDiv(3 * y - 1, y), 3); + EXPECT_EQ(floorDiv(3 * y - 2, y), 2); + EXPECT_EQ(ceilDiv(3 * y - 2, y), 3); + + EXPECT_EQ(mod(3 * y, y), 0); + EXPECT_EQ(mod(3 * y + 1, y), 1); + EXPECT_EQ(mod(3 * y + 2, y), 2); + + EXPECT_EQ(abs(y), y); + EXPECT_EQ(abs(-y), y); + + EXPECT_EQ(gcd(3 * y, three), three); + EXPECT_EQ(lcm(y, three), 3 * y); + EXPECT_EQ(gcd(2 * y, 3 * y), y); + EXPECT_EQ(lcm(2 * y, 3 * y), 6 * y); + EXPECT_EQ(gcd(15 * y, 6 * y), 3 * y); + EXPECT_EQ(lcm(15 * y, 6 * y), 30 * y); + } +} diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp index cb8df8b346011..fa1f32970b146 100644 --- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp +++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp @@ -317,7 +317,7 @@ TEST(MatrixTest, intInverse) { mat = makeIntMatrix(2, 2, {{0, 0}, {1, 2}}); - DynamicAPInt det = mat.determinant(&inv); + MPInt det = mat.determinant(&inv); EXPECT_EQ(det, 0); } diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp index 63d0243808555..2f4fa27138914 100644 --- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp +++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp @@ -21,26 +21,26 @@ using namespace presburger; /// Convenience functions to pass literals to Simplex. void addInequality(SimplexBase &simplex, ArrayRef coeffs) { - simplex.addInequality(getDynamicAPIntVec(coeffs)); + simplex.addInequality(getMPIntVec(coeffs)); } void addEquality(SimplexBase &simplex, ArrayRef coeffs) { - simplex.addEquality(getDynamicAPIntVec(coeffs)); + simplex.addEquality(getMPIntVec(coeffs)); } bool isRedundantInequality(Simplex &simplex, ArrayRef coeffs) { - return simplex.isRedundantInequality(getDynamicAPIntVec(coeffs)); + return simplex.isRedundantInequality(getMPIntVec(coeffs)); } bool isRedundantInequality(LexSimplex &simplex, ArrayRef coeffs) { - return simplex.isRedundantInequality(getDynamicAPIntVec(coeffs)); + return simplex.isRedundantInequality(getMPIntVec(coeffs)); } bool isRedundantEquality(Simplex &simplex, ArrayRef coeffs) { - return simplex.isRedundantEquality(getDynamicAPIntVec(coeffs)); + return simplex.isRedundantEquality(getMPIntVec(coeffs)); } bool isSeparateInequality(LexSimplex &simplex, ArrayRef coeffs) { - return simplex.isSeparateInequality(getDynamicAPIntVec(coeffs)); + return simplex.isSeparateInequality(getMPIntVec(coeffs)); } Simplex::IneqType findIneqType(Simplex &simplex, ArrayRef coeffs) { - return simplex.findIneqType(getDynamicAPIntVec(coeffs)); + return simplex.findIneqType(getMPIntVec(coeffs)); } /// Take a snapshot, add constraints making the set empty, and rollback. @@ -433,8 +433,8 @@ TEST(SimplexTest, pivotRedundantRegressionTest) { // After the rollback, the only remaining constraint is x <= -1. // The maximum value of x should be -1. simplex.rollback(snapshot); - MaybeOptimum maxX = simplex.computeOptimum( - Simplex::Direction::Up, getDynamicAPIntVec({1, 0, 0})); + MaybeOptimum maxX = + simplex.computeOptimum(Simplex::Direction::Up, getMPIntVec({1, 0, 0})); EXPECT_TRUE(maxX.isBounded() && *maxX == Fraction(-1, 1)); } @@ -467,9 +467,9 @@ TEST(SimplexTest, appendVariable) { EXPECT_EQ(simplex.getNumVariables(), 2u); EXPECT_EQ(simplex.getNumConstraints(), 2u); - EXPECT_EQ(simplex.computeIntegerBounds(getDynamicAPIntVec({0, 1, 0})), - std::make_pair(MaybeOptimum(DynamicAPInt(yMin)), - MaybeOptimum(DynamicAPInt(yMax)))); + EXPECT_EQ(simplex.computeIntegerBounds(getMPIntVec({0, 1, 0})), + std::make_pair(MaybeOptimum(MPInt(yMin)), + MaybeOptimum(MPInt(yMax)))); simplex.rollback(snapshot1); EXPECT_EQ(simplex.getNumVariables(), 1u); @@ -569,11 +569,10 @@ TEST(SimplexTest, IsRationalSubsetOf) { TEST(SimplexTest, addDivisionVariable) { Simplex simplex(/*nVar=*/1); - simplex.addDivisionVariable(getDynamicAPIntVec({1, 0}), DynamicAPInt(2)); + simplex.addDivisionVariable(getMPIntVec({1, 0}), MPInt(2)); addInequality(simplex, {1, 0, -3}); // x >= 3. addInequality(simplex, {-1, 0, 9}); // x <= 9. - std::optional> sample = - simplex.findIntegerSample(); + std::optional> sample = simplex.findIntegerSample(); ASSERT_TRUE(sample.has_value()); EXPECT_EQ((*sample)[0] / 2, (*sample)[1]); } diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h index ef4429b5c6bc8..6b00898a7e274 100644 --- a/mlir/unittests/Analysis/Presburger/Utils.h +++ b/mlir/unittests/Analysis/Presburger/Utils.h @@ -28,7 +28,6 @@ namespace mlir { namespace presburger { -using llvm::dynamicAPIntFromInt64; inline IntMatrix makeIntMatrix(unsigned numRow, unsigned numColumns, ArrayRef> matrix) { @@ -38,7 +37,7 @@ inline IntMatrix makeIntMatrix(unsigned numRow, unsigned numColumns, assert(matrix[i].size() == numColumns && "Output expression has incorrect dimensionality!"); for (unsigned j = 0; j < numColumns; ++j) - results(i, j) = DynamicAPInt(matrix[i][j]); + results(i, j) = MPInt(matrix[i][j]); } return results; } @@ -131,8 +130,8 @@ inline void EXPECT_EQ_REPR_QUASIPOLYNOMIAL(QuasiPolynomial a, /// lhs and rhs represent non-negative integers or positive infinity. The /// infinity case corresponds to when the Optional is empty. -inline bool infinityOrUInt64LE(std::optional lhs, - std::optional rhs) { +inline bool infinityOrUInt64LE(std::optional lhs, + std::optional rhs) { // No constraint. if (!rhs) return true; @@ -146,9 +145,9 @@ inline bool infinityOrUInt64LE(std::optional lhs, /// the true volume `trueVolume`, while also being at least as good an /// approximation as `resultBound`. inline void expectComputedVolumeIsValidOverapprox( - const std::optional &computedVolume, - const std::optional &trueVolume, - const std::optional &resultBound) { + const std::optional &computedVolume, + const std::optional &trueVolume, + const std::optional &resultBound) { assert(infinityOrUInt64LE(trueVolume, resultBound) && "can't expect result to be less than the true volume"); EXPECT_TRUE(infinityOrUInt64LE(trueVolume, computedVolume)); @@ -156,12 +155,11 @@ inline void expectComputedVolumeIsValidOverapprox( } inline void expectComputedVolumeIsValidOverapprox( - const std::optional &computedVolume, + const std::optional &computedVolume, std::optional trueVolume, std::optional resultBound) { expectComputedVolumeIsValidOverapprox( - computedVolume, - llvm::transformOptional(trueVolume, dynamicAPIntFromInt64), - llvm::transformOptional(resultBound, dynamicAPIntFromInt64)); + computedVolume, llvm::transformOptional(trueVolume, mpintFromInt64), + llvm::transformOptional(resultBound, mpintFromInt64)); } } // namespace presburger diff --git a/mlir/unittests/Analysis/Presburger/UtilsTest.cpp b/mlir/unittests/Analysis/Presburger/UtilsTest.cpp index d91a0f2da9cee..f09a1a760ce60 100644 --- a/mlir/unittests/Analysis/Presburger/UtilsTest.cpp +++ b/mlir/unittests/Analysis/Presburger/UtilsTest.cpp @@ -14,10 +14,9 @@ using namespace mlir; using namespace presburger; -static DivisionRepr -parseDivisionRepr(unsigned numVars, unsigned numDivs, - ArrayRef> dividends, - ArrayRef divisors) { +static DivisionRepr parseDivisionRepr(unsigned numVars, unsigned numDivs, + ArrayRef> dividends, + ArrayRef divisors) { DivisionRepr repr(numVars, numDivs); for (unsigned i = 0, rows = dividends.size(); i < rows; ++i) repr.setDiv(i, dividends[i], divisors[i]); @@ -38,15 +37,12 @@ static void checkEqual(DivisionRepr &a, DivisionRepr &b) { TEST(UtilsTest, ParseAndCompareDivisionReprTest) { auto merge = [](unsigned i, unsigned j) -> bool { return true; }; - DivisionRepr a = parseDivisionRepr(1, 1, {{DynamicAPInt(1), DynamicAPInt(2)}}, - {DynamicAPInt(2)}), - b = parseDivisionRepr(1, 1, {{DynamicAPInt(1), DynamicAPInt(2)}}, - {DynamicAPInt(2)}), - c = parseDivisionRepr( - 2, 2, - {{DynamicAPInt(0), DynamicAPInt(1), DynamicAPInt(2)}, - {DynamicAPInt(0), DynamicAPInt(1), DynamicAPInt(2)}}, - {DynamicAPInt(2), DynamicAPInt(2)}); + DivisionRepr a = parseDivisionRepr(1, 1, {{MPInt(1), MPInt(2)}}, {MPInt(2)}), + b = parseDivisionRepr(1, 1, {{MPInt(1), MPInt(2)}}, {MPInt(2)}), + c = parseDivisionRepr(2, 2, + {{MPInt(0), MPInt(1), MPInt(2)}, + {MPInt(0), MPInt(1), MPInt(2)}}, + {MPInt(2), MPInt(2)}); c.removeDuplicateDivs(merge); checkEqual(a, b); checkEqual(a, c); @@ -54,21 +50,16 @@ TEST(UtilsTest, ParseAndCompareDivisionReprTest) { TEST(UtilsTest, DivisionReprNormalizeTest) { auto merge = [](unsigned i, unsigned j) -> bool { return true; }; - DivisionRepr a = parseDivisionRepr( - 2, 1, {{DynamicAPInt(1), DynamicAPInt(2), DynamicAPInt(-1)}}, - {DynamicAPInt(2)}), - b = parseDivisionRepr( - 2, 1, - {{DynamicAPInt(16), DynamicAPInt(32), DynamicAPInt(-16)}}, - {DynamicAPInt(32)}), - c = parseDivisionRepr(1, 1, - {{DynamicAPInt(12), DynamicAPInt(-4)}}, - {DynamicAPInt(8)}), - d = parseDivisionRepr( - 2, 2, - {{DynamicAPInt(1), DynamicAPInt(2), DynamicAPInt(-1)}, - {DynamicAPInt(4), DynamicAPInt(8), DynamicAPInt(-4)}}, - {DynamicAPInt(2), DynamicAPInt(8)}); + DivisionRepr a = parseDivisionRepr(2, 1, {{MPInt(1), MPInt(2), MPInt(-1)}}, + {MPInt(2)}), + b = parseDivisionRepr(2, 1, {{MPInt(16), MPInt(32), MPInt(-16)}}, + {MPInt(32)}), + c = parseDivisionRepr(1, 1, {{MPInt(12), MPInt(-4)}}, + {MPInt(8)}), + d = parseDivisionRepr(2, 2, + {{MPInt(1), MPInt(2), MPInt(-1)}, + {MPInt(4), MPInt(8), MPInt(-4)}}, + {MPInt(2), MPInt(8)}); b.removeDuplicateDivs(merge); c.removeDuplicateDivs(merge); d.removeDuplicateDivs(merge);