Skip to content

Commit

Permalink
[mlir][affine][NFC] Extract core functionality of canonicalizeMinMaxOp
Browse files Browse the repository at this point in the history
Move code from SCF to Affine: Add a new helper function `simplifyConstrainedMinMaxOp` to Affine/Analysis/Utils.h. `canonicalizeMinMaxOp` was originally designed for loop peeling, but it is not SCF-specific and can be used to simplify any affine.min/max ops.

Various functions in SCF/Transforms are simplified by dropping unnecessary parameters.

Differential Revision: https://reviews.llvm.org/D140962
  • Loading branch information
matthias-springer committed Jan 4, 2023
1 parent 5bedd67 commit 3a5811a
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 218 deletions.
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
namespace mlir {

class AffineForOp;
class AffineValueMap;
class Block;
class Location;
struct MemRefAccess;
Expand Down Expand Up @@ -384,6 +385,13 @@ unsigned getInnermostCommonLoopDepth(
ArrayRef<Operation *> ops,
SmallVectorImpl<AffineForOp> *surroundingLoops = nullptr);

/// Try to simplify the given affine.min or affine.max op to an affine map with
/// a single result and operands, taking into account the specified constraint
/// set. Return failure if no simplified version could be found.
///
/// `op` is expected to be an affine.min or affine.max op; the implementation
/// asserts this. `constraints` is taken by value, so the caller's constraint
/// set is not modified by this function.
FailureOr<AffineValueMap>
simplifyConstrainedMinMaxOp(Operation *op,
FlatAffineValueConstraints constraints);

} // namespace mlir

#endif // MLIR_DIALECT_AFFINE_ANALYSIS_UTILS_H
64 changes: 8 additions & 56 deletions mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,70 +39,23 @@ class IfOp;
using LoopMatcherFn = function_ref<LogicalResult(
Value, OpFoldResult &, OpFoldResult &, OpFoldResult &)>;

/// Try to canonicalize an min/max operations in the context of for `loops` with
/// a known range.
/// Try to canonicalize the given affine.min/max operation in the context of
/// for `loops` with a known range.
///
/// `map` is the body of the min/max operation and `operands` are the SSA values
/// that the dimensions and symbols are bound to; dimensions are listed first.
/// If `isMin`, the operation is a min operation; otherwise, a max operation.
/// `loopMatcher` is used to retrieve loop bounds and the step size for a given
/// iteration variable.
///
/// Note: `loopMatcher` allows this function to be used with any "for loop"-like
/// operation (scf.for, scf.parallel and even ops defined in other dialects).
LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op,
AffineMap map, ValueRange operands,
bool isMin, LoopMatcherFn loopMatcher);
LoopMatcherFn loopMatcher);

/// Attempt to canonicalize min/max operations by proving that their value is
/// bounded by the same lower and upper bound. In such cases, the operation can
/// be folded away.
///
/// Bounds are computed by FlatAffineValueConstraints. Invariants required for
/// finding/proving bounds should be supplied via `constraints`.
///
/// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
/// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
/// case of `!isMin`) and bind it to `opBound`. SSA values that are used in
/// `op` but are not part of `constraints`, are added as extra symbols.
/// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
/// * If `isMin`: r_i >= opBound
/// * If `isMax`: r_i <= opBound
/// If this is the case, ub(op) == lb(op).
/// 4. Replace `op` with `opBound`.
///
/// In summary, the following constraints are added throughout this function.
/// Note: `invar` are dimensions added by the caller to express the invariants.
/// (Showing only the case where `isMin`.)
///
/// invar | op | opBound | r_i | extra syms... | const | eq/ineq
/// ------+-------+---------+-----+---------------+-------+-------------------
/// (various eq./ineq. constraining `invar`, added by the caller)
/// ... | 0 | 0 | 0 | 0 | ... | ...
/// ------+-------+---------+-----+---------------+-------+-------------------
/// (various ineq. constraining `op` in terms of `op` operands (`invar` and
/// extra `op` operands "extra syms" that are not in `invar`)).
/// ... | -1 | 0 | 0 | ... | ... | >= 0
/// ------+-------+---------+-----+---------------+-------+-------------------
/// (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
/// ... | 0 | -1 | 0 | ... | ... | = 0
/// ------+-------+---------+-----+---------------+-------+-------------------
/// (for each `op` map result r_i: set r_i to corresponding map result,
/// prove that r_i >= minOpUb via contradiction)
/// ... | 0 | 0 | -1 | ... | ... | = 0
/// 0 | 0 | 1 | -1 | 0 | -1 | >= 0
///
FailureOr<AffineApplyOp>
canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
ValueRange operands, bool isMin,
FlatAffineValueConstraints constraints);

/// Try to simplify a min/max operation `op` after loop peeling. This function
/// can simplify min/max operations such as (ub is the previous upper bound of
/// the unpeeled loop):
/// Try to simplify the given affine.min/max operation `op` after loop peeling.
/// This function can simplify min/max operations such as (ub is the previous
/// upper bound of the unpeeled loop):
/// ```
/// #map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
/// %r = affine.min #affine.min #map(%iv)[%step, %ub]
/// %r = affine.min #map(%iv)[%step, %ub]
/// ```
/// and rewrites them into (in the case of the peeled loop):
/// ```
Expand All @@ -111,8 +64,7 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
/// min/max operations inside the partial iteration are rewritten in a similar
/// way.
LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op,
AffineMap map, ValueRange operands,
bool isMin, Value iv, Value ub, Value step,
Value iv, Value ub, Value step,
bool insideLoop);

} // namespace scf
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Affine/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineAnalysis
MLIRAnalysis
MLIRCallInterfaces
MLIRControlFlowInterfaces
MLIRDialectUtils
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces
MLIRPresburger
Expand Down
182 changes: 182 additions & 0 deletions mlir/lib/Dialect/Affine/Analysis/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -1362,3 +1363,184 @@ IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
assert(simplifiedSet && "guaranteed to succeed while roundtripping");
return simplifiedSet;
}

/// Overwrite `target` with one entry per element of `source`: the wrapped
/// Value if the optional holds one, and a default-constructed (null) Value
/// otherwise.
static void unpackOptionalValues(ArrayRef<Optional<Value>> source,
                                 SmallVector<Value> &target) {
  target.clear();
  target.reserve(source.size());
  for (const Optional<Value> &maybeValue : source) {
    if (maybeValue.has_value())
      target.push_back(*maybeValue);
    else
      target.push_back(Value());
  }
}

/// Add a bound of kind `type` on the variable at position `pos` of
/// `constraints`. The bound is described by `map`, whose dimensions/symbols
/// are bound to the SSA values in `operands`. Before the bound is added, the
/// map is aligned with the dimension/symbol values of `constraints`.
/// Note: Every dimension/symbol of the affine map that does not yet exist in
/// `constraints` is appended to `constraints` as a new symbol column.
static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
                                      IntegerPolyhedron::BoundType type,
                                      unsigned pos, AffineMap map,
                                      ValueRange operands) {
  // Snapshot the values currently attached to the dim and symbol columns.
  SmallVector<Value> dimValues, symValues;
  unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dimValues);
  unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), symValues);

  SmallVector<Value> alignedSyms;
  AffineMap mapOnConstraintVars = alignAffineMapWithValues(
      map, operands, dimValues, symValues, &alignedSyms);

  // Entries of `alignedSyms` at index >= symValues.size() have no symbol
  // column in `constraints` yet; append one for each of them.
  for (unsigned idx = symValues.size(), end = alignedSyms.size(); idx < end;
       ++idx)
    constraints.appendSymbolVar(alignedSyms[idx]);

  return constraints.addBound(type, pos, mapOnConstraintVars);
}

/// Return a map with the same number of dims and symbols as `map` in which
/// `val` has been added to every result expression.
static AffineMap addConstToResults(AffineMap map, int64_t val) {
  auto shiftedResults = llvm::to_vector(llvm::map_range(
      map.getResults(),
      [val](AffineExpr expr) -> AffineExpr { return expr + val; }));
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
                        map.getContext());
}

// Attempt to simplify the given min/max operation by proving that its value is
// bounded by the same lower and upper bound.
//
// Bounds are computed by FlatAffineValueConstraints. Invariants required for
// finding/proving bounds should be supplied via `constraints`. `constraints`
// is taken by value; all columns/rows added below only affect the local copy.
//
// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
//    case of `!isMin`) and bind it to `opBound`. SSA values that are used in
//    `op` but are not part of `constraints`, are added as extra symbols.
// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
//    * If `isMin`: r_i >= opBound
//    * If `isMax`: r_i <= opBound
//    If this is the case, ub(op) == lb(op).
// 4. Replace `op` with `opBound`.
//
// In summary, the following constraints are added throughout this function.
// Note: `invar` are dimensions added by the caller to express the invariants.
// (Showing only the case where `isMin`.)
//
// invar | op | opBound | r_i | extra syms... | const | eq/ineq
// ------+-------+---------+-----+---------------+-------+-------------------
// (various eq./ineq. constraining `invar`, added by the caller)
// ... | 0 | 0 | 0 | 0 | ... | ...
// ------+-------+---------+-----+---------------+-------+-------------------
// (various ineq. constraining `op` in terms of `op` operands (`invar` and
// extra `op` operands "extra syms" that are not in `invar`)).
// ... | -1 | 0 | 0 | ... | ... | >= 0
// ------+-------+---------+-----+---------------+-------+-------------------
// (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
// ... | 0 | -1 | 0 | ... | ... | = 0
// ------+-------+---------+-----+---------------+-------+-------------------
// (for each `op` map result r_i: set r_i to corresponding map result,
// prove that r_i >= opBound via contradiction)
// ... | 0 | 0 | -1 | ... | ... | = 0
// 0 | 0 | 1 | -1 | 0 | -1 | >= 0
//
// Returns the simplified single-result map together with its operands, or
// failure if no bound could be computed, the constraint set is inconsistent,
// or the proof in step 3 does not go through for some result.
FailureOr<AffineValueMap>
mlir::simplifyConstrainedMinMaxOp(Operation *op,
                                  FlatAffineValueConstraints constraints) {
  bool isMin = isa<AffineMinOp>(op);
  assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
  MLIRContext *ctx = op->getContext();
  Builder builder(ctx);
  AffineMap map =
      isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
  ValueRange operands = op->getOperands();
  unsigned numResults = map.getNumResults();

  // Add a few extra dimensions.
  unsigned dimOp = constraints.appendDimVar();      // `op`
  unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
  unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);

  // Add an inequality for each result expr_i of map:
  // isMin: op <= expr_i, !isMin: op >= expr_i
  auto boundType = isMin ? IntegerPolyhedron::UB : IntegerPolyhedron::LB;
  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
  AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
  if (failed(
          alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
    return failure();

  // Try to compute a lower/upper bound for op, expressed in terms of the other
  // `dims` and extra symbols.
  SmallVector<AffineMap> opLb(1), opUb(1);
  constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
  // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
  // a TODO of `getSliceBounds` and not handled here.
  if (!sliceBound || sliceBound.getNumResults() != 1)
    return failure(); // No or multiple bounds found.
  // Recover the inclusive UB in the case of an `affine.min`.
  AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;

  // Add an equality: Set dimOpBound to computed bound.
  // Add back dimension for op. (Was removed by `getSliceBounds`.)
  AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
  if (failed(constraints.addBound(IntegerPolyhedron::EQ, dimOpBound,
                                  alignedBoundMap)))
    return failure();

  // If the constraint system is empty, there is an inconsistency. (E.g., this
  // can happen if loop lb > ub.)
  if (constraints.isEmpty())
    return failure();

  // In the case of `isMin` (the `!isMin` case is the mirror image):
  // Prove that each result of `map` has a lower bound that is equal to (or
  // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
  // can be replaced with the bound. I.e., prove that for each result
  // expr_i (represented by dimension r_i):
  //
  // r_i >= opBound
  //
  // To prove this inequality, add its negation to the constraint set and prove
  // that the constraint set is empty.
  for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
    // Work on a copy so that the negated inequality of one result does not
    // leak into the proof for the next result.
    FlatAffineValueConstraints newConstr(constraints);

    // Add an equality: r_i = expr_i
    // Note: These equalities could have been added earlier and used to express
    // op <= expr_i. However, then we run the risk that `getSliceBounds`
    // computes the bound of `op` in terms of r_i dims, which is not desired.
    if (failed(alignAndAddBound(newConstr, IntegerPolyhedron::EQ, i,
                                map.getSubMap({i - resultDimStart}), operands)))
      return failure();

    // If `isMin`: Add inequality: r_i < opBound
    // equiv.: opBound - r_i - 1 >= 0
    // If `!isMin`: Add inequality: r_i > opBound
    // equiv.: -opBound + r_i - 1 >= 0
    SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
    ineq[dimOpBound] = isMin ? 1 : -1;
    ineq[i] = isMin ? -1 : 1;
    ineq[newConstr.getNumCols() - 1] = -1;
    newConstr.addInequality(ineq);
    if (!newConstr.isEmpty())
      return failure();
  }

  // Lower and upper bound of `op` are equal. Replace `op` with its bound.
  AffineMap newMap = alignedBoundMap;
  SmallVector<Value> newOperands;
  unpackOptionalValues(constraints.getMaybeValues(), newOperands);
  // If dims/symbols have known constant values, use those in order to simplify
  // the affine map further.
  // Note: Iterate over the entries of `newOperands` rather than over
  // `constraints.getNumVars()`. The latter may also count local variables
  // (which flattening a floordiv/mod expression can introduce), for which
  // `newOperands` may have no entry; indexing `newOperands` with such a
  // position would read out of range.
  for (int64_t i = 0, e = newOperands.size(); i < e; ++i) {
    // Skip unused operands and operands that are already constants.
    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
      continue;
    if (auto bound = constraints.getConstantBound64(IntegerPolyhedron::EQ, i)) {
      // Positions below the number of map dims denote dims; the remaining
      // positions denote symbols.
      AffineExpr expr =
          i < newMap.getNumDims()
              ? builder.getAffineDimExpr(i)
              : builder.getAffineSymbolExpr(i - newMap.getNumDims());
      newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
                              newMap.getNumDims(), newMap.getNumSymbols());
    }
  }
  mlir::canonicalizeMapAndOperands(&newMap, &newOperands);
  return AffineValueMap(newMap, newOperands);
}
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {

/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
/// and scf.parallel loops with a known range.
template <typename OpTy, bool IsMin>
template <typename OpTy>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;

Expand Down Expand Up @@ -192,8 +192,7 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
return failure();
};

return scf::canonicalizeMinMaxOpInLoop(
rewriter, op, op.getAffineMap(), op.getOperands(), IsMin, loopMatcher);
return scf::canonicalizeMinMaxOpInLoop(rewriter, op, loopMatcher);
}
};

Expand All @@ -214,8 +213,8 @@ void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
patterns
.add<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
.add<AffineOpSCFCanonicalizationPattern<AffineMinOp>,
AffineOpSCFCanonicalizationPattern<AffineMaxOp>,
DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
DimOfLoopResultFolder<tensor::DimOp>,
DimOfLoopResultFolder<memref::DimOp>>(ctx);
Expand Down
28 changes: 13 additions & 15 deletions mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
return success();
}

template <typename OpTy, bool IsMin>
static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
ForOp partialIteration,
Value previousUb) {
Expand All @@ -164,18 +163,20 @@ static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
"expected same step in main and partial loop");
Value step = forOp.getStep();

forOp.walk([&](OpTy affineOp) {
AffineMap map = affineOp.getAffineMap();
(void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
affineOp.getOperands(), IsMin, mainIv,
previousUb, step,
forOp.walk([&](Operation *affineOp) {
if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
return WalkResult::advance();
(void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, mainIv, previousUb,
step,
/*insideLoop=*/true);
return WalkResult::advance();
});
partialIteration.walk([&](OpTy affineOp) {
AffineMap map = affineOp.getAffineMap();
(void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
affineOp.getOperands(), IsMin, partialIv,
previousUb, step, /*insideLoop=*/false);
partialIteration.walk([&](Operation *affineOp) {
if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
return WalkResult::advance();
(void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb,
step, /*insideLoop=*/false);
return WalkResult::advance();
});
}

Expand All @@ -188,10 +189,7 @@ LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
return failure();

// Rewrite affine.min and affine.max ops.
rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
rewriter, forOp, partialIteration, previousUb);
rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
rewriter, forOp, partialIteration, previousUb);
rewriteAffineOpAfterPeeling(rewriter, forOp, partialIteration, previousUb);

return success();
}
Expand Down
Loading

0 comments on commit 3a5811a

Please sign in to comment.