diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h index 74127a900d53e..a410f528e1f80 100644 --- a/mlir/include/mlir/Analysis/Presburger/Fraction.h +++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h @@ -30,7 +30,8 @@ struct Fraction { Fraction() = default; /// Construct a Fraction from a numerator and denominator. - Fraction(const MPInt &oNum, const MPInt &oDen = MPInt(1)) : num(oNum), den(oDen) { + Fraction(const MPInt &oNum, const MPInt &oDen = MPInt(1)) + : num(oNum), den(oDen) { if (den < 0) { num = -num; den = -den; @@ -38,7 +39,8 @@ struct Fraction { } /// Overloads for passing literals. Fraction(const MPInt &num, int64_t den = 1) : Fraction(num, MPInt(den)) {} - Fraction(int64_t num, const MPInt &den = MPInt(1)) : Fraction(MPInt(num), den) {} + Fraction(int64_t num, const MPInt &den = MPInt(1)) + : Fraction(MPInt(num), den) {} Fraction(int64_t num, int64_t den) : Fraction(MPInt(num), MPInt(den)) {} // Return the value of the fraction as an integer. This should only be called @@ -102,7 +104,7 @@ inline bool operator>=(const Fraction &x, const Fraction &y) { inline Fraction reduce(const Fraction &f) { if (f == Fraction(0)) return Fraction(0, 1); - MPInt g = gcd(f.num, f.den); + MPInt g = gcd(abs(f.num), abs(f.den)); return Fraction(f.num / g, f.den / g); } @@ -122,22 +124,22 @@ inline Fraction operator-(const Fraction &x, const Fraction &y) { return reduce(Fraction(x.num * y.den - x.den * y.num, x.den * y.den)); } -inline Fraction& operator+=(Fraction &x, const Fraction &y) { +inline Fraction &operator+=(Fraction &x, const Fraction &y) { x = x + y; return x; } -inline Fraction& operator-=(Fraction &x, const Fraction &y) { +inline Fraction &operator-=(Fraction &x, const Fraction &y) { x = x - y; return x; } -inline Fraction& operator/=(Fraction &x, const Fraction &y) { +inline Fraction &operator/=(Fraction &x, const Fraction &y) { x = x / y; return x; } -inline Fraction& operator*=(Fraction &x, const Fraction &y) { +inline Fraction &operator*=(Fraction &x, const Fraction &y) { x = x * y; return x; } diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt index 7b0124ee24c35..b6ce273e35a0e 100644 --- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt +++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_unittest(MLIRPresburgerTests + FractionTest.cpp IntegerPolyhedronTest.cpp IntegerRelationTest.cpp LinearTransformTest.cpp diff --git a/mlir/unittests/Analysis/Presburger/FractionTest.cpp b/mlir/unittests/Analysis/Presburger/FractionTest.cpp new file mode 100644 index 0000000000000..5fee9de1994c8 --- /dev/null +++ b/mlir/unittests/Analysis/Presburger/FractionTest.cpp @@ -0,0 +1,51 @@ +#include "mlir/Analysis/Presburger/Fraction.h" +#include "./Utils.h" +#include +#include + +using namespace mlir; +using namespace presburger; + +TEST(FractionTest, getAsInteger) { + Fraction f(3, 1); + EXPECT_EQ(f.getAsInteger(), MPInt(3)); +} + +TEST(FractionTest, nearIntegers) { + Fraction f(52, 14); + + EXPECT_EQ(floor(f), 3); + EXPECT_EQ(ceil(f), 4); +} + +TEST(FractionTest, reduce) { + Fraction f(20, 35), g(-56, 63); + EXPECT_EQ(f, Fraction(4, 7)); + EXPECT_EQ(g, Fraction(-8, 9)); +} + +TEST(FractionTest, arithmetic) { + Fraction f(3, 4), g(-2, 3); + + EXPECT_EQ(f / g, Fraction(-9, 8)); + EXPECT_EQ(f * g, Fraction(-1, 2)); + EXPECT_EQ(f + g, Fraction(1, 12)); + EXPECT_EQ(f - g, Fraction(17, 12)); + + f /= g; + EXPECT_EQ(f, Fraction(-9, 8)); + f *= g; + EXPECT_EQ(f, Fraction(3, 4)); + f += g; + EXPECT_EQ(f, Fraction(Fraction(1, 12))); + f -= g; + EXPECT_EQ(f, Fraction(3, 4)); +} + +TEST(FractionTest, relational) { + Fraction f(2, 5), g(3, 7); + EXPECT_TRUE(f < g); + EXPECT_FALSE(g < f); + + EXPECT_EQ(f, Fraction(4, 10)); +}