diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h index 5e887fbf0cabb..c735ddd037c80 100644 --- a/mlir/include/mlir/Analysis/Presburger/Utils.h +++ b/mlir/include/mlir/Analysis/Presburger/Utils.h @@ -101,6 +101,25 @@ struct MaybeLocalRepr { } repr; }; +/// If `q` is defined to be equal to `expr floordiv d`, this equivalent to +/// saying that `q` is an integer and `q` is subject to the inequalities +/// `0 <= expr - d*q <= c - 1` (quotient remainder theorem). +/// +/// Rearranging, we get the bounds on `q`: d*q <= expr <= d*q + d - 1. +/// +/// `getDivUpperBound` returns `d*q <= expr`, and +/// `getDivLowerBound` returns `expr <= d*q + d - 1`. +/// +/// The parameter `dividend` corresponds to `expr` above, `divisor` to `d`, and +/// `localVarIdx` to the position of `q` in the coefficient list. +/// +/// The coefficient of `q` in `dividend` must be zero, as it is not allowed for +/// local variable to be a floor division of an expression involving itself. +SmallVector getDivUpperBound(ArrayRef dividend, + int64_t divisor, unsigned localVarIdx); +SmallVector getDivLowerBound(ArrayRef dividend, + int64_t divisor, unsigned localVarIdx); + /// Check if the pos^th variable can be expressed as a floordiv of an affine /// function of other variables (where the divisor is a positive constant). /// `foundRepr` contains a boolean for each variable indicating if the diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index d9c3c86105b2a..ef23acbab8f91 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1326,21 +1326,12 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef dividend, appendVar(VarKind::Local); - // Add two constraints for this new variable 'q'. - SmallVector bound(dividend.size() + 1); - - // dividend - q * divisor >= 0 - std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1, - bound.begin()); - bound.back() = dividend.back(); - bound[getNumVars() - 1] = -divisor; - addInequality(bound); - - // -dividend +qdivisor * q + divisor - 1 >= 0 - std::transform(bound.begin(), bound.end(), bound.begin(), - std::negate()); - bound[bound.size() - 1] += divisor - 1; - addInequality(bound); + SmallVector dividendCopy(dividend.begin(), dividend.end()); + dividendCopy.insert(dividendCopy.end() - 1, 0); + addInequality( + getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2)); + addInequality( + getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2)); } /// Finds an equality that equates the specified variable to a constant. @@ -2281,4 +2272,4 @@ unsigned IntegerPolyhedron::insertVar(VarKind kind, unsigned pos, assert((kind != VarKind::Domain || num == 0) && "Domain has to be zero in a set"); return IntegerRelation::insertVar(kind, pos, num); -} \ No newline at end of file +} diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp index 9131aaeed3ac5..8b14df655262b 100644 --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -228,30 +228,37 @@ static PresburgerRelation getSetDifference(IntegerRelation b, // Similarly, we also want to rollback simplex to its original state. unsigned initialSnapshot = simplex.getSnapshot(); - // Find out which inequalities of sI correspond to division inequalities - // for the local variables of sI. - std::vector repr(sI.getNumLocalVars()); - sI.getLocalReprs(repr); - // Add sI's locals to b, after b's locals. Only those locals of sI which // do not already exist in b will be added. (i.e., duplicate divisions // will not be added.) Also add b's locals to sI, in such a way that both // have the same locals in the same order in the end. b.mergeLocalVars(sI); + // Find out which inequalities of sI correspond to division inequalities + // for the local variables of sI. + // + // Careful! This has to be done after the merge above; otherwise, the + // dividends won't contain the new ids inserted during the merge. + std::vector repr; + std::vector> dividends; + SmallVector divisors; + sI.getLocalReprs(dividends, divisors, repr); + // Mark which inequalities of sI are division inequalities and add all // such inequalities to b. llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() + 2 * sI.getNumEqualities()); - for (MaybeLocalRepr &maybeRepr : repr) { + for (unsigned i = initBCounts.getSpace().getNumLocalVars(), + e = sI.getNumLocalVars(); + i < e; ++i) { assert( - maybeRepr && + repr[i] && "Subtraction is not supported when a representation of the local " "variables of the subtrahend cannot be found!"); - if (maybeRepr.kind == ReprKind::Inequality) { - unsigned lb = maybeRepr.repr.inequalityPair.lowerBoundIdx; - unsigned ub = maybeRepr.repr.inequalityPair.upperBoundIdx; + if (repr[i].kind == ReprKind::Inequality) { + unsigned lb = repr[i].repr.inequalityPair.lowerBoundIdx; + unsigned ub = repr[i].repr.inequalityPair.upperBoundIdx; b.addInequality(sI.getInequality(lb)); b.addInequality(sI.getInequality(ub)); @@ -261,14 +268,30 @@ static PresburgerRelation getSetDifference(IntegerRelation b, canIgnoreIneq[lb] = true; canIgnoreIneq[ub] = true; } else { - assert(maybeRepr.kind == ReprKind::Equality && + assert(repr[i].kind == ReprKind::Equality && "ReprKind isn't inequality so should be equality"); - unsigned idx = maybeRepr.repr.equalityIdx; - b.addEquality(sI.getEquality(idx)); - // We can ignore both inequalities corresponding to this equality. - unsigned offset = sI.getNumInequalities(); - canIgnoreIneq[offset + 2 * idx] = true; - canIgnoreIneq[offset + 2 * idx + 1] = true; + + // Consider the case (x) : (x = 3e + 1), where e is a local. + // Its complement is (x) : (x = 3e) or (x = 3e + 2). + // + // This can be computed by considering the set to be + // (x) : (x = 3*(x floordiv 3) + 1). + // + // Now there are no equalities defining divisions; the division is + // defined by the standard division equalities for e = x floordiv 3, + // i.e., 0 <= x - 3*e <= 2. + // So now as before, we add these division inequalities to b. The + // equality is now just an ordinary constraint that must be considered + // in the remainder of the algorithm. The division inequalities must + // need not be considered, same as above, and they automatically will + // not be because they were never a part of sI; we just infer them + // from the equality and add them only to b. + b.addInequality( + getDivLowerBound(dividends[i], divisors[i], + sI.getVarKindOffset(VarKind::Local) + i)); + b.addInequality( + getDivUpperBound(dividends[i], divisors[i], + sI.getVarKindOffset(VarKind::Local) + i)); } } diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp index e985c821fb9f6..199261789f703 100644 --- a/mlir/lib/Analysis/Presburger/Utils.cpp +++ b/mlir/lib/Analysis/Presburger/Utils.cpp @@ -338,6 +338,29 @@ void presburger::mergeLocalVars( presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge); } +SmallVector presburger::getDivUpperBound(ArrayRef dividend, + int64_t divisor, + unsigned localVarIdx) { + assert(dividend[localVarIdx] == 0 && + "Local to be set to division must have zero coeff!"); + SmallVector ineq(dividend.begin(), dividend.end()); + ineq[localVarIdx] = -divisor; + return ineq; +} + +SmallVector presburger::getDivLowerBound(ArrayRef dividend, + int64_t divisor, + unsigned localVarIdx) { + assert(dividend[localVarIdx] == 0 && + "Local to be set to division must have zero coeff!"); + SmallVector ineq(dividend.size()); + std::transform(dividend.begin(), dividend.end(), ineq.begin(), + std::negate()); + ineq[localVarIdx] = divisor; + ineq.back() += divisor - 1; + return ineq; +} + int64_t presburger::gcdRange(ArrayRef range) { int64_t gcd = 0; for (int64_t elem : range) { diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp index 07d4565d116ce..02f801ae98b7b 100644 --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -459,11 +459,50 @@ TEST(SetTest, divisions) { PresburgerSet setA{parsePoly("(x) : (-x >= 0)")}; PresburgerSet setB{parsePoly("(x) : (x floordiv 2 - 4 >= 0)")}; EXPECT_TRUE(setA.subtract(setB).isEqual(setA)); +} + +void convertSuffixDimsToLocals(IntegerPolyhedron &poly, unsigned numLocals) { + poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numLocals, + poly.getNumDimVars(), VarKind::Local); +} + +inline IntegerPolyhedron parsePolyAndMakeLocals(StringRef str, + unsigned numLocals) { + IntegerPolyhedron poly = parsePoly(str); + convertSuffixDimsToLocals(poly, numLocals); + return poly; +} + +TEST(SetTest, divisionsDefByEq) { + // evens = {x : exists q, x = 2q}. + PresburgerSet evens{ + parsePolyAndMakeLocals("(x, y) : (x - 2 * y == 0)", /*numLocals=*/1)}; + + // odds = {x : exists q, x = 2q + 1}. + PresburgerSet odds{ + parsePolyAndMakeLocals("(x, y) : (x - 2 * y - 1 == 0)", /*numLocals=*/1)}; + + // multiples3 = {x : exists q, x = 3q}. + PresburgerSet multiples3{ + parsePolyAndMakeLocals("(x, y) : (x - 3 * y == 0)", /*numLocals=*/1)}; + + // multiples6 = {x : exists q, x = 6q}. + PresburgerSet multiples6{ + parsePolyAndMakeLocals("(x, y) : (x - 6 * y == 0)", /*numLocals=*/1)}; + + // evens /\ odds = empty. + expectEmpty(PresburgerSet(evens).intersect(PresburgerSet(odds))); + // evens U odds = universe. + expectEqual(evens.unionSet(odds), + PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1)))); + expectEqual(evens.complement(), odds); + expectEqual(odds.complement(), evens); + // even multiples of 3 = multiples of 6. + expectEqual(multiples3.intersect(evens), multiples6); - IntegerPolyhedron evensDefByEquality(PresburgerSpace::getSetSpace( - /*numDims=*/1, /*numSymbols=*/0, /*numLocals=*/1)); - evensDefByEquality.addEquality({1, -2, 0}); - expectEqual(evens, PresburgerSet(evensDefByEquality)); + PresburgerSet evensDefByIneq{ + parsePoly("(x) : (x - 2 * (x floordiv 2) == 0)")}; + expectEqual(evens, PresburgerSet(evensDefByIneq)); } TEST(SetTest, subtractDuplicateDivsRegression) {