Skip to content

Commit

Permalink
[MLIR][Presburger] Use PresburgerSpace in constructors
Browse files Browse the repository at this point in the history
This patch modifies IntegerPolyhedron, IntegerRelation, PresburgerRelation,
PresburgerSet, PWMAFunction, constructors to take PresburgerSpace instead of
dimensions. This allows information present in PresburgerSpace to be carried
better and allows for a general interface.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D122842
  • Loading branch information
Groverkss committed Apr 1, 2022
1 parent a1901f5 commit a5a598b
Show file tree
Hide file tree
Showing 15 changed files with 194 additions and 167 deletions.
41 changes: 17 additions & 24 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Expand Up @@ -57,30 +57,24 @@ class IntegerRelation : public PresburgerSpace {
/// of constraints and identifiers.
IntegerRelation(unsigned numReservedInequalities,
unsigned numReservedEqualities, unsigned numReservedCols,
unsigned numDomain, unsigned numRange, unsigned numSymbols,
unsigned numLocals)
: PresburgerSpace(numDomain, numRange, numSymbols, numLocals),
const PresburgerSpace &space)
: PresburgerSpace(space),
equalities(0, getNumIds() + 1, numReservedEqualities, numReservedCols),
inequalities(0, getNumIds() + 1, numReservedInequalities,
numReservedCols) {
assert(numReservedCols >= getNumIds() + 1);
}

/// Constructs a relation with the specified number of dimensions and symbols.
IntegerRelation(unsigned numDomain = 0, unsigned numRange = 0,
unsigned numSymbols = 0, unsigned numLocals = 0)
IntegerRelation(const PresburgerSpace &space)
: IntegerRelation(/*numReservedInequalities=*/0,
/*numReservedEqualities=*/0,
/*numReservedCols=*/numDomain + numRange + numSymbols +
numLocals + 1,
numDomain, numRange, numSymbols, numLocals) {}
/*numReservedCols=*/space.getNumIds() + 1, space) {}

/// Return a system with no constraints, i.e., one which is satisfied by all
/// points.
static IntegerRelation getUniverse(unsigned numDomain = 0,
unsigned numRange = 0,
unsigned numSymbols = 0) {
return IntegerRelation(numDomain, numRange, numSymbols);
static IntegerRelation getUniverse(const PresburgerSpace &space) {
return IntegerRelation(space);
}

/// Return the kind of this IntegerRelation.
Expand Down Expand Up @@ -562,25 +556,24 @@ class IntegerPolyhedron : public IntegerRelation {
/// of constraints and identifiers.
IntegerPolyhedron(unsigned numReservedInequalities,
unsigned numReservedEqualities, unsigned numReservedCols,
unsigned numDims, unsigned numSymbols, unsigned numLocals)
const PresburgerSpace &space)
: IntegerRelation(numReservedInequalities, numReservedEqualities,
numReservedCols, /*numDomain=*/0, /*numRange=*/numDims,
numSymbols, numLocals) {}
numReservedCols, space) {
assert(space.getNumDomainIds() == 0 &&
"Number of domain id's should be zero in Set kind space.");
}

/// Constructs a relation with the specified number of dimensions and symbols.
IntegerPolyhedron(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0)
/// Constructs a relation with the specified number of dimensions and
/// symbols.
IntegerPolyhedron(const PresburgerSpace &space)
: IntegerPolyhedron(/*numReservedInequalities=*/0,
/*numReservedEqualities=*/0,
/*numReservedCols=*/numDims + numSymbols + numLocals +
1,
numDims, numSymbols, numLocals) {}
/*numReservedCols=*/space.getNumIds() + 1, space) {}

