
Commit 266a741

[mlir][sparse] move tensor expression builder into Merger utility
Rationale:
Follow-up on migrating lattice and tensor expression related methods into
the new utility. This also prepares the next step of generalizing the op
kinds that are handled.

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D105219
1 parent 8c7349b · commit 266a741

4 files changed (+104, -84)


mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 13 additions & 5 deletions
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/BitVector.h"
 
@@ -148,11 +149,6 @@ class Merger {
   /// Returns true if any set bit corresponds to queried dim.
   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
 
-  /// Builds the iteration lattices in a bottom-up traversal given the remaining
-  /// tensor (sub)expression and the next loop index in the iteration graph.
-  /// Returns index of the root expression.
-  unsigned buildLattices(unsigned exp, unsigned idx);
-
   /// Setter
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
@@ -169,7 +165,19 @@ class Merger {
   void dumpBits(const llvm::BitVector &bits) const;
 #endif
 
+  /// Builds the iteration lattices in a bottom-up traversal given the remaining
+  /// tensor (sub)expression and the next loop index in the iteration graph.
+  /// Returns index of the root expression.
+  unsigned buildLattices(unsigned exp, unsigned idx);
+
+  /// Builds a tensor expression from the given Linalg operation.
+  /// Returns index of the root expression on success.
+  Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
+
 private:
+  /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
+  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value val);
+
   const unsigned outTensor;
   const unsigned syntheticTensor;
   const unsigned numTensors;
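
With the builder entry points now on the Merger, a client pass drives the
whole pipeline through one object. A minimal sketch of the intended call
order follows; it is hedged, not part of this diff: `merger`, `op`, and
`ldx` stand for an already configured Merger, the linalg.generic being
sparsified, and a loop index taken from the iteration graph.

  // Sketch only: assumes a Merger already set up with dimension info.
  Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
  if (!exp.hasValue())
    return failure(); // kernel uses ops the builder cannot handle yet
  // One lattice set per loop index, built bottom-up from the root.
  unsigned lts = merger.buildLattices(exp.getValue(), ldx);
  lts = merger.optimizeSet(lts); // prune redundant lattice points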

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 4 additions & 51 deletions
@@ -208,51 +208,6 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   return true;
 }
 
-/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
-/// This simplifies constructing (sub)expressions during iteration lattice
-/// building (compared to using the SSA representation everywhere).
-static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
-                                         Value val) {
-  if (auto arg = val.dyn_cast<BlockArgument>()) {
-    unsigned argN = arg.getArgNumber();
-    // Any argument of the generic op that is not marked as a scalar
-    // argument is considered a tensor, indexed by the implicit loop
-    // bounds. This includes rank-0 tensor arguments.
-    if (arg.getOwner()->getParentOp() == op) {
-      OpOperand *t = op.getInputAndOutputOperands()[argN];
-      if (!op.isScalar(t))
-        return merger.addExp(Kind::kTensor, argN);
-      val = t->get(); // get scalar value
-    }
-    // Any other argument (marked as scalar argument for the generic op
-    // or belonging to an enveloping op) is considered invariant.
-    return merger.addExp(Kind::kInvariant, val);
-  }
-  Operation *def = val.getDefiningOp();
-  if (def->getBlock() != &op.region().front()) {
-    // Something defined outside is invariant.
-    return merger.addExp(Kind::kInvariant, val);
-  } else if (def->getNumOperands() == 2) {
-    // Construct binary operations if subexpressions could be built.
-    auto x = buildTensorExp(merger, op, def->getOperand(0));
-    auto y = buildTensorExp(merger, op, def->getOperand(1));
-    if (x.hasValue() && y.hasValue()) {
-      unsigned e0 = x.getValue();
-      unsigned e1 = y.getValue();
-      if (isa<MulFOp>(def))
-        return merger.addExp(Kind::kMulF, e0, e1);
-      if (isa<MulIOp>(def))
-        return merger.addExp(Kind::kMulI, e0, e1);
-      if (isa<AddFOp>(def))
-        return merger.addExp(Kind::kAddF, e0, e1);
-      if (isa<AddIOp>(def))
-        return merger.addExp(Kind::kAddI, e0, e1);
-    }
-  }
-  // Cannot build (yet).
-  return None;
-}
-
 /// Returns true if given tensor co-iterates with conjunction only.
 /// For the output tensor, this defines a "simply dynamic" operation.
 /// For instance: A(I) = A(I) * B(I) * C(I)
@@ -1224,14 +1179,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
         !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
       return failure();
 
-    // Finds the terminating yield statement and builds the tensor
-    // expression for the Linalg operation in SSA form.
-    Operation *yield = op.region().front().getTerminator();
-    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
+    // Builds the tensor expression for the Linalg operation in SSA form.
+    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
     if (!exp.hasValue())
-      return failure(); // build failure
+      return failure();
 
-    // Reject an inadmissable tensor expression.
+    // Rejects an inadmissable tensor expression.
     if (!isAdmissableTensorExp(merger, op, exp.getValue()))
       return failure();
 
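
For intuition about what the moved builder produces, consider a kernel
whose region yields addf(mulf(%a, %b), %c), i.e. x(i) = a(i) * b(i) + c(i).
A hypothetical trace of the recursive traversal, with expression indices
assigned in construction order by addExp (the e0..e4 names are illustrative
only):

  // Hypothetical trace; argument numbers 0, 1, 2 map to %a, %b, %c.
  e0 = addExp(Kind::kTensor, 0);    // block argument %a
  e1 = addExp(Kind::kTensor, 1);    // block argument %b
  e2 = addExp(Kind::kMulF, e0, e1); // mulf node
  e3 = addExp(Kind::kTensor, 2);    // block argument %c
  e4 = addExp(Kind::kAddF, e2, e3); // addf node, returned as root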

mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -6,4 +6,5 @@ add_mlir_dialect_library(MLIRSparseTensorUtils
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRLinalg
 )

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 86 additions & 28 deletions
@@ -14,6 +14,10 @@
 namespace mlir {
 namespace sparse_tensor {
 
+//
+// Lattice methods.
+//
+
 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
   unsigned e = tensorExps.size();
   tensorExps.push_back(TensorExp(k, e0, e1, v));
@@ -68,7 +72,7 @@ unsigned Merger::optimizeSet(unsigned s0) {
     if (p0 != p1) {
       // Is this a straightforward copy?
       unsigned e = latPoints[p1].exp;
-      if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+      if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor)
        continue;
       // Conjunction already covered?
       for (unsigned p2 : latSets[s]) {
@@ -137,33 +141,6 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
   return false;
 }
 
-unsigned Merger::buildLattices(unsigned e, unsigned idx) {
-  Kind kind = exp(e).kind;
-  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
-    // Either the index is really used in the tensor expression, or it is
-    // set to the undefined index in that dimension. An invariant expression
-    // is set to a synthetic tensor with undefined indices only.
-    unsigned s = addSet();
-    unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
-    set(s).push_back(addLat(t, idx, e));
-    return s;
-  }
-  unsigned s0 = buildLattices(exp(e).e0, idx);
-  unsigned s1 = buildLattices(exp(e).e1, idx);
-  switch (kind) {
-  case Kind::kTensor:
-  case Kind::kInvariant:
-    llvm_unreachable("handled above");
-  case Kind::kMulF:
-  case Kind::kMulI:
-    return takeConj(kind, s0, s1);
-  case Kind::kAddF:
-  case Kind::kAddI:
-    return takeDisj(kind, s0, s1);
-  }
-  llvm_unreachable("unexpected expression kind");
-}
-
 #ifndef NDEBUG
 
 //
@@ -173,6 +150,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
+    if (tensorExps[e].e0 == syntheticTensor)
+      llvm::dbgs() << "synthetic_";
+    else if (tensorExps[e].e0 == outTensor)
+      llvm::dbgs() << "output_";
     llvm::dbgs() << "tensor_" << tensorExps[e].e0;
     break;
   case Kind::kInvariant:
@@ -242,5 +223,82 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
 
 #endif // NDEBUG
 
+//
+// Builder methods.
+//
+
+unsigned Merger::buildLattices(unsigned e, unsigned idx) {
+  Kind kind = tensorExps[e].kind;
+  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
+    // Either the index is really used in the tensor expression, or it is
+    // set to the undefined index in that dimension. An invariant expression
+    // is set to a synthetic tensor with undefined indices only.
+    unsigned s = addSet();
+    unsigned t = kind == Kind::kTensor ? tensorExps[e].e0 : syntheticTensor;
+    latSets[s].push_back(addLat(t, idx, e));
+    return s;
+  }
+  unsigned s0 = buildLattices(tensorExps[e].e0, idx);
+  unsigned s1 = buildLattices(tensorExps[e].e1, idx);
+  switch (kind) {
+  case Kind::kTensor:
+  case Kind::kInvariant:
+    llvm_unreachable("handled above");
+  case Kind::kMulF:
+  case Kind::kMulI:
+    return takeConj(kind, s0, s1);
+  case Kind::kAddF:
+  case Kind::kAddI:
+    return takeDisj(kind, s0, s1);
+  }
+  llvm_unreachable("unexpected expression kind");
+}
+
+Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
+  Operation *yield = op.region().front().getTerminator();
+  return buildTensorExp(op, yield->getOperand(0));
+}
+
+Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
+  if (auto arg = val.dyn_cast<BlockArgument>()) {
+    unsigned argN = arg.getArgNumber();
+    // Any argument of the generic op that is not marked as a scalar
+    // argument is considered a tensor, indexed by the implicit loop
+    // bounds. This includes rank-0 tensor arguments.
+    if (arg.getOwner()->getParentOp() == op) {
+      OpOperand *t = op.getInputAndOutputOperands()[argN];
+      if (!op.isScalar(t))
+        return addExp(Kind::kTensor, argN);
+      val = t->get(); // get scalar value
+    }
+    // Any other argument (marked as scalar argument for the generic op
+    // or belonging to an enveloping op) is considered invariant.
+    return addExp(Kind::kInvariant, val);
+  }
+  // Something defined outside is invariant.
+  Operation *def = val.getDefiningOp();
+  if (def->getBlock() != &op.region().front())
+    return addExp(Kind::kInvariant, val);
+  // Construct binary operations if subexpressions could be built.
+  if (def->getNumOperands() == 2) {
+    auto x = buildTensorExp(op, def->getOperand(0));
+    auto y = buildTensorExp(op, def->getOperand(1));
+    if (x.hasValue() && y.hasValue()) {
+      unsigned e0 = x.getValue();
+      unsigned e1 = y.getValue();
+      if (isa<MulFOp>(def))
+        return addExp(Kind::kMulF, e0, e1);
+      if (isa<MulIOp>(def))
+        return addExp(Kind::kMulI, e0, e1);
+      if (isa<AddFOp>(def))
+        return addExp(Kind::kAddF, e0, e1);
+      if (isa<AddIOp>(def))
+        return addExp(Kind::kAddI, e0, e1);
+    }
+  }
+  // Cannot build.
+  return None;
+}
+
 } // namespace sparse_tensor
 } // namespace mlir
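
The conjunction/disjunction split in buildLattices mirrors sparse iteration
semantics: a product is nonzero only where both operands have stored values,
while a sum must also cover points where exactly one operand is present. For
the x(i) = a(i) * b(i) + c(i) example above, a hedged, informal reading of
the resulting lattice set at loop index i:

  // takeConj({a}, {b})    -> { a/\b }
  // takeDisj({a/\b}, {c}) -> { a/\b/\c, a/\b, c }
  // No lattice point iterates a or b alone: a(i)*b(i) vanishes there.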
