Skip to content

Commit

Permalink
[MLIR] FlatAffineValueConstraints: Fix bug in mergeSymbolIds
Browse files Browse the repository at this point in the history
This patch fixes a bug in implementation `mergeSymbolIds` where symbol
identifiers were not unique after merging them. Asserts for checking uniqueness
before and after the merge are also added. The asserts checking uniqueness
after the merge fail without the fix on existing test cases.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D111958
  • Loading branch information
Groverkss committed Oct 24, 2021
1 parent 2ae67c9 commit f5f5926
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
5 changes: 3 additions & 2 deletions mlir/include/mlir/Analysis/AffineStructures.h
Expand Up @@ -884,8 +884,9 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
}

/// Merge and align symbols of `this` and `other` such that both get union of
/// of symbols that are unique. Symbols with Value as `None` are considered
/// to be inequal to all other symbols.
/// of symbols that are unique. Symbols in `this` and `other` should be
/// unique. Symbols with Value as `None` are considered to be inequal to all
/// other symbols.
void mergeSymbolIds(FlatAffineValueConstraints &other);

protected:
Expand Down
50 changes: 43 additions & 7 deletions mlir/lib/Analysis/AffineStructures.cpp
Expand Up @@ -448,17 +448,46 @@ bool FlatAffineValueConstraints::areIdsAlignedWithOther(
return areIdsAligned(*this, other);
}

/// Checks if the SSA values associated with `cst`'s identifiers are unique.
static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineValueConstraints &cst) {
/// Checks if the SSA values associated with `cst`'s identifiers in range
/// [start, end) are unique.
static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
const FlatAffineValueConstraints &cst, unsigned start, unsigned end) {

assert(start <= cst.getNumIds() && "Start position out of bounds");
assert(end <= cst.getNumIds() && "End position out of bounds");

if (start >= end)
return true;

SmallPtrSet<Value, 8> uniqueIds;
for (auto val : cst.getMaybeValues()) {
ArrayRef<Optional<Value>> maybeValues = cst.getMaybeValues();
for (Optional<Value> val : maybeValues) {
if (val.hasValue() && !uniqueIds.insert(val.getValue()).second)
return false;
}
return true;
}

/// Checks if the SSA values associated with `cst`'s identifiers are unique.
static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineConstraints &cst) {
return areIdsUnique(cst, 0, cst.getNumIds());
}

/// Checks if the SSA values associated with `cst`'s identifiers of kind `kind`
/// are unique.
static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
const FlatAffineValueConstraints &cst, FlatAffineConstraints::IdKind kind) {

if (kind == FlatAffineConstraints::IdKind::Dimension)
return areIdsUnique(cst, 0, cst.getNumDimIds());
if (kind == FlatAffineConstraints::IdKind::Symbol)
return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds());
if (kind == FlatAffineConstraints::IdKind::Local)
return areIdsUnique(cst, cst.getNumDimAndSymbolIds(), cst.getNumIds());
llvm_unreachable("Unexpected IdKind");
}

/// Merge and align the identifiers of A and B starting at 'offset', so that
/// both constraint systems get the union of the contained identifiers that is
/// dimension-wise and symbol-wise unique; both constraint systems are updated
Expand Down Expand Up @@ -592,10 +621,15 @@ static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) {
}

/// Merge and align symbols of `this` and `other` such that both get union of
/// of symbols that are unique. Symbols with Value as `None` are considered
/// to be inequal to all other symbols.
/// of symbols that are unique. Symbols in `this` and `other` should be
/// unique. Symbols with Value as `None` are considered to be inequal to all
/// other symbols.
void FlatAffineValueConstraints::mergeSymbolIds(
FlatAffineValueConstraints &other) {

assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique");
assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique");

SmallVector<Value, 4> aSymValues;
getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues);

Expand All @@ -606,7 +640,7 @@ void FlatAffineValueConstraints::mergeSymbolIds(
// If the id is a symbol in `other`, then align it, otherwise assume that
// it is a new symbol
if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() &&
loc < getNumDimAndSymbolIds())
loc < other.getNumDimAndSymbolIds())
other.swapId(s, loc);
else
other.insertSymbolId(s - other.getNumDimIds(), aSymValue);
Expand All @@ -621,6 +655,8 @@ void FlatAffineValueConstraints::mergeSymbolIds(

assert(getNumSymbolIds() == other.getNumSymbolIds() &&
"expected same number of symbols");
assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique");
assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique");
}

// Changes all symbol identifiers which are loop IVs to dim identifiers.
Expand Down

0 comments on commit f5f5926

Please sign in to comment.