/// Return a system with no constraints, i.e., one which is satisfied by all
/// points.
static IntegerPolyhedron getUniverse(unsigned numDims = 0,
unsigned numSymbols = 0) {
return IntegerPolyhedron(numDims, numSymbols);
static IntegerPolyhedron getUniverse(const PresburgerSpace &space) {
return IntegerPolyhedron(space);
}

/// Return the kind of this IntegerRelation.
Expand Down
13 changes: 6 additions & 7 deletions mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
Expand Up @@ -57,9 +57,8 @@ class MultiAffineFunction : protected IntegerPolyhedron {

MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
: IntegerPolyhedron(domain), output(output) {}
MultiAffineFunction(const Matrix &output, unsigned numDims,
unsigned numSymbols = 0, unsigned numLocals = 0)
: IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {}
MultiAffineFunction(const Matrix &output, const PresburgerSpace &space)
: IntegerPolyhedron(space), output(output) {}

~MultiAffineFunction() override = default;
Kind getKind() const override { return Kind::MultiAffineFunction; }
Expand Down Expand Up @@ -137,10 +136,10 @@ class MultiAffineFunction : protected IntegerPolyhedron {
/// finding the value of the function at a point.
class PWMAFunction : public PresburgerSpace {
public:
PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
: PresburgerSpace(/*numDomain=*/0, /*numRange=*/numDims, numSymbols,
/*numLocals=*/0),
numOutputs(numOutputs) {
PWMAFunction(const PresburgerSpace &space, unsigned numOutputs)
: PresburgerSpace(space), numOutputs(numOutputs) {
assert(getNumDomainIds() == 0 && "Set type space should zero domain ids.");
assert(getNumLocalIds() == 0 && "PWMAFunction cannot have local ids.");
assert(numOutputs >= 1 && "The function must output something!");
}

Expand Down
26 changes: 13 additions & 13 deletions mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
Expand Up @@ -37,13 +37,10 @@ class SetCoalescer;
class PresburgerRelation : public PresburgerSpace {
public:
/// Return a universe set of the specified type that contains all points.
static PresburgerRelation getUniverse(unsigned numDomain, unsigned numRange,
unsigned numSymbols);
static PresburgerRelation getUniverse(const PresburgerSpace &space);

/// Return an empty set of the specified type that contains no points.
static PresburgerRelation getEmpty(unsigned numDomain = 0,
unsigned numRange = 0,
unsigned numSymbols = 0);
static PresburgerRelation getEmpty(const PresburgerSpace &space);

explicit PresburgerRelation(const IntegerRelation &disjunct);

Expand Down Expand Up @@ -119,9 +116,10 @@ class PresburgerRelation : public PresburgerSpace {
protected:
/// Construct an empty PresburgerRelation with the specified number of
/// dimension and symbols.
PresburgerRelation(unsigned numDomain = 0, unsigned numRange = 0,
unsigned numSymbols = 0)
: PresburgerSpace(numDomain, numRange, numSymbols, /*numLocals=*/0) {}
PresburgerRelation(const PresburgerSpace &space) : PresburgerSpace(space) {
assert(space.getNumLocalIds() == 0 &&
"PresburgerRelation cannot have local ids.");
}

/// The list of disjuncts that this set is the union of.
SmallVector<IntegerRelation, 2> integerRelations;
Expand All @@ -132,11 +130,10 @@ class PresburgerRelation : public PresburgerSpace {
class PresburgerSet : public PresburgerRelation {
public:
/// Return a universe set of the specified type that contains all points.
static PresburgerSet getUniverse(unsigned numDims = 0,
unsigned numSymbols = 0);
static PresburgerSet getUniverse(const PresburgerSpace &space);

/// Return an empty set of the specified type that contains no points.
static PresburgerSet getEmpty(unsigned numDims = 0, unsigned numSymbols = 0);
static PresburgerSet getEmpty(const PresburgerSpace &space);

/// Create a set from a relation.
explicit PresburgerSet(const IntegerPolyhedron &disjunct);
Expand All @@ -154,8 +151,11 @@ class PresburgerSet : public PresburgerRelation {
protected:
/// Construct an empty PresburgerRelation with the specified number of
/// dimension and symbols.
PresburgerSet(unsigned numDims = 0, unsigned numSymbols = 0)
: PresburgerRelation(/*numDomain=*/0, numDims, numSymbols) {}
PresburgerSet(const PresburgerSpace &space) : PresburgerRelation(space) {
assert(space.getNumDomainIds() == 0 && "Set type cannot have domain ids.");
assert(space.getNumLocalIds() == 0 &&
"PresburgerRelation cannot have local ids.");
}
};

} // namespace presburger
Expand Down
31 changes: 27 additions & 4 deletions mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
Expand Up @@ -64,10 +64,24 @@ enum class IdKind { Symbol, Local, Domain, Range, SetDim = Range };
/// identifiers of each kind are equal.
class PresburgerSpace {
public:
PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0,
unsigned numSymbols = 0, unsigned numLocals = 0)
: numDomain(numDomain), numRange(numRange), numSymbols(numSymbols),
numLocals(numLocals) {}
static PresburgerSpace getRelationSpace(unsigned numDomain = 0,
unsigned numRange = 0,
unsigned numSymbols = 0,
unsigned numLocals = 0) {
return PresburgerSpace(numDomain, numRange, numSymbols, numLocals);
}

static PresburgerSpace getSetSpace(unsigned numDims = 0,
unsigned numSymbols = 0,
unsigned numLocals = 0) {
return PresburgerSpace(/*numDomain=*/0, /*numRange=*/numDims, numSymbols,
numLocals);
}

PresburgerSpace getSpace() const { return *this; }
PresburgerSpace getCompatibleSpace() const {
return PresburgerSpace(numDomain, numRange, numSymbols);
}

virtual ~PresburgerSpace() = default;

Expand Down Expand Up @@ -99,6 +113,9 @@ class PresburgerSpace {
unsigned getIdKindOverlap(IdKind kind, unsigned idStart,
unsigned idLimit) const;

/// Return the IdKind of the id at the specified position.
IdKind getIdKindAt(unsigned pos) const;

/// Insert `num` identifiers of the specified kind at position `pos`.
/// Positions are relative to the kind of identifier. Return the absolute
/// column position (i.e., not relative to the kind of identifier) of the
Expand Down Expand Up @@ -131,6 +148,12 @@ class PresburgerSpace {
void print(llvm::raw_ostream &os) const;
void dump() const;

protected:
PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0,
unsigned numSymbols = 0, unsigned numLocals = 0)
: numDomain(numDomain), numRange(numRange), numSymbols(numSymbols),
numLocals(numLocals) {}

private:
// Number of identifiers corresponding to domain identifiers.
unsigned numDomain;
Expand Down
15 changes: 8 additions & 7 deletions mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
Expand Up @@ -65,18 +65,19 @@ class FlatAffineConstraints : public presburger::IntegerPolyhedron {
unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims,
unsigned numSymbols, unsigned numLocals)
: IntegerPolyhedron(numReservedInequalities, numReservedEqualities,
numReservedCols, numDims, numSymbols, numLocals) {}
: IntegerPolyhedron(
numReservedInequalities, numReservedEqualities, numReservedCols,
PresburgerSpace::getSetSpace(numDims, numSymbols, numLocals)) {}

/// Constructs a constraint system with the specified number of
/// dimensions and symbols.
FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0)
: IntegerPolyhedron(/*numReservedInequalities=*/0,
/*numReservedEqualities=*/0,
/*numReservedCols=*/numDims + numSymbols + numLocals +
1,
numDims, numSymbols, numLocals) {}
: FlatAffineConstraints(/*numReservedInequalities=*/0,
/*numReservedEqualities=*/0,
/*numReservedCols=*/numDims + numSymbols +
numLocals + 1,
numDims, numSymbols, numLocals) {}

explicit FlatAffineConstraints(const IntegerPolyhedron &poly)
: IntegerPolyhedron(poly) {}
Expand Down
21 changes: 7 additions & 14 deletions mlir/lib/Analysis/Presburger/IntegerRelation.cpp
Expand Up @@ -1702,20 +1702,14 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
}
}

// Set the number of dimensions, symbols, locals in the resulting system.
unsigned newNumDomain =
getNumDomainIds() - getIdKindOverlap(IdKind::Domain, pos, pos + 1);
unsigned newNumRange =
getNumRangeIds() - getIdKindOverlap(IdKind::Range, pos, pos + 1);
unsigned newNumSymbols =
getNumSymbolIds() - getIdKindOverlap(IdKind::Symbol, pos, pos + 1);
unsigned newNumLocals =
getNumLocalIds() - getIdKindOverlap(IdKind::Local, pos, pos + 1);
PresburgerSpace newSpace = getSpace();
IdKind idKindRemove = newSpace.getIdKindAt(pos);
unsigned relativePos = pos - newSpace.getIdKindOffset(idKindRemove);
newSpace.removeIdRange(idKindRemove, relativePos, relativePos + 1);

/// Create the new system which has one identifier less.
IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(),
getNumEqualities(), getNumCols() - 1, newNumDomain,
newNumRange, newNumSymbols, newNumLocals);
getNumEqualities(), getNumCols() - 1, newSpace);

// This will be used to check if the elimination was integer exact.
unsigned lcmProducts = 1;
Expand Down Expand Up @@ -1866,8 +1860,7 @@ static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
// Returns constraints that are common to both A & B.
static void getCommonConstraints(const IntegerRelation &a,
const IntegerRelation &b, IntegerRelation &c) {
c = IntegerRelation(a.getNumDomainIds(), a.getNumRangeIds(),
a.getNumSymbolIds(), a.getNumLocalIds());
c = IntegerRelation(a.getSpace());
// a naive O(n^2) check should be enough here given the input sizes.
for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) {
Expand Down Expand Up @@ -1896,7 +1889,7 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {

// Get the constraints common to both systems; these will be added as is to
// the union.
IntegerRelation commonCst;
IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
getCommonConstraints(*this, otherCst, commonCst);

std::vector<SmallVector<int64_t, 8>> boundingLbs;
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Analysis/Presburger/LinearTransform.cpp
Expand Up @@ -113,8 +113,7 @@ LinearTransform::makeTransformToColumnEchelon(Matrix m) {
}

IntegerRelation LinearTransform::applyTo(const IntegerRelation &rel) const {
IntegerRelation result(rel.getNumDomainIds(), rel.getNumRangeIds(),
rel.getNumSymbolIds(), rel.getNumLocalIds());
IntegerRelation result(rel.getSpace());

for (unsigned i = 0, e = rel.getNumEqualities(); i < e; ++i) {
ArrayRef<int64_t> eq = rel.getEquality(i);
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Analysis/Presburger/PWMAFunction.cpp
Expand Up @@ -27,8 +27,7 @@ static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
}

PresburgerSet PWMAFunction::getDomain() const {
PresburgerSet domain =
PresburgerSet::getEmpty(getNumDimIds(), getNumSymbolIds());
PresburgerSet domain = PresburgerSet::getEmpty(getSpace());
for (const MultiAffineFunction &piece : pieces)
domain.unionInPlace(piece.getDomain());
return domain;
Expand Down

0 comments on commit a5a598b

Please sign in to comment.