diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h index 8e840e744064d..1ea7375220815 100644 --- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h @@ -53,6 +53,17 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op); /// maximally compose chains of AffineApplyOps. FailureOr decompose(RewriterBase &rewriter, AffineApplyOp op); +/// Reify a bound for the given variable in terms of SSA values for which +/// `stopCondition` is met. +/// +/// By default, lower/equal bounds are closed and upper bounds are open. If +/// `closedUB` is set to "true", upper bounds are also closed. +FailureOr +reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, + const ValueBoundsConstraintSet::Variable &var, + ValueBoundsConstraintSet::StopConditionFn stopCondition, + bool closedUB = false); + /// Reify a bound for the given index-typed value in terms of SSA values for /// which `stopCondition` is met. If no stop condition is specified, reify in /// terms of the operands of the owner op. diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h index 970a52a06a11a..bbc7e5d3e0dd7 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h @@ -24,6 +24,17 @@ enum class BoundType; namespace arith { +/// Reify a bound for the given variable in terms of SSA values for which +/// `stopCondition` is met. +/// +/// By default, lower/equal bounds are closed and upper bounds are open. If +/// `closedUB` is set to "true", upper bounds are also closed. +FailureOr +reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, + const ValueBoundsConstraintSet::Variable &var, + ValueBoundsConstraintSet::StopConditionFn stopCondition, + bool closedUB = false); + /// Reify a bound for the given index-typed value in terms of SSA values for /// which `stopCondition` is met. If no stop condition is specified, reify in /// terms of the operands of the owner op. diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 1d7bc6ea961cc..ac17ace5a976d 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -15,6 +15,7 @@ #include "mlir/IR/Value.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/ExtensibleRTTI.h" #include @@ -111,6 +112,39 @@ class ValueBoundsConstraintSet public: static char ID; + /// A variable that can be added to the constraint set as a "column". The + /// value bounds infrastructure can compute bounds for variables and compare + /// two variables. + /// + /// Internally, a variable is represented as an affine map and operands. + class Variable { + public: + /// Construct a variable for an index-typed attribute or SSA value. + Variable(OpFoldResult ofr); + + /// Construct a variable for an index-typed SSA value. + Variable(Value indexValue); + + /// Construct a variable for a dimension of a shaped value. + Variable(Value shapedValue, int64_t dim); + + /// Construct a variable for an index-typed attribute/SSA value or for a + /// dimension of a shaped value. A non-null dimension must be provided if + /// and only if `ofr` is a shaped value. + Variable(OpFoldResult ofr, std::optional dim); + + /// Construct a variable for a map and its operands. + Variable(AffineMap map, ArrayRef mapOperands); + Variable(AffineMap map, ArrayRef mapOperands); + + MLIRContext *getContext() const { return map.getContext(); } + + private: + friend class ValueBoundsConstraintSet; + AffineMap map; + ValueDimList mapOperands; + }; + /// The stop condition when traversing the backward slice of a shaped value/ /// index-type value. The traversal continues until the stop condition /// evaluates to "true" for a value. @@ -121,35 +155,31 @@ class ValueBoundsConstraintSet using StopConditionFn = std::function /*dim*/, ValueBoundsConstraintSet &cstr)>; - /// Compute a bound for the given index-typed value or shape dimension size. - /// The computed bound is stored in `resultMap`. The operands of the bound are - /// stored in `mapOperands`. An operand is either an index-type SSA value - /// or a shaped value and a dimension. + /// Compute a bound for the given variable. The computed bound is stored in + /// `resultMap`. The operands of the bound are stored in `mapOperands`. An + /// operand is either an index-type SSA value or a shaped value and a + /// dimension. /// - /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound - /// is computed in terms of values/dimensions for which `stopCondition` - /// evaluates to "true". To that end, the backward slice (reverse use-def - /// chain) of the given value is visited in a worklist-driven manner and the - /// constraint set is populated according to `ValueBoundsOpInterface` for each - /// visited value. + /// The bound is computed in terms of values/dimensions for which + /// `stopCondition` evaluates to "true". To that end, the backward slice + /// (reverse use-def chain) of the given value is visited in a worklist-driven + /// manner and the constraint set is populated according to + /// `ValueBoundsOpInterface` for each visited value. /// /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. - static LogicalResult computeBound(AffineMap &resultMap, - ValueDimList &mapOperands, - presburger::BoundType type, Value value, - std::optional dim, - StopConditionFn stopCondition, - bool closedUB = false); + static LogicalResult + computeBound(AffineMap &resultMap, ValueDimList &mapOperands, + presburger::BoundType type, const Variable &var, + StopConditionFn stopCondition, bool closedUB = false); /// Compute a bound in terms of the values/dimensions in `dependencies`. The /// computed bound consists of only constant terms and dependent values (or /// dimension sizes thereof). static LogicalResult computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands, - presburger::BoundType type, Value value, - std::optional dim, ValueDimList dependencies, - bool closedUB = false); + presburger::BoundType type, const Variable &var, + ValueDimList dependencies, bool closedUB = false); /// Compute a bound in that is independent of all values in `independencies`. /// @@ -161,13 +191,10 @@ class ValueBoundsConstraintSet /// appear in the computed bound. static LogicalResult computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands, - presburger::BoundType type, Value value, - std::optional dim, ValueRange independencies, - bool closedUB = false); + presburger::BoundType type, const Variable &var, + ValueRange independencies, bool closedUB = false); - /// Compute a constant bound for the given affine map, where dims and symbols - /// are bound to the given operands. The affine map must have exactly one - /// result. + /// Compute a constant bound for the given variable. /// /// This function traverses the backward slice of the given operands in a /// worklist-driven manner until `stopCondition` evaluates to "true". The @@ -182,16 +209,9 @@ class ValueBoundsConstraintSet /// By default, lower/equal bounds are closed and upper bounds are open. If /// `closedUB` is set to "true", upper bounds are also closed. static FailureOr - computeConstantBound(presburger::BoundType type, Value value, - std::optional dim = std::nullopt, + computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition = nullptr, bool closedUB = false); - static FailureOr computeConstantBound( - presburger::BoundType type, AffineMap map, ValueDimList mapOperands, - StopConditionFn stopCondition = nullptr, bool closedUB = false); - static FailureOr computeConstantBound( - presburger::BoundType type, AffineMap map, ArrayRef mapOperands, - StopConditionFn stopCondition = nullptr, bool closedUB = false); /// Compute a constant delta between the given two values. Return "failure" /// if a constant delta could not be determined. @@ -221,9 +241,8 @@ class ValueBoundsConstraintSet /// proven. This could be because the specified relation does in fact not hold /// or because there is not enough information in the constraint set. In other /// words, if we do not know for sure, this function returns "false". - bool populateAndCompare(OpFoldResult lhs, std::optional lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional rhsDim); + bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp, + const Variable &rhs); /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the /// specified relation could not be proven. This could be because the @@ -233,24 +252,12 @@ class ValueBoundsConstraintSet /// /// This function keeps traversing the backward slice of lhs/rhs until could /// prove the relation or until it ran out of IR. - static bool compare(OpFoldResult lhs, std::optional lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional rhsDim); - static bool compare(AffineMap lhs, ValueDimList lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ValueDimList rhsOperands); - static bool compare(AffineMap lhs, ArrayRef lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ArrayRef rhsOperands); - - /// Compute whether the given values/dimensions are equal. Return "failure" if + static bool compare(const Variable &lhs, ComparisonOperator cmp, + const Variable &rhs); + + /// Compute whether the given variables are equal. Return "failure" if /// equality could not be determined. - /// - /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are - /// index-typed. - static FailureOr areEqual(OpFoldResult value1, OpFoldResult value2, - std::optional dim1 = std::nullopt, - std::optional dim2 = std::nullopt); + static FailureOr areEqual(const Variable &var1, const Variable &var2); /// Return "true" if the given slices are guaranteed to be overlapping. /// Return "false" if the given slices are guaranteed to be non-overlapping. @@ -317,9 +324,6 @@ class ValueBoundsConstraintSet /// /// This function does not analyze any IR and does not populate any additional /// constraints. - bool compareValueDims(OpFoldResult lhs, std::optional lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional rhsDim); bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos); /// Given an affine map with a single result (and map operands), add a new @@ -374,6 +378,7 @@ class ValueBoundsConstraintSet /// constraint system. Return the position of the new column. Any operands /// that were not analyzed yet are put on the worklist. int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true); + int64_t insert(const Variable &var, bool isSymbol = true); /// Project out the given column in the constraint set. void projectOut(int64_t pos); @@ -381,6 +386,8 @@ class ValueBoundsConstraintSet /// Project out all columns for which the condition holds. void projectOut(function_ref condition); + void projectOutAnonymous(std::optional except = std::nullopt); + /// Mapping of columns to values/shape dimensions. SmallVector> positionToValueDim; /// Reverse mapping of values/shape dimensions to columns. diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index e0c3abe7a0f71..82a9fb0d49088 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) { mapOperands.push_back(value1); mapOperands.push_back(value2); affine::fullyComposeAffineMapAndOperands(&map, &mapOperands); - ValueDimList valueDims; - for (Value v : mapOperands) - valueDims.push_back({v, std::nullopt}); return ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::EQ, map, valueDims); + presburger::BoundType::EQ, + ValueBoundsConstraintSet::Variable(map, mapOperands)); } diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp index 117ee8e8701ad..1a266b72d1f8d 100644 --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -16,16 +16,15 @@ using namespace mlir; using namespace mlir::affine; -static FailureOr -reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, - Value value, std::optional dim, - ValueBoundsConstraintSet::StopConditionFn stopCondition, - bool closedUB) { +FailureOr mlir::affine::reifyValueBound( + OpBuilder &b, Location loc, presburger::BoundType type, + const ValueBoundsConstraintSet::Variable &var, + ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { // Compute bound. AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeBound( - boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) + boundMap, mapOperands, type, var, stopCondition, closedUB))) return failure(); // Reify bound. @@ -93,7 +92,7 @@ FailureOr mlir::affine::reifyShapedValueDimBound( // the owner of `value`. return v != value; }; - return reifyValueBound(b, loc, type, value, dim, + return reifyValueBound(b, loc, type, {value, dim}, stopCondition ? stopCondition : reifyToOperands, closedUB); } @@ -105,7 +104,7 @@ FailureOr mlir::affine::reifyIndexValueBound( ValueBoundsConstraintSet &cstr) { return v != value; }; - return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, + return reifyValueBound(b, loc, type, value, stopCondition ? stopCondition : reifyToOperands, closedUB); } diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp index f0d43808bc45d..7cfcc4180539c 100644 --- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp @@ -107,9 +107,9 @@ struct SelectOpInterface // If trueValue <= falseValue: // * result <= falseValue // * result >= trueValue - if (cstr.compare(trueValue, dim, + if (cstr.compare(/*lhs=*/{trueValue, dim}, ValueBoundsConstraintSet::ComparisonOperator::LE, - falseValue, dim)) { + /*rhs=*/{falseValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim); @@ -121,9 +121,9 @@ struct SelectOpInterface // If falseValue <= trueValue: // * result <= trueValue // * result >= falseValue - if (cstr.compare(falseValue, dim, + if (cstr.compare(/*lhs=*/{falseValue, dim}, ValueBoundsConstraintSet::ComparisonOperator::LE, - trueValue, dim)) { + /*rhs=*/{trueValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim); diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp index 79fabd6ed2e99..f87f3d6350c02 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern { return failure(); FailureOr ub = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, in, /*dim=*/std::nullopt, + presburger::BoundType::UB, in, /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(ub)) return failure(); diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index fad221288f190..5fb7953f93700 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -61,16 +61,15 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, return buildExpr(map.getResult(0)); } -static FailureOr -reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, - Value value, std::optional dim, - ValueBoundsConstraintSet::StopConditionFn stopCondition, - bool closedUB) { +FailureOr mlir::arith::reifyValueBound( + OpBuilder &b, Location loc, presburger::BoundType type, + const ValueBoundsConstraintSet::Variable &var, + ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { // Compute bound. AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeBound( - boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) + boundMap, mapOperands, type, var, stopCondition, closedUB))) return failure(); // Materialize tensor.dim/memref.dim ops. @@ -128,7 +127,7 @@ FailureOr mlir::arith::reifyShapedValueDimBound( // the owner of `value`. return v != value; }; - return reifyValueBound(b, loc, type, value, dim, + return reifyValueBound(b, loc, type, {value, dim}, stopCondition ? stopCondition : reifyToOperands, closedUB); } @@ -140,7 +139,7 @@ FailureOr mlir::arith::reifyIndexValueBound( ValueBoundsConstraintSet &cstr) { return v != value; }; - return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, + return reifyValueBound(b, loc, type, value, stopCondition ? stopCondition : reifyToOperands, closedUB); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp index 8c4b70db24898..518d2e138c02a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad, // Otherwise, try to compute a constant upper bound for the size value. FailureOr upperBound = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, opOperand->get(), - /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true); + presburger::BoundType::UB, + {opOperand->get(), + /*dim=*/i}, + /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(upperBound)) { LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding"); return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index ac896d6c30d04..71eb59d40836c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -257,14 +257,12 @@ FailureOr mlir::linalg::promoteSubviewAsNewBuffer( if (auto attr = llvm::dyn_cast_if_present(rangeValue.size)) { size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); } else { - Value materializedSize = - getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); FailureOr upperBound = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt, + presburger::BoundType::UB, rangeValue.size, /*stopCondition=*/nullptr, /*closedUB=*/true); size = failed(upperBound) - ? materializedSize + ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size) : b.create(loc, *upperBound); } LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index 10ba508265e7b..1f06318cbd60e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -23,12 +23,11 @@ static FailureOr makeIndependent(OpBuilder &b, Location loc, ValueRange independencies) { if (ofr.is()) return ofr; - Value value = ofr.get(); AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeIndependentBound( - boundMap, mapOperands, presburger::BoundType::UB, value, - /*dim=*/std::nullopt, independencies, /*closedUB=*/true))) + boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies, + /*closedUB=*/true))) return failure(); return affine::materializeComputedBound(b, loc, boundMap, mapOperands); } diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 087ffc438a830..17a1c016ea16d 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -61,12 +61,13 @@ struct ForOpInterface // An EQ constraint can be added if the yielded value (dimension size) // equals the corresponding block argument (dimension size). if (cstr.populateAndCompare( - yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ, - iterArg, dim)) { + /*lhs=*/{yieldedValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::EQ, + /*rhs=*/{iterArg, dim})) { if (dim.has_value()) { cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); } else { - cstr.bound(value) == initArg; + cstr.bound(value) == cstr.getExpr(initArg); } } } @@ -113,8 +114,9 @@ struct IfOpInterface // * result <= elseValue // * result >= thenValue if (cstr.populateAndCompare( - thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, - elseValue, dim)) { + /*lhs=*/{thenValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{elseValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim); @@ -127,8 +129,9 @@ struct IfOpInterface // * result <= thenValue // * result >= elseValue if (cstr.populateAndCompare( - elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, - thenValue, dim)) { + /*lhs=*/{elseValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{thenValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 67080d8e301c1..d25efcf50ec56 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, info.isAlignedToInnerTileSize = false; FailureOr cstSize = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, - getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt, + presburger::BoundType::UB, tileSize, /*stopCondition=*/nullptr, /*closedUB=*/true); std::optional cstInnerSize = getConstantIntValue(innerTileSize); if (!failed(cstSize) && cstInnerSize) { diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp index 721730862d49b..a89ce20048dff 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp @@ -28,7 +28,8 @@ static FailureOr makeIndependent(OpBuilder &b, Location loc, ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeIndependentBound( boundMap, mapOperands, presburger::BoundType::UB, value, - /*dim=*/std::nullopt, independencies, /*closedUB=*/true))) + independencies, + /*closedUB=*/true))) return failure(); return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands); } diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp index 2dd91e2f7a170..15381ec520e21 100644 --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) { continue; } FailureOr equalDimSize = ValueBoundsConstraintSet::areEqual( - op.getSource(), op.getResult(), srcDim, resultDim); + {op.getSource(), srcDim}, {op.getResult(), resultDim}); if (failed(equalDimSize) || !*equalDimSize) return false; ++srcDim; @@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) { continue; } FailureOr equalDimSize = ValueBoundsConstraintSet::areEqual( - op.getSource(), op.getResult(), dim, resultDim); + {op.getSource(), dim}, {op.getResult(), resultDim}); if (failed(equalDimSize) || !*equalDimSize) return false; ++resultDim; diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index ffa4c0b55cad7..87937591e60ad 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -25,6 +25,12 @@ namespace mlir { #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc" } // namespace mlir +static Operation *getOwnerOfValue(Value value) { + if (auto bbArg = dyn_cast(value)) + return bbArg.getOwner()->getParentOp(); + return value.getDefiningOp(); +} + HyperrectangularSlice::HyperrectangularSlice(ArrayRef offsets, ArrayRef sizes, ArrayRef strides) @@ -67,6 +73,83 @@ static std::optional getConstantIntValue(OpFoldResult ofr) { return std::nullopt; } +ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr) + : Variable(ofr, std::nullopt) {} + +ValueBoundsConstraintSet::Variable::Variable(Value indexValue) + : Variable(static_cast(indexValue)) {} + +ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim) + : Variable(static_cast(shapedValue), std::optional(dim)) {} + +ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr, + std::optional dim) { + Builder b(ofr.getContext()); + if (auto constInt = ::getConstantIntValue(ofr)) { + assert(!dim && "expected no dim for index-typed values"); + map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, + b.getAffineConstantExpr(*constInt)); + return; + } + Value value = cast(ofr); +#ifndef NDEBUG + if (dim) { + assert(isa(value.getType()) && "expected shaped type"); + } else { + assert(value.getType().isIndex() && "expected index type"); + } +#endif // NDEBUG + map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, + b.getAffineSymbolExpr(0)); + mapOperands.emplace_back(value, dim); +} + +ValueBoundsConstraintSet::Variable::Variable(AffineMap map, + ArrayRef mapOperands) { + assert(map.getNumResults() == 1 && "expected single result"); + + // Turn all dims into symbols. + Builder b(map.getContext()); + SmallVector dimReplacements, symReplacements; + for (int64_t i = 0, e = map.getNumDims(); i < e; ++i) + dimReplacements.push_back(b.getAffineSymbolExpr(i)); + for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i) + symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims())); + AffineMap tmpMap = map.replaceDimsAndSymbols( + dimReplacements, symReplacements, /*numResultDims=*/0, + /*numResultSyms=*/map.getNumSymbols() + map.getNumDims()); + + // Inline operands. + DenseMap replacements; + for (auto [index, var] : llvm::enumerate(mapOperands)) { + assert(var.map.getNumResults() == 1 && "expected single result"); + assert(var.map.getNumDims() == 0 && "expected only symbols"); + SmallVector symReplacements; + for (auto valueDim : var.mapOperands) { + auto it = llvm::find(this->mapOperands, valueDim); + if (it != this->mapOperands.end()) { + // There is already a symbol for this operand. + symReplacements.push_back(b.getAffineSymbolExpr( + std::distance(this->mapOperands.begin(), it))); + } else { + // This is a new operand: add a new symbol. + symReplacements.push_back( + b.getAffineSymbolExpr(this->mapOperands.size())); + this->mapOperands.push_back(valueDim); + } + } + replacements[b.getAffineSymbolExpr(index)] = + var.map.getResult(0).replaceSymbols(symReplacements); + } + this->map = tmpMap.replace(replacements, /*numResultDims=*/0, + /*numResultSyms=*/this->mapOperands.size()); +} + +ValueBoundsConstraintSet::Variable::Variable(AffineMap map, + ArrayRef mapOperands) + : Variable(map, llvm::map_to_vector(mapOperands, + [](Value v) { return Variable(v); })) {} + ValueBoundsConstraintSet::ValueBoundsConstraintSet( MLIRContext *ctx, StopConditionFn stopCondition) : builder(ctx), stopCondition(stopCondition) { @@ -176,6 +259,11 @@ int64_t ValueBoundsConstraintSet::insert(Value value, assert(!valueDimToPosition.contains(valueDim) && "already mapped"); int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol) : cstr.appendVar(VarKind::SetDim); + LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos + << " for: " << value + << " (dim: " << dim.value_or(kIndexValue) + << ", owner: " << getOwnerOfValue(value)->getName() + << ")\n"); positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim); // Update reverse mapping. for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) @@ -194,6 +282,8 @@ int64_t ValueBoundsConstraintSet::insert(Value value, int64_t ValueBoundsConstraintSet::insert(bool isSymbol) { int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol) : cstr.appendVar(VarKind::SetDim); + LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos + << "\n"); positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt); // Update reverse mapping. for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i) @@ -224,6 +314,10 @@ int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands, return pos; } +int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) { + return insert(var.map, var.mapOperands, isSymbol); +} + int64_t ValueBoundsConstraintSet::getPos(Value value, std::optional dim) const { #ifndef NDEBUG @@ -232,7 +326,10 @@ int64_t ValueBoundsConstraintSet::getPos(Value value, cast(value).getOwner()->isEntryBlock()) && "unstructured control flow is not supported"); #endif // NDEBUG - + LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value + << " (dim: " << dim.value_or(kIndexValue) + << ", owner: " << getOwnerOfValue(value)->getName() + << ")\n"); auto it = valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue))); assert(it != valueDimToPosition.end() && "expected mapped entry"); @@ -253,12 +350,6 @@ bool ValueBoundsConstraintSet::isMapped(Value value, return it != valueDimToPosition.end(); } -static Operation *getOwnerOfValue(Value value) { - if (auto bbArg = dyn_cast(value)) - return bbArg.getOwner()->getParentOp(); - return value.getDefiningOp(); -} - void ValueBoundsConstraintSet::processWorklist() { LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n"); while (!worklist.empty()) { @@ -346,41 +437,47 @@ void ValueBoundsConstraintSet::projectOut( } } +void ValueBoundsConstraintSet::projectOutAnonymous( + std::optional except) { + int64_t nextPos = 0; + while (nextPos < static_cast(positionToValueDim.size())) { + if (positionToValueDim[nextPos].has_value() || except == nextPos) { + ++nextPos; + } else { + projectOut(nextPos); + // The column was projected out so another column is now at that position. + // Do not increase the counter. + } + } +} + LogicalResult ValueBoundsConstraintSet::computeBound( AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, - Value value, std::optional dim, StopConditionFn stopCondition, - bool closedUB) { -#ifndef NDEBUG - assertValidValueDim(value, dim); -#endif // NDEBUG - + const Variable &var, StopConditionFn stopCondition, bool closedUB) { + MLIRContext *ctx = var.getContext(); int64_t ubAdjustment = closedUB ? 0 : 1; - Builder b(value.getContext()); + Builder b(ctx); mapOperands.clear(); // Process the backward slice of `value` (i.e., reverse use-def chain) until // `stopCondition` is met. - ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); - ValueBoundsConstraintSet cstr(value.getContext(), stopCondition); - assert(!stopCondition(value, dim, cstr) && - "stop condition should not be satisfied for starting point"); - int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false); + ValueBoundsConstraintSet cstr(ctx, stopCondition); + int64_t pos = cstr.insert(var, /*isSymbol=*/false); + assert(pos == 0 && "expected first column"); cstr.processWorklist(); // Project out all variables (apart from `valueDim`) that do not match the // stop condition. cstr.projectOut([&](ValueDim p) { - // Do not project out `valueDim`. - if (valueDim == p) - return false; auto maybeDim = p.second == kIndexValue ? std::nullopt : std::make_optional(p.second); return !stopCondition(p.first, maybeDim, cstr); }); + cstr.projectOutAnonymous(/*except=*/pos); // Compute lower and upper bounds for `valueDim`. SmallVector lb(1), ub(1); - cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub, + cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub, /*closedUB=*/true); // Note: There are TODOs in the implementation of `getSliceBounds`. In such a @@ -477,10 +574,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound( LogicalResult ValueBoundsConstraintSet::computeDependentBound( AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, - Value value, std::optional dim, ValueDimList dependencies, - bool closedUB) { + const Variable &var, ValueDimList dependencies, bool closedUB) { return computeBound( - resultMap, mapOperands, type, value, dim, + resultMap, mapOperands, type, var, [&](Value v, std::optional d, ValueBoundsConstraintSet &cstr) { return llvm::is_contained(dependencies, std::make_pair(v, d)); }, @@ -489,8 +585,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound( LogicalResult ValueBoundsConstraintSet::computeIndependentBound( AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, - Value value, std::optional dim, ValueRange independencies, - bool closedUB) { + const Variable &var, ValueRange independencies, bool closedUB) { // Return "true" if the given value is independent of all values in // `independencies`. I.e., neither the value itself nor any value in the // backward slice (reverse use-def chain) is contained in `independencies`. @@ -516,7 +611,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound( // Reify bounds in terms of any independent values. return computeBound( - resultMap, mapOperands, type, value, dim, + resultMap, mapOperands, type, var, [&](Value v, std::optional d, ValueBoundsConstraintSet &cstr) { return isIndependent(v); }, @@ -524,35 +619,8 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound( } FailureOr ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType type, Value value, std::optional dim, - StopConditionFn stopCondition, bool closedUB) { -#ifndef NDEBUG - assertValidValueDim(value, dim); -#endif // NDEBUG - - AffineMap map = - AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, - Builder(value.getContext()).getAffineDimExpr(0)); - return computeConstantBound(type, map, {{value, dim}}, stopCondition, - closedUB); -} - -FailureOr ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType type, AffineMap map, ArrayRef operands, + presburger::BoundType type, const Variable &var, StopConditionFn stopCondition, bool closedUB) { - ValueDimList valueDims; - for (Value v : operands) { - assert(v.getType().isIndex() && "expected index type"); - valueDims.emplace_back(v, std::nullopt); - } - return computeConstantBound(type, map, valueDims, stopCondition, closedUB); -} - -FailureOr ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType type, AffineMap map, ValueDimList operands, - StopConditionFn stopCondition, bool closedUB) { - assert(map.getNumResults() == 1 && "expected affine map with one result"); - // Default stop condition if none was specified: Keep adding constraints until // a bound could be computed. int64_t pos = 0; @@ -562,8 +630,8 @@ FailureOr ValueBoundsConstraintSet::computeConstantBound( }; ValueBoundsConstraintSet cstr( - map.getContext(), stopCondition ? stopCondition : defaultStopCondition); - pos = cstr.populateConstraints(map, operands); + var.getContext(), stopCondition ? stopCondition : defaultStopCondition); + pos = cstr.populateConstraints(var.map, var.mapOperands); assert(pos == 0 && "expected `map` is the first column"); // Compute constant bound for `valueDim`. @@ -608,22 +676,13 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2, Builder b(value1.getContext()); AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, b.getAffineDimExpr(0) - b.getAffineDimExpr(1)); - return computeConstantBound(presburger::BoundType::EQ, map, - {{value1, dim1}, {value2, dim2}}); + return computeConstantBound(presburger::BoundType::EQ, + Variable(map, {{value1, dim1}, {value2, dim2}})); } -bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs, - std::optional lhsDim, - ComparisonOperator cmp, - OpFoldResult rhs, - std::optional rhsDim) { -#ifndef NDEBUG - if (auto lhsVal = dyn_cast(lhs)) - assertValidValueDim(lhsVal, lhsDim); - if (auto rhsVal = dyn_cast(rhs)) - assertValidValueDim(rhsVal, rhsDim); -#endif // NDEBUG - +bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, + ComparisonOperator cmp, + int64_t rhsPos) { // This function returns "true" if "lhs CMP rhs" is proven to hold. // // Example for ComparisonOperator::LE and index-typed values: We would like to @@ -640,50 +699,6 @@ bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs, return false; } - // EQ can be expressed as LE and GE. - if (cmp == EQ) - return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) && - compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim); - - // Construct inequality. For the above example: lhs > rhs. - // `IntegerRelation` inequalities are expressed in the "flattened" form and - // with ">= 0". I.e., lhs - rhs - 1 >= 0. - SmallVector eq(cstr.getNumCols(), 0); - auto addToEq = [&](OpFoldResult ofr, std::optional dim, - int64_t factor) { - if (auto constVal = ::getConstantIntValue(ofr)) { - eq[cstr.getNumCols() - 1] += *constVal * factor; - } else { - eq[getPos(cast(ofr), dim)] += factor; - } - }; - if (cmp == LT || cmp == LE) { - addToEq(lhs, lhsDim, 1); - addToEq(rhs, rhsDim, -1); - } else if (cmp == GT || cmp == GE) { - addToEq(lhs, lhsDim, -1); - addToEq(rhs, rhsDim, 1); - } else { - llvm_unreachable("unsupported comparison operator"); - } - if (cmp == LE || cmp == GE) - eq[cstr.getNumCols() - 1] -= 1; - - // Add inequality to the constraint set and check if it made the constraint - // set empty. - int64_t ineqPos = cstr.getNumInequalities(); - cstr.addInequality(eq); - bool isEmpty = cstr.isEmpty(); - cstr.removeInequality(ineqPos); - return isEmpty; -} - -bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, - ComparisonOperator cmp, - int64_t rhsPos) { - // This function returns "true" if "lhs CMP rhs" is proven to hold. For - // detailed documentation, see `compareValueDims`. - // EQ can be expressed as LE and GE. if (cmp == EQ) return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) && @@ -712,48 +727,17 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, return isEmpty; } -bool ValueBoundsConstraintSet::populateAndCompare( - OpFoldResult lhs, std::optional lhsDim, ComparisonOperator cmp, - OpFoldResult rhs, std::optional rhsDim) { -#ifndef NDEBUG - if (auto lhsVal = dyn_cast(lhs)) - assertValidValueDim(lhsVal, lhsDim); - if (auto rhsVal = dyn_cast(rhs)) - assertValidValueDim(rhsVal, rhsDim); -#endif // NDEBUG - - if (auto lhsVal = dyn_cast(lhs)) - populateConstraints(lhsVal, lhsDim); - if (auto rhsVal = dyn_cast(rhs)) - populateConstraints(rhsVal, rhsDim); - - return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim); +bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs, + ComparisonOperator cmp, + const Variable &rhs) { + int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands); + int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands); + return comparePos(lhsPos, cmp, rhsPos); } -bool ValueBoundsConstraintSet::compare(OpFoldResult lhs, - std::optional lhsDim, - ComparisonOperator cmp, OpFoldResult rhs, - std::optional rhsDim) { - auto stopCondition = [&](Value v, std::optional dim, - ValueBoundsConstraintSet &cstr) { - // Keep processing as long as lhs/rhs are not mapped. - if (auto lhsVal = dyn_cast(lhs)) - if (!cstr.isMapped(lhsVal, dim)) - return false; - if (auto rhsVal = dyn_cast(rhs)) - if (!cstr.isMapped(rhsVal, dim)) - return false; - // Keep processing as long as the relation cannot be proven. - return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim); - }; - - ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); - return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim); -} - -bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ValueDimList rhsOperands) { +bool ValueBoundsConstraintSet::compare(const Variable &lhs, + ComparisonOperator cmp, + const Variable &rhs) { int64_t lhsPos = -1, rhsPos = -1; auto stopCondition = [&](Value v, std::optional dim, ValueBoundsConstraintSet &cstr) { @@ -765,39 +749,17 @@ bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands, return cstr.comparePos(lhsPos, cmp, rhsPos); }; ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); - lhsPos = cstr.insert(lhs, lhsOperands); - rhsPos = cstr.insert(rhs, rhsOperands); - cstr.processWorklist(); + lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands); + rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands); return cstr.comparePos(lhsPos, cmp, rhsPos); } -bool ValueBoundsConstraintSet::compare(AffineMap lhs, - ArrayRef lhsOperands, - ComparisonOperator cmp, AffineMap rhs, - ArrayRef rhsOperands) { - ValueDimList lhsValueDimOperands = - llvm::map_to_vector(lhsOperands, [](Value v) { - return std::make_pair(v, std::optional()); - }); - ValueDimList rhsValueDimOperands = - llvm::map_to_vector(rhsOperands, [](Value v) { - return std::make_pair(v, std::optional()); - }); - return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs, - rhsValueDimOperands); -} - -FailureOr -ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2, - std::optional dim1, - std::optional dim2) { - if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ, - value2, dim2)) +FailureOr ValueBoundsConstraintSet::areEqual(const Variable &var1, + const Variable &var2) { + if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2)) return true; - if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT, - value2, dim2) || - ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT, - value2, dim2)) + if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) || + ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2)) return false; return failure(); } @@ -833,7 +795,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx, AffineMap foldedMap = foldAttributesIntoMap(b, map, ofrOperands, valueOperands); FailureOr constBound = computeConstantBound( - presburger::BoundType::EQ, foldedMap, valueOperands); + presburger::BoundType::EQ, Variable(foldedMap, valueOperands)); foundUnknownBound |= failed(constBound); if (succeeded(constBound) && *constBound <= 0) return false; @@ -850,7 +812,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx, AffineMap foldedMap = foldAttributesIntoMap(b, map, ofrOperands, valueOperands); FailureOr constBound = computeConstantBound( - presburger::BoundType::EQ, foldedMap, valueOperands); + presburger::BoundType::EQ, Variable(foldedMap, valueOperands)); foundUnknownBound |= failed(constBound); if (succeeded(constBound) && *constBound <= 0) return false; diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir index 23c6872dcebe9..935c08aceff54 100644 --- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir @@ -131,3 +131,27 @@ func.func @compare_affine_min(%a: index, %b: index) { "test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> () return } + +// ----- + +func.func @compare_const_map() { + %c5 = arith.constant 5 : index + // expected-remark @below{{true}} + "test.compare"(%c5) {cmp = "GT", rhs_map = affine_map<() -> (4)>} + : (index) -> () + // expected-remark @below{{true}} + "test.compare"(%c5) {cmp = "LT", lhs_map = affine_map<() -> (4)>} + : (index) -> () + return +} + +// ----- + +func.func @compare_maps(%a: index, %b: index) { + // expected-remark @below{{true}} + "test.compare"(%a, %b, %b, %a) + {cmp = "GT", lhs_map = affine_map<(d0, d1) -> (1 + d0 + d1)>, + rhs_map = affine_map<(d0, d1) -> (d0 + d1)>} + : (index, index, index, index) -> () + return +} diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index 6730f9b292ad9..b098a5a23fd31 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -109,7 +109,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, FailureOr reified = failure(); if (constant) { auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound( - boundType, value, dim, /*stopCondition=*/nullptr); + boundType, {value, dim}, /*stopCondition=*/nullptr); if (succeeded(reifiedConst)) reified = FailureOr(rewriter.getIndexAttr(*reifiedConst)); } else if (scalable) { @@ -128,22 +128,12 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, rewriter, loc, reifiedScalable->map, vscaleOperand); } } else { - if (dim) { - if (useArithOps) { - reified = arith::reifyShapedValueDimBound( - rewriter, op->getLoc(), boundType, value, *dim, stopCondition); - } else { - reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType, - value, *dim, stopCondition); - } + if (useArithOps) { + reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType, + op.getVariable(), stopCondition); } else { - if (useArithOps) { - reified = arith::reifyIndexValueBound( - rewriter, op->getLoc(), boundType, value, stopCondition); - } else { - reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType, - value, stopCondition); - } + reified = reifyValueBound(rewriter, op->getLoc(), boundType, + op.getVariable(), stopCondition); } } if (failed(reified)) { @@ -188,9 +178,7 @@ static LogicalResult testEquality(func::FuncOp funcOp) { } auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) { - return ValueBoundsConstraintSet::compare( - /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp, - /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt); + return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs()); }; if (compare(cmpType)) { op->emitRemark("true"); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 25c5190ca0ef3..36d7606fe1345 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -549,6 +549,12 @@ LogicalResult ReifyBoundOp::verify() { return success(); } +::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { + if (getDim().has_value()) + return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); + return ValueBoundsConstraintSet::Variable(getVar()); +} + ::mlir::ValueBoundsConstraintSet::ComparisonOperator CompareOp::getComparisonOperator() { if (getCmp() == "EQ") @@ -564,6 +570,37 @@ CompareOp::getComparisonOperator() { llvm_unreachable("invalid comparison operator"); } +::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { + if (!getLhsMap()) + return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); + SmallVector mapOperands( + getVarOperands().slice(0, getLhsMap()->getNumInputs())); + return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); +} + +::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { + int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; + if (!getRhsMap()) + return ValueBoundsConstraintSet::Variable( + getVarOperands()[rhsOperandsBegin]); + SmallVector mapOperands( + getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs())); + return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); +} + +LogicalResult CompareOp::verify() { + if (getCompose() && (getLhsMap() || getRhsMap())) + return emitOpError( + "'compose' not supported when 'lhs_map' or 'rhs_map' is present"); + int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; + expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; + if (getVarOperands().size() != expectedNumOperands) + return emitOpError("expected ") + << expectedNumOperands << " operands, but got " + << getVarOperands().size(); + return success(); +} + //===----------------------------------------------------------------------===// // Test removing op with inner ops. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index ebf158b8bb820..b641b3da719c7 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2207,6 +2207,7 @@ def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> { let extraClassDeclaration = [{ ::mlir::presburger::BoundType getBoundType(); + ::mlir::ValueBoundsConstraintSet::Variable getVariable(); }]; let hasVerifier = 1; @@ -2217,18 +2218,29 @@ def CompareOp : TEST_Op<"compare"> { Compare `lhs` and `rhs`. A remark is emitted which indicates whether the specified comparison operator was proven to hold. The remark also indicates whether the opposite comparison operator was proven to hold. + + `var_operands` must have exactly two operands: one for the LHS operand and + one for the RHS operand. If `lhs_map` is specified, as many operands as + `lhs_map` has inputs are expected instead of the first operand. If `rhs_map` + is specified, as many operands as `rhs_map` has inputs are expected instead + of the second operand. }]; - let arguments = (ins Index:$lhs, - Index:$rhs, + let arguments = (ins Variadic:$var_operands, DefaultValuedAttr:$cmp, + OptionalAttr:$lhs_map, + OptionalAttr:$rhs_map, UnitAttr:$compose); let results = (outs); let extraClassDeclaration = [{ ::mlir::ValueBoundsConstraintSet::ComparisonOperator getComparisonOperator(); + ::mlir::ValueBoundsConstraintSet::Variable getLhs(); + ::mlir::ValueBoundsConstraintSet::Variable getRhs(); }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===//