Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Arith] ValueBoundsOpInterface: Support arith.select #86383

Merged

Conversation

matthias-springer
Copy link
Member

This commit adds a ValueBoundsOpInterface implementation for arith.select. The implementation is almost identical to scf.if (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 23, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a ValueBoundsOpInterface implementation for arith.select. The implementation is almost identical to scf.if (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.


Full diff: https://github.com/llvm/llvm-project/pull/86383.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp (+70)
  • (modified) mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir (+31)
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 9c6b50e767ea26..bb7b9c939fcb09 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,6 +66,75 @@ struct MulIOpInterface
   }
 };
 
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+
+  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    Value value = selectOp.getResult();
+    Value condition = selectOp.getCondition();
+    Value trueValue = selectOp.getTrueValue();
+    Value falseValue = selectOp.getFalseValue();
+
+    if (isa<ShapedType>(condition.getType())) {
+      // If the condition is a shaped type, the condition is applied
+      // element-wise. All three operands must have the same shape.
+      cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+      return;
+    }
+
+    // Populate constraints for the true/false values (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(trueValue, dim);
+    cstr.populateConstraints(falseValue, dim);
+    auto boundsBuilder = cstr.bound(value);
+    if (dim)
+      boundsBuilder[*dim];
+
+    // Compare yielded values.
+    // If trueValue <= falseValue:
+    // * result <= falseValue
+    // * result >= trueValue
+    if (cstr.compare(trueValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     falseValue, dim)) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
+      } else {
+        cstr.bound(value) >= trueValue;
+        cstr.bound(value) <= falseValue;
+      }
+    }
+    // If falseValue <= trueValue:
+    // * result <= trueValue
+    // * result >= falseValue
+    if (cstr.compare(falseValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     trueValue, dim)) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
+      } else {
+        cstr.bound(value) >= falseValue;
+        cstr.bound(value) <= trueValue;
+      }
+    }
+  }
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), dim, cstr);
+  }
+};
 } // namespace
 } // namespace arith
 } // namespace mlir
@@ -77,5 +146,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
     arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
     arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
     arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
+    arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
index 83d5f1c9c9e86c..8fb3ba1a1eccef 100644
--- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -74,3 +74,34 @@ func.func @arith_const() -> index {
   %0 = "test.reify_bound"(%c5) : (index) -> (index)
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @arith_select(
+func.func @arith_select(%c: i1) -> (index, index) {
+  // CHECK: arith.constant 5 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: arith.constant 9 : index
+  %c9 = arith.constant 9 : index
+  %r = arith.select %c, %c5, %c9 : index
+  // CHECK: %[[c5:.*]] = arith.constant 5 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %0 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %1 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: return %[[c5]], %[[c10]]
+  return %0, %1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @arith_select_elementwise(
+//  CHECK-SAME:     %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>, %[[c:.*]]: tensor<?xi1>)
+func.func @arith_select_elementwise(%a: tensor<?xf32>, %b: tensor<?xf32>, %c: tensor<?xi1>) -> index {
+  %r = arith.select %c, %a, %b : tensor<?xi1>, tensor<?xf32>
+  // CHECK: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[dim:.*]] = tensor.dim %[[a]], %[[c0]]
+  %0 = "test.reify_bound"(%r) {type = "EQ", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  // CHECK: return %[[dim]]
+  return %0 : index
+}

@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_scf_for branch from 3c4adb5 to f8f4249 Compare April 5, 2024 04:17
Base automatically changed from users/matthias-springer/value_bounds_scf_for to main April 5, 2024 04:27
This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to `scf.if` (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/value_bounds_arith_select branch from 680d04d to 4ef0c78 Compare April 5, 2024 04:31
@matthias-springer matthias-springer merged commit 62b58d3 into main Apr 5, 2024
3 of 4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/value_bounds_arith_select branch April 5, 2024 04:39
joker-eph added a commit that referenced this pull request Apr 5, 2024
matthias-springer added a commit that referenced this pull request Apr 6, 2024
This commit adds a `ValueBoundsOpInterface` implementation for
`arith.select`. The implementation is almost identical to `scf.if`
(#85895), but there is one special case: if the condition is a shaped
value, the selection is applied element-wise and the result shape can be
inferred from either operand.
matthias-springer added a commit that referenced this pull request Apr 7, 2024
This commit adds a `ValueBoundsOpInterface` implementation for
`arith.select`. The implementation is almost identical to `scf.if`
(#85895), but there is one special case: if the condition is a shaped
value, the selection is applied element-wise and the result shape can be
inferred from either operand.

Note: This is a re-upload of #86383.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants