Skip to content

Commit

Permalink
[mlir][Presburger] Introduce Domain and Range identifiers in Presburg…
Browse files Browse the repository at this point in the history
…erSpace

This patch introducing seperating dimensions into two types: Domain and Range.
This allows building relations over PresburgerSpace.

This patch is part of a series of patches to introduce relations in Presburger
library.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D119709
  • Loading branch information
Groverkss committed Feb 18, 2022
1 parent 3ad0bda commit eae62b2
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 46 deletions.
95 changes: 80 additions & 15 deletions mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
Expand Up @@ -15,6 +15,7 @@
#define MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H

#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {

Expand All @@ -31,23 +32,45 @@ class PresburgerLocalSpace;
///
/// Local: Local identifiers correspond to existentially quantified variables.
///
/// PresburgerSpace only supports identifiers of kind Dimension and Symbol.
/// Dimension identifiers are further divided into Domain and Range identifiers
/// to support building relations.
///
/// Spaces with distinction between domain and range identifiers should use
/// IdKind::Domain and IdKind::Range to refer to domain and range identifiers.
///
/// Spaces with no distinction between domain and range identifiers should use
/// IdKind::SetDim to refer to dimension identifiers.
///
/// PresburgerSpace does not support identifiers of kind Local. See
/// PresburgerLocalSpace for an extension that supports Local ids.
class PresburgerSpace {
friend PresburgerLocalSpace;

public:
/// Kind of identifier (column).
enum IdKind { Dimension, Symbol, Local };
/// Kind of identifier. Implementation wise SetDims are treated as Range
/// ids, and spaces with no distinction between dimension ids are treated
/// as relations with zero domain ids.
enum IdKind { Symbol, Local, Domain, Range, SetDim = Range };

PresburgerSpace(unsigned numDims, unsigned numSymbols)
: numDims(numDims), numSymbols(numSymbols), numLocals(0) {}
static PresburgerSpace getRelationSpace(unsigned numDomain, unsigned numRange,
unsigned numSymbols);

static PresburgerSpace getSetSpace(unsigned numDims, unsigned numSymbols);

virtual ~PresburgerSpace() = default;

unsigned getNumIds() const { return numDims + numSymbols + numLocals; }
unsigned getNumDimIds() const { return numDims; }
unsigned getNumDomainIds() const { return numDomain; }
unsigned getNumRangeIds() const { return numRange; }
unsigned getNumSymbolIds() const { return numSymbols; }
unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; }
unsigned getNumSetDimIds() const { return numRange; }

unsigned getNumDimIds() const { return numDomain + numRange; }
unsigned getNumDimAndSymbolIds() const {
return numDomain + numRange + numSymbols;
}
unsigned getNumIds() const {
return numDomain + numRange + numSymbols + numLocals;
}

/// Get the number of ids of the specified kind.
unsigned getNumIdKind(IdKind kind) const;
Expand Down Expand Up @@ -78,12 +101,36 @@ class PresburgerSpace {
/// split become dimensions.
void setDimSymbolSeparation(unsigned newSymbolCount);

void print(llvm::raw_ostream &os) const;
void dump() const;

protected:
/// Space constructor for Relation space type.
PresburgerSpace(unsigned numDomain, unsigned numRange, unsigned numSymbols)
: PresburgerSpace(Relation, numDomain, numRange, numSymbols,
/*numLocals=*/0) {}

/// Space constructor for Set space type.
PresburgerSpace(unsigned numDims, unsigned numSymbols)
: PresburgerSpace(Set, /*numDomain=*/0, numDims, numSymbols,
/*numLocals=*/0) {}

private:
PresburgerSpace(unsigned numDims, unsigned numSymbols, unsigned numLocals)
: numDims(numDims), numSymbols(numSymbols), numLocals(numLocals) {}
/// Kind of space.
enum SpaceKind { Set, Relation };

PresburgerSpace(SpaceKind spaceKind, unsigned numDomain, unsigned numRange,
unsigned numSymbols, unsigned numLocals)
: spaceKind(spaceKind), numDomain(numDomain), numRange(numRange),
numSymbols(numSymbols), numLocals(numLocals) {}

/// Number of identifiers corresponding to real dimensions.
unsigned numDims;
SpaceKind spaceKind;

// Number of identifiers corresponding to domain identifiers.
unsigned numDomain;

// Number of identifiers corresponding to range identifiers.
unsigned numRange;

/// Number of identifiers corresponding to symbols (unknown but constant for
/// analysis).
Expand All @@ -96,9 +143,13 @@ class PresburgerSpace {
/// Extension of PresburgerSpace supporting Local identifiers.
class PresburgerLocalSpace : public PresburgerSpace {
public:
PresburgerLocalSpace(unsigned numDims, unsigned numSymbols,
unsigned numLocals)
: PresburgerSpace(numDims, numSymbols, numLocals) {}
static PresburgerLocalSpace getRelationSpace(unsigned numDomain,
unsigned numRange,
unsigned numSymbols,
unsigned numLocals);

static PresburgerLocalSpace getSetSpace(unsigned numDims, unsigned numSymbols,
unsigned numLocals);

unsigned getNumLocalIds() const { return numLocals; }

Expand All @@ -110,6 +161,20 @@ class PresburgerLocalSpace : public PresburgerSpace {

/// Removes identifiers in the column range [idStart, idLimit).
void removeIdRange(unsigned idStart, unsigned idLimit) override;

void print(llvm::raw_ostream &os) const;
void dump() const;

protected:
/// Local Space constructor for Relation space type.
PresburgerLocalSpace(unsigned numDomain, unsigned numRange,
unsigned numSymbols, unsigned numLocals)
: PresburgerSpace(Relation, numDomain, numRange, numSymbols, numLocals) {}

/// Local Space constructor for Set space type.
PresburgerLocalSpace(unsigned numDims, unsigned numSymbols,
unsigned numLocals)
: PresburgerSpace(Set, /*numDomain=*/0, numDims, numSymbols, numLocals) {}
};

} // namespace mlir
Expand Down
13 changes: 6 additions & 7 deletions mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
Expand Up @@ -93,7 +93,7 @@ IntegerPolyhedron::getRationalLexMin() const {
}

unsigned IntegerPolyhedron::insertDimId(unsigned pos, unsigned num) {
return insertId(IdKind::Dimension, pos, num);
return insertId(IdKind::SetDim, pos, num);
}

unsigned IntegerPolyhedron::insertSymbolId(unsigned pos, unsigned num) {
Expand All @@ -107,16 +107,15 @@ unsigned IntegerPolyhedron::insertLocalId(unsigned pos, unsigned num) {
unsigned IntegerPolyhedron::insertId(IdKind kind, unsigned pos, unsigned num) {
assert(pos <= getNumIdKind(kind));

unsigned absolutePos = getIdKindOffset(kind) + pos;
inequalities.insertColumns(absolutePos, num);
equalities.insertColumns(absolutePos, num);

return PresburgerLocalSpace::insertId(kind, pos, num);
unsigned insertPos = PresburgerLocalSpace::insertId(kind, pos, num);
inequalities.insertColumns(insertPos, num);
equalities.insertColumns(insertPos, num);
return insertPos;
}

unsigned IntegerPolyhedron::appendDimId(unsigned num) {
unsigned pos = getNumDimIds();
insertId(IdKind::Dimension, pos, num);
insertId(IdKind::SetDim, pos, num);
return pos;
}

Expand Down
99 changes: 81 additions & 18 deletions mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
Expand Up @@ -12,24 +12,56 @@

using namespace mlir;

PresburgerSpace PresburgerSpace::getRelationSpace(unsigned numDomain,
unsigned numRange,
unsigned numSymbols) {
return PresburgerSpace(numDomain, numRange, numSymbols);
}

PresburgerSpace PresburgerSpace::getSetSpace(unsigned numDims,
unsigned numSymbols) {
return PresburgerSpace(numDims, numSymbols);
}

PresburgerLocalSpace
PresburgerLocalSpace::getRelationSpace(unsigned numDomain, unsigned numRange,
unsigned numSymbols,
unsigned numLocals) {
return PresburgerLocalSpace(numDomain, numRange, numSymbols, numLocals);
}

PresburgerLocalSpace PresburgerLocalSpace::getSetSpace(unsigned numDims,
unsigned numSymbols,
unsigned numLocals) {
return PresburgerLocalSpace(numDims, numSymbols, numLocals);
}

unsigned PresburgerSpace::getNumIdKind(IdKind kind) const {
if (kind == IdKind::Dimension)
return getNumDimIds();
if (kind == IdKind::Domain) {
assert(spaceKind == Relation && "IdKind::Domain is not supported in Set.");
return getNumDomainIds();
}
if (kind == IdKind::Range)
return getNumRangeIds();
if (kind == IdKind::Symbol)
return getNumSymbolIds();
if (kind == IdKind::Local)
return numLocals;
llvm_unreachable("IdKind does not exit!");
llvm_unreachable("IdKind does not exist!");
}

unsigned PresburgerSpace::getIdKindOffset(IdKind kind) const {
if (kind == IdKind::Dimension)
if (kind == IdKind::Domain) {
assert(spaceKind == Relation && "IdKind::Domain is not supported in Set.");
return 0;
}
if (kind == IdKind::Range)
return getNumDomainIds();
if (kind == IdKind::Symbol)
return getNumDimIds();
if (kind == IdKind::Local)
return getNumDimAndSymbolIds();
llvm_unreachable("IdKind does not exit!");
llvm_unreachable("IdKind does not exist!");
}

unsigned PresburgerSpace::getIdKindEnd(IdKind kind) const {
Expand All @@ -56,13 +88,16 @@ unsigned PresburgerSpace::insertId(IdKind kind, unsigned pos, unsigned num) {

unsigned absolutePos = getIdKindOffset(kind) + pos;

if (kind == IdKind::Dimension)
numDims += num;
else if (kind == IdKind::Symbol)
if (kind == IdKind::Domain) {
assert(spaceKind == Relation && "IdKind::Domain is not supported in Set.");
numDomain += num;
} else if (kind == IdKind::Range) {
numRange += num;
} else if (kind == IdKind::Symbol) {
numSymbols += num;
else
llvm_unreachable(
"PresburgerSpace only supports Dimensions and Symbol identifiers!");
} else {
llvm_unreachable("PresburgerSpace does not support local identifiers!");
}

return absolutePos;
}
Expand All @@ -76,13 +111,17 @@ void PresburgerSpace::removeIdRange(unsigned idStart, unsigned idLimit) {
// We are going to be removing one or more identifiers from the range.
assert(idStart < getNumIds() && "invalid idStart position");

// Update members numDims, numSymbols and numIds.
unsigned numDimsEliminated =
getIdKindOverlap(IdKind::Dimension, idStart, idLimit);
// Update members numDomain, numRange, numSymbols and numIds.
unsigned numDomainEliminated = 0;
if (spaceKind == Relation)
numDomainEliminated = getIdKindOverlap(IdKind::Domain, idStart, idLimit);
unsigned numRangeEliminated =
getIdKindOverlap(IdKind::Range, idStart, idLimit);
unsigned numSymbolsEliminated =
getIdKindOverlap(IdKind::Symbol, idStart, idLimit);

numDims -= numDimsEliminated;
numDomain -= numDomainEliminated;
numRange -= numRangeEliminated;
numSymbols -= numSymbolsEliminated;
}

Expand All @@ -108,8 +147,7 @@ void PresburgerLocalSpace::removeIdRange(unsigned idStart, unsigned idLimit) {
getIdKindOverlap(IdKind::Local, idStart, idLimit);

// Update space parameters.
PresburgerSpace::removeIdRange(
idStart, std::min(idLimit, PresburgerSpace::getNumIds()));
PresburgerSpace::removeIdRange(idStart, idLimit);

// Update local ids.
numLocals -= numLocalsEliminated;
Expand All @@ -118,6 +156,31 @@ void PresburgerLocalSpace::removeIdRange(unsigned idStart, unsigned idLimit) {
void PresburgerSpace::setDimSymbolSeparation(unsigned newSymbolCount) {
assert(newSymbolCount <= getNumDimAndSymbolIds() &&
"invalid separation position");
numDims = numDims + numSymbols - newSymbolCount;
numRange = numRange + numSymbols - newSymbolCount;
numSymbols = newSymbolCount;
}

void PresburgerSpace::print(llvm::raw_ostream &os) const {
if (spaceKind == Relation) {
os << "Domain: " << getNumDomainIds() << ", "
<< "Range: " << getNumRangeIds() << ", ";
} else {
os << "Dimension: " << getNumDomainIds() << ", ";
}
os << "Symbols: " << getNumSymbolIds() << "\n";
}

void PresburgerSpace::dump() const { print(llvm::errs()); }

void PresburgerLocalSpace::print(llvm::raw_ostream &os) const {
if (spaceKind == Relation) {
os << "Domain: " << getNumDomainIds() << ", "
<< "Range: " << getNumRangeIds() << ", ";
} else {
os << "Dimension: " << getNumDomainIds() << ", ";
}
os << "Symbols: " << getNumSymbolIds() << ", "
<< "Locals" << getNumLocalIds() << "\n";
}

void PresburgerLocalSpace::dump() const { print(llvm::errs()); }
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
Expand Up @@ -268,7 +268,7 @@ void FlatAffineValueConstraints::reset(unsigned newNumDims,

unsigned FlatAffineValueConstraints::appendDimId(ValueRange vals) {
unsigned pos = getNumDimIds();
insertId(IdKind::Dimension, pos, vals);
insertId(IdKind::SetDim, pos, vals);
return pos;
}

Expand All @@ -280,7 +280,7 @@ unsigned FlatAffineValueConstraints::appendSymbolId(ValueRange vals) {

unsigned FlatAffineValueConstraints::insertDimId(unsigned pos,
ValueRange vals) {
return insertId(IdKind::Dimension, pos, vals);
return insertId(IdKind::SetDim, pos, vals);
}

unsigned FlatAffineValueConstraints::insertSymbolId(unsigned pos,
Expand Down Expand Up @@ -365,7 +365,7 @@ areIdsUnique(const FlatAffineConstraints &cst) {
static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
const FlatAffineValueConstraints &cst, FlatAffineConstraints::IdKind kind) {

if (kind == FlatAffineConstraints::IdKind::Dimension)
if (kind == FlatAffineConstraints::IdKind::SetDim)
return areIdsUnique(cst, 0, cst.getNumDimIds());
if (kind == FlatAffineConstraints::IdKind::Symbol)
return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds());
Expand Down Expand Up @@ -1214,8 +1214,8 @@ FlatAffineValueConstraints::computeAlignedMap(AffineMap map,

dims.reserve(getNumDimIds());
syms.reserve(getNumSymbolIds());
for (unsigned i = getIdKindOffset(IdKind::Dimension),
e = getIdKindEnd(IdKind::Dimension);
for (unsigned i = getIdKindOffset(IdKind::SetDim),
e = getIdKindEnd(IdKind::SetDim);
i < e; ++i)
dims.push_back(values[i] ? *values[i] : Value());
for (unsigned i = getIdKindOffset(IdKind::Symbol),
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/Analysis/Presburger/CMakeLists.txt
Expand Up @@ -3,6 +3,7 @@ add_mlir_unittest(MLIRPresburgerTests
LinearTransformTest.cpp
MatrixTest.cpp
PresburgerSetTest.cpp
PresburgerSpaceTest.cpp
PWMAFunctionTest.cpp
SimplexTest.cpp
../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
Expand Down
Expand Up @@ -158,7 +158,7 @@ TEST(IntegerPolyhedronTest, removeIdRange) {
EXPECT_THAT(set.getInequality(0),
testing::ElementsAre(10, 11, 12, 20, 30, 40));

set.removeIdRange(IntegerPolyhedron::IdKind::Dimension, 0, 2);
set.removeIdRange(IntegerPolyhedron::IdKind::SetDim, 0, 2);
EXPECT_THAT(set.getInequality(0), testing::ElementsAre(12, 20, 30, 40));

set.removeIdRange(IntegerPolyhedron::IdKind::Local, 1, 1);
Expand Down

0 comments on commit eae62b2

Please sign in to comment.