Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][linalg] Improve implementation of hoist padding.
Instead of relying on adhoc bounds calculations, use a projection-based implementation. This simplifies the implementation and finds more static constant sizes than previously/ Differential Revision: https://reviews.llvm.org/D106054
- Loading branch information
1 parent
5024fe9
commit 01bdb0f
Showing
5 changed files
with
335 additions
and
125 deletions.
There are no files selected for viewing
67 changes: 67 additions & 0 deletions
67
mlir/include/mlir/Dialect/Linalg/Analysis/ConstraintsSet.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
//===- ConstraintsSet.h - Extensions for FlatAffineConstraints --*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Linalg-specific constraints set extensions. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_ | ||
#define MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_ | ||
|
||
#include "mlir/Analysis/AffineStructures.h" | ||
#include "mlir/IR/AffineMap.h" | ||
|
||
namespace mlir { | ||
class ValueRange; | ||
|
||
/// Linalg-specific constraints set extensions. | ||
class ConstraintsSet : public FlatAffineConstraints { | ||
public: | ||
ConstraintsSet() : FlatAffineConstraints() {} | ||
|
||
/// Assuming `val` is defined by `val = affine.min map (operands)`, introduce | ||
/// all the constraints `val >= expr_i(operands)`, where expr_i are all the | ||
/// results of `map`. | ||
// This API avoids taking a dependence on the AffineMinOp definition. | ||
LogicalResult composeMin(Value val, AffineMap map, ValueRange operands) { | ||
return composeMinOrMaxMapAndOperands(val, map, operands, /*min=*/true); | ||
} | ||
|
||
/// Assuming `val` is defined by `val = affine.max map (operands)`, introduce | ||
/// all the constraints `val <= expr_i(operands)`, where expr_i are all the | ||
/// results of `map`. | ||
// This API avoids taking a dependence on the AffineMaxOp definition. | ||
LogicalResult composeMax(Value val, AffineMap map, ValueRange operands) { | ||
return composeMinOrMaxMapAndOperands(val, map, operands, /*min=*/false); | ||
} | ||
|
||
/// Assuming `val` is defined by `val = affine.apply map (operands)`, call | ||
/// composeMap. | ||
// This API avoids taking a dependence on the AffineMApplyOp definition. | ||
LogicalResult composeAffineApply(Value val, AffineMap map, | ||
ValueRange operands); | ||
|
||
/// Asserts the identifier `id` is in the constraints set and returns it. | ||
unsigned lookupPos(Value id) const; | ||
|
||
/// If v is not in the constraint set, insert it as a dim or symbol depending | ||
/// on `asDim`. | ||
/// Return success if v is of dim id type when `asDim` is true and of symbol | ||
/// id type when `asDim` is false. | ||
/// Return failure otherwise. | ||
LogicalResult ensureIdOfType(Value v, bool asDim); | ||
|
||
private: | ||
/// Implementation detail for composeMin/Max. | ||
LogicalResult composeMinOrMaxMapAndOperands(Value val, AffineMap map, | ||
ValueRange operands, bool min); | ||
}; | ||
|
||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
add_mlir_dialect_library(MLIRLinalgAnalysis | ||
ConstraintsSet.cpp | ||
DependenceAnalysis.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRAnalysis | ||
MLIRIR | ||
MLIRLinalg | ||
MLIRLoopAnalysis | ||
MLIRMemRef | ||
MLIRStandard | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
//===- ConstraintsSet.cpp - Extensions for FlatAffineConstraints ----------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Linalg-specific constraints set extensions. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Linalg/Analysis/ConstraintsSet.h" | ||
#include "mlir/Dialect/Affine/IR/AffineValueMap.h" | ||
#include "mlir/IR/AffineMap.h" | ||
|
||
using namespace mlir; | ||
|
||
unsigned ConstraintsSet::lookupPos(Value id) const { | ||
unsigned pos; | ||
if (!findId(id, &pos)) { | ||
llvm::errs() << "Lookup failed: " << id << "\n"; | ||
llvm_unreachable("Lookup failed"); | ||
} | ||
return pos; | ||
} | ||
|
||
LogicalResult ConstraintsSet::ensureIdOfType(Value v, bool asDim) { | ||
if (!containsId(v)) { | ||
if (asDim) | ||
addDimId(getNumDimIds(), v); | ||
else | ||
addSymbolId(getNumSymbolIds(), v); | ||
return success(); | ||
} | ||
unsigned pos = lookupPos(v); | ||
return success((asDim && pos < getNumDimIds()) || | ||
(!asDim && getNumDimIds() <= pos && | ||
pos < getNumDimIds() + getNumSymbolIds())); | ||
} | ||
|
||
LogicalResult ConstraintsSet::composeAffineApply(Value val, AffineMap map, | ||
ValueRange operands) { | ||
AffineValueMap avm(map, operands, val); | ||
return composeMap(&avm); | ||
} | ||
|
||
LogicalResult ConstraintsSet::composeMinOrMaxMapAndOperands(Value val, | ||
AffineMap map, | ||
ValueRange operands, | ||
bool min) { | ||
ConstraintsSet localCst; | ||
std::vector<SmallVector<int64_t, 8>> flatExprs; | ||
if (failed(getFlattenedAffineExprs(map, &flatExprs, &localCst))) | ||
return failure(); | ||
assert(flatExprs.size() == map.getNumResults() && | ||
"incorrect number of flattened expressiosn"); | ||
|
||
// Local vars on a per-need basis. | ||
if (localCst.getNumLocalIds() != 0) | ||
return failure(); | ||
|
||
// Add one inequality for each result connecting `val` to the other ids in | ||
// `operands`. For instance, uf the expression is: | ||
// `16 * i0 + i1` and | ||
// `min` is true | ||
// add: | ||
// -d_val + 16 * i0 + i1 >= 0. | ||
for (const auto &flatExpr : flatExprs) { | ||
assert(flatExpr.size() >= operands.size() + 1); | ||
SmallVector<int64_t, 8> ineq(getNumCols(), 0); | ||
for (unsigned i = 0, e = operands.size(); i < e; i++) | ||
ineq[lookupPos(operands[i])] = min ? flatExpr[i] : -flatExpr[i]; | ||
|
||
// Set the coefficient for `d_val`. | ||
ineq[lookupPos(val)] = min ? -1 : 1; | ||
|
||
// Set the constant term (upper bound in flatExpr is exclusive). | ||
ineq[getNumCols() - 1] = min ? flatExpr[flatExpr.size() - 1] - 1 | ||
: -flatExpr[flatExpr.size() - 1]; | ||
|
||
// Add the inequality connecting the result of the map to the rest. | ||
addInequality(ineq); | ||
} | ||
|
||
return success(); | ||
} |
Oops, something went wrong.