Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Presburger] Implement computation of generating function for unimodular cones #77235

Merged
merged 11 commits into from Jan 10, 2024
6 changes: 6 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/Barvinok.h
Expand Up @@ -24,6 +24,7 @@
#ifndef MLIR_ANALYSIS_PRESBURGER_BARVINOK_H
#define MLIR_ANALYSIS_PRESBURGER_BARVINOK_H

#include "mlir/Analysis/Presburger/GeneratingFunction.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/Matrix.h"
#include <optional>
Expand Down Expand Up @@ -77,6 +78,11 @@ ConeV getDual(ConeH cone);
/// The returned cone is pointed at the origin.
ConeH getDual(ConeV cone);

/// Compute the generating function for a unimodular cone.
/// It assert-fails if the input cone is not unimodular.
GeneratingFunction unimodularConeGeneratingFunction(ParamPoint vertex, int sign,
ConeH cone);

} // namespace detail
} // namespace presburger
} // namespace mlir
Expand Down
68 changes: 68 additions & 0 deletions mlir/lib/Analysis/Presburger/Barvinok.cpp
Expand Up @@ -63,3 +63,71 @@ MPInt mlir::presburger::detail::getIndex(ConeV cone) {

return cone.determinant();
}

/// Compute the generating function for a unimodular cone.
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
GeneratingFunction mlir::presburger::detail::unimodularConeGeneratingFunction(
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
ParamPoint vertex, int sign, ConeH cone) {
// `cone` is assumed to be unimodular.
assert(getIndex(getDual(cone)) == 1 && "input cone is not unimodular!");

unsigned numVar = cone.getNumVars();
unsigned numIneq = cone.getNumInequalities();

// Thus its ray matrix, U, is the inverse of the
// transpose of its inequality matrix, `cone`.
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
FracMatrix transp(numVar, numIneq);
for (unsigned i = 0; i < numVar; i++)
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
for (unsigned j = 0; j < numIneq; j++)
transp(j, i) = Fraction(cone.atIneq(i, j), 1);

FracMatrix generators(numVar, numIneq);
transp.determinant(&generators); // This is the U-matrix.

// The denominators of the generating function
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// are given by the generators of the cone, i.e.,
// the rows of the matrix U.
std::vector<Point> denominator(numIneq);
ArrayRef<Fraction> row;
for (unsigned i = 0; i < numVar; i++) {
row = generators.getRow(i);
denominator[i] = Point(row);
}

// The vertex is v : [d, n+1].
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// We need to find affine functions of parameters λi(p)
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// such that v = Σ λi(p)*ui.
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// The λi are given by the columns of Λ = v^T @ U^{-1} = v^T @ transp.
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// Then the numerator will be Σ -floor(-λi(p))*u_i.
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// Thus we store the numerator as the affine function -Λ,
// since the generators are already stored in the denominator.
// Note that the outer -1 will have to be accounted for, as it is not stored.
// See end for an example.

unsigned numColumns = vertex.getNumColumns();
unsigned numRows = vertex.getNumRows();
ParamPoint numerator(numColumns, numRows);
SmallVector<Fraction> ithCol(numRows);
for (unsigned i = 0; i < numColumns; i++) {
for (unsigned j = 0; j < numRows; j++)
ithCol[j] = vertex(j, i);
numerator.setRow(i, transp.preMultiplyWithRow(ithCol));
numerator.negateRow(i);
}

return GeneratingFunction(numColumns - 1, SmallVector<int>(1, sign),
std::vector({numerator}),
std::vector({denominator}));

// Suppose the vertex is given by the matrix [ 2 2 0], with 2 params
// [-1 -1/2 1]
// and the cone has H-representation [0 -1] => U-matrix [ 2 -1]
// [-1 -2] [-1 0]
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
// Therefore Λ will be given by [ 1 0 ] and the negation of this will be
// stored as the numerator.
// [ 1/2 -1 ]
// [ -1 -2 ]

// Algebraically, the numerator exponent is
// [ -2 ⌊ -N - M/2 +1 ⌋ + 1 ⌊ 0 +M +2 ⌋ ] -> first COLUMN of U is [2, -1]
// [ 1 ⌊ -N - M/2 +1 ⌋ + 0 ⌊ 0 +M +2 ⌋ ] -> second COLUMN of U is [-1, 0]
Abhinav271828 marked this conversation as resolved.
Show resolved Hide resolved
}
36 changes: 36 additions & 0 deletions mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
Expand Up @@ -46,3 +46,39 @@ TEST(BarvinokTest, getIndex) {
4, 4, {{4, 2, 5, 1}, {4, 1, 3, 6}, {8, 2, 5, 6}, {5, 2, 5, 7}});
EXPECT_EQ(getIndex(cone), cone.determinant());
}

// The following cones and vertices are randomly generated
// (s.t. the cones are unimodular) and the generating functions
// are computed. We check that the results contain the correct
// matrices.
TEST(BarvinokTest, unimodularConeGeneratingFunction) {
ConeH cone = defineHRep(2);
cone.addInequality({0, -1, 0});
cone.addInequality({-1, -2, 0});

ParamPoint vertex =
makeFracMatrix(2, 3, {{2, 2, 0}, {-1, -Fraction(1, 2), 1}});

GeneratingFunction gf = unimodularConeGeneratingFunction(vertex, 1, cone);

EXPECT_EQ_REPR_GENERATINGFUNCTION(
gf, GeneratingFunction(
2, {1},
{makeFracMatrix(3, 2, {{-1, 0}, {-Fraction(1, 2), 1}, {1, 2}})},
{{{2, -1}, {-1, 0}}}));

cone = defineHRep(3);
cone.addInequality({7, 1, 6, 0});
cone.addInequality({9, 1, 7, 0});
cone.addInequality({8, -1, 1, 0});

vertex = makeFracMatrix(3, 2, {{5, 2}, {6, 2}, {7, 1}});

gf = unimodularConeGeneratingFunction(vertex, 1, cone);

EXPECT_EQ_REPR_GENERATINGFUNCTION(
gf,
GeneratingFunction(
1, {1}, {makeFracMatrix(2, 3, {{-83, -100, -41}, {-22, -27, -15}})},
{{{8, 47, -17}, {-7, -41, 15}, {1, 5, -2}}}));
}