diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index 4b1802413f75f..335a2dddc7561 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -485,6 +485,22 @@ class IntegerRelation { addLocalFloorDiv(getDynamicAPIntVec(dividend), DynamicAPInt(divisor)); } + /// Adds a new local variable as the modulus of an affine function of other + /// variables, the coefficients of which are provided in `exprs`. The modulus + /// is with respect to a positive constant `modulus`. The function returns the + /// absolute index of the new local variable representing the result of the + /// modulus operation. Two new local variables are added to the system, one + /// representing the floor div with respect to the modulus and one + /// representing the mod. Three constraints are added to the system to capture + /// the equivalance. The first two are required to compute the result of the + /// floor division `q`, and the third computes the equality relation: + /// result = exprs - modulus * q. + unsigned addLocalModulo(ArrayRef exprs, + const DynamicAPInt &modulus); + unsigned addLocalModulo(ArrayRef exprs, int64_t modulus) { + return addLocalModulo(getDynamicAPIntVec(exprs), DynamicAPInt(modulus)); + } + /// Projects out (aka eliminates) `num` variables starting at position /// `pos`. The resulting constraint system is the shadow along the dimensions /// that still exist. This method may not always be integer exact. diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 5c4d4d13580a0..1d1e4ded19db1 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1515,6 +1515,27 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef dividend, getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2)); } +unsigned IntegerRelation::addLocalModulo(ArrayRef exprs, + const DynamicAPInt &modulus) { + assert(exprs.size() == getNumCols() && "incorrect exprs size"); + assert(modulus > 0 && "positive modulus expected"); + + /// Add a local variable for q = expr floordiv modulus + addLocalFloorDiv(exprs, modulus); + + /// Add a local var to represent the result + auto resultIndex = appendVar(VarKind::Local); + + SmallVector exprsCopy(exprs); + /// Insert the two new locals before the constant + /// Add locals that correspond to `q` and `result` to compute + /// 0 = (expr - modulus * q) - result + exprsCopy.insert(exprsCopy.end() - 1, + {DynamicAPInt(-modulus), DynamicAPInt(-1)}); + addEquality(exprsCopy); + return resultIndex; +} + int IntegerRelation::findEqualityToConstant(unsigned pos, bool symbolic) const { assert(pos < getNumVars() && "invalid position"); for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp index a6ed5c5b21e79..9ae90a4841f3c 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -714,3 +714,14 @@ TEST(IntegerRelationTest, getVarKindRange) { } EXPECT_THAT(actual, ElementsAre(2, 3, 4)); } + +TEST(IntegerRelationTest, addLocalModulo) { + IntegerRelation rel = parseRelationFromSet("(x) : (x >= 0, 100 - x >= 0)", 1); + unsigned result = rel.addLocalModulo({1, 0}, 32); // x % 32 + rel.convertVarKind(VarKind::Local, + result - rel.getVarKindOffset(VarKind::Local), + rel.getNumVarKind(VarKind::Local), VarKind::Range); + for (unsigned x = 0; x <= 100; ++x) { + EXPECT_TRUE(rel.containsPointNoLocal({x, x % 32})); + } +}