diff --git a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h index 31e19ff1ad39f..67a6581eb2fb4 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h +++ b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h @@ -29,9 +29,12 @@ struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet { struct ScalableValueBoundsConstraintSet : public llvm::RTTIExtends { - ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin, - unsigned vscaleMax) - : RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){}; + ScalableValueBoundsConstraintSet( + MLIRContext *context, + ValueBoundsConstraintSet::StopConditionFn stopCondition, + unsigned vscaleMin, unsigned vscaleMax) + : RTTIExtends(context, stopCondition), vscaleMin(vscaleMin), + vscaleMax(vscaleMax) {}; using RTTIExtends::bound; using RTTIExtends::StopConditionFn; diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index bdfd689c7ac4f..83107a3f5f941 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -117,8 +117,9 @@ class ValueBoundsConstraintSet /// /// The first parameter of the function is the shaped value/index-typed /// value. The second parameter is the dimension in case of a shaped value. - using StopConditionFn = - function_ref /*dim*/)>; + /// The third parameter is this constraint set. + using StopConditionFn = std::function /*dim*/, ValueBoundsConstraintSet &cstr)>; /// Compute a bound for the given index-typed value or shape dimension size. /// The computed bound is stored in `resultMap`. The operands of the bound are @@ -271,22 +272,20 @@ class ValueBoundsConstraintSet /// An index-typed value or the dimension of a shaped-type value. using ValueDim = std::pair; - ValueBoundsConstraintSet(MLIRContext *ctx); + ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition); /// Populates the constraint set for a value/map without actually computing /// the bound. Returns the position for the value/map (via the return value /// and `posOut` output parameter). int64_t populateConstraintsSet(Value value, - std::optional dim = std::nullopt, - StopConditionFn stopCondition = nullptr); + std::optional dim = std::nullopt); int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands, - StopConditionFn stopCondition = nullptr, int64_t *posOut = nullptr); /// Iteratively process all elements on the worklist until an index-typed /// value or shaped value meets `stopCondition`. Such values are not processed /// any further. - void processWorklist(StopConditionFn stopCondition); + void processWorklist(); /// Bound the given column in the underlying constraint set by the given /// expression. @@ -333,6 +332,9 @@ class ValueBoundsConstraintSet /// Builder for constructing affine expressions. Builder builder; + + /// The current stop condition function. + StopConditionFn stopCondition = nullptr; }; } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp index 37b36f76d4465..117ee8e8701ad 100644 --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -84,7 +84,8 @@ FailureOr mlir::affine::reifyShapedValueDimBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { - auto reifyToOperands = [&](Value v, std::optional d) { + auto reifyToOperands = [&](Value v, std::optional d, + ValueBoundsConstraintSet &cstr) { // We are trying to reify a bound for `value` in terms of the owning op's // operands. Construct a stop condition that evaluates to "true" for any SSA // value except for `value`. I.e., the bound will be computed in terms of @@ -100,7 +101,8 @@ FailureOr mlir::affine::reifyShapedValueDimBound( FailureOr mlir::affine::reifyIndexValueBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { - auto reifyToOperands = [&](Value v, std::optional d) { + auto reifyToOperands = [&](Value v, std::optional d, + ValueBoundsConstraintSet &cstr) { return v != value; }; return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index 8d9fd1478aa9e..fad221288f190 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -119,7 +119,8 @@ FailureOr mlir::arith::reifyShapedValueDimBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { - auto reifyToOperands = [&](Value v, std::optional d) { + auto reifyToOperands = [&](Value v, std::optional d, + ValueBoundsConstraintSet &cstr) { // We are trying to reify a bound for `value` in terms of the owning op's // operands. Construct a stop condition that evaluates to "true" for any SSA // value expect for `value`. I.e., the bound will be computed in terms of @@ -135,7 +136,8 @@ FailureOr mlir::arith::reifyShapedValueDimBound( FailureOr mlir::arith::reifyIndexValueBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { - auto reifyToOperands = [&](Value v, std::optional d) { + auto reifyToOperands = [&](Value v, std::optional d, + ValueBoundsConstraintSet &cstr) { return v != value; }; return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index b32ea8eebaecb..c3a08ce86082a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter, FailureOr loopUb = affine::reifyIndexValueBound( rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(), /*stopCondition=*/ - [&](Value v, std::optional d) { + [&](Value v, std::optional d, ValueBoundsConstraintSet &cstr) { if (v == forOp.getUpperBound()) return false; // Compute a bound that is independent of any affine op results. diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index cb36e0cecf0d2..1e13e60068ee7 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -58,7 +58,7 @@ struct ForOpInterface ValueDimList boundOperands; LogicalResult status = ValueBoundsConstraintSet::computeBound( bound, boundOperands, BoundType::EQ, yieldedValue, dim, - [&](Value v, std::optional d) { + [&](Value v, std::optional d, ValueBoundsConstraintSet &cstr) { // Stop when reaching a block argument of the loop body. if (auto bbArg = llvm::dyn_cast(v)) return bbArg.getOwner()->getParentOp() == forOp; diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp index 6d7e3bc70f59d..52359fa8a510d 100644 --- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp +++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp @@ -47,17 +47,26 @@ ScalableValueBoundsConstraintSet::computeScalableBound( unsigned vscaleMax, presburger::BoundType boundType, bool closedUB, StopConditionFn stopCondition) { using namespace presburger; - assert(vscaleMin <= vscaleMax); - ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin, - vscaleMax); - int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition); + // No stop condition specified: Keep adding constraints until the worklist + // is empty. + auto defaultStopCondition = [&](Value v, std::optional dim, + mlir::ValueBoundsConstraintSet &cstr) { + return false; + }; + + ScalableValueBoundsConstraintSet scalableCstr( + value.getContext(), stopCondition ? stopCondition : defaultStopCondition, + vscaleMin, vscaleMax); + int64_t pos = scalableCstr.populateConstraintsSet(value, dim); // Project out all variables apart from vscale. // This should result in constraints in terms of vscale only. - scalableCstr.projectOut( - [&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); }); + auto projectOutFn = [&](ValueDim p) { + return p.first != scalableCstr.getVscaleValue(); + }; + scalableCstr.projectOut(projectOutFn); assert(scalableCstr.cstr.getNumDimAndSymbolVars() == scalableCstr.positionToValueDim.size() && diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index 9a3185d55d6e8..0d362c7efa0a0 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -67,8 +67,11 @@ static std::optional getConstantIntValue(OpFoldResult ofr) { return std::nullopt; } -ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx) - : builder(ctx) {} +ValueBoundsConstraintSet::ValueBoundsConstraintSet( + MLIRContext *ctx, StopConditionFn stopCondition) + : builder(ctx), stopCondition(stopCondition) { + assert(stopCondition && "expected non-null stop condition"); +} char ValueBoundsConstraintSet::ID = 0; @@ -193,7 +196,8 @@ static Operation *getOwnerOfValue(Value value) { return value.getDefiningOp(); } -void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) { +void ValueBoundsConstraintSet::processWorklist() { + LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n"); while (!worklist.empty()) { int64_t pos = worklist.front(); worklist.pop(); @@ -214,13 +218,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) { // Do not process any further if the stop condition is met. auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim); - if (stopCondition(value, maybeDim)) + if (stopCondition(value, maybeDim, *this)) { + LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value + << " (dim: " << maybeDim << ")\n"); continue; + } // Query `ValueBoundsOpInterface` for constraints. New items may be added to // the worklist. auto valueBoundsOp = dyn_cast(getOwnerOfValue(value)); + LLVM_DEBUG(llvm::dbgs() + << "Query value bounds for: " << value + << " (owner: " << getOwnerOfValue(value)->getName() << ")\n"); if (valueBoundsOp) { if (dim == kIndexValue) { valueBoundsOp.populateBoundsForIndexValue(value, *this); @@ -229,6 +239,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) { } continue; } + LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n"); // If the op does not implement `ValueBoundsOpInterface`, check if it // implements the `DestinationStyleOpInterface`. OpResults of such ops are @@ -278,8 +289,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound( bool closedUB) { #ifndef NDEBUG assertValidValueDim(value, dim); - assert(!stopCondition(value, dim) && - "stop condition should not be satisfied for starting point"); #endif // NDEBUG int64_t ubAdjustment = closedUB ? 0 : 1; @@ -289,9 +298,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound( // Process the backward slice of `value` (i.e., reverse use-def chain) until // `stopCondition` is met. ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); - ValueBoundsConstraintSet cstr(value.getContext()); + ValueBoundsConstraintSet cstr(value.getContext(), stopCondition); + assert(!stopCondition(value, dim, cstr) && + "stop condition should not be satisfied for starting point"); int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false); - cstr.processWorklist(stopCondition); + cstr.processWorklist(); // Project out all variables (apart from `valueDim`) that do not match the // stop condition. @@ -301,7 +312,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound( return false; auto maybeDim = p.second == kIndexValue ? std::nullopt : std::make_optional(p.second); - return !stopCondition(p.first, maybeDim); + return !stopCondition(p.first, maybeDim, cstr); }); // Compute lower and upper bounds for `valueDim`. @@ -407,7 +418,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound( bool closedUB) { return computeBound( resultMap, mapOperands, type, value, dim, - [&](Value v, std::optional d) { + [&](Value v, std::optional d, ValueBoundsConstraintSet &cstr) { return llvm::is_contained(dependencies, std::make_pair(v, d)); }, closedUB); @@ -443,7 +454,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound( // Reify bounds in terms of any independent values. return computeBound( resultMap, mapOperands, type, value, dim, - [&](Value v, std::optional d) { return isIndependent(v); }, + [&](Value v, std::optional d, ValueBoundsConstraintSet &cstr) { + return isIndependent(v); + }, closedUB); } @@ -476,21 +489,19 @@ FailureOr ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType type, AffineMap map, ValueDimList operands, StopConditionFn stopCondition, bool closedUB) { assert(map.getNumResults() == 1 && "expected affine map with one result"); - ValueBoundsConstraintSet cstr(map.getContext()); - int64_t pos = 0; - if (stopCondition) { - cstr.populateConstraintsSet(map, operands, stopCondition, &pos); - } else { - // No stop condition specified: Keep adding constraints until a bound could - // be computed. - cstr.populateConstraintsSet( - map, operands, - [&](Value v, std::optional dim) { - return cstr.cstr.getConstantBound64(type, pos).has_value(); - }, - &pos); - } + // Default stop condition if none was specified: Keep adding constraints until + // a bound could be computed. + int64_t pos; + auto defaultStopCondition = [&](Value v, std::optional dim, + ValueBoundsConstraintSet &cstr) { + return cstr.cstr.getConstantBound64(type, pos).has_value(); + }; + + ValueBoundsConstraintSet cstr( + map.getContext(), stopCondition ? stopCondition : defaultStopCondition); + cstr.populateConstraintsSet(map, operands, &pos); + // Compute constant bound for `valueDim`. int64_t ubAdjustment = closedUB ? 0 : 1; if (auto bound = cstr.cstr.getConstantBound64(type, pos)) @@ -498,8 +509,9 @@ FailureOr ValueBoundsConstraintSet::computeConstantBound( return failure(); } -int64_t ValueBoundsConstraintSet::populateConstraintsSet( - Value value, std::optional dim, StopConditionFn stopCondition) { +int64_t +ValueBoundsConstraintSet::populateConstraintsSet(Value value, + std::optional dim) { #ifndef NDEBUG assertValidValueDim(value, dim); #endif // NDEBUG @@ -507,12 +519,12 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet( AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, Builder(value.getContext()).getAffineDimExpr(0)); - return populateConstraintsSet(map, {{value, dim}}, stopCondition); + return populateConstraintsSet(map, {{value, dim}}); } -int64_t ValueBoundsConstraintSet::populateConstraintsSet( - AffineMap map, ValueDimList operands, StopConditionFn stopCondition, - int64_t *posOut) { +int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map, + ValueDimList operands, + int64_t *posOut) { assert(map.getNumResults() == 1 && "expected affine map with one result"); int64_t pos = insert(/*isSymbol=*/false); if (posOut) @@ -533,13 +545,7 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet( // Process the backward slice of `operands` (i.e., reverse use-def chain) // until `stopCondition` is met. - if (stopCondition) { - processWorklist(stopCondition); - } else { - // No stop condition specified: Keep adding constraints until the worklist - // is empty. - processWorklist([](Value v, std::optional dim) { return false; }); - } + processWorklist(); return pos; } diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index 5e160b720db62..4b2b1a06341b7 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -117,14 +117,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, // Prepare stop condition. By default, reify in terms of the op's // operands. No stop condition is used when a constant was requested. - std::function)> stopCondition = - [&](Value v, std::optional d) { + std::function, + ValueBoundsConstraintSet & cstr)> + stopCondition = [&](Value v, std::optional d, + ValueBoundsConstraintSet &cstr) { // Reify in terms of SSA values that are different from `value`. return v != value; }; if (reifyToFuncArgs) { // Reify in terms of function block arguments. - stopCondition = stopCondition = [](Value v, std::optional d) { + stopCondition = [](Value v, std::optional d, + ValueBoundsConstraintSet &cstr) { auto bbArg = dyn_cast(v); if (!bbArg) return false;