diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h index a02eeeb08d820..4db03487cbe2f 100644 --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -280,7 +280,7 @@ class SimplexBase { Unknown &unknownFromRow(unsigned row); /// Add a new row to the tableau and the associated data structures. The row - /// is initialized to zero. + /// is initialized to zero. Returns the index of the added row. unsigned addZeroRow(bool makeRestricted = false); /// Add a new row to the tableau and the associated data structures. @@ -316,17 +316,12 @@ class SimplexBase { /// Return the number of fixed columns, as described in the constructor above, /// this is the number of columns beyond those for the variables in var. unsigned getNumFixedCols() const { return usingBigM ? 3u : 2u; } + unsigned getNumRows() const { return tableau.getNumRows(); } + unsigned getNumColumns() const { return tableau.getNumColumns(); } /// Stores whether or not a big M column is present in the tableau. bool usingBigM; - /// The number of rows in the tableau. - unsigned nRow; - - /// The number of columns in the tableau, including the common denominator - /// and the constant column. - unsigned nCol; - /// The number of redundant rows in the tableau. These are the first /// nRedundant rows. unsigned nRedundant; diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index 8f4c0e64d11bd..f7e25b39805ee 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -33,8 +33,8 @@ scaleAndAddForAssert(ArrayRef a, int64_t scale, ArrayRef b) { SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset, unsigned nSymbol) - : usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar), - nRedundant(0), nSymbol(nSymbol), tableau(0, nCol), empty(false) { + : usingBigM(mustUseBigM), nRedundant(0), nSymbol(nSymbol), + tableau(0, getNumFixedCols() + nVar), empty(false) { assert(symbolOffset + nSymbol <= nVar); colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex); @@ -57,12 +57,12 @@ const Simplex::Unknown &SimplexBase::unknownFromIndex(int index) const { } const Simplex::Unknown &SimplexBase::unknownFromColumn(unsigned col) const { - assert(col < nCol && "Invalid column"); + assert(col < getNumColumns() && "Invalid column"); return unknownFromIndex(colUnknown[col]); } const Simplex::Unknown &SimplexBase::unknownFromRow(unsigned row) const { - assert(row < nRow && "Invalid row"); + assert(row < getNumRows() && "Invalid row"); return unknownFromIndex(rowUnknown[row]); } @@ -72,26 +72,24 @@ Simplex::Unknown &SimplexBase::unknownFromIndex(int index) { } Simplex::Unknown &SimplexBase::unknownFromColumn(unsigned col) { - assert(col < nCol && "Invalid column"); + assert(col < getNumColumns() && "Invalid column"); return unknownFromIndex(colUnknown[col]); } Simplex::Unknown &SimplexBase::unknownFromRow(unsigned row) { - assert(row < nRow && "Invalid row"); + assert(row < getNumRows() && "Invalid row"); return unknownFromIndex(rowUnknown[row]); } unsigned SimplexBase::addZeroRow(bool makeRestricted) { - ++nRow; // Resize the tableau to accommodate the extra row. - tableau.resizeVertically(nRow); - // TODO: consider eliminating nRow, as it stores redundant information. - assert(tableau.getNumRows() == nRow && "Inconsistent tableau size"); + unsigned newRow = tableau.appendExtraRow(); + assert(getNumRows() == getNumRows() && "Inconsistent tableau size"); rowUnknown.push_back(~con.size()); - con.emplace_back(Orientation::Row, makeRestricted, nRow - 1); + con.emplace_back(Orientation::Row, makeRestricted, newRow); undoLog.push_back(UndoLogEntry::RemoveLastConstraint); - tableau(nRow - 1, 0) = 1; - return con.size() - 1; + tableau(newRow, 0) = 1; + return newRow; } /// Add a new row to the tableau corresponding to the given constant term and @@ -100,9 +98,11 @@ unsigned SimplexBase::addZeroRow(bool makeRestricted) { unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { assert(coeffs.size() == var.size() + 1 && "Incorrect number of coefficients!"); + assert(var.size() + getNumFixedCols() == getNumColumns() && + "inconsistent column count!"); - addZeroRow(makeRestricted); - tableau(nRow - 1, 1) = coeffs.back(); + unsigned newRow = addZeroRow(makeRestricted); + tableau(newRow, 1) = coeffs.back(); if (usingBigM) { // When the lexicographic pivot rule is used, instead of the variables // @@ -123,7 +123,7 @@ unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { if (!var[i].isSymbol) bigMCoeff -= coeffs[i]; // The coefficient to the big M parameter is stored in column 2. - tableau(nRow - 1, 2) = bigMCoeff; + tableau(newRow, 2) = bigMCoeff; } // Process each given variable coefficient. @@ -136,7 +136,7 @@ unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { // If a variable is in column position at column col, then we just add the // coefficient for that variable (scaled by the common row denominator) to // the corresponding entry in the new row. - tableau(nRow - 1, pos) += coeffs[i] * tableau(nRow - 1, 0); + tableau(newRow, pos) += coeffs[i] * tableau(newRow, 0); continue; } @@ -144,16 +144,16 @@ unsigned SimplexBase::addRow(ArrayRef coeffs, bool makeRestricted) { // row, scaled by the coefficient for the variable, accounting for the two // rows potentially having different denominators. The new denominator is // the lcm of the two. - int64_t lcm = mlir::lcm(tableau(nRow - 1, 0), tableau(pos, 0)); - int64_t nRowCoeff = lcm / tableau(nRow - 1, 0); + int64_t lcm = mlir::lcm(tableau(newRow, 0), tableau(pos, 0)); + int64_t nRowCoeff = lcm / tableau(newRow, 0); int64_t idxRowCoeff = coeffs[i] * (lcm / tableau(pos, 0)); - tableau(nRow - 1, 0) = lcm; - for (unsigned col = 1; col < nCol; ++col) - tableau(nRow - 1, col) = - nRowCoeff * tableau(nRow - 1, col) + idxRowCoeff * tableau(pos, col); + tableau(newRow, 0) = lcm; + for (unsigned col = 1, e = getNumColumns(); col < e; ++col) + tableau(newRow, col) = + nRowCoeff * tableau(newRow, col) + idxRowCoeff * tableau(pos, col); } - tableau.normalizeRow(nRow - 1); + tableau.normalizeRow(newRow); // Push to undo log along with the index of the new constraint. return con.size() - 1; } @@ -256,13 +256,13 @@ MaybeOptimum> LexSimplex::findRationalLexMin() { /// so we immediately try to move it to a column. LogicalResult LexSimplexBase::addCut(unsigned row) { int64_t d = tableau(row, 0); - addZeroRow(/*makeRestricted=*/true); - tableau(nRow - 1, 0) = d; - tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -c%d. - tableau(nRow - 1, 2) = 0; - for (unsigned col = 3 + nSymbol; col < nCol; ++col) - tableau(nRow - 1, col) = mod(tableau(row, col), d); // b_i%d. - return moveRowUnknownToColumn(nRow - 1); + unsigned cutRow = addZeroRow(/*makeRestricted=*/true); + tableau(cutRow, 0) = d; + tableau(cutRow, 1) = -mod(-tableau(row, 1), d); // -c%d. + tableau(cutRow, 2) = 0; + for (unsigned col = 3 + nSymbol, e = getNumColumns(); col < e; ++col) + tableau(cutRow, col) = mod(tableau(row, col), d); // b_i%d. + return moveRowUnknownToColumn(cutRow); } Optional LexSimplex::maybeGetNonIntegralVarRow() const { @@ -340,7 +340,7 @@ SymbolicLexSimplex::getSymbolicSampleIneq(unsigned row) const { void LexSimplexBase::appendSymbol() { appendVariable(); - swapColumns(3 + nSymbol, nCol - 1); + swapColumns(3 + nSymbol, getNumColumns() - 1); var.back().isSymbol = true; nSymbol++; } @@ -414,18 +414,18 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { appendSymbol(); // Add the cut (sum_i (b_i%d)y_i - (-c%d) + sum_i -(-a_i%d)s_i + q*d)/d >= 0. - addZeroRow(/*makeRestricted=*/true); - tableau(nRow - 1, 0) = d; - tableau(nRow - 1, 2) = 0; + unsigned cutRow = addZeroRow(/*makeRestricted=*/true); + tableau(cutRow, 0) = d; + tableau(cutRow, 2) = 0; - tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -(-c%d). + tableau(cutRow, 1) = -mod(-tableau(row, 1), d); // -(-c%d). for (unsigned col = 3; col < 3 + nSymbol - 1; ++col) - tableau(nRow - 1, col) = -mod(-tableau(row, col), d); // -(-a_i%d)s_i. - tableau(nRow - 1, 3 + nSymbol - 1) = d; // q*d. + tableau(cutRow, col) = -mod(-tableau(row, col), d); // -(-a_i%d)s_i. + tableau(cutRow, 3 + nSymbol - 1) = d; // q*d. - for (unsigned col = 3 + nSymbol; col < nCol; ++col) - tableau(nRow - 1, col) = mod(tableau(row, col), d); // (b_i%d)y_i. - return moveRowUnknownToColumn(nRow - 1); + for (unsigned col = 3 + nSymbol, e = getNumColumns(); col < e; ++col) + tableau(cutRow, col) = mod(tableau(row, col), d); // (b_i%d)y_i. + return moveRowUnknownToColumn(cutRow); } void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const { @@ -466,11 +466,11 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const { Optional SymbolicLexSimplex::maybeGetAlwaysViolatedRow() { // First look for rows that are clearly violated just from the big M // coefficient, without needing to perform any simplex queries on the domain. - for (unsigned row = 0; row < nRow; ++row) + for (unsigned row = 0, e = getNumRows(); row < e; ++row) if (tableau(row, 2) < 0) return row; - for (unsigned row = 0; row < nRow; ++row) { + for (unsigned row = 0, e = getNumRows(); row < e; ++row) { if (tableau(row, 2) > 0) continue; if (domainSimplex.isSeparateInequality(getSymbolicSampleIneq(row))) { @@ -541,9 +541,9 @@ SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() { continue; } - unsigned splitRow; SmallVector symbolicSample; - for (splitRow = 0; splitRow < nRow; ++splitRow) { + unsigned splitRow = 0; + for (unsigned e = getNumRows(); splitRow < e; ++splitRow) { if (tableau(splitRow, 2) > 0) continue; assert(tableau(splitRow, 2) == 0 && @@ -561,7 +561,7 @@ SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() { break; } - if (splitRow < nRow) { + if (splitRow < getNumRows()) { unsigned domainSnapshot = domainSimplex.getSnapshot(); IntegerRelation::CountsSnapshot domainPolyCounts = domainPoly.getCounts(); @@ -658,7 +658,7 @@ bool LexSimplex::rowIsViolated(unsigned row) const { } Optional LexSimplex::maybeGetViolatedRow() const { - for (unsigned row = 0; row < nRow; ++row) + for (unsigned row = 0, e = getNumRows(); row < e; ++row) if (rowIsViolated(row)) return row; return {}; @@ -740,7 +740,7 @@ LogicalResult LexSimplex::restoreRationalConsistency() { // minimizes the change in sample value. LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) { Optional maybeColumn; - for (unsigned col = 3 + nSymbol; col < nCol; ++col) { + for (unsigned col = 3 + nSymbol, e = getNumColumns(); col < e; ++col) { if (tableau(row, col) <= 0) continue; maybeColumn = @@ -850,7 +850,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, Optional Simplex::findPivot(int row, Direction direction) const { Optional col; - for (unsigned j = 2; j < nCol; ++j) { + for (unsigned j = 2, e = getNumColumns(); j < e; ++j) { int64_t elem = tableau(row, j); if (elem == 0) continue; @@ -925,7 +925,7 @@ void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) { tableau(pivotRow, 0) = -tableau(pivotRow, 0); tableau(pivotRow, pivotCol) = -tableau(pivotRow, pivotCol); } else { - for (unsigned col = 1; col < nCol; ++col) { + for (unsigned col = 1, e = getNumColumns(); col < e; ++col) { if (col == pivotCol) continue; tableau(pivotRow, col) = -tableau(pivotRow, col); @@ -933,18 +933,18 @@ void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) { } tableau.normalizeRow(pivotRow); - for (unsigned row = 0; row < nRow; ++row) { + for (unsigned row = 0, numRows = getNumRows(); row < numRows; ++row) { if (row == pivotRow) continue; if (tableau(row, pivotCol) == 0) // Nothing to do. continue; tableau(row, 0) *= tableau(pivotRow, 0); - for (unsigned j = 1; j < nCol; ++j) { - if (j == pivotCol) + for (unsigned col = 1, numCols = getNumColumns(); col < numCols; ++col) { + if (col == pivotCol) continue; // Add rather than subtract because the pivot row has been negated. - tableau(row, j) = tableau(row, j) * tableau(pivotRow, 0) + - tableau(row, pivotCol) * tableau(pivotRow, j); + tableau(row, col) = tableau(row, col) * tableau(pivotRow, 0) + + tableau(row, pivotCol) * tableau(pivotRow, col); } tableau(row, pivotCol) *= tableau(pivotRow, pivotCol); tableau.normalizeRow(row); @@ -1001,7 +1001,7 @@ Optional Simplex::findPivotRow(Optional skipRow, // reality, these are always initialized when that line is reached since these // are set whenever retRow is set. int64_t retElem = 0, retConst = 0; - for (unsigned row = nRedundant; row < nRow; ++row) { + for (unsigned row = nRedundant, e = getNumRows(); row < e; ++row) { if (skipRow && row == *skipRow) continue; int64_t elem = tableau(row, col); @@ -1043,7 +1043,8 @@ void SimplexBase::swapRows(unsigned i, unsigned j) { } void SimplexBase::swapColumns(unsigned i, unsigned j) { - assert(i < nCol && j < nCol && "Invalid columns provided!"); + assert(i < getNumColumns() && j < getNumColumns() && + "Invalid columns provided!"); if (i == j) return; tableau.swapColumns(i, j); @@ -1115,12 +1116,11 @@ void SimplexBase::removeLastConstraintRowOrientation() { // Move this unknown to the last row and remove the last row from the // tableau. - swapRows(con.back().pos, nRow - 1); + swapRows(con.back().pos, getNumRows() - 1); // It is not strictly necessary to shrink the tableau, but for now we - // maintain the invariant that the tableau has exactly nRow rows. - tableau.resizeVertically(nRow - 1); - nRow--; - assert(tableau.getNumRows() == nRow && "inconsistent tableau size!"); + // maintain the invariant that the tableau has exactly getNumRows() + // rows. + tableau.resizeVertically(getNumRows() - 1); rowUnknown.pop_back(); con.pop_back(); } @@ -1135,7 +1135,7 @@ void SimplexBase::removeLastConstraintRowOrientation() { // If we have a variable, then the column has zero coefficients for every row // iff no constraints have been added with a non-zero coefficient for this row. Optional SimplexBase::findAnyPivotRow(unsigned col) { - for (unsigned row = nRedundant; row < nRow; ++row) + for (unsigned row = nRedundant, e = getNumRows(); row < e; ++row) if (tableau(row, col) != 0) return row; return {}; @@ -1210,12 +1210,10 @@ void SimplexBase::undo(UndoLogEntry entry) { // Move this variable to the last column and remove the column from the // tableau. - swapColumns(var.back().pos, nCol - 1); - tableau.resizeHorizontally(nCol - 1); + swapColumns(var.back().pos, getNumColumns() - 1); + tableau.resizeHorizontally(getNumColumns() - 1); var.pop_back(); colUnknown.pop_back(); - nCol--; - assert(tableau.getNumColumns() == nCol && "inconsistent tableau size!"); } else if (entry == UndoLogEntry::UnmarkEmpty) { empty = false; } else if (entry == UndoLogEntry::UnmarkLastRedundant) { @@ -1230,7 +1228,8 @@ void SimplexBase::undo(UndoLogEntry entry) { Unknown &u = unknownFromIndex(index); if (u.orientation == Orientation::Column) continue; - for (unsigned col = getNumFixedCols(); col < nCol; col++) { + for (unsigned col = getNumFixedCols(), e = getNumColumns(); col < e; + col++) { assert(colUnknown[col] != nullIndex && "Column should not be a fixed column!"); if (std::find(basis.begin(), basis.end(), colUnknown[col]) != @@ -1286,13 +1285,11 @@ void SimplexBase::appendVariable(unsigned count) { var.reserve(var.size() + count); colUnknown.reserve(colUnknown.size() + count); for (unsigned i = 0; i < count; ++i) { - nCol++; var.emplace_back(Orientation::Column, /*restricted=*/false, - /*pos=*/nCol - 1); + /*pos=*/getNumColumns() + i); colUnknown.push_back(var.size() - 1); } - tableau.resizeHorizontally(nCol); - assert(tableau.getNumColumns() == nCol); + tableau.resizeHorizontally(getNumColumns() + count); undoLog.insert(undoLog.end(), count, UndoLogEntry::RemoveLastVariable); } @@ -1462,7 +1459,7 @@ Simplex Simplex::makeProduct(const Simplex &a, const Simplex &b) { unsigned numCon = a.getNumConstraints() + b.getNumConstraints(); Simplex result(numVar); - result.tableau.resizeVertically(numCon); + result.tableau.reserveRows(numCon); result.empty = a.empty || b.empty; auto concat = [](ArrayRef v, ArrayRef w) { @@ -1481,39 +1478,39 @@ Simplex Simplex::makeProduct(const Simplex &a, const Simplex &b) { }; result.colUnknown.assign(2, nullIndex); - for (unsigned i = 2; i < a.nCol; ++i) { + for (unsigned i = 2, e = a.getNumColumns(); i < e; ++i) { result.colUnknown.push_back(a.colUnknown[i]); result.unknownFromIndex(result.colUnknown.back()).pos = result.colUnknown.size() - 1; } - for (unsigned i = 2; i < b.nCol; ++i) { + for (unsigned i = 2, e = b.getNumColumns(); i < e; ++i) { result.colUnknown.push_back(indexFromBIndex(b.colUnknown[i])); result.unknownFromIndex(result.colUnknown.back()).pos = result.colUnknown.size() - 1; } auto appendRowFromA = [&](unsigned row) { - for (unsigned col = 0; col < a.nCol; ++col) - result.tableau(result.nRow, col) = a.tableau(row, col); + unsigned resultRow = result.tableau.appendExtraRow(); + for (unsigned col = 0, e = a.getNumColumns(); col < e; ++col) + result.tableau(resultRow, col) = a.tableau(row, col); result.rowUnknown.push_back(a.rowUnknown[row]); result.unknownFromIndex(result.rowUnknown.back()).pos = result.rowUnknown.size() - 1; - result.nRow++; }; // Also fixes the corresponding entry in rowUnknown and var/con (as the case // may be). auto appendRowFromB = [&](unsigned row) { - result.tableau(result.nRow, 0) = b.tableau(row, 0); - result.tableau(result.nRow, 1) = b.tableau(row, 1); + unsigned resultRow = result.tableau.appendExtraRow(); + result.tableau(resultRow, 0) = b.tableau(row, 0); + result.tableau(resultRow, 1) = b.tableau(row, 1); - unsigned offset = a.nCol - 2; - for (unsigned col = 2; col < b.nCol; ++col) - result.tableau(result.nRow, offset + col) = b.tableau(row, col); + unsigned offset = a.getNumColumns() - 2; + for (unsigned col = 2, e = b.getNumColumns(); col < e; ++col) + result.tableau(resultRow, offset + col) = b.tableau(row, col); result.rowUnknown.push_back(indexFromBIndex(b.rowUnknown[row])); result.unknownFromIndex(result.rowUnknown.back()).pos = result.rowUnknown.size() - 1; - result.nRow++; }; result.nRedundant = a.nRedundant + b.nRedundant; @@ -1521,15 +1518,11 @@ Simplex Simplex::makeProduct(const Simplex &a, const Simplex &b) { appendRowFromA(row); for (unsigned row = 0; row < b.nRedundant; ++row) appendRowFromB(row); - for (unsigned row = a.nRedundant; row < a.nRow; ++row) + for (unsigned row = a.nRedundant, e = a.getNumRows(); row < e; ++row) appendRowFromA(row); - for (unsigned row = b.nRedundant; row < b.nRow; ++row) + for (unsigned row = b.nRedundant, e = b.getNumRows(); row < e; ++row) appendRowFromB(row); - assert(result.tableau.getNumRows() == result.nRow && - "inconsistent row size!"); - assert(result.tableau.getNumColumns() == result.nCol && - "inconsistent row size!"); return result; } @@ -2076,7 +2069,7 @@ Simplex::computeIntegerBounds(ArrayRef coeffs) { } void SimplexBase::print(raw_ostream &os) const { - os << "rows = " << nRow << ", columns = " << nCol << "\n"; + os << "rows = " << getNumRows() << ", columns = " << getNumColumns() << "\n"; if (empty) os << "Simplex marked empty!\n"; os << "var: "; @@ -2092,18 +2085,18 @@ void SimplexBase::print(raw_ostream &os) const { con[i].print(os); } os << '\n'; - for (unsigned row = 0; row < nRow; ++row) { + for (unsigned row = 0, e = getNumRows(); row < e; ++row) { if (row > 0) os << ", "; os << "r" << row << ": " << rowUnknown[row]; } os << '\n'; os << "c0: denom, c1: const"; - for (unsigned col = 2; col < nCol; ++col) + for (unsigned col = 2, e = getNumColumns(); col < e; ++col) os << ", c" << col << ": " << colUnknown[col]; os << '\n'; - for (unsigned row = 0; row < nRow; ++row) { - for (unsigned col = 0; col < nCol; ++col) + for (unsigned row = 0, numRows = getNumRows(); row < numRows; ++row) { + for (unsigned col = 0, numCols = getNumColumns(); col < numCols; ++col) os << tableau(row, col) << '\t'; os << '\n'; }