Skip to content

Commit

Permalink
[mlir][SCF] ValueBoundsOpInterface: Support scf.for results and ite…
Browse files Browse the repository at this point in the history
…r_args

If an `scf.for` loop yields an equal index-typed value or a shaped value with the same dimension sizes (in comparison to the corresponding iter_arg), bounds can be computed for the iter_arg and the OpResult of the `scf.for` op.

Differential Revision: https://reviews.llvm.org/D146306
  • Loading branch information
matthias-springer committed Apr 7, 2023
1 parent 46e409c commit c3f5fd7
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 30 deletions.
24 changes: 19 additions & 5 deletions mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
Expand Up @@ -84,24 +84,38 @@ class ValueBoundsConstraintSet {
/// The stop condition when traversing the backward slice of a shaped value/
/// index-type value. The traversal continues until the stop condition
/// evaluates to "true" for a value.
using StopConditionFn = function_ref<bool(Value)>;
///
/// The first parameter of the function is the shaped value/index-typed
/// value. The second parameter is the dimension in case of a shaped value.
using StopConditionFn =
function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;

/// Compute a bound for the given index-typed value or shape dimension size.
/// The computed bound is stored in `resultMap`. The operands of the bound are
/// stored in `mapOperands`. An operand is either an index-type SSA value
/// or a shaped value and a dimension.
///
/// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
/// is computed in terms of values for which `stopCondition` evaluates to
/// "true". To that end, the backward slice (reverse use-def chain) of the
/// given value is visited in a worklist-driven manner and the constraint set
/// is populated according to `ValueBoundsOpInterface` for each visited value.
/// is computed in terms of values/dimensions for which `stopCondition`
/// evaluates to "true". To that end, the backward slice (reverse use-def
/// chain) of the given value is visited in a worklist-driven manner and the
/// constraint set is populated according to `ValueBoundsOpInterface` for each
/// visited value.
static LogicalResult computeBound(AffineMap &resultMap,
ValueDimList &mapOperands,
presburger::BoundType type, Value value,
std::optional<int64_t> dim,
StopConditionFn stopCondition);

/// Compute a bound in terms of the values/dimensions in `dependencies`. The
/// computed bound consists of only constant terms and dependent values (or
/// dimension sizes thereof).
static LogicalResult computeBound(AffineMap &resultMap,
ValueDimList &mapOperands,
presburger::BoundType type, Value value,
std::optional<int64_t> dim,
ValueDimList dependencies);

/// Compute a constant bound for the given index-typed value or shape
/// dimension size.
///
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
Expand Up @@ -21,19 +21,19 @@ FailureOr<OpFoldResult> mlir::reifyValueBound(OpBuilder &b, Location loc,
std::optional<int64_t> dim) {
// We are trying to reify a bound for `value`. Construct a stop condition that
// evaluates to "true" for any SSA value expect for `value`. I.e., the bound
// will be computed in terms of any SSA values expect for `value`. The first
// will be computed in terms of any SSA values except for `value`. The first
// such values are operands of the owner of `value`.
auto stopCondition = [&](Value v) {
auto stopCondition = [&](Value v, std::optional<int64_t> d) {
// Reify in terms of SSA values that are different from `value`.
return v != value;
};
return reifyValueBound(b, loc, type, value, dim, stopCondition);
}

FailureOr<OpFoldResult>
mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
Value value, std::optional<int64_t> dim,
function_ref<bool(Value)> stopCondition) {
FailureOr<OpFoldResult> mlir::reifyValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
std::optional<int64_t> dim,
function_ref<bool(Value, std::optional<int64_t>)> stopCondition) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
Expand Down
90 changes: 83 additions & 7 deletions mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
Expand Up @@ -12,29 +12,105 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using presburger::BoundType;

namespace mlir {
namespace scf {
namespace {

struct ForOpInterface
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {

/// Populate bounds of values/dimensions for iter_args/OpResults.
static void populateIterArgBounds(scf::ForOp forOp, Value value,
std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
// `value` is an iter_arg or an OpResult.
int64_t iterArgIdx;
if (auto iterArg = value.dyn_cast<BlockArgument>()) {
iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
} else {
iterArgIdx = value.cast<OpResult>().getResultNumber();
}

// An EQ constraint can be added if the yielded value (dimension size)
// equals the corresponding block argument (dimension size).
assert(forOp.getLoopBody().hasOneBlock() &&
"multiple blocks not supported");
Value yieldedValue =
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator())
.getOperand(iterArgIdx);
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
Value initArg = forOp.getInitArgs()[iterArgIdx];

auto addEqBound = [&]() {
if (dim.has_value()) {
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
} else {
cstr.bound(value) == initArg;
}
};

if (yieldedValue == iterArg) {
addEqBound();
return;
}

// Compute EQ bound for yielded value.
AffineMap bound;
ValueDimList boundOperands;
LogicalResult status = ValueBoundsConstraintSet::computeBound(
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
[&](Value v, std::optional<int64_t> d) {
// Stop when reaching a block argument of the loop body.
if (auto bbArg = v.dyn_cast<BlockArgument>())
return bbArg.getOwner()->getParentOp() == forOp;
// Stop when reaching a value that is defined outside of the loop. It
// is impossible to reach an iter_arg from there.
Operation *op = v.getDefiningOp();
return forOp.getLoopBody().findAncestorOpInRegion(*op) == nullptr;
});
if (failed(status))
return;
if (bound.getNumResults() != 1)
return;

// Check if computed bound equals the corresponding iter_arg.
Value singleValue = nullptr;
std::optional<int64_t> singleDim = std::nullopt;
if (auto dimExpr = bound.getResult(0).dyn_cast<AffineDimExpr>()) {
int64_t idx = dimExpr.getPosition();
singleValue = boundOperands[idx].first;
singleDim = boundOperands[idx].second;
} else if (auto symExpr = bound.getResult(0).dyn_cast<AffineSymbolExpr>()) {
int64_t idx = symExpr.getPosition() + bound.getNumDims();
singleValue = boundOperands[idx].first;
singleDim = boundOperands[idx].second;
}
if (singleValue == iterArg && singleDim == dim)
addEqBound();
}

void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
auto forOp = cast<ForOp>(op);
// Only IV is supported at the moment.
if (value != forOp.getInductionVar())

if (value == forOp.getInductionVar()) {
// TODO: Take into account step size.
cstr.bound(value) >= forOp.getLowerBound();
cstr.bound(value) < forOp.getUpperBound();
return;
}

// TODO: Take into account step size.
cstr.bound(value) >= forOp.getLowerBound();
cstr.bound(value) < forOp.getUpperBound();
// Handle iter_args and OpResults.
populateIterArgBounds(forOp, value, std::nullopt, cstr);
}

void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
// iter_arg / return value not supported.
return;
auto forOp = cast<ForOp>(op);
// Handle iter_args and OpResults.
populateIterArgBounds(forOp, value, dim, cstr);
}
};

Expand Down
28 changes: 22 additions & 6 deletions mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
Expand Up @@ -164,7 +164,8 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
}

// Do not process any further if the stop condition is met.
if (stopCondition(value))
auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
if (stopCondition(value, maybeDim))
continue;

// Query `ValueBoundsOpInterface` for constraints. New items may be added to
Expand Down Expand Up @@ -213,12 +214,14 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
assert(!stopCondition(value, dim) &&
"stop condition should not be satisfied for starting point");
#endif // NDEBUG

Builder b(value.getContext());
mapOperands.clear();

if (stopCondition(value)) {
if (stopCondition(value, dim)) {
// Special case: If the stop condition is satisfied for the input
// value/dimension, directly return it.
mapOperands.push_back(std::make_pair(value, dim));
Expand All @@ -239,7 +242,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
// Do not project out `valueDim`.
if (valueDim == p)
return false;
return !stopCondition(p.first);
auto maybeDim =
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
return !stopCondition(p.first, maybeDim);
});

// Compute lower and upper bounds for `valueDim`.
Expand Down Expand Up @@ -338,6 +343,16 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
return success();
}

LogicalResult ValueBoundsConstraintSet::computeBound(
AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
Value value, std::optional<int64_t> dim, ValueDimList dependencies) {
return computeBound(resultMap, mapOperands, type, value, dim,
[&](Value v, std::optional<int64_t> d) {
return llvm::is_contained(dependencies,
std::make_pair(v, d));
});
}

FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType type, Value value, std::optional<int64_t> dim,
StopConditionFn stopCondition) {
Expand All @@ -354,9 +369,10 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
} else {
// No stop condition specified: Keep adding constraints until a bound could
// be computed.
cstr.processWorklist(/*stopCondition=*/[&](Value v) {
return cstr.cstr.getConstantBound64(type, pos).has_value();
});
cstr.processWorklist(
/*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
return cstr.cstr.getConstantBound64(type, pos).has_value();
});
}

// Compute constant bound for `valueDim`.
Expand Down
92 changes: 92 additions & 0 deletions mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
Expand Up @@ -12,3 +12,95 @@ func.func @scf_for(%a: index, %b: index, %c: index) {
}
return
}

// -----

// CHECK-LABEL: func @scf_for_index_result_small(
// CHECK-SAME: %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
// CHECK: "test.some_use"(%[[i]])
// CHECK: "test.some_use"(%[[i]])
func.func @scf_for_index_result_small(%i: index, %a: index, %b: index, %c: index) {
%0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
%1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
"test.some_use"(%1) : (index) -> ()
scf.yield %arg : index
}
%2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
"test.some_use"(%2) : (index) -> ()
return
}

// -----

// CHECK-LABEL: func @scf_for_index_result(
// CHECK-SAME: %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
// CHECK: "test.some_use"(%[[i]])
// CHECK: "test.some_use"(%[[i]])
func.func @scf_for_index_result(%i: index, %a: index, %b: index, %c: index) {
%0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
%add = arith.addi %arg, %a : index
%sub = arith.subi %add, %a : index

%1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
"test.some_use"(%1) : (index) -> ()
scf.yield %sub : index
}
%2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
"test.some_use"(%2) : (index) -> ()
return
}

// -----

// CHECK-LABEL: func @scf_for_tensor_result_small(
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
// CHECK: "test.some_use"(%[[dim]])
// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
// CHECK: "test.some_use"(%[[dim]])
func.func @scf_for_tensor_result_small(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
%0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
%1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
"test.some_use"(%1) : (index) -> ()
scf.yield %arg : tensor<?xf32>
}
%2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
"test.some_use"(%2) : (index) -> ()
return
}

// -----

// CHECK-LABEL: func @scf_for_tensor_result(
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
// CHECK: "test.some_use"(%[[dim]])
// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
// CHECK: "test.some_use"(%[[dim]])
func.func @scf_for_tensor_result(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
%cst = arith.constant 5.0 : f32
%0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
%filled = linalg.fill ins(%cst : f32) outs(%arg : tensor<?xf32>) -> tensor<?xf32>
%1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
"test.some_use"(%1) : (index) -> ()
scf.yield %filled : tensor<?xf32>
}
%2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
"test.some_use"(%2) : (index) -> ()
return
}

// -----

func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: index, %b: index, %c: index) {
%cst = arith.constant 5.0 : f32
%r1, %r2 = scf.for %iv = %a to %b step %c iter_args(%arg1 = %t1, %arg2 = %t2) -> (tensor<?xf32>, tensor<?xf32>) {
%filled1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
%filled2 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
scf.yield %filled2, %filled1 : tensor<?xf32>, tensor<?xf32>
}
// expected-error @below{{could not reify bound}}
%reify1 = "test.reify_bound"(%r1) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
"test.some_use"(%reify1) : (index) -> ()
return
}
2 changes: 2 additions & 0 deletions mlir/test/lib/Dialect/Affine/CMakeLists.txt
Expand Up @@ -25,5 +25,7 @@ add_mlir_library(MLIRAffineTransformsTestPasses
MLIRIR
MLIRPass
MLIRSupport
MLIRMemRefDialect
MLIRTensorDialect
MLIRVectorUtils
)
16 changes: 10 additions & 6 deletions mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
Expand Up @@ -9,6 +9,8 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -33,7 +35,8 @@ struct TestReifyValueBounds
TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
registry
.insert<AffineDialect, tensor::TensorDialect, memref::MemRefDialect>();
}

void runOnOperation() override;
Expand Down Expand Up @@ -101,13 +104,14 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,

// Prepare stop condition. By default, reify in terms of the op's
// operands. No stop condition is used when a constant was requested.
std::function<bool(Value)> stopCondition = [&](Value v) {
// Reify in terms of SSA values that are different from `value`.
return v != value;
};
std::function<bool(Value, std::optional<int64_t>)> stopCondition =
[&](Value v, std::optional<int64_t> d) {
// Reify in terms of SSA values that are different from `value`.
return v != value;
};
if (reifyToFuncArgs) {
// Reify in terms of function block arguments.
stopCondition = stopCondition = [](Value v) {
stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
auto bbArg = v.dyn_cast<BlockArgument>();
if (!bbArg)
return false;
Expand Down

0 comments on commit c3f5fd7

Please sign in to comment.