Skip to content

Commit

Permalink
[mlir][affine][NFC] Split reifyValueBound in two functions
Browse files Browse the repository at this point in the history
There are now two entry points. One for shaped values and one for index-typed values. This addresses a comment in D146524.

Differential Revision: https://reviews.llvm.org/D147987
  • Loading branch information
matthias-springer committed Apr 18, 2023
1 parent c446f2d commit 912fedf
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 42 deletions.
38 changes: 19 additions & 19 deletions mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
Expand Up @@ -49,20 +49,9 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
/// maximally compose chains of AffineApplyOps.
FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);

/// Reify a bound for the given index-typed value or shape dimension size in
/// terms of the owning op's operands. `dim` must be `nullopt` if and only if
/// `value` is index-typed.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
presburger::BoundType type, Value value,
std::optional<int64_t> dim,
bool closedUB = false);

/// Reify a bound for the given index-typed value or shape dimension size in
/// terms of SSA values for which `stopCondition` is met. `dim` must be
/// `nullopt` if and only if `value` is index-typed.
/// Reify a bound for the given index-typed value in terms of SSA values for
/// which `stopCondition` is met. If no stop condition is specified, reify in
/// terms of the operands of the owner op.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
Expand All @@ -77,11 +66,22 @@ FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
/// is an EQ bound for %1.
/// * Otherwise, if the owners of %a, %b or %c do not implement the
/// ValueBoundsOpInterface, no bound can be computed.
FailureOr<OpFoldResult>
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
Value value, std::optional<int64_t> dim,
ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB = false);
FailureOr<OpFoldResult> reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition = nullptr,
bool closedUB = false);

/// Reify a bound for the specified dimension of the given shaped value in terms
/// of SSA values for which `stopCondition` is met. If no stop condition is
/// specified, reify in terms of the operands of the owner op.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
FailureOr<OpFoldResult> reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim,
ValueBoundsConstraintSet::StopConditionFn stopCondition = nullptr,
bool closedUB = false);

} // namespace mlir

Expand Down
52 changes: 33 additions & 19 deletions mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
Expand Up @@ -15,25 +15,11 @@

using namespace mlir;

FailureOr<OpFoldResult>
mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
Value value, std::optional<int64_t> dim, bool closedUB) {
// We are trying to reify a bound for `value`. Construct a stop condition that
// evaluates to "true" for any SSA value expect for `value`. I.e., the bound
// will be computed in terms of any SSA values except for `value`. The first
// such values are operands of the owner of `value`.
auto stopCondition = [&](Value v, std::optional<int64_t> d) {
// Reify in terms of SSA values that are different from `value`.
return v != value;
};
return reifyValueBound(b, loc, type, value, dim, stopCondition, closedUB);
}

FailureOr<OpFoldResult> mlir::reifyValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
std::optional<int64_t> dim,
function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
bool closedUB) {
static FailureOr<OpFoldResult>
reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
Value value, std::optional<int64_t> dim,
function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
Expand Down Expand Up @@ -85,3 +71,31 @@ FailureOr<OpFoldResult> mlir::reifyValueBound(
return static_cast<OpFoldResult>(
b.create<AffineApplyOp>(loc, boundMap, operands).getResult());
}

FailureOr<OpFoldResult> mlir::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value except for `value`. I.e., the bound will be computed in terms of
// any SSA values except for `value`. The first such values are operands of
// the owner of `value`.
return v != value;
};
return reifyValueBound(b, loc, type, value, dim,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}

FailureOr<OpFoldResult> mlir::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
Expand Up @@ -462,9 +462,9 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
// of the enclosing loops.
for (auto forOp : packingLoops) {
// Compute an upper bound `ubVal` for the upper bound of `forOp`.
FailureOr<OpFoldResult> loopUb = reifyValueBound(
FailureOr<OpFoldResult> loopUb = reifyIndexValueBound(
rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
/*dim=*/std::nullopt, /*stopCondition=*/
/*stopCondition=*/
[&](Value v, std::optional<int64_t> d) {
if (v == forOp.getUpperBound())
return false;
Expand Down
9 changes: 7 additions & 2 deletions mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
Expand Up @@ -130,8 +130,13 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
reified =
FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
} else {
reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
dim, stopCondition);
if (dim) {
reified = reifyShapedValueDimBound(rewriter, op->getLoc(), *boundType,
value, *dim, stopCondition);
} else {
reified = reifyIndexValueBound(rewriter, op->getLoc(), *boundType,
value, stopCondition);
}
}
if (failed(reified)) {
op->emitOpError("could not reify bound");
Expand Down

0 comments on commit 912fedf

Please sign in to comment.