diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h index 8156b7b71f3a20..fd2dca6f99b984 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h @@ -498,9 +498,17 @@ class IntegerPolyhedron { /// Return the index at which the specified kind of id starts. unsigned getIdKindOffset(IdKind kind) const; + /// Return the index at which the specified kind of id ends. + unsigned getIdKindEnd(IdKind kind) const; + /// Get the number of ids of the specified kind. unsigned getNumIdKind(IdKind kind) const; + /// Get the number of elements of the specified kind in the range + /// [idStart, idLimit). + unsigned getIdKindOverlap(IdKind kind, unsigned idStart, + unsigned idLimit) const; + /// Removes identifiers in the column range [idStart, idLimit), and copies any /// remaining valid data into place, updates member variables, and resizes /// arrays as needed. diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp index 578f8547297071..c1c228b5c49924 100644 --- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp @@ -74,7 +74,7 @@ bool IntegerPolyhedron::isSubsetOf(const IntegerPolyhedron &other) const { Optional> IntegerPolyhedron::getRationalLexMin() const { - assert(numSymbols == 0 && "Symbols are not supported!"); + assert(getNumSymbolIds() == 0 && "Symbols are not supported!"); Optional> maybeLexMin = LexSimplex(*this).getRationalLexMin(); @@ -172,32 +172,21 @@ void IntegerPolyhedron::removeIdRange(unsigned idStart, unsigned idLimit) { return; // We are going to be removing one or more identifiers from the range. - assert(idStart < numIds && "invalid idStart position"); + assert(idStart < getNumIds() && "invalid idStart position"); - // TODO: Make 'removeIdRange' a lambda called from here. // Remove eliminated identifiers from the constraints.. equalities.removeColumns(idStart, idLimit - idStart); inequalities.removeColumns(idStart, idLimit - idStart); // Update members numDims, numSymbols and numIds. - unsigned numDimsEliminated = 0; - unsigned numLocalsEliminated = 0; - unsigned numColsEliminated = idLimit - idStart; - if (idStart < numDims) { - numDimsEliminated = std::min(numDims, idLimit) - idStart; - } - // Check how many local id's were removed. Note that our identifier order is - // [dims, symbols, locals]. Local id start at position numDims + numSymbols. - if (idLimit > numDims + numSymbols) { - numLocalsEliminated = std::min( - idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds()); - } + unsigned numDimsEliminated = + getIdKindOverlap(IdKind::Dimension, idStart, idLimit); unsigned numSymbolsEliminated = - numColsEliminated - numDimsEliminated - numLocalsEliminated; + getIdKindOverlap(IdKind::Symbol, idStart, idLimit); numDims -= numDimsEliminated; numSymbols -= numSymbolsEliminated; - numIds = numIds - numColsEliminated; + numIds -= (idLimit - idStart); } void IntegerPolyhedron::removeEquality(unsigned pos) { @@ -243,6 +232,10 @@ unsigned IntegerPolyhedron::getIdKindOffset(IdKind kind) const { llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!"); } +unsigned IntegerPolyhedron::getIdKindEnd(IdKind kind) const { + return getIdKindOffset(kind) + getNumIdKind(kind); +} + unsigned IntegerPolyhedron::getNumIdKind(IdKind kind) const { if (kind == IdKind::Dimension) return getNumDimIds(); @@ -253,6 +246,21 @@ unsigned IntegerPolyhedron::getNumIdKind(IdKind kind) const { llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!"); } +unsigned IntegerPolyhedron::getIdKindOverlap(IdKind kind, unsigned idStart, + unsigned idLimit) const { + unsigned idRangeStart = getIdKindOffset(kind); + unsigned idRangeEnd = getIdKindEnd(kind); + + // Compute number of elements in intersection of the ranges [idStart, idLimit) + // and [idRangeStart, idRangeEnd). + unsigned overlapStart = std::max(idStart, idRangeStart); + unsigned overlapEnd = std::min(idLimit, idRangeEnd); + + if (overlapStart > overlapEnd) + return 0; + return overlapEnd - overlapStart; +} + void IntegerPolyhedron::clearConstraints() { equalities.resizeVertically(0); inequalities.resizeVertically(0); @@ -319,7 +327,8 @@ bool IntegerPolyhedron::hasConsistentState() const { return false; // Catches errors where numDims, numSymbols, numIds aren't consistent. - if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds) + if (getNumDimIds() > getNumIds() || getNumSymbolIds() > getNumIds() || + getNumDimAndSymbolIds() > getNumIds()) return false; return true; @@ -911,7 +920,7 @@ void IntegerPolyhedron::gcdTightenInequalities() { unsigned IntegerPolyhedron::gaussianEliminateIds(unsigned posStart, unsigned posLimit) { // Return if identifier positions to eliminate are out of range. - assert(posLimit <= numIds); + assert(posLimit <= getNumIds()); assert(hasConsistentState()); if (posStart >= posLimit) @@ -1253,7 +1262,7 @@ void IntegerPolyhedron::addLocalFloorDiv(ArrayRef dividend, } void IntegerPolyhedron::setDimSymbolSeparation(unsigned newSymbolCount) { - assert(newSymbolCount <= numDims + numSymbols && + assert(newSymbolCount <= getNumDimAndSymbolIds() && "invalid separation position"); numDims = numDims + numSymbols - newSymbolCount; numSymbols = newSymbolCount; @@ -1924,7 +1933,7 @@ static void getCommonConstraints(const IntegerPolyhedron &a, // lower bounds and the max of the upper bounds along each of the dimensions. LogicalResult IntegerPolyhedron::unionBoundingBox(const IntegerPolyhedron &otherCst) { - assert(otherCst.getNumDimIds() == numDims && "dims mismatch"); + assert(otherCst.getNumDimIds() == getNumDimIds() && "dims mismatch"); assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here"); assert(getNumLocalIds() == 0 && "local ids not supported yet here"); diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp index 8628cd86e08d30..cc9e4970721830 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -188,7 +188,7 @@ FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) // Construct from an IntegerSet. FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set) : FlatAffineConstraints(set) { - values.resize(numIds, None); + values.resize(getNumIds(), None); } // Construct a hyperrectangular constraint set from ValueRanges that represent @@ -1212,11 +1212,15 @@ FlatAffineValueConstraints::computeAlignedMap(AffineMap map, SmallVector *newSymsPtr = nullptr; #endif // NDEBUG - dims.reserve(numDims); - syms.reserve(numSymbols); - for (unsigned i = 0; i < numDims; ++i) + dims.reserve(getNumDimIds()); + syms.reserve(getNumSymbolIds()); + for (unsigned i = getIdKindOffset(IdKind::Dimension), + e = getIdKindEnd(IdKind::Dimension); + i < e; ++i) dims.push_back(values[i] ? *values[i] : Value()); - for (unsigned i = numDims, e = numDims + numSymbols; i < e; ++i) + for (unsigned i = getIdKindOffset(IdKind::Symbol), + e = getIdKindEnd(IdKind::Symbol); + i < e; ++i) syms.push_back(values[i] ? *values[i] : Value()); AffineMap alignedMap = @@ -1371,13 +1375,13 @@ void FlatAffineValueConstraints::clearAndCopyFrom( *static_cast(this) = other; values.clear(); - values.resize(numIds, None); + values.resize(getNumIds(), None); } void FlatAffineValueConstraints::fourierMotzkinEliminate( unsigned pos, bool darkShadow, bool *isResultIntegerExact) { SmallVector, 8> newVals; - newVals.reserve(numIds - 1); + newVals.reserve(getNumIds() - 1); newVals.append(values.begin(), values.begin() + pos); newVals.append(values.begin() + pos + 1, values.end()); // Note: Base implementation discards all associated Values. @@ -1397,7 +1401,7 @@ void FlatAffineValueConstraints::projectOut(Value val) { LogicalResult FlatAffineValueConstraints::unionBoundingBox( const FlatAffineValueConstraints &otherCst) { - assert(otherCst.getNumDimIds() == numDims && "dims mismatch"); + assert(otherCst.getNumDimIds() == getNumDimIds() && "dims mismatch"); assert(otherCst.getMaybeValues() .slice(0, getNumDimIds()) .equals(getMaybeValues().slice(0, getNumDimIds())) && @@ -1408,7 +1412,7 @@ LogicalResult FlatAffineValueConstraints::unionBoundingBox( // Align `other` to this. if (!areIdsAligned(*this, otherCst)) { FlatAffineValueConstraints otherCopy(otherCst); - mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy); + mergeAndAlignIds(/*offset=*/getNumDimIds(), this, &otherCopy); return FlatAffineConstraints::unionBoundingBox(otherCopy); }