From f5f592683f82a78ca53f999716f476f8030863e0 Mon Sep 17 00:00:00 2001 From: Groverkss Date: Sun, 24 Oct 2021 20:06:03 +0530 Subject: [PATCH] [MLIR] FlatAffineValueConstraints: Fix bug in mergeSymbolIds This patch fixes a bug in implementation `mergeSymbolIds` where symbol identifiers were not unique after merging them. Asserts for checking uniqueness before and after the merge are also added. The asserts checking uniqueness after the merge fail without the fix on existing test cases. Reviewed By: arjunp Differential Revision: https://reviews.llvm.org/D111958 --- mlir/include/mlir/Analysis/AffineStructures.h | 5 +- mlir/lib/Analysis/AffineStructures.cpp | 50 ++++++++++++++++--- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 2c2344145ffc6e..59424fb9619b2a 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -884,8 +884,9 @@ class FlatAffineValueConstraints : public FlatAffineConstraints { } /// Merge and align symbols of `this` and `other` such that both get union of - /// of symbols that are unique. Symbols with Value as `None` are considered - /// to be inequal to all other symbols. + /// of symbols that are unique. Symbols in `this` and `other` should be + /// unique. Symbols with Value as `None` are considered to be inequal to all + /// other symbols. void mergeSymbolIds(FlatAffineValueConstraints &other); protected: diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 74b1a8960094e3..085396b840b9a1 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -448,17 +448,46 @@ bool FlatAffineValueConstraints::areIdsAlignedWithOther( return areIdsAligned(*this, other); } -/// Checks if the SSA values associated with `cst`'s identifiers are unique. -static bool LLVM_ATTRIBUTE_UNUSED -areIdsUnique(const FlatAffineValueConstraints &cst) { +/// Checks if the SSA values associated with `cst`'s identifiers in range +/// [start, end) are unique. +static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique( + const FlatAffineValueConstraints &cst, unsigned start, unsigned end) { + + assert(start <= cst.getNumIds() && "Start position out of bounds"); + assert(end <= cst.getNumIds() && "End position out of bounds"); + + if (start >= end) + return true; + SmallPtrSet uniqueIds; - for (auto val : cst.getMaybeValues()) { + ArrayRef> maybeValues = cst.getMaybeValues(); + for (Optional val : maybeValues) { if (val.hasValue() && !uniqueIds.insert(val.getValue()).second) return false; } return true; } +/// Checks if the SSA values associated with `cst`'s identifiers are unique. +static bool LLVM_ATTRIBUTE_UNUSED +areIdsUnique(const FlatAffineConstraints &cst) { + return areIdsUnique(cst, 0, cst.getNumIds()); +} + +/// Checks if the SSA values associated with `cst`'s identifiers of kind `kind` +/// are unique. +static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique( + const FlatAffineValueConstraints &cst, FlatAffineConstraints::IdKind kind) { + + if (kind == FlatAffineConstraints::IdKind::Dimension) + return areIdsUnique(cst, 0, cst.getNumDimIds()); + if (kind == FlatAffineConstraints::IdKind::Symbol) + return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds()); + if (kind == FlatAffineConstraints::IdKind::Local) + return areIdsUnique(cst, cst.getNumDimAndSymbolIds(), cst.getNumIds()); + llvm_unreachable("Unexpected IdKind"); +} + /// Merge and align the identifiers of A and B starting at 'offset', so that /// both constraint systems get the union of the contained identifiers that is /// dimension-wise and symbol-wise unique; both constraint systems are updated @@ -592,10 +621,15 @@ static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) { } /// Merge and align symbols of `this` and `other` such that both get union of -/// of symbols that are unique. Symbols with Value as `None` are considered -/// to be inequal to all other symbols. +/// of symbols that are unique. Symbols in `this` and `other` should be +/// unique. Symbols with Value as `None` are considered to be inequal to all +/// other symbols. void FlatAffineValueConstraints::mergeSymbolIds( FlatAffineValueConstraints &other) { + + assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique"); + assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique"); + SmallVector aSymValues; getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues); @@ -606,7 +640,7 @@ void FlatAffineValueConstraints::mergeSymbolIds( // If the id is a symbol in `other`, then align it, otherwise assume that // it is a new symbol if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() && - loc < getNumDimAndSymbolIds()) + loc < other.getNumDimAndSymbolIds()) other.swapId(s, loc); else other.insertSymbolId(s - other.getNumDimIds(), aSymValue); @@ -621,6 +655,8 @@ void FlatAffineValueConstraints::mergeSymbolIds( assert(getNumSymbolIds() == other.getNumSymbolIds() && "expected same number of symbols"); + assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique"); + assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique"); } // Changes all symbol identifiers which are loop IVs to dim identifiers.