Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Interfaces] ValueBoundsOpInterface: Add API to compare values #86915

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Mar 28, 2024

This commit adds a new public API to ValueBoundsOpInterface to compare values/dims. Supported comparison operators are: LT, LE, EQ, GE, GT.

The new ValueBoundsOpInterface::compare API replaces and generalizes ValueBoundsOpInterface::areEqual. Not only does it provide additional comparison operators, it also works in cases where the difference between the two values/dims is non-constant. The previous implementation of areEqual used to compute a constant bound of val1 - val2 (check if it == 0 or != 0).

Note: This commit refactors, generalizes and adds a public API for value/dim comparison. The comparison functionality itself was introduced in #85895 and is already in use for analyzing scf.if.

In the long term, this improvement will allow for a more powerful analysis of subset ops. A future commit will update areOverlappingSlices to use the new comparison API. (areEquivalentSlices is already using the new API.) This will improve subset equivalence/disjointness checks with non-constant offsets/sizes/strides.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 28, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir-scf

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new public API to ValueBoundsOpInterface to compare values/dims. Supported comparison operators are: LT, LE, EQ, GE, GT.

The new ValueBoundsOpInterface::compare API replaces and generalizes ValueBoundsOpInterface::areEqual. Not only does it provide additional comparison operators, it also works in cases where the difference between the two values/dims is non-constant. The previous implementation of areEqual used to compute a constant bound of val1 - val2.

Note: This commit refactors, generalizes and adds a public API for value/dim comparison. The comparison functionality itself was introduced in #85895 and is already in use for analyzing scf.if.

In the long term, this improvement will allow for a more powerful analysis of subset ops. A future commit will update areOverlappingSlices to use the new comparison API. (areEquivalentSlices is already using the new API.) This will improve subset equivalence/disjointness checks with non-constant offsets/sizes/strides.


Patch is 31.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86915.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+51-10)
  • (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+9-22)
  • (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+181-58)
  • (modified) mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir (+42-1)
  • (modified) mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir (+12)
  • (modified) mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir (+8-8)
  • (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+66-13)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index f35432ca0136f3..d27081fad8c6c0 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -211,7 +211,8 @@ class ValueBoundsConstraintSet
   /// Comparison operator for `ValueBoundsConstraintSet::compare`.
   enum ComparisonOperator { LT, LE, EQ, GT, GE };
 
-  /// Try to prove that, based on the current state of this constraint set
+  /// Populate constraints for lhs/rhs (until the stop condition is met). Then,
+  /// try to prove that, based on the current state of this constraint set
   /// (i.e., without analyzing additional IR or adding new constraints), the
   /// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
   ///
@@ -220,24 +221,37 @@ class ValueBoundsConstraintSet
   /// proven. This could be because the specified relation does in fact not hold
   /// or because there is not enough information in the constraint set. In other
   /// words, if we do not know for sure, this function returns "false".
-  bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
-               Value rhs, std::optional<int64_t> rhsDim);
+  bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+                          ComparisonOperator cmp, OpFoldResult rhs,
+                          std::optional<int64_t> rhsDim);
+
+  /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
+  /// specified relation could not be proven. This could be because the
+  /// specified relation does in fact not hold or because there is not enough
+  /// information in the constraint set. In other words, if we do not know for
+  /// sure, this function returns "false".
+  ///
+  /// This function keeps traversing the backward slice of lhs/rhs until could
+  /// prove the relation or until it ran out of IR.
+  static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+                      ComparisonOperator cmp, OpFoldResult rhs,
+                      std::optional<int64_t> rhsDim);
+  static bool compare(AffineMap lhs, ValueDimList lhsOperands,
+                      ComparisonOperator cmp, AffineMap rhs,
+                      ValueDimList rhsOperands);
+  static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
+                      ComparisonOperator cmp, AffineMap rhs,
+                      ArrayRef<Value> rhsOperands);
 
   /// Compute whether the given values/dimensions are equal. Return "failure" if
   /// equality could not be determined.
   ///
   /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
   /// index-typed.
-  static FailureOr<bool> areEqual(Value value1, Value value2,
+  static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
                                   std::optional<int64_t> dim1 = std::nullopt,
                                   std::optional<int64_t> dim2 = std::nullopt);
 
-  /// Compute whether the given values/attributes are equal. Return "failure" if
-  /// equality could not be determined.
-  ///
-  /// `ofr1`/`ofr2` must be of index type.
-  static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
-
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
   /// Return "failure" if unknown.
@@ -290,6 +304,20 @@ class ValueBoundsConstraintSet
 
   ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
 
+  /// Return "true" if, based on the current state of the constraint system,
+  /// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation
+  /// could not be proven. This could be because the specified relation does in
+  /// fact not hold or because there is not enough information in the constraint
+  /// set. In other words, if we do not know for sure, this function returns
+  /// "false".
+  ///
+  /// This function does not analyze any IR and does not populate any additional
+  /// constraints.
+  bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+                        ComparisonOperator cmp, OpFoldResult rhs,
+                        std::optional<int64_t> rhsDim);
+  bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
+
   /// Given an affine map with a single result (and map operands), add a new
   /// column to the constraint set that represents the result of the map.
   /// Traverse additional IR starting from the map operands as needed (as long
@@ -311,6 +339,14 @@ class ValueBoundsConstraintSet
   /// value/dimension exists in the constraint set.
   int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
 
+  /// Return an affine expression that represents column `pos` in the constraint
+  /// set.
+  AffineExpr getPosExpr(int64_t pos);
+
+  /// Return "true" if the given value/dim is mapped (i.e., has a corresponding
+  /// column in the constraint system).
+  bool isMapped(Value value, std::optional<int64_t> dim = std::nullopt) const;
+
   /// Insert a value/dimension into the constraint set. If `isSymbol` is set to
   /// "false", a dimension is added. The value/dimension is added to the
   /// worklist if `addToWorklist` is set.
@@ -330,6 +366,11 @@ class ValueBoundsConstraintSet
   /// dimensions but not for symbols.
   int64_t insert(bool isSymbol = true);
 
+  /// Insert the given affine map and its bound operands as a new column in the
+  /// constraint system. Return the position of the new column. Any operands
+  /// that were not analyzed yet are put on the worklist.
+  int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
 
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 72c5aaa2306783..087ffc438a830a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,20 +58,11 @@ struct ForOpInterface
     Value iterArg = forOp.getRegionIterArg(iterArgIdx);
     Value initArg = forOp.getInitArgs()[iterArgIdx];
 
-    // Populate constraints for the yielded value.
-    cstr.populateConstraints(yieldedValue, dim);
-    // Populate constraints for the iter_arg. This is just to ensure that the
-    // iter_arg is mapped in the constraint set, which is a prerequisite for
-    // `compare`. It may lead to a recursive call to this function in case the
-    // iter_arg was not visited when the constraints for the yielded value were
-    // populated, but no additional work is done.
-    cstr.populateConstraints(iterArg, dim);
-
     // An EQ constraint can be added if the yielded value (dimension size)
     // equals the corresponding block argument (dimension size).
-    if (cstr.compare(yieldedValue, dim,
-                     ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
-                     dim)) {
+    if (cstr.populateAndCompare(
+            yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
+            iterArg, dim)) {
       if (dim.has_value()) {
         cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
       } else {
@@ -113,10 +104,6 @@ struct IfOpInterface
     Value thenValue = ifOp.thenYield().getResults()[resultNum];
     Value elseValue = ifOp.elseYield().getResults()[resultNum];
 
-    // Populate constraints for the yielded value (and all values on the
-    // backward slice, as long as the current stop condition is not satisfied).
-    cstr.populateConstraints(thenValue, dim);
-    cstr.populateConstraints(elseValue, dim);
     auto boundsBuilder = cstr.bound(value);
     if (dim)
       boundsBuilder[*dim];
@@ -125,9 +112,9 @@ struct IfOpInterface
     // If thenValue <= elseValue:
     // * result <= elseValue
     // * result >= thenValue
-    if (cstr.compare(thenValue, dim,
-                     ValueBoundsConstraintSet::ComparisonOperator::LE,
-                     elseValue, dim)) {
+    if (cstr.populateAndCompare(
+            thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
+            elseValue, dim)) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -139,9 +126,9 @@ struct IfOpInterface
     // If elseValue <= thenValue:
     // * result <= thenValue
     // * result >= elseValue
-    if (cstr.compare(elseValue, dim,
-                     ValueBoundsConstraintSet::ComparisonOperator::LE,
-                     thenValue, dim)) {
+    if (cstr.populateAndCompare(
+            elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
+            thenValue, dim)) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index dd98da9adc7d96..d7ffed14daccdd 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -212,6 +212,28 @@ int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
   return pos;
 }
 
+int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
+                                         bool isSymbol) {
+  assert(map.getNumResults() == 1 && "expected affine map with one result");
+  int64_t pos = insert(/*isSymbol=*/false);
+
+  // Add map and operands to the constraint set. Dimensions are converted to
+  // symbols. All operands are added to the worklist (unless they were already
+  // processed).
+  auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
+    return getExpr(v.first, v.second);
+  };
+  SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
+  SmallVector<AffineExpr> symReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
+  addBound(
+      presburger::BoundType::EQ, pos,
+      map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+  return pos;
+}
+
 int64_t ValueBoundsConstraintSet::getPos(Value value,
                                          std::optional<int64_t> dim) const {
 #ifndef NDEBUG
@@ -227,6 +249,20 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
   return it->second;
 }
 
+AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
+  assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
+  return pos < cstr.getNumDimVars()
+             ? builder.getAffineDimExpr(pos)
+             : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
+}
+
+bool ValueBoundsConstraintSet::isMapped(Value value,
+                                        std::optional<int64_t> dim) const {
+  auto it =
+      valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
+  return it != valueDimToPosition.end();
+}
+
 static Operation *getOwnerOfValue(Value value) {
   if (auto bbArg = dyn_cast<BlockArgument>(value))
     return bbArg.getOwner()->getParentOp();
@@ -563,27 +599,10 @@ void ValueBoundsConstraintSet::populateConstraints(Value value,
 
 int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
                                                       ValueDimList operands) {
-  assert(map.getNumResults() == 1 && "expected affine map with one result");
-  int64_t pos = insert(/*isSymbol=*/false);
-
-  // Add map and operands to the constraint set. Dimensions are converted to
-  // symbols. All operands are added to the worklist (unless they were already
-  // processed).
-  auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
-    return getExpr(v.first, v.second);
-  };
-  SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
-      llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
-  SmallVector<AffineExpr> symReplacements = llvm::to_vector(
-      llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
-  addBound(
-      presburger::BoundType::EQ, pos,
-      map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
-
+  int64_t pos = insert(map, operands, /*isSymbol=*/false);
   // Process the backward slice of `operands` (i.e., reverse use-def chain)
   // until `stopCondition` is met.
   processWorklist();
-
   return pos;
 }
 
@@ -603,9 +622,18 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
                               {{value1, dim1}, {value2, dim2}});
 }
 
-bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
-                                       ComparisonOperator cmp, Value rhs,
-                                       std::optional<int64_t> rhsDim) {
+bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
+                                                std::optional<int64_t> lhsDim,
+                                                ComparisonOperator cmp,
+                                                OpFoldResult rhs,
+                                                std::optional<int64_t> rhsDim) {
+#ifndef NDEBUG
+  if (auto lhsVal = dyn_cast<Value>(lhs))
+    assertValidValueDim(lhsVal, lhsDim);
+  if (auto rhsVal = dyn_cast<Value>(rhs))
+    assertValidValueDim(rhsVal, rhsDim);
+#endif // NDEBUG
+
   // This function returns "true" if "lhs CMP rhs" is proven to hold.
   //
   // Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -624,19 +652,61 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
 
   // EQ can be expressed as LE and GE.
   if (cmp == EQ)
-    return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
-           compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
+    return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
+           compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
 
   // Construct inequality. For the above example: lhs > rhs.
   // `IntegerRelation` inequalities are expressed in the "flattened" form and
   // with ">= 0". I.e., lhs - rhs - 1 >= 0.
+  SmallVector<int64_t> eq(cstr.getNumCols(), 0);
+  auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
+                     int64_t factor) {
+    if (auto constVal = ::getConstantIntValue(ofr)) {
+      eq[cstr.getNumCols() - 1] += *constVal * factor;
+    } else {
+      eq[getPos(cast<Value>(ofr), dim)] += factor;
+    }
+  };
+  if (cmp == LT || cmp == LE) {
+    addToEq(lhs, lhsDim, 1);
+    addToEq(rhs, rhsDim, -1);
+  } else if (cmp == GT || cmp == GE) {
+    addToEq(lhs, lhsDim, -1);
+    addToEq(rhs, rhsDim, 1);
+  } else {
+    llvm_unreachable("unsupported comparison operator");
+  }
+  if (cmp == LE || cmp == GE)
+    eq[cstr.getNumCols() - 1] -= 1;
+
+  // Add inequality to the constraint set and check if it made the constraint
+  // set empty.
+  int64_t ineqPos = cstr.getNumInequalities();
+  cstr.addInequality(eq);
+  bool isEmpty = cstr.isEmpty();
+  cstr.removeInequality(ineqPos);
+  return isEmpty;
+}
+
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+                                          ComparisonOperator cmp,
+                                          int64_t rhsPos) {
+  // This function returns "true" if "lhs CMP rhs" is proven to hold. For
+  // detailed documentation, see `compareValueDims`.
+
+  // EQ can be expressed as LE and GE.
+  if (cmp == EQ)
+    return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
+           comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
+
+  // Construct inequality.
   SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
   if (cmp == LT || cmp == LE) {
-    ++eq[getPos(lhs, lhsDim)];
-    --eq[getPos(rhs, rhsDim)];
+    ++eq[lhsPos];
+    --eq[rhsPos];
   } else if (cmp == GT || cmp == GE) {
-    --eq[getPos(lhs, lhsDim)];
-    ++eq[getPos(rhs, rhsDim)];
+    --eq[lhsPos];
+    ++eq[rhsPos];
   } else {
     llvm_unreachable("unsupported comparison operator");
   }
@@ -652,40 +722,93 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
   return isEmpty;
 }
 
+bool ValueBoundsConstraintSet::populateAndCompare(
+    OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
+    OpFoldResult rhs, std::optional<int64_t> rhsDim) {
+#ifndef NDEBUG
+  if (auto lhsVal = dyn_cast<Value>(lhs))
+    assertValidValueDim(lhsVal, lhsDim);
+  if (auto rhsVal = dyn_cast<Value>(rhs))
+    assertValidValueDim(rhsVal, rhsDim);
+#endif // NDEBUG
+
+  if (auto lhsVal = dyn_cast<Value>(lhs))
+    populateConstraints(lhsVal, lhsDim);
+  if (auto rhsVal = dyn_cast<Value>(rhs))
+    populateConstraints(rhsVal, rhsDim);
+
+  return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+}
+
+bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
+                                       std::optional<int64_t> lhsDim,
+                                       ComparisonOperator cmp, OpFoldResult rhs,
+                                       std::optional<int64_t> rhsDim) {
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep processing as long as lhs/rhs are not mapped.
+    if (auto lhsVal = dyn_cast<Value>(lhs))
+      if (!cstr.isMapped(lhsVal, dim))
+        return false;
+    if (auto rhsVal = dyn_cast<Value>(rhs))
+      if (!cstr.isMapped(rhsVal, dim))
+        return false;
+    // Keep processing as long as the relation cannot be proven.
+    return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+  };
+
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
+}
+
+bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
+                                       ComparisonOperator cmp, AffineMap rhs,
+                                       ValueDimList rhsOperands) {
+  int64_t lhsPos = -1, rhsPos = -1;
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep processing as long as lhs/rhs were not processed.
+    if (lhsPos >= cstr.positionToValueDim.size() ||
+        rhsPos >= cstr.positionToValueDim.size())
+      return false;
+    // Keep processing as long as the relation cannot be proven.
+    return cstr.comparePos(lhsPos, cmp, rhsPos);
+  };
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  lhsPos = cstr.insert(lhs, lhsOperands);
+  rhsPos = cstr.insert(rhs, rhsOperands);
+  return cstr.comparePos(lhsPos, cmp, rhsPos);
+}
+
+bool ValueBoundsConstraintSet::compare(AffineMap lhs,
+                                       ArrayRef<Value> lhsOperands,
+                                       ComparisonOperator cmp, AffineMap rhs,
+                                       ArrayRef<Value> rhsOperands) {
+  ValueDimList lhsValueDimOperands =
+      llvm::map_to_vector(lhsOperands, [](Value v) {
+        return std::make_pair(v, std::optional<int64_t>());
+      });
+  ValueDimList rhsValueDimOperands =
+      llvm::map_to_vector(rhsOperands, [](Value v) {
+        return std::make_pair(v, std::optional<int64_t>());
+      });
+  return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs,
+                                           rhsValueDimOperands);
+}
+
 FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
                                    std::optional<int64_t> dim1,
                                    std::optional<int64_t> dim2) {
-  // Subtract the two values/dimensions from each other. If the result is 0,
-  // both are equal.
-  FailureOr<int64_t> delta = computeConstantDelta(value1, value2, dim1, dim2);
-  if (failed(delta))
-    return failure();
-  return *delta == 0;
-}
-
-FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
-                                                   OpFoldR...
[truncated]

///
/// This function keeps traversing the backward slice of lhs/rhs until could
/// prove the relation or until it ran out of IR.
static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like we need a helper class of sorts here that groups the results/affinemap and dims/operands together. Just makes it more difficult to mix up + gives a bit more of a consistent lhs cmp rhs interface. The packing and unpacking should folded away during compilation. [just thinking out loud when reading this, not saying required change for this]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. We could add something like:

class Variable {
  Variable(OpFoldResult);  // asserts that Value/Attribute is index-typed
  Variable(Value);  // asserts that Value is index-typed
  Variable(Value, int64_t);  // asserts that Value is a shaped value
  Variable(OpFoldResult, std::optional<int64_t>);  // must be index-typed+nullopt or shaped value+non-nullopt
  Variable(AffineMap, mapOperands);

 private:
  OpFoldResult ofr;
  std::optional<int64_t> dim;
  AffineMap map;
  SmallVector<Variable> mapOperands;
};

I think then we could have a single entry point into each of computeBound (and variants), computeConstantBound, compare and areEqual. No more overloads needed. (It requires a bit of "flattening" if we allow Variable as mapOperands.)

I'm going to prepare a follow-up PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New API added in #87980.

}
};
if (cmp == LT || cmp == LE) {
addToEq(lhs, lhsDim, 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you write out the equations here being added? (I think it would make the +1 and -1 here be easier to read)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's written in line 660. (lhs - rhs - 1 >= 0, I'm using this example throughout the function.)

This particular line here is part of the code that builds the inequality. E.g., when adding a LT or LE inequality, the lhs must appear with a negative factor in the inequality. (Because inequalities are always in the form ... >= 0 in FlatLinearConstraints.)

addToEq(lhs, lhsDim, -1);
addToEq(rhs, rhsDim, 1);
} else {
llvm_unreachable("unsupported comparison operator");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't one be able to return false here and be conservatively correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have this check above in the same function:

  // EQ can be expressed as LE and GE.
  if (cmp == EQ)
    ...

So we do have complete coverage of all comparison operators in here, and we should not be able to end up in this else branch.

op->emitRemark("true");
} else if (*cmpType != ValueBoundsConstraintSet::EQ &&
compare(invertComparisonOperator(*cmpType))) {
op->emitRemark("false");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left over debugging?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is on purpose. I look for these diagnostics in the test cases (verify-diagnostics).

@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_arith_select branch from 680d04d to 4ef0c78 Compare April 5, 2024 04:31
Base automatically changed from users/matthias-springer/value_bounds_arith_select to main April 5, 2024 04:39
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_compare branch 2 times, most recently from 62fe9bb to 772389c Compare April 8, 2024 11:21
Also use `compare` API for `areEqual` etc.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_compare branch from 772389c to 80949fe Compare April 8, 2024 13:51
@matthias-springer matthias-springer merged commit 297eca9 into main Apr 11, 2024
4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/value_bounds_compare branch April 11, 2024 06:23

// Construct inequality. For the above example: lhs > rhs.
// `IntegerRelation` inequalities are expressed in the "flattened" form and
// with ">= 0". I.e., lhs - rhs - 1 >= 0.
SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
Copy link
Contributor

@hanhanW hanhanW Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm debugging an issue triggered by #87862 and down to here. It hit the below assertion. It looks like the size of eq should be getNumCols. Is it a bug fixed by this PR?

void IntegerRelation::addInequality(ArrayRef<MPInt> inEq) {
assert(inEq.size() == getNumCols());
unsigned row = inequalities.appendExtraRow();
for (unsigned i = 0, e = inEq.size(); i < e; ++i)
inequalities(row, i) = inEq[i];
}

EDIT: I verified that my issue is fixed by the PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants