Skip to content

Commit 01bdb0f

Browse files
[mlir][linalg] Improve implementation of hoist padding.
Instead of relying on adhoc bounds calculations, use a projection-based implementation. This simplifies the implementation and finds more static constant sizes than previously/ Differential Revision: https://reviews.llvm.org/D106054
1 parent 5024fe9 commit 01bdb0f

File tree

5 files changed

+335
-125
lines changed

5 files changed

+335
-125
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//===- ConstraintsSet.h - Extensions for FlatAffineConstraints --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Linalg-specific constraints set extensions.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_
14+
#define MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_
15+
16+
#include "mlir/Analysis/AffineStructures.h"
17+
#include "mlir/IR/AffineMap.h"
18+
19+
namespace mlir {
20+
class ValueRange;
21+
22+
/// Linalg-specific constraints set extensions.
23+
class ConstraintsSet : public FlatAffineConstraints {
24+
public:
25+
ConstraintsSet() : FlatAffineConstraints() {}
26+
27+
/// Assuming `val` is defined by `val = affine.min map (operands)`, introduce
28+
/// all the constraints `val >= expr_i(operands)`, where expr_i are all the
29+
/// results of `map`.
30+
// This API avoids taking a dependence on the AffineMinOp definition.
31+
LogicalResult composeMin(Value val, AffineMap map, ValueRange operands) {
32+
return composeMinOrMaxMapAndOperands(val, map, operands, /*min=*/true);
33+
}
34+
35+
/// Assuming `val` is defined by `val = affine.max map (operands)`, introduce
36+
/// all the constraints `val <= expr_i(operands)`, where expr_i are all the
37+
/// results of `map`.
38+
// This API avoids taking a dependence on the AffineMaxOp definition.
39+
LogicalResult composeMax(Value val, AffineMap map, ValueRange operands) {
40+
return composeMinOrMaxMapAndOperands(val, map, operands, /*min=*/false);
41+
}
42+
43+
/// Assuming `val` is defined by `val = affine.apply map (operands)`, call
44+
/// composeMap.
45+
// This API avoids taking a dependence on the AffineMApplyOp definition.
46+
LogicalResult composeAffineApply(Value val, AffineMap map,
47+
ValueRange operands);
48+
49+
/// Asserts the identifier `id` is in the constraints set and returns it.
50+
unsigned lookupPos(Value id) const;
51+
52+
/// If v is not in the constraint set, insert it as a dim or symbol depending
53+
/// on `asDim`.
54+
/// Return success if v is of dim id type when `asDim` is true and of symbol
55+
/// id type when `asDim` is false.
56+
/// Return failure otherwise.
57+
LogicalResult ensureIdOfType(Value v, bool asDim);
58+
59+
private:
60+
/// Implementation detail for composeMin/Max.
61+
LogicalResult composeMinOrMaxMapAndOperands(Value val, AffineMap map,
62+
ValueRange operands, bool min);
63+
};
64+
65+
} // namespace mlir
66+
67+
#endif // MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
add_mlir_dialect_library(MLIRLinalgAnalysis
2+
ConstraintsSet.cpp
23
DependenceAnalysis.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
67

78
LINK_LIBS PUBLIC
9+
MLIRAnalysis
810
MLIRIR
911
MLIRLinalg
12+
MLIRLoopAnalysis
1013
MLIRMemRef
1114
MLIRStandard
1215
)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//===- ConstraintsSet.cpp - Extensions for FlatAffineConstraints ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Linalg-specific constraints set extensions.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Linalg/Analysis/ConstraintsSet.h"
14+
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
15+
#include "mlir/IR/AffineMap.h"
16+
17+
using namespace mlir;
18+
19+
unsigned ConstraintsSet::lookupPos(Value id) const {
20+
unsigned pos;
21+
if (!findId(id, &pos)) {
22+
llvm::errs() << "Lookup failed: " << id << "\n";
23+
llvm_unreachable("Lookup failed");
24+
}
25+
return pos;
26+
}
27+
28+
LogicalResult ConstraintsSet::ensureIdOfType(Value v, bool asDim) {
29+
if (!containsId(v)) {
30+
if (asDim)
31+
addDimId(getNumDimIds(), v);
32+
else
33+
addSymbolId(getNumSymbolIds(), v);
34+
return success();
35+
}
36+
unsigned pos = lookupPos(v);
37+
return success((asDim && pos < getNumDimIds()) ||
38+
(!asDim && getNumDimIds() <= pos &&
39+
pos < getNumDimIds() + getNumSymbolIds()));
40+
}
41+
42+
LogicalResult ConstraintsSet::composeAffineApply(Value val, AffineMap map,
43+
ValueRange operands) {
44+
AffineValueMap avm(map, operands, val);
45+
return composeMap(&avm);
46+
}
47+
48+
LogicalResult ConstraintsSet::composeMinOrMaxMapAndOperands(Value val,
49+
AffineMap map,
50+
ValueRange operands,
51+
bool min) {
52+
ConstraintsSet localCst;
53+
std::vector<SmallVector<int64_t, 8>> flatExprs;
54+
if (failed(getFlattenedAffineExprs(map, &flatExprs, &localCst)))
55+
return failure();
56+
assert(flatExprs.size() == map.getNumResults() &&
57+
"incorrect number of flattened expressiosn");
58+
59+
// Local vars on a per-need basis.
60+
if (localCst.getNumLocalIds() != 0)
61+
return failure();
62+
63+
// Add one inequality for each result connecting `val` to the other ids in
64+
// `operands`. For instance, uf the expression is:
65+
// `16 * i0 + i1` and
66+
// `min` is true
67+
// add:
68+
// -d_val + 16 * i0 + i1 >= 0.
69+
for (const auto &flatExpr : flatExprs) {
70+
assert(flatExpr.size() >= operands.size() + 1);
71+
SmallVector<int64_t, 8> ineq(getNumCols(), 0);
72+
for (unsigned i = 0, e = operands.size(); i < e; i++)
73+
ineq[lookupPos(operands[i])] = min ? flatExpr[i] : -flatExpr[i];
74+
75+
// Set the coefficient for `d_val`.
76+
ineq[lookupPos(val)] = min ? -1 : 1;
77+
78+
// Set the constant term (upper bound in flatExpr is exclusive).
79+
ineq[getNumCols() - 1] = min ? flatExpr[flatExpr.size() - 1] - 1
80+
: -flatExpr[flatExpr.size() - 1];
81+
82+
// Add the inequality connecting the result of the map to the rest.
83+
addInequality(ineq);
84+
}
85+
86+
return success();
87+
}

0 commit comments

Comments
 (0)