Skip to content

Commit 1ba6043

Browse files
committed
[MLIR][Presburger] Refactor subtraction in preparation for making it iterative
Refactor the operation of subtraction by - removing the usage of SimplexRollbackScopeExit since this can't be used in the iterative version - reducing the number of stack variables to make the iterative version easier to follow Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D123156
1 parent 9be6e7b commit 1ba6043

File tree

1 file changed

+71
-60
lines changed

1 file changed

+71
-60
lines changed

mlir/lib/Analysis/Presburger/PresburgerRelation.cpp

Lines changed: 71 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,35 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
100100
return result;
101101
}
102102

103+
/// Return the coefficients of the ineq in `rel` specified by `idx`.
104+
/// `idx` can refer not only to an actual inequality of `rel`, but also
105+
/// to either of the inequalities that make up an equality in `rel`.
106+
///
107+
/// When 0 <= idx < rel.getNumInequalities(), this returns the coeffs of the
108+
/// idx-th inequality of `rel`.
109+
///
110+
/// Otherwise, it is then considered to index into the ineqs corresponding to
111+
/// eqs of `rel`, and it must hold that
112+
///
113+
/// 0 <= idx - rel.getNumInequalities() < 2*getNumEqualities().
114+
///
115+
/// For every eq `coeffs == 0` there are two possible ineqs to index into.
116+
/// The first is coeffs >= 0 and the second is coeffs <= 0.
117+
static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
118+
unsigned idx) {
119+
assert(idx < rel.getNumInequalities() + 2 * rel.getNumEqualities() &&
120+
"idx out of bounds!");
121+
if (idx < rel.getNumInequalities())
122+
return llvm::to_vector<8>(rel.getInequality(idx));
123+
124+
idx -= rel.getNumInequalities();
125+
ArrayRef<int64_t> eqCoeffs = rel.getEquality(idx / 2);
126+
127+
if (idx % 2 == 0)
128+
return llvm::to_vector<8>(eqCoeffs);
129+
return getNegatedCoeffs(eqCoeffs);
130+
}
131+
103132
/// Return the set difference b \ s and accumulate the result into `result`.
104133
/// `simplex` must correspond to b.
105134
///
@@ -133,15 +162,13 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
133162
/// that some constraints are redundant. These redundant constraints are
134163
/// ignored.
135164
///
136-
/// b and simplex are callee saved, i.e., their values on return are
137-
/// semantically equivalent to their values when the function is called.
138-
///
139165
/// b should not have duplicate divs because this might lead to existing
140166
/// divs disappearing in the call to mergeLocalIds below, which cannot be
141167
/// handled.
142168
static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
143169
const PresburgerRelation &s, unsigned i,
144170
PresburgerRelation &result) {
171+
145172
if (i == s.getNumDisjuncts()) {
146173
result.unionInPlace(b);
147174
return;
@@ -156,17 +183,9 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
156183
// rollback b to its initial state before returning, which we will do by
157184
// removing all constraints beyond the original number of inequalities
158185
// and equalities, so we store these counts first.
159-
const IntegerRelation::CountsSnapshot bCounts = b.getCounts();
186+
IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
160187
// Similarly, we also want to rollback simplex to its original state.
161-
const unsigned initialSnapshot = simplex.getSnapshot();
162-
163-
auto restoreState = [&]() {
164-
b.truncate(bCounts);
165-
simplex.rollback(initialSnapshot);
166-
};
167-
168-
// Automatically restore the original state when we return.
169-
auto stateRestorer = llvm::make_scope_exit(restoreState);
188+
unsigned initialSnapshot = simplex.getSnapshot();
170189

171190
// Find out which inequalities of sI correspond to division inequalities for
172191
// the local variables of sI.
@@ -176,105 +195,97 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
176195
// Add sI's locals to b, after b's locals. Also add b's locals to sI, before
177196
// sI's locals.
178197
b.mergeLocalIds(sI);
198+
unsigned numLocalsAdded =
199+
b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
200+
// Update simplex to also include the new locals in `b` from merging.
201+
simplex.appendVariable(numLocalsAdded);
179202

180-
// Mark which inequalities of sI are division inequalities and add all such
181-
// inequalities to b.
182-
llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
203+
// Equalities are processed by considering them as a pair of inequalities.
204+
// The first sI.getNumInequalities() elements are for sI's inequalities;
205+
// then a pair of inequalities occurs for each of sI's equalities.
206+
// If the equality is expr == 0, the first element in the pair
207+
// corresponds to expr >= 0, and the second to expr <= 0.
208+
llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
209+
2 * sI.getNumEqualities());
210+
211+
// Add all division inequalities to `b`.
183212
for (MaybeLocalRepr &maybeInequality : repr) {
184213
assert(maybeInequality.kind == ReprKind::Inequality &&
185214
"Subtraction is not supported when a representation of the local "
186215
"variables of the subtrahend cannot be found!");
187-
auto lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
188-
auto ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
216+
unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
217+
unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
189218

190219
b.addInequality(sI.getInequality(lb));
191220
b.addInequality(sI.getInequality(ub));
192221

193222
assert(lb != ub &&
194223
"Upper and lower bounds must be different inequalities!");
195-
isDivInequality[lb] = true;
196-
isDivInequality[ub] = true;
224+
225+
// We just added these inequalities to `b`, so there is no point considering
226+
// the parts where these inequalities occur complemented -- such parts are
227+
// empty. Therefore, we mark that these can be ignored.
228+
canIgnoreIneq[lb] = true;
229+
canIgnoreIneq[ub] = true;
197230
}
198231

199232
unsigned offset = simplex.getNumConstraints();
200-
unsigned numLocalsAdded =
201-
b.getNumLocalIds() - bCounts.getSpace().getNumLocalIds();
202-
simplex.appendVariable(numLocalsAdded);
203-
204233
unsigned snapshotBeforeIntersect = simplex.getSnapshot();
205234
simplex.intersectIntegerRelation(sI);
206235

207236
if (simplex.isEmpty()) {
208237
// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
209238
// We are ignoring level i completely, so we restore the state
210239
// *before* going to level i + 1.
211-
restoreState();
240+
b.truncate(initBCounts);
241+
simplex.rollback(initialSnapshot);
212242
subtractRecursively(b, simplex, s, i + 1, result);
213-
214-
// We already restored the state above and the recursive call should have
215-
// restored to the same state before returning, so we don't need to restore
216-
// the state again.
217-
stateRestorer.release();
218243
return;
219244
}
220245

221246
simplex.detectRedundant();
222247

223-
// Equalities are added to simplex as a pair of inequalities.
224248
unsigned totalNewSimplexInequalities =
225249
2 * sI.getNumEqualities() + sI.getNumInequalities();
226-
llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
250+
// Redundant inequalities can be safely ignored. This is not required for
251+
// correctness but improves performance and results in a more compact
252+
// representation of the set difference.
227253
for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
228-
isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
229-
254+
canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
230255
simplex.rollback(snapshotBeforeIntersect);
231256

257+
SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
258+
for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
259+
if (!canIgnoreIneq[i])
260+
ineqsToProcess.push_back(i);
261+
232262
// Recurse with the part b ^ ~ineq. Note that b is modified throughout
233263
// subtractRecursively. At the time this function is called, the current b is
234264
// actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
235265
// inequality, s_{i,j+1}. This function recurses into the next level i + 1
236266
// with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
237267
auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
238-
SimplexRollbackScopeExit scopeExit(simplex);
239268
b.addInequality(ineq);
240269
simplex.addInequality(ineq);
241270
subtractRecursively(b, simplex, s, i + 1, result);
242-
b.removeInequality(b.getNumInequalities() - 1);
243271
};
244272

245273
// For each inequality ineq, we first recurse with the part where ineq
246274
// is not satisfied, and then add the ineq to b and simplex because
247275
// ineq must be satisfied by all later parts.
248276
auto processInequality = [&](ArrayRef<int64_t> ineq) {
277+
unsigned snapshot = simplex.getSnapshot();
278+
IntegerRelation::CountsSnapshot bCounts = b.getCounts();
249279
recurseWithInequality(getComplementIneq(ineq));
280+
simplex.rollback(snapshot);
281+
b.truncate(bCounts);
282+
250283
b.addInequality(ineq);
251284
simplex.addInequality(ineq);
252285
};
253286

254-
// Process all the inequalities, ignoring redundant inequalities and division
255-
// inequalities. The result is correct whether or not we ignore these, but
256-
// ignoring them makes the result simpler.
257-
for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
258-
if (isMarkedRedundant[j])
259-
continue;
260-
if (isDivInequality[j])
261-
continue;
262-
processInequality(sI.getInequality(j));
263-
}
264-
265-
offset = sI.getNumInequalities();
266-
for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
267-
ArrayRef<int64_t> coeffs = sI.getEquality(j);
268-
// For each equality, process the positive and negative inequalities that
269-
// make up this equality. If Simplex found an inequality to be redundant, we
270-
// skip it as above to make the result simpler. Divisions are always
271-
// represented in terms of inequalities and not equalities, so we do not
272-
// check for division inequalities here.
273-
if (!isMarkedRedundant[offset + 2 * j])
274-
processInequality(coeffs);
275-
if (!isMarkedRedundant[offset + 2 * j + 1])
276-
processInequality(getNegatedCoeffs(coeffs));
277-
}
287+
for (unsigned idx : ineqsToProcess)
288+
processInequality(getIneqCoeffsFromIdx(sI, idx));
278289
}
279290

280291
/// Return the set difference disjunct \ set.

0 commit comments

Comments
 (0)