Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Presburger] Implement computation of generating function for unimodular cones #77235

Merged
merged 11 commits into from Jan 10, 2024
6 changes: 6 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/Barvinok.h
Expand Up @@ -24,6 +24,7 @@
#ifndef MLIR_ANALYSIS_PRESBURGER_BARVINOK_H
#define MLIR_ANALYSIS_PRESBURGER_BARVINOK_H

#include "mlir/Analysis/Presburger/GeneratingFunction.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/Matrix.h"
#include <optional>
Expand Down Expand Up @@ -77,6 +78,11 @@ ConeV getDual(ConeH cone);
/// The returned cone is pointed at the origin.
ConeH getDual(ConeV cone);

/// Compute the generating function for a unimodular cone.
/// The input cone must be unimodular; it assert-fails otherwise.
GeneratingFunction unimodularConeGeneratingFunction(ParamPoint vertex, int sign,
ConeH cone);

} // namespace detail
} // namespace presburger
} // namespace mlir
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Expand Up @@ -221,6 +221,8 @@ class IntegerRelation {
return getInt64Vec(inequalities.getRow(idx));
}

inline IntMatrix getInequalities() const { return inequalities; }

/// Get the number of vars of the specified kind.
unsigned getNumVarKind(VarKind kind) const {
return space.getNumVarKind(kind);
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/Matrix.h
Expand Up @@ -181,6 +181,9 @@ class Matrix {
/// `elems` must be equal to the number of columns.
unsigned appendExtraRow(ArrayRef<T> elems);

// Transpose the matrix without modifying it.
Matrix<T> transpose() const;

/// Print the matrix.
void print(raw_ostream &os) const;
void dump() const;
Expand Down
80 changes: 79 additions & 1 deletion mlir/lib/Analysis/Presburger/Barvinok.cpp
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Presburger/Barvinok.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;
using namespace presburger;
Expand All @@ -24,7 +25,7 @@ ConeV mlir::presburger::detail::getDual(ConeH cone) {
// is represented as a row [a1, ..., an, b]
// and that b = 0.

for (unsigned i = 0; i < numIneq; ++i) {
for (auto i : llvm::seq<int>(0, numIneq)) {
assert(cone.atIneq(i, numVar) == 0 &&
"H-representation of cone is not centred at the origin!");
for (unsigned j = 0; j < numVar; ++j) {
Expand Down Expand Up @@ -63,3 +64,80 @@ MPInt mlir::presburger::detail::getIndex(ConeV cone) {

return cone.determinant();
}

/// Compute the generating function for a unimodular cone.
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
/// This consists of a single term of the form
/// sign * x^num / prod_j (1 - x^den_j)
///
/// sign is either +1 or -1.
/// den_j is defined as the set of generators of the cone.
/// num is computed by expressing the vertex as a weighted
/// sum of the generators, and then taking the floor of the
/// coefficients.
GeneratingFunction mlir::presburger::detail::unimodularConeGeneratingFunction(
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
ParamPoint vertex, int sign, ConeH cone) {
// `cone` must be unimodular.
assert(getIndex(getDual(cone)) == 1 && "input cone is not unimodular!");

unsigned numVar = cone.getNumVars();
unsigned numIneq = cone.getNumInequalities();

// Thus its ray matrix, U, is the inverse of the
// transpose of its inequality matrix, `cone`.
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// The last column of the inequality matrix is null,
// so we remove it to obtain a square matrix.
FracMatrix transp = FracMatrix(cone.getInequalities()).transpose();
transp.removeRow(numVar);

FracMatrix generators(numVar, numIneq);
transp.determinant(/*inverse=*/&generators); // This is the U-matrix.

// The powers in the denominator of the generating
// function are given by the generators of the cone,
// i.e., the rows of the matrix U.
std::vector<Point> denominator(numIneq);
ArrayRef<Fraction> row;
for (auto i : llvm::seq<int>(0, numVar)) {
row = generators.getRow(i);
denominator[i] = Point(row);
}

// The vertex is v \in Z^{d x (n+1)}
// We need to find affine functions of parameters λ_i(p)
// such that v = Σ λ_i(p)*u_i.
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// The λi are given by the columns of Λ = v^T U^{-1} = v^T transp.
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// Then the exponent in the numerator will be
// Σ -floor(-λ_i(p))*u_i.
// Thus we store the (exponent of the) numerator as the affine function -Λ,
// since the generators are already stored as the exponent of the denominator.
// Note that the outer -1 will have to be accounted for, as it is not stored.
// See end for an example.

unsigned numColumns = vertex.getNumColumns();
unsigned numRows = vertex.getNumRows();
ParamPoint numerator(numColumns, numRows);
SmallVector<Fraction> ithCol(numRows);
for (auto i : llvm::seq<int>(0, numColumns)) {
for (auto j : llvm::seq<int>(0, numRows))
ithCol[j] = vertex(j, i);
numerator.setRow(i, transp.preMultiplyWithRow(ithCol));
numerator.negateRow(i);
}

return GeneratingFunction(numColumns - 1, SmallVector<int>(1, sign),
std::vector({numerator}),
std::vector({denominator}));

// Suppose the vertex is given by the matrix [ 2 2 0], with 2 params
// [-1 -1/2 1]
// and the cone has H-representation [0 -1] => U-matrix [ 2 -1]
// [-1 -2] [-1 0]
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// Therefore Λ will be given by [ 1 0 ] and the negation of this will be
// stored as the numerator.
// [ 1/2 -1 ]
// [ -1 -2 ]

// Algebraically, the numerator exponent is
// [ -2 ⌊ -N - M/2 +1 ⌋ + 1 ⌊ 0 +M +2 ⌋ ] -> first COLUMN of U is [2, -1]
// [ 1 ⌊ -N - M/2 +1 ⌋ + 0 ⌊ 0 +M +2 ⌋ ] -> second COLUMN of U is [-1, 0]
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
}
10 changes: 10 additions & 0 deletions mlir/lib/Analysis/Presburger/Matrix.cpp
Expand Up @@ -62,6 +62,16 @@ unsigned Matrix<T>::appendExtraRow(ArrayRef<T> elems) {
return row;
}

template <typename T>
Matrix<T> Matrix<T>::transpose() {
Matrix<T> transp(nColumns, nRows);
for (unsigned row = 0; row < nRows; ++row)
for (unsigned col = 0; col < nColumns; ++col)
transp(col, row) = at(row, col);

return transp;
}

template <typename T>
void Matrix<T>::resizeHorizontally(unsigned newNColumns) {
if (newNColumns < nColumns)
Expand Down
36 changes: 36 additions & 0 deletions mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
Expand Up @@ -46,3 +46,39 @@ TEST(BarvinokTest, getIndex) {
4, 4, {{4, 2, 5, 1}, {4, 1, 3, 6}, {8, 2, 5, 6}, {5, 2, 5, 7}});
EXPECT_EQ(getIndex(cone), cone.determinant());
}

// The following cones and vertices are randomly generated
// (s.t. the cones are unimodular) and the generating functions
// are computed. We check that the results contain the correct
// matrices.
TEST(BarvinokTest, unimodularConeGeneratingFunction) {
ConeH cone = defineHRep(2);
cone.addInequality({0, -1, 0});
cone.addInequality({-1, -2, 0});

ParamPoint vertex =
makeFracMatrix(2, 3, {{2, 2, 0}, {-1, -Fraction(1, 2), 1}});

GeneratingFunction gf = unimodularConeGeneratingFunction(vertex, 1, cone);

EXPECT_EQ_REPR_GENERATINGFUNCTION(
gf, GeneratingFunction(
2, {1},
{makeFracMatrix(3, 2, {{-1, 0}, {-Fraction(1, 2), 1}, {1, 2}})},
{{{2, -1}, {-1, 0}}}));

cone = defineHRep(3);
cone.addInequality({7, 1, 6, 0});
cone.addInequality({9, 1, 7, 0});
cone.addInequality({8, -1, 1, 0});

vertex = makeFracMatrix(3, 2, {{5, 2}, {6, 2}, {7, 1}});

gf = unimodularConeGeneratingFunction(vertex, 1, cone);

EXPECT_EQ_REPR_GENERATINGFUNCTION(
gf,
GeneratingFunction(
1, {1}, {makeFracMatrix(2, 3, {{-83, -100, -41}, {-22, -27, -15}})},
{{{8, 47, -17}, {-7, -41, 15}, {1, 5, -2}}}));
}