Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Interfaces] Variable abstraction for ValueBoundsOpInterface #87980

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Apr 8, 2024

This commit generalizes and cleans up the ValueBoundsConstraintSet API. The API used to provide function overloads for comparing/computing bounds of:

  • index-typed SSA value
  • dimension of shaped value
  • affine map + operands

This commit removes all overloads. There is now a single entry point for each compare variant and each computeBound variant. These functions now take a Variable, which is internally represented as an affine map and map operands.

This commit also adds support for computing bounds for an affine map + operands. There was previously no public API for that.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 8, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

Changes

This commit generalizes and cleans up the ValueBoundsConstraintSet API. The API used to provide function overloads for comparing/computing bounds of:

  • index-typed SSA value
  • dimension of shaped value
  • affine map + operands

This commit removes all overloads. There is now a single entry point for each compare variant and each computeBound variant. These functions now take a Variable, which is internally represented as an affine map and map operands.

This commit also adds support for computing bounds for an affine map + operands. There was previously no public API for that.

WIP until I added a test case for computeBounds(AffineMap).


Patch is 47.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87980.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+61-56)
  • (modified) mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp (+69)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Padding.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (+2-4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp (+2-3)
  • (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+10-7)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+2-2)
  • (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+149-188)
  • (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+3-3)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961cc3a..3e1502b4f5c357a 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ExtensibleRTTI.h"
 
 #include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
 public:
   static char ID;
 
+  /// A variable that can be added to the constraint set as a "column". The
+  /// value bounds infrastructure can compute bounds for variables and compare
+  /// two variables.
+  ///
+  /// Internally, a variable is represented as an affine map and operands.
+  class Variable {
+  public:
+    /// Construct a variable for an index-typed attribute or SSA value.
+    Variable(OpFoldResult ofr);
+
+    /// Construct a variable for an index-typed SSA value.
+    Variable(Value indexValue);
+
+    /// Construct a variable for a dimension of a shaped value.
+    Variable(Value shapedValue, int64_t dim);
+
+    /// Construct a variable for an index-typed attribute/SSA value or for a
+    /// dimension of a shaped value. A non-null dimension must be provided if
+    /// and only if `ofr` is a shaped value.
+    Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+    /// Construct a variable for a map and its operands.
+    Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+    Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+    MLIRContext *getContext() const { return map.getContext(); }
+
+  private:
+    friend class ValueBoundsConstraintSet;
+    AffineMap map;
+    ValueDimList mapOperands;
+  };
+
   /// The stop condition when traversing the backward slice of a shaped value/
   /// index-type value. The traversal continues until the stop condition
   /// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
   using StopConditionFn = std::function<bool(
       Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
-  /// Compute a bound for the given index-typed value or shape dimension size.
-  /// The computed bound is stored in `resultMap`. The operands of the bound are
-  /// stored in `mapOperands`. An operand is either an index-type SSA value
-  /// or a shaped value and a dimension.
+  /// Compute a bound for the given variable. The computed bound is stored in
+  /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+  /// operand is either an index-type SSA value or a shaped value and a
+  /// dimension.
   ///
-  /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
-  /// is computed in terms of values/dimensions for which `stopCondition`
-  /// evaluates to "true". To that end, the backward slice (reverse use-def
-  /// chain) of the given value is visited in a worklist-driven manner and the
-  /// constraint set is populated according to `ValueBoundsOpInterface` for each
-  /// visited value.
+  /// The bound is computed in terms of values/dimensions for which
+  /// `stopCondition` evaluates to "true". To that end, the backward slice
+  /// (reverse use-def chain) of the given value is visited in a worklist-driven
+  /// manner and the constraint set is populated according to
+  /// `ValueBoundsOpInterface` for each visited value.
   ///
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
-  static LogicalResult computeBound(AffineMap &resultMap,
-                                    ValueDimList &mapOperands,
-                                    presburger::BoundType type, Value value,
-                                    std::optional<int64_t> dim,
-                                    StopConditionFn stopCondition,
-                                    bool closedUB = false);
+  static LogicalResult
+  computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+               presburger::BoundType type, const Variable &var,
+               StopConditionFn stopCondition, bool closedUB = false);
 
   /// Compute a bound in terms of the values/dimensions in `dependencies`. The
   /// computed bound consists of only constant terms and dependent values (or
   /// dimension sizes thereof).
   static LogicalResult
   computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                        presburger::BoundType type, Value value,
-                        std::optional<int64_t> dim, ValueDimList dependencies,
-                        bool closedUB = false);
+                        presburger::BoundType type, const Variable &var,
+                        ValueDimList dependencies, bool closedUB = false);
 
   /// Compute a bound in that is independent of all values in `independencies`.
   ///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
   /// appear in the computed bound.
   static LogicalResult
   computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                          presburger::BoundType type, Value value,
-                          std::optional<int64_t> dim, ValueRange independencies,
-                          bool closedUB = false);
+                          presburger::BoundType type, const Variable &var,
+                          ValueRange independencies, bool closedUB = false);
 
-  /// Compute a constant bound for the given affine map, where dims and symbols
-  /// are bound to the given operands. The affine map must have exactly one
-  /// result.
+  /// Compute a constant bound for the given variable.
   ///
   /// This function traverses the backward slice of the given operands in a
   /// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
   static FailureOr<int64_t>
-  computeConstantBound(presburger::BoundType type, Value value,
-                       std::optional<int64_t> dim = std::nullopt,
+  computeConstantBound(presburger::BoundType type, const Variable &var,
                        StopConditionFn stopCondition = nullptr,
                        bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
   /// Compute a constant delta between the given two values. Return "failure"
   /// if a constant delta could not be determined.
@@ -221,9 +241,7 @@ class ValueBoundsConstraintSet
   /// proven. This could be because the specified relation does in fact not hold
   /// or because there is not enough information in the constraint set. In other
   /// words, if we do not know for sure, this function returns "false".
-  bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                          ComparisonOperator cmp, OpFoldResult rhs,
-                          std::optional<int64_t> rhsDim);
+  bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs);
 
   /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
   /// specified relation could not be proven. This could be because the
@@ -233,24 +251,11 @@ class ValueBoundsConstraintSet
   ///
   /// This function keeps traversing the backward slice of lhs/rhs until could
   /// prove the relation or until it ran out of IR.
-  static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                      ComparisonOperator cmp, OpFoldResult rhs,
-                      std::optional<int64_t> rhsDim);
-  static bool compare(AffineMap lhs, ValueDimList lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ValueDimList rhsOperands);
-  static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ArrayRef<Value> rhsOperands);
-
-  /// Compute whether the given values/dimensions are equal. Return "failure" if
+  static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs);
+
+  /// Compute whether the given variables are equal. Return "failure" if
   /// equality could not be determined.
-  ///
-  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
-  /// index-typed.
-  static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
-                                  std::optional<int64_t> dim1 = std::nullopt,
-                                  std::optional<int64_t> dim2 = std::nullopt);
+  static FailureOr<bool> areEqual(Variable var1, Variable var2);
 
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +322,6 @@ class ValueBoundsConstraintSet
   ///
   /// This function does not analyze any IR and does not populate any additional
   /// constraints.
-  bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                        ComparisonOperator cmp, OpFoldResult rhs,
-                        std::optional<int64_t> rhsDim);
   bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
 
   /// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +376,7 @@ class ValueBoundsConstraintSet
   /// constraint system. Return the position of the new column. Any operands
   /// that were not analyzed yet are put on the worklist.
   int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+  int64_t insert(const Variable &var, bool isSymbol = true);
 
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
@@ -381,6 +384,8 @@ class ValueBoundsConstraintSet
   /// Project out all columns for which the condition holds.
   void projectOut(function_ref<bool(ValueDim)> condition);
 
+  void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
   /// Mapping of columns to values/shape dimensions.
   SmallVector<std::optional<ValueDim>> positionToValueDim;
   /// Reverse mapping of values/shape dimensions to columns.
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f71d1..82a9fb0d490882f 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
   mapOperands.push_back(value1);
   mapOperands.push_back(value2);
   affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
-  ValueDimList valueDims;
-  for (Value v : mapOperands)
-    valueDims.push_back({v, std::nullopt});
   return ValueBoundsConstraintSet::computeConstantBound(
-      presburger::BoundType::EQ, map, valueDims);
+      presburger::BoundType::EQ,
+      ValueBoundsConstraintSet::Variable(map, mapOperands));
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701ad7c..6c59df91e8af781 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB)))
     return failure();
 
   // Reify bound.
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 90895e381c74b5a..411fc117a4d9f5d 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -75,6 +75,75 @@ struct MulIOpInterface
   }
 };
 
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+
+  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    Value value = selectOp.getResult();
+    Value condition = selectOp.getCondition();
+    Value trueValue = selectOp.getTrueValue();
+    Value falseValue = selectOp.getFalseValue();
+
+    if (isa<ShapedType>(condition.getType())) {
+      // If the condition is a shaped type, the condition is applied
+      // element-wise. All three operands must have the same shape.
+      cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+      return;
+    }
+
+    // Populate constraints for the true/false values (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(trueValue, dim);
+    cstr.populateConstraints(falseValue, dim);
+    auto boundsBuilder = cstr.bound(value);
+    if (dim)
+      boundsBuilder[*dim];
+
+    // Compare yielded values.
+    // If trueValue <= falseValue:
+    // * result <= falseValue
+    // * result >= trueValue
+    if (cstr.compare(/*lhs=*/{trueValue, dim},
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     /*rhs=*/{falseValue, dim})) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
+      } else {
+        cstr.bound(value) >= trueValue;
+        cstr.bound(value) <= falseValue;
+      }
+    }
+    // If falseValue <= trueValue:
+    // * result <= trueValue
+    // * result >= falseValue
+    if (cstr.compare(/*lhs=*/{falseValue, dim},
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     /*rhs=*/{trueValue, dim})) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
+      } else {
+        cstr.bound(value) >= falseValue;
+        cstr.bound(value) <= trueValue;
+      }
+    }
+  }
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), dim, cstr);
+  }
+};
 } // namespace
 } // namespace arith
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e99a2..f87f3d6350c0221 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
       return failure();
 
     FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
-        presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+        presburger::BoundType::UB, in,
         /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(ub))
       return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f190ed..5bb7d83bf1e3f86 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type,
+          ValueBoundsConstraintSet::Variable(value, dim), stopCondition,
+          closedUB)))
     return failure();
 
   // Materialize tensor.dim/memref.dim ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db2489897..518d2e138c02a97 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
-            presburger::BoundType::UB, opOperand->get(),
-            /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+            presburger::BoundType::UB,
+            {opOperand->get(),
+             /*dim=*/i},
+            /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
       LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
       return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d049d..71eb59d40836c1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
       size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
     } else {
-      Value materializedSize =
-          getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
       FailureOr<int64_t> upperBound =
           ValueBoundsConstraintSet::computeConstantBound(
-              presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+              presburger::BoundType::UB, rangeValue.size,
               /*stopCondition=*/nullptr, /*closedUB=*/true);
       size = failed(upperBound)
-                 ? materializedSize
+                 ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
                  : b.create<arith::ConstantIndexOp>(loc, *upperBound);
     }
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7b9f..1f06318cbd60e04 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                                ValueRange independencies) {
   if (ofr.is<Attribute>())
     return ofr;
-  Value value = ofr.get<Value>();
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
-          boundMap, mapOperands, presburger::BoundType::UB, value,
-          /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+          boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+          /*closedUB=*/true)))
     return failure();
   return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
 }
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a830a3..17a1c016ea16d5a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
     // An EQ constraint can be added if the y...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 8, 2024

@llvm/pr-subscribers-mlir-tensor

Author: Matthias Springer (matthias-springer)

Changes

This commit generalizes and cleans up the ValueBoundsConstraintSet API. The API used to provide function overloads for comparing/computing bounds of:

  • index-typed SSA value
  • dimension of shaped value
  • affine map + operands

This commit removes all overloads. There is now a single entry point for each compare variant and each computeBound variant. These functions now take a Variable, which is internally represented as an affine map and map operands.

This commit also adds support for computing bounds for an affine map + operands. There was previously no public API for that.

WIP until I added a test case for computeBounds(AffineMap).


Patch is 47.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87980.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+61-56)
  • (modified) mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp (+69)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Padding.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (+2-4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp (+2-3)
  • (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+10-7)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+2-2)
  • (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+149-188)
  • (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+3-3)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961cc3a..3e1502b4f5c357a 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ExtensibleRTTI.h"
 
 #include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
 public:
   static char ID;
 
+  /// A variable that can be added to the constraint set as a "column". The
+  /// value bounds infrastructure can compute bounds for variables and compare
+  /// two variables.
+  ///
+  /// Internally, a variable is represented as an affine map and operands.
+  class Variable {
+  public:
+    /// Construct a variable for an index-typed attribute or SSA value.
+    Variable(OpFoldResult ofr);
+
+    /// Construct a variable for an index-typed SSA value.
+    Variable(Value indexValue);
+
+    /// Construct a variable for a dimension of a shaped value.
+    Variable(Value shapedValue, int64_t dim);
+
+    /// Construct a variable for an index-typed attribute/SSA value or for a
+    /// dimension of a shaped value. A non-null dimension must be provided if
+    /// and only if `ofr` is a shaped value.
+    Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+    /// Construct a variable for a map and its operands.
+    Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+    Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+    MLIRContext *getContext() const { return map.getContext(); }
+
+  private:
+    friend class ValueBoundsConstraintSet;
+    AffineMap map;
+    ValueDimList mapOperands;
+  };
+
   /// The stop condition when traversing the backward slice of a shaped value/
   /// index-type value. The traversal continues until the stop condition
   /// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
   using StopConditionFn = std::function<bool(
       Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
-  /// Compute a bound for the given index-typed value or shape dimension size.
-  /// The computed bound is stored in `resultMap`. The operands of the bound are
-  /// stored in `mapOperands`. An operand is either an index-type SSA value
-  /// or a shaped value and a dimension.
+  /// Compute a bound for the given variable. The computed bound is stored in
+  /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+  /// operand is either an index-type SSA value or a shaped value and a
+  /// dimension.
   ///
-  /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
-  /// is computed in terms of values/dimensions for which `stopCondition`
-  /// evaluates to "true". To that end, the backward slice (reverse use-def
-  /// chain) of the given value is visited in a worklist-driven manner and the
-  /// constraint set is populated according to `ValueBoundsOpInterface` for each
-  /// visited value.
+  /// The bound is computed in terms of values/dimensions for which
+  /// `stopCondition` evaluates to "true". To that end, the backward slice
+  /// (reverse use-def chain) of the given value is visited in a worklist-driven
+  /// manner and the constraint set is populated according to
+  /// `ValueBoundsOpInterface` for each visited value.
   ///
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
-  static LogicalResult computeBound(AffineMap &resultMap,
-                                    ValueDimList &mapOperands,
-                                    presburger::BoundType type, Value value,
-                                    std::optional<int64_t> dim,
-                                    StopConditionFn stopCondition,
-                                    bool closedUB = false);
+  static LogicalResult
+  computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+               presburger::BoundType type, const Variable &var,
+               StopConditionFn stopCondition, bool closedUB = false);
 
   /// Compute a bound in terms of the values/dimensions in `dependencies`. The
   /// computed bound consists of only constant terms and dependent values (or
   /// dimension sizes thereof).
   static LogicalResult
   computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                        presburger::BoundType type, Value value,
-                        std::optional<int64_t> dim, ValueDimList dependencies,
-                        bool closedUB = false);
+                        presburger::BoundType type, const Variable &var,
+                        ValueDimList dependencies, bool closedUB = false);
 
   /// Compute a bound in that is independent of all values in `independencies`.
   ///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
   /// appear in the computed bound.
   static LogicalResult
   computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                          presburger::BoundType type, Value value,
-                          std::optional<int64_t> dim, ValueRange independencies,
-                          bool closedUB = false);
+                          presburger::BoundType type, const Variable &var,
+                          ValueRange independencies, bool closedUB = false);
 
-  /// Compute a constant bound for the given affine map, where dims and symbols
-  /// are bound to the given operands. The affine map must have exactly one
-  /// result.
+  /// Compute a constant bound for the given variable.
   ///
   /// This function traverses the backward slice of the given operands in a
   /// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
   static FailureOr<int64_t>
-  computeConstantBound(presburger::BoundType type, Value value,
-                       std::optional<int64_t> dim = std::nullopt,
+  computeConstantBound(presburger::BoundType type, const Variable &var,
                        StopConditionFn stopCondition = nullptr,
                        bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
   /// Compute a constant delta between the given two values. Return "failure"
   /// if a constant delta could not be determined.
@@ -221,9 +241,7 @@ class ValueBoundsConstraintSet
   /// proven. This could be because the specified relation does in fact not hold
   /// or because there is not enough information in the constraint set. In other
   /// words, if we do not know for sure, this function returns "false".
-  bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                          ComparisonOperator cmp, OpFoldResult rhs,
-                          std::optional<int64_t> rhsDim);
+  bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs);
 
   /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
   /// specified relation could not be proven. This could be because the
@@ -233,24 +251,11 @@ class ValueBoundsConstraintSet
   ///
   /// This function keeps traversing the backward slice of lhs/rhs until could
   /// prove the relation or until it ran out of IR.
-  static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                      ComparisonOperator cmp, OpFoldResult rhs,
-                      std::optional<int64_t> rhsDim);
-  static bool compare(AffineMap lhs, ValueDimList lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ValueDimList rhsOperands);
-  static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ArrayRef<Value> rhsOperands);
-
-  /// Compute whether the given values/dimensions are equal. Return "failure" if
+  static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs);
+
+  /// Compute whether the given variables are equal. Return "failure" if
   /// equality could not be determined.
-  ///
-  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
-  /// index-typed.
-  static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
-                                  std::optional<int64_t> dim1 = std::nullopt,
-                                  std::optional<int64_t> dim2 = std::nullopt);
+  static FailureOr<bool> areEqual(Variable var1, Variable var2);
 
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +322,6 @@ class ValueBoundsConstraintSet
   ///
   /// This function does not analyze any IR and does not populate any additional
   /// constraints.
-  bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                        ComparisonOperator cmp, OpFoldResult rhs,
-                        std::optional<int64_t> rhsDim);
   bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
 
   /// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +376,7 @@ class ValueBoundsConstraintSet
   /// constraint system. Return the position of the new column. Any operands
   /// that were not analyzed yet are put on the worklist.
   int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+  int64_t insert(const Variable &var, bool isSymbol = true);
 
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
@@ -381,6 +384,8 @@ class ValueBoundsConstraintSet
   /// Project out all columns for which the condition holds.
   void projectOut(function_ref<bool(ValueDim)> condition);
 
+  void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
   /// Mapping of columns to values/shape dimensions.
   SmallVector<std::optional<ValueDim>> positionToValueDim;
   /// Reverse mapping of values/shape dimensions to columns.
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f71d1..82a9fb0d490882f 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
   mapOperands.push_back(value1);
   mapOperands.push_back(value2);
   affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
-  ValueDimList valueDims;
-  for (Value v : mapOperands)
-    valueDims.push_back({v, std::nullopt});
   return ValueBoundsConstraintSet::computeConstantBound(
-      presburger::BoundType::EQ, map, valueDims);
+      presburger::BoundType::EQ,
+      ValueBoundsConstraintSet::Variable(map, mapOperands));
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701ad7c..6c59df91e8af781 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB)))
     return failure();
 
   // Reify bound.
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 90895e381c74b5a..411fc117a4d9f5d 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -75,6 +75,75 @@ struct MulIOpInterface
   }
 };
 
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+
+  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    Value value = selectOp.getResult();
+    Value condition = selectOp.getCondition();
+    Value trueValue = selectOp.getTrueValue();
+    Value falseValue = selectOp.getFalseValue();
+
+    if (isa<ShapedType>(condition.getType())) {
+      // If the condition is a shaped type, the condition is applied
+      // element-wise. All three operands must have the same shape.
+      cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+      return;
+    }
+
+    // Populate constraints for the true/false values (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(trueValue, dim);
+    cstr.populateConstraints(falseValue, dim);
+    auto boundsBuilder = cstr.bound(value);
+    if (dim)
+      boundsBuilder[*dim];
+
+    // Compare yielded values.
+    // If trueValue <= falseValue:
+    // * result <= falseValue
+    // * result >= trueValue
+    if (cstr.compare(/*lhs=*/{trueValue, dim},
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     /*rhs=*/{falseValue, dim})) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
+      } else {
+        cstr.bound(value) >= trueValue;
+        cstr.bound(value) <= falseValue;
+      }
+    }
+    // If falseValue <= trueValue:
+    // * result <= trueValue
+    // * result >= falseValue
+    if (cstr.compare(/*lhs=*/{falseValue, dim},
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     /*rhs=*/{trueValue, dim})) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
+      } else {
+        cstr.bound(value) >= falseValue;
+        cstr.bound(value) <= trueValue;
+      }
+    }
+  }
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), dim, cstr);
+  }
+};
 } // namespace
 } // namespace arith
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e99a2..f87f3d6350c0221 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
       return failure();
 
     FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
-        presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+        presburger::BoundType::UB, in,
         /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(ub))
       return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f190ed..5bb7d83bf1e3f86 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type,
+          ValueBoundsConstraintSet::Variable(value, dim), stopCondition,
+          closedUB)))
     return failure();
 
   // Materialize tensor.dim/memref.dim ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db2489897..518d2e138c02a97 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
-            presburger::BoundType::UB, opOperand->get(),
-            /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+            presburger::BoundType::UB,
+            {opOperand->get(),
+             /*dim=*/i},
+            /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
       LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
       return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d049d..71eb59d40836c1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
       size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
     } else {
-      Value materializedSize =
-          getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
       FailureOr<int64_t> upperBound =
           ValueBoundsConstraintSet::computeConstantBound(
-              presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+              presburger::BoundType::UB, rangeValue.size,
               /*stopCondition=*/nullptr, /*closedUB=*/true);
       size = failed(upperBound)
-                 ? materializedSize
+                 ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
                  : b.create<arith::ConstantIndexOp>(loc, *upperBound);
     }
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7b9f..1f06318cbd60e04 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                                ValueRange independencies) {
   if (ofr.is<Attribute>())
     return ofr;
-  Value value = ofr.get<Value>();
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
-          boundMap, mapOperands, presburger::BoundType::UB, value,
-          /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+          boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+          /*closedUB=*/true)))
     return failure();
   return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
 }
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a830a3..17a1c016ea16d5a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
     // An EQ constraint can be added if the y...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 8, 2024

@llvm/pr-subscribers-mlir-scf

Author: Matthias Springer (matthias-springer)

Changes

This commit generalizes and cleans up the ValueBoundsConstraintSet API. The API used to provide function overloads for comparing/computing bounds of:

  • index-typed SSA value
  • dimension of shaped value
  • affine map + operands

This commit removes all overloads. There is now a single entry point for each compare variant and each computeBound variant. These functions now take a Variable, which is internally represented as an affine map and map operands.

This commit also adds support for computing bounds for an affine map + operands. There was previously no public API for that.

WIP until I added a test case for computeBounds(AffineMap).


Patch is 47.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87980.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+61-56)
  • (modified) mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp (+69)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Padding.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (+2-4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp (+2-3)
  • (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+10-7)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+2-2)
  • (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+149-188)
  • (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+3-3)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961cc3a..3e1502b4f5c357a 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ExtensibleRTTI.h"
 
 #include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
 public:
   static char ID;
 
+  /// A variable that can be added to the constraint set as a "column". The
+  /// value bounds infrastructure can compute bounds for variables and compare
+  /// two variables.
+  ///
+  /// Internally, a variable is represented as an affine map and operands.
+  class Variable {
+  public:
+    /// Construct a variable for an index-typed attribute or SSA value.
+    Variable(OpFoldResult ofr);
+
+    /// Construct a variable for an index-typed SSA value.
+    Variable(Value indexValue);
+
+    /// Construct a variable for a dimension of a shaped value.
+    Variable(Value shapedValue, int64_t dim);
+
+    /// Construct a variable for an index-typed attribute/SSA value or for a
+    /// dimension of a shaped value. A non-null dimension must be provided if
+    /// and only if `ofr` is a shaped value.
+    Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+    /// Construct a variable for a map and its operands.
+    Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+    Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+    MLIRContext *getContext() const { return map.getContext(); }
+
+  private:
+    friend class ValueBoundsConstraintSet;
+    AffineMap map;
+    ValueDimList mapOperands;
+  };
+
   /// The stop condition when traversing the backward slice of a shaped value/
   /// index-type value. The traversal continues until the stop condition
   /// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
   using StopConditionFn = std::function<bool(
       Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
-  /// Compute a bound for the given index-typed value or shape dimension size.
-  /// The computed bound is stored in `resultMap`. The operands of the bound are
-  /// stored in `mapOperands`. An operand is either an index-type SSA value
-  /// or a shaped value and a dimension.
+  /// Compute a bound for the given variable. The computed bound is stored in
+  /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+  /// operand is either an index-type SSA value or a shaped value and a
+  /// dimension.
   ///
-  /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
-  /// is computed in terms of values/dimensions for which `stopCondition`
-  /// evaluates to "true". To that end, the backward slice (reverse use-def
-  /// chain) of the given value is visited in a worklist-driven manner and the
-  /// constraint set is populated according to `ValueBoundsOpInterface` for each
-  /// visited value.
+  /// The bound is computed in terms of values/dimensions for which
+  /// `stopCondition` evaluates to "true". To that end, the backward slice
+  /// (reverse use-def chain) of the given value is visited in a worklist-driven
+  /// manner and the constraint set is populated according to
+  /// `ValueBoundsOpInterface` for each visited value.
   ///
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
-  static LogicalResult computeBound(AffineMap &resultMap,
-                                    ValueDimList &mapOperands,
-                                    presburger::BoundType type, Value value,
-                                    std::optional<int64_t> dim,
-                                    StopConditionFn stopCondition,
-                                    bool closedUB = false);
+  static LogicalResult
+  computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+               presburger::BoundType type, const Variable &var,
+               StopConditionFn stopCondition, bool closedUB = false);
 
   /// Compute a bound in terms of the values/dimensions in `dependencies`. The
   /// computed bound consists of only constant terms and dependent values (or
   /// dimension sizes thereof).
   static LogicalResult
   computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                        presburger::BoundType type, Value value,
-                        std::optional<int64_t> dim, ValueDimList dependencies,
-                        bool closedUB = false);
+                        presburger::BoundType type, const Variable &var,
+                        ValueDimList dependencies, bool closedUB = false);
 
   /// Compute a bound in that is independent of all values in `independencies`.
   ///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
   /// appear in the computed bound.
   static LogicalResult
   computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                          presburger::BoundType type, Value value,
-                          std::optional<int64_t> dim, ValueRange independencies,
-                          bool closedUB = false);
+                          presburger::BoundType type, const Variable &var,
+                          ValueRange independencies, bool closedUB = false);
 
-  /// Compute a constant bound for the given affine map, where dims and symbols
-  /// are bound to the given operands. The affine map must have exactly one
-  /// result.
+  /// Compute a constant bound for the given variable.
   ///
   /// This function traverses the backward slice of the given operands in a
   /// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
   static FailureOr<int64_t>
-  computeConstantBound(presburger::BoundType type, Value value,
-                       std::optional<int64_t> dim = std::nullopt,
+  computeConstantBound(presburger::BoundType type, const Variable &var,
                        StopConditionFn stopCondition = nullptr,
                        bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
   /// Compute a constant delta between the given two values. Return "failure"
   /// if a constant delta could not be determined.
@@ -221,9 +241,7 @@ class ValueBoundsConstraintSet
   /// proven. This could be because the specified relation does in fact not hold
   /// or because there is not enough information in the constraint set. In other
   /// words, if we do not know for sure, this function returns "false".
-  bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                          ComparisonOperator cmp, OpFoldResult rhs,
-                          std::optional<int64_t> rhsDim);
+  bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs);
 
   /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
   /// specified relation could not be proven. This could be because the
@@ -233,24 +251,11 @@ class ValueBoundsConstraintSet
   ///
   /// This function keeps traversing the backward slice of lhs/rhs until could
   /// prove the relation or until it ran out of IR.
-  static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                      ComparisonOperator cmp, OpFoldResult rhs,
-                      std::optional<int64_t> rhsDim);
-  static bool compare(AffineMap lhs, ValueDimList lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ValueDimList rhsOperands);
-  static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ArrayRef<Value> rhsOperands);
-
-  /// Compute whether the given values/dimensions are equal. Return "failure" if
+  static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs);
+
+  /// Compute whether the given variables are equal. Return "failure" if
   /// equality could not be determined.
-  ///
-  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
-  /// index-typed.
-  static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
-                                  std::optional<int64_t> dim1 = std::nullopt,
-                                  std::optional<int64_t> dim2 = std::nullopt);
+  static FailureOr<bool> areEqual(Variable var1, Variable var2);
 
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +322,6 @@ class ValueBoundsConstraintSet
   ///
   /// This function does not analyze any IR and does not populate any additional
   /// constraints.
-  bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                        ComparisonOperator cmp, OpFoldResult rhs,
-                        std::optional<int64_t> rhsDim);
   bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
 
   /// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +376,7 @@ class ValueBoundsConstraintSet
   /// constraint system. Return the position of the new column. Any operands
   /// that were not analyzed yet are put on the worklist.
   int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+  int64_t insert(const Variable &var, bool isSymbol = true);
 
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
@@ -381,6 +384,8 @@ class ValueBoundsConstraintSet
   /// Project out all columns for which the condition holds.
   void projectOut(function_ref<bool(ValueDim)> condition);
 
+  void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
   /// Mapping of columns to values/shape dimensions.
   SmallVector<std::optional<ValueDim>> positionToValueDim;
   /// Reverse mapping of values/shape dimensions to columns.
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f71d1..82a9fb0d490882f 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
   mapOperands.push_back(value1);
   mapOperands.push_back(value2);
   affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
-  ValueDimList valueDims;
-  for (Value v : mapOperands)
-    valueDims.push_back({v, std::nullopt});
   return ValueBoundsConstraintSet::computeConstantBound(
-      presburger::BoundType::EQ, map, valueDims);
+      presburger::BoundType::EQ,
+      ValueBoundsConstraintSet::Variable(map, mapOperands));
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701ad7c..6c59df91e8af781 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB)))
     return failure();
 
   // Reify bound.
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 90895e381c74b5a..411fc117a4d9f5d 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -75,6 +75,75 @@ struct MulIOpInterface
   }
 };
 
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+
+  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    Value value = selectOp.getResult();
+    Value condition = selectOp.getCondition();
+    Value trueValue = selectOp.getTrueValue();
+    Value falseValue = selectOp.getFalseValue();
+
+    if (isa<ShapedType>(condition.getType())) {
+      // If the condition is a shaped type, the condition is applied
+      // element-wise. All three operands must have the same shape.
+      cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+      return;
+    }
+
+    // Populate constraints for the true/false values (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(trueValue, dim);
+    cstr.populateConstraints(falseValue, dim);
+    auto boundsBuilder = cstr.bound(value);
+    if (dim)
+      boundsBuilder[*dim];
+
+    // Compare yielded values.
+    // If trueValue <= falseValue:
+    // * result <= falseValue
+    // * result >= trueValue
+    if (cstr.compare(/*lhs=*/{trueValue, dim},
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     /*rhs=*/{falseValue, dim})) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
+      } else {
+        cstr.bound(value) >= trueValue;
+        cstr.bound(value) <= falseValue;
+      }
+    }
+    // If falseValue <= trueValue:
+    // * result <= trueValue
+    // * result >= falseValue
+    if (cstr.compare(/*lhs=*/{falseValue, dim},
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     /*rhs=*/{trueValue, dim})) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
+      } else {
+        cstr.bound(value) >= falseValue;
+        cstr.bound(value) <= trueValue;
+      }
+    }
+  }
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), dim, cstr);
+  }
+};
 } // namespace
 } // namespace arith
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e99a2..f87f3d6350c0221 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
       return failure();
 
     FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
-        presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+        presburger::BoundType::UB, in,
         /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(ub))
       return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f190ed..5bb7d83bf1e3f86 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type,
+          ValueBoundsConstraintSet::Variable(value, dim), stopCondition,
+          closedUB)))
     return failure();
 
   // Materialize tensor.dim/memref.dim ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db2489897..518d2e138c02a97 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
-            presburger::BoundType::UB, opOperand->get(),
-            /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+            presburger::BoundType::UB,
+            {opOperand->get(),
+             /*dim=*/i},
+            /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
       LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
       return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d049d..71eb59d40836c1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
       size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
     } else {
-      Value materializedSize =
-          getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
       FailureOr<int64_t> upperBound =
           ValueBoundsConstraintSet::computeConstantBound(
-              presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+              presburger::BoundType::UB, rangeValue.size,
               /*stopCondition=*/nullptr, /*closedUB=*/true);
       size = failed(upperBound)
-                 ? materializedSize
+                 ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
                  : b.create<arith::ConstantIndexOp>(loc, *upperBound);
     }
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7b9f..1f06318cbd60e04 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                                ValueRange independencies) {
   if (ofr.is<Attribute>())
     return ofr;
-  Value value = ofr.get<Value>();
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
-          boundMap, mapOperands, presburger::BoundType::UB, value,
-          /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+          boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+          /*closedUB=*/true)))
     return failure();
   return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
 }
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a830a3..17a1c016ea16d5a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
     // An EQ constraint can be added if the y...
[truncated]

@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_typo branch from 4a019ca to d8c3fd6 Compare April 8, 2024 13:52
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_variable branch from ed12ff5 to 183ab14 Compare April 8, 2024 13:53
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice refactor, I think easier to read

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp Outdated Show resolved Hide resolved
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_typo branch from d8c3fd6 to 0a7a53e Compare April 11, 2024 06:25
Base automatically changed from users/matthias-springer/value_bounds_typo to main April 11, 2024 06:27
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_variable branch 2 times, most recently from 79734b8 to 8ba7e9e Compare April 15, 2024 13:31
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/value_bounds_test_op April 15, 2024 13:31
@matthias-springer matthias-springer changed the title [mlir][Interfaces][WIP] Variable abstraction for ValueBoundsOpInterface [mlir][Interfaces] Variable abstraction for ValueBoundsOpInterface Apr 15, 2024
Base automatically changed from users/matthias-springer/value_bounds_test_op to main April 15, 2024 16:14
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_variable branch from 8ba7e9e to 3876dc9 Compare April 16, 2024 08:30
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_variable branch from 3876dc9 to afb0d75 Compare April 16, 2024 08:32
@matthias-springer matthias-springer merged commit 40dd3aa into main Apr 16, 2024
3 of 4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/value_bounds_variable branch April 16, 2024 08:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants