Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][SCF] ValueBoundsConstraintSet: Support scf.if (branches) #85895

Merged

Conversation

matthias-springer
Copy link
Member

This commit adds support for scf.if to ValueBoundsConstraintSet.

Example:

%0 = scf.if ... -> index {
  scf.yield %a : index
} else {
  scf.yield %b : index
}

The following constraints hold for %0:

  • %0 >= min(%a, %b)
  • %0 <= max(%a, %b)

Such constraints cannot be added to the constraint set; min/max is not supported by IntegerRelation. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b:

  • %0 >= %a
  • %0 <= %b

This commit required a few minor changes to the ValueBoundsConstraintSet infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.

@llvmbot
Copy link

llvmbot commented Mar 20, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: Matthias Springer (matthias-springer)

Changes

This commit adds support for scf.if to ValueBoundsConstraintSet.

Example:

%0 = scf.if ... -&gt; index {
  scf.yield %a : index
} else {
  scf.yield %b : index
}

The following constraints hold for %0:

  • %0 >= min(%a, %b)
  • %0 <= max(%a, %b)

Such constraints cannot be added to the constraint set; min/max is not supported by IntegerRelation. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b:

  • %0 >= %a
  • %0 <= %b

This commit required a few minor changes to the ValueBoundsConstraintSet infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.


Patch is 20.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85895.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+31-5)
  • (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+63)
  • (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+122-24)
  • (modified) mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir (+117-2)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 28dadfb9ecf868..d11ed704680f61 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -198,6 +198,28 @@ class ValueBoundsConstraintSet {
                        std::optional<int64_t> dim1 = std::nullopt,
                        std::optional<int64_t> dim2 = std::nullopt);
 
+  /// Traverse the IR starting from the given value/dim and add populate
+  /// constraints as long as the currently set stop condition holds. Also
+  /// processes all values/dims that are already on the worklist.
+  void populateConstraints(Value value, std::optional<int64_t> dim);
+
+  /// Comparison operator for `ValueBoundsConstraintSet::compare`.
+  enum ComparisonOperator { LT, LE, EQ, GT, GE };
+
+  /// Try to prove that, based on the current state of this constraint set
+  /// (i.e., without analyzing additional IR or adding new constraints), it can
+  /// be deduced that the first given value/dim is LE/LT/EQ/GT/GE than the
+  /// second given value/dim.
+  ///
+  /// Return "true" if the specified relation between the two values/dims was
+  /// proven to hold. Return "false" if the specified relation could not be
+  /// proven. This could be because the specified relation does in fact not hold
+  /// or because there is not enough information in the constraint set. In other
+  /// words, if we do not know for sure, this function returns "false".
+  bool compare(Value value1, std::optional<int64_t> dim1,
+               ComparisonOperator cmp, Value value2,
+               std::optional<int64_t> dim2);
+
   /// Compute whether the given values/dimensions are equal. Return "failure" if
   /// equality could not be determined.
   ///
@@ -266,9 +288,9 @@ class ValueBoundsConstraintSet {
   ValueBoundsConstraintSet(MLIRContext *ctx);
 
   /// Iteratively process all elements on the worklist until an index-typed
-  /// value or shaped value meets `stopCondition`. Such values are not processed
-  /// any further.
-  void processWorklist(StopConditionFn stopCondition);
+  /// value or shaped value meets `currentStopCondition`. Such values are not
+  /// processed any further.
+  void processWorklist();
 
   /// Bound the given column in the underlying constraint set by the given
   /// expression.
@@ -280,12 +302,13 @@ class ValueBoundsConstraintSet {
 
   /// Insert a value/dimension into the constraint set. If `isSymbol` is set to
   /// "false", a dimension is added. The value/dimension is added to the
-  /// worklist.
+  /// worklist if `addToWorklist` is set.
   ///
   /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
   /// cannot be multiplied. Furthermore, bounds can only be queried for
   /// dimensions but not for symbols.
-  int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
+  int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true,
+                 bool addToWorklist = true);
 
   /// Insert an anonymous column into the constraint set. The column is not
   /// bound to any value/dimension. If `isSymbol` is set to "false", a dimension
@@ -315,6 +338,9 @@ class ValueBoundsConstraintSet {
 
   /// Builder for constructing affine expressions.
   Builder builder;
+
+  /// The current stop condition function.
+  StopConditionFn currentStopCondition = nullptr;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index cb36e0cecf0d24..bd9615b8eb5532 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -111,6 +111,68 @@ struct ForOpInterface
   }
 };
 
+struct IfOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto ifOp = cast<IfOp>(op);
+    unsigned int resultNum = cast<OpResult>(value).getResultNumber();
+    Value thenValue = ifOp.thenYield().getResults()[resultNum];
+    Value elseValue = ifOp.elseYield().getResults()[resultNum];
+
+    // Populate constraints for the yielded value (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(thenValue, /*valueDim=*/std::nullopt);
+    cstr.populateConstraints(elseValue, /*valueDim=*/std::nullopt);
+
+    // Compare yielded values.
+    // If thenValue <= elseValue:
+    // * result <= elseValue
+    // * result >= thenValue
+    if (cstr.compare(thenValue, /*dim1=*/std::nullopt,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     elseValue, /*dim2=*/std::nullopt)) {
+      cstr.bound(value) >= thenValue;
+      cstr.bound(value) <= elseValue;
+    }
+    // If elseValue <= thenValue:
+    // * result <= thenValue
+    // * result >= elseValue
+    if (cstr.compare(elseValue, /*dim1=*/std::nullopt,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     thenValue, /*dim2=*/std::nullopt)) {
+      cstr.bound(value) >= elseValue;
+      cstr.bound(value) <= thenValue;
+    }
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    // See `populateBoundsForIndexValue` for documentation.
+    auto ifOp = cast<IfOp>(op);
+    unsigned int resultNum = cast<OpResult>(value).getResultNumber();
+    Value thenValue = ifOp.thenYield().getResults()[resultNum];
+    Value elseValue = ifOp.elseYield().getResults()[resultNum];
+
+    cstr.populateConstraints(thenValue, dim);
+    cstr.populateConstraints(elseValue, dim);
+
+    if (cstr.compare(thenValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     elseValue, dim)) {
+      cstr.bound(value)[dim] >= cstr.getExpr(thenValue, dim);
+      cstr.bound(value)[dim] <= cstr.getExpr(elseValue, dim);
+    }
+    if (cstr.compare(elseValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     thenValue, dim)) {
+      cstr.bound(value)[dim] >= cstr.getExpr(elseValue, dim);
+      cstr.bound(value)[dim] <= cstr.getExpr(thenValue, dim);
+    }
+  }
+};
+
 } // namespace
 } // namespace scf
 } // namespace mlir
@@ -119,5 +181,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
     scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
+    scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
   });
 }
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 85abc2df894797..b30b34bad075b0 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -105,25 +105,43 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
   assertValidValueDim(value, dim);
 #endif // NDEBUG
 
+  auto getPosExpr = [&](int64_t pos) {
+    assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() &&
+           "invalid position");
+    return pos < cstr.getNumDimVars()
+               ? builder.getAffineDimExpr(pos)
+               : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
+  };
+
+  // If the value/dim is already mapped, return the corresponding expression
+  // directly.
+  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
+  if (valueDimToPosition.contains(valueDim))
+    return getPosExpr(getPos(value, dim));
+
   auto shapedType = dyn_cast<ShapedType>(value.getType());
   if (shapedType) {
-    // Static dimension: return constant directly.
-    if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
-      return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
+    // Static dimension: add EQ bound and return expression without pushing the
+    // dim onto the worklist.
+    if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim)) {
+      int64_t pos =
+          insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
+      bound(value)[*dim] == shapedType.getDimSize(*dim);
+      return getPosExpr(pos);
+    }
   } else {
-    // Constant index value: return directly.
-    if (auto constInt = ::getConstantIntValue(value))
-      return builder.getAffineConstantExpr(*constInt);
+    // Constant index value: add EQ bound and return expression without pushing
+    // the value onto the worklist.
+    if (auto constInt = ::getConstantIntValue(value)) {
+      int64_t pos =
+          insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
+      bound(value) == *constInt;
+      return getPosExpr(pos);
+    }
   }
 
-  // Dynamic value: add to constraint set.
-  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
-  if (!valueDimToPosition.contains(valueDim))
-    (void)insert(value, dim);
-  int64_t pos = getPos(value, dim);
-  return pos < cstr.getNumDimVars()
-             ? builder.getAffineDimExpr(pos)
-             : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
+  // Dynamic value/dim: add to worklist.
+  return getPosExpr(insert(value, dim, /*isSymbol=*/true));
 }
 
 AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
@@ -140,7 +158,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
 
 int64_t ValueBoundsConstraintSet::insert(Value value,
                                          std::optional<int64_t> dim,
-                                         bool isSymbol) {
+                                         bool isSymbol, bool addToWorklist) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
 #endif // NDEBUG
@@ -155,7 +173,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
     if (positionToValueDim[i].has_value())
       valueDimToPosition[*positionToValueDim[i]] = i;
 
-  worklist.push(pos);
+  if (addToWorklist) {
+    LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
+                            << " (dim: " << dim.value_or(kIndexValue) << ")\n");
+    worklist.push(pos);
+  }
+
   return pos;
 }
 
@@ -191,7 +214,8 @@ static Operation *getOwnerOfValue(Value value) {
   return value.getDefiningOp();
 }
 
-void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+void ValueBoundsConstraintSet::processWorklist() {
+  LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
@@ -212,13 +236,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
 
     // Do not process any further if the stop condition is met.
     auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
-    if (stopCondition(value, maybeDim))
+    if (currentStopCondition(value, maybeDim)) {
+      LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
+                              << " (dim: " << maybeDim << ")\n");
       continue;
+    }
 
     // Query `ValueBoundsOpInterface` for constraints. New items may be added to
     // the worklist.
     auto valueBoundsOp =
         dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+    LLVM_DEBUG(llvm::dbgs()
+               << "Query value bounds for: " << value
+               << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
     if (valueBoundsOp) {
       if (dim == kIndexValue) {
         valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -226,6 +256,9 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
         valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
       }
       continue;
+    } else {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "--> ValueBoundsOpInterface not implemented\n");
     }
 
     // If the op does not implement `ValueBoundsOpInterface`, check if it
@@ -301,7 +334,8 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
   ValueBoundsConstraintSet cstr(value.getContext());
   int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
-  cstr.processWorklist(stopCondition);
+  cstr.currentStopCondition = stopCondition;
+  cstr.processWorklist();
 
   // Project out all variables (apart from `valueDim`) that do not match the
   // stop condition.
@@ -494,14 +528,16 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   // Process the backward slice of `operands` (i.e., reverse use-def chain)
   // until `stopCondition` is met.
   if (stopCondition) {
-    cstr.processWorklist(stopCondition);
+    cstr.currentStopCondition = stopCondition;
+    cstr.processWorklist();
   } else {
     // No stop condition specified: Keep adding constraints until a bound could
     // be computed.
-    cstr.processWorklist(
-        /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
-          return cstr.cstr.getConstantBound64(type, pos).has_value();
-        });
+    auto stopCondFn = [&](Value v, std::optional<int64_t> dim) {
+      return cstr.cstr.getConstantBound64(type, pos).has_value();
+    };
+    cstr.currentStopCondition = stopCondFn;
+    cstr.processWorklist();
   }
 
   // Compute constant bound for `valueDim`.
@@ -538,6 +574,68 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
                               {{value1, dim1}, {value2, dim2}});
 }
 
+void ValueBoundsConstraintSet::populateConstraints(Value value,
+                                                   std::optional<int64_t> dim) {
+  // `getExpr` pushes the value/dim onto the worklist (unless it was already
+  // analyzed).
+  (void)getExpr(value, dim);
+  // Process all values/dims on the worklist. This may traverse and analyze
+  // additional IR, depending the current stop function.
+  processWorklist();
+}
+
+bool ValueBoundsConstraintSet::compare(Value value1,
+                                       std::optional<int64_t> dim1,
+                                       ComparisonOperator cmp, Value value2,
+                                       std::optional<int64_t> dim2) {
+  // This function returns "true" if value1/dim1 CMP value2/dim2 is proved to
+  // hold.
+  //
+  // Example for ComparisonOperator::LE and index-typed values: We would like to
+  // prove that value1 <= value2. Proof by contradiction: add the inverse
+  // relation (value1 > value2) to the constraint set and check if the resulting
+  // constraint set is "empty" (i.e. has no solution). In that case,
+  // value1 > value2 must be incorrect and we can deduce that value1 <= value2
+  // holds.
+
+  // We cannot use prove anything if the constraint set is already empty.
+  if (cstr.isEmpty()) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "cannot compare value/dims: constraint system is already empty");
+    return false;
+  }
+
+  // EQ can be expressed as LE and GE.
+  if (cmp == EQ)
+    return compare(value1, dim1, ComparisonOperator::LE, value2, dim2) &&
+           compare(value1, dim1, ComparisonOperator::GE, value2, dim2);
+
+  // Construct inequality. For the above example: value1 > value2.
+  // `IntegerRelation` inequalities are expressed in the "flattened" form and
+  // with ">= 0". I.e., value1 - value2 - 1 >= 0.
+  SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
+  if (cmp == LT || cmp == LE) {
+    eq[getPos(value1, dim1)]++;
+    eq[getPos(value2, dim2)]--;
+  } else if (cmp == GT || cmp == GE) {
+    eq[getPos(value1, dim1)]--;
+    eq[getPos(value2, dim2)]++;
+  } else {
+    llvm_unreachable("unsupported comparison operator");
+  }
+  if (cmp == LE || cmp == GE)
+    eq[cstr.getNumDimAndSymbolVars()] -= 1;
+
+  // Add inequality to the constraint set and check if it made the constraint
+  // set empty.
+  int64_t ineqPos = cstr.getNumInequalities();
+  cstr.addInequality(eq);
+  bool isEmpty = cstr.isEmpty();
+  cstr.removeInequality(ineqPos);
+  return isEmpty;
+}
+
 FailureOr<bool>
 ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
                                    std::optional<int64_t> dim1,
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index e4d71415924994..0ea06737886d41 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
-// RUN:     -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
+// RUN:     -verify-diagnostics -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @scf_for(
 //  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
@@ -104,3 +104,118 @@ func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: in
   "test.some_use"(%reify1) : (index) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_if_constant(
+func.func @scf_if_constant(%c : i1) {
+  // CHECK: arith.constant 4 : index
+  // CHECK: arith.constant 9 : index
+  %c4 = arith.constant 4 : index
+  %c9 = arith.constant 9 : index
+  %r = scf.if %c -> index {
+    scf.yield %c4 : index
+  } else {
+    scf.yield %c9 : index
+  }
+
+  // CHECK: %[[c4:.*]] = arith.constant 4 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: "test.some_use"(%[[c4]], %[[c10]])
+  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+  return
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK-LABEL: func @scf_if_dynamic(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
+func.func @scf_if_dynamic(%a: index, %b: index, %c : i1) {
+  %c4 = arith.constant 4 : index
+  %r = scf.if %c -> index {
+    %add1 = arith.addi %a, %b : index
+    scf.yield %add1 : index
+  } else {
+    %add2 = arith.addi %b, %c4 : index
+    %add3 = arith.addi %add2, %a : index
+    scf.yield %add3 : index
+  }
+
+  // CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
+  // CHECK: %[[ub:.*]] = affine.apply #[[$map1]]()[%[[a]], %[[b]]]
+  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: "test.some_use"(%[[lb]], %[[ub]])
+  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @scf_if_no_affine_bound(%a: index, %b: index, %c : i1) {
+  %r = scf.if %c -> index {
+    scf.yield %a : index
+  } else {
+    scf.yield %b : index
+  }
+  // The reified bound would be min(%a, %b). min/max expressions are not
+  // supported in reified bounds.
+  // expected-error @below{{could not reify bound}}
+  %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  "test.some_use"(%reify1) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_tensor_dim(
+func.func @scf_if_tensor_dim(%c : i1) {
+  // CHECK: arith.constant 4 : index
+  // CHECK: arith.constant 9 : index
+  %c4 = arith.constant 4 : index
+  %c9 = arith.constant 9 : index
+  %t1 = tensor.empty(%c4) : tensor<?xf32>
+  %t2 = tensor.empty(%c9) : tensor<?xf32>
+  %r = scf.if %c -> tensor<?xf32> {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    scf.yield %t2 : tensor<?xf32>
+  }
+
+  // CHECK: %[[c4:.*]] = arith.constant 4 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %reify1 = "test.reify_bound"(%r) {type = "LB", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  %reify2 = "test.reify_bound"(%r) {type = "UB", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  // CHECK: "test.some_use"(%[[c4]], %[[c10]])
+  "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+  return
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @scf_if_eq(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:....
[truncated]

/// Traverse the IR starting from the given value/dim and add populate
/// constraints as long as the currently set stop condition holds. Also
/// processes all values/dims that are already on the worklist.
void populateConstraints(Value value, std::optional<int64_t> dim);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of this may clash with #83876. I'm going to rebase this PR when #83876 has been merged.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I'll land it soon :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rebased all PRs. It would be great if you could also review #86097, then I can start merging PRs. You probably know the codebase best out of all reviewers that I added to that PR.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_scf_if branch 2 times, most recently from 0c94128 to 21b4e40 Compare March 21, 2024 08:10
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/value_bounds_stop_fn_constr March 21, 2024 08:11
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_stop_fn_constr branch from db3dde1 to 305001b Compare March 22, 2024 02:02
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_scf_if branch from 21b4e40 to 8057ddd Compare March 22, 2024 02:04
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Just minor nits that you can address before landing. Thanks!

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h Outdated Show resolved Hide resolved
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp Outdated Show resolved Hide resolved
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp Outdated Show resolved Hide resolved
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp Outdated Show resolved Hide resolved
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_stop_fn_constr branch from 305001b to ad1b2ac Compare March 23, 2024 05:58
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_scf_if branch from 8057ddd to b4bab14 Compare March 23, 2024 06:00
matthias-springer added a commit that referenced this pull request Mar 23, 2024
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_stop_fn_constr branch from ad1b2ac to 6dcdd66 Compare April 4, 2024 08:01
Base automatically changed from users/matthias-springer/value_bounds_stop_fn_constr to main April 4, 2024 08:05
…or branches

This commit adds support for `scf.if` to `ValueBoundsConstraintSet`.

Example:
```
%0 = scf.if ... -> index {
  scf.yield %a : index
} else {
  scf.yield %b : index
}
```

The following constraints hold for %0:
* %0 >= min(%a, %b)
* %0 <= max(%a, %b)

Such constraints cannot be added to the constraint set; min/max is not supported by `IntegerRelation`. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b:
* %0 >= %a
* %0 <= %b

This commit required a few minor changes to the `ValueBoundsConstraintSet` infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_scf_if branch from b4bab14 to cec5571 Compare April 5, 2024 04:07
@matthias-springer matthias-springer merged commit 6b30ffe into main Apr 5, 2024
3 of 4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/value_bounds_scf_if branch April 5, 2024 04:14
matthias-springer added a commit that referenced this pull request Apr 5, 2024
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.
matthias-springer added a commit that referenced this pull request Apr 5, 2024
This commit adds a `ValueBoundsOpInterface` implementation for
`arith.select`. The implementation is almost identical to `scf.if`
(#85895), but there is one special case: if the condition is a shaped
value, the selection is applied element-wise and the result shape can be
inferred from either operand.
@joker-eph
Copy link
Collaborator

This broke the gcc-7 build: https://lab.llvm.org/buildbot/#/builders/264/builds/9087/steps/6/logs/FAIL__MLIR__value-bounds-op-interface-impl_mlir

# .---command stderr------------
# | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir:148:12: error: CHECK: expected string not found in input
# |  // CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
# |            ^
# | <stdin>:92:66: note: scanning from here
# |  func.func @scf_if_dynamic(%arg0: index, %arg1: index, %arg2: i1) {
# |                                                                  ^
# | <stdin>:92:66: note: with "$map" equal to "map"
# |  func.func @scf_if_dynamic(%arg0: index, %arg1: index, %arg2: i1) {
# |                                                                  ^
# | <stdin>:92:66: note: with "a" equal to "arg0"
# |  func.func @scf_if_dynamic(%arg0: index, %arg1: index, %arg2: i1) {
# |                                                                  ^
# | <stdin>:92:66: note: with "b" equal to "arg1"
# |  func.func @scf_if_dynamic(%arg0: index, %arg1: index, %arg2: i1) {
# |                                                                  ^
# | <stdin>:102:2: note: possible intended match here
# |  %1 = affine.apply #map()[%arg1, %arg0]
# |  ^

joker-eph added a commit that referenced this pull request Apr 5, 2024
…nches) (#85895)"

This reverts commit 6b30ffe.

gcc7 bot is broken
@joker-eph
Copy link
Collaborator

I reverted your last 3 commits @matthias-springer , please reland with the fix!

matthias-springer added a commit that referenced this pull request Apr 6, 2024
The C++ standard does not specify an evaluation order for addition/...
operands. E.g., in `a() + b()`, the compiler is free to evaluate `a` or
`b` first.

This lead to different `mlir-opt` outputs in #85895. (FileCheck passed
when compiled with LLVM but failed when compiled with gcc.)
matthias-springer added a commit that referenced this pull request Apr 6, 2024
The C++ standard does not specify an evaluation order for addition/...
operands. E.g., in `a() + b()`, the compiler is free to evaluate `a` or
`b` first.

This lead to different `mlir-opt` outputs in #85895. (FileCheck passed
when compiled with LLVM but failed when compiled with gcc.)
matthias-springer added a commit that referenced this pull request Apr 6, 2024
…87860)

This commit adds support for `scf.if` to `ValueBoundsConstraintSet`.

Example:
```
%0 = scf.if ... -> index {
  scf.yield %a : index
} else {
  scf.yield %b : index
}
```

The following constraints hold for %0:
* %0 >= min(%a, %b)
* %0 <= max(%a, %b)

Such constraints cannot be added to the constraint set; min/max is not
supported by `IntegerRelation`. However, if we know which one of %a and
%b is larger, we can add constraints for %0. E.g., if %a <= %b:
* %0 >= %a
* %0 <= %b

This commit required a few minor changes to the
`ValueBoundsConstraintSet` infrastructure, so that values can be
compared while we are still in the process of traversing the IR/adding
constraints.

Note: This is a re-upload of #85895, which was reverted. The bug that
caused the failure was fixed in #87859.
matthias-springer added a commit that referenced this pull request Apr 6, 2024
This commit adds a `ValueBoundsOpInterface` implementation for
`arith.select`. The implementation is almost identical to `scf.if`
(#85895), but there is one special case: if the condition is a shaped
value, the selection is applied element-wise and the result shape can be
inferred from either operand.
matthias-springer added a commit that referenced this pull request Apr 7, 2024
This commit adds a `ValueBoundsOpInterface` implementation for
`arith.select`. The implementation is almost identical to `scf.if`
(#85895), but there is one special case: if the condition is a shaped
value, the selection is applied element-wise and the result shape can be
inferred from either operand.

Note: This is a re-upload of #86383.
matthias-springer added a commit that referenced this pull request Apr 11, 2024
#86915)

This commit adds a new public API to `ValueBoundsOpInterface` to compare
values/dims. Supported comparison operators are: LT, LE, EQ, GE, GT.

The new `ValueBoundsOpInterface::compare` API replaces and generalizes
`ValueBoundsOpInterface::areEqual`. Not only does it provide additional
comparison operators, it also works in cases where the difference
between the two values/dims is non-constant. The previous implementation
of `areEqual` used to compute a constant bound of `val1 - val2` (check
if it `== 0` or `!= 0`).

Note: This commit refactors, generalizes and adds a public API for
value/dim comparison. The comparison functionality itself was introduced
in #85895 and is already in use for analyzing `scf.if`.

In the long term, this improvement will allow for a more powerful
analysis of subset ops. A future commit will update
`areOverlappingSlices` to use the new comparison API.
(`areEquivalentSlices` is already using the new API.) This will improve
subset equivalence/disjointness checks with non-constant
offsets/sizes/strides.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants