@@ -14,6 +14,10 @@
 namespace mlir {
 namespace sparse_tensor {
 
+//
+// Lattice methods.
+//
+
 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
   unsigned e = tensorExps.size();
   tensorExps.push_back(TensorExp(k, e0, e1, v));
@@ -68,7 +72,7 @@ unsigned Merger::optimizeSet(unsigned s0) {
     if (p0 != p1) {
       // Is this a straightforward copy?
       unsigned e = latPoints[p1].exp;
-      if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+      if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor)
         continue;
       // Conjunction already covered?
       for (unsigned p2 : latSets[s]) {
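A note on the guard above: for a kTensor expression, e0 holds the tensor's operand number rather than a child expression, so comparing it against outTensor recognizes a lattice point whose expression merely re-reads the output tensor, the "straightforward copy" the comment refers to (the same encoding drives the dumpExp change further down). A tiny standalone sketch of the check, with simplified stand-ins for the real Merger types; the struct and helper here are illustrative only, not the actual header:

#include <cassert>
#include <vector>

enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };

// Simplified stand-in: for a kTensor leaf, e0 is an operand number.
struct TensorExp {
  Kind kind;
  unsigned e0, e1;
};

// Mirrors the guard in optimizeSet: a plain kTensor leaf naming the
// output operand is a copy and needs no lattice point of its own.
bool isOutTensorCopy(const std::vector<TensorExp> &exps, unsigned e,
                     unsigned outTensor) {
  return exps[e].kind == Kind::kTensor && exps[e].e0 == outTensor;
}

int main() {
  const unsigned outTensor = 2; // hypothetical: output is operand #2
  std::vector<TensorExp> exps = {{Kind::kTensor, outTensor, 0}};
  assert(isOutTensorCopy(exps, 0, outTensor));
}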
@@ -137,33 +141,6 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
   return false;
 }
 
-unsigned Merger::buildLattices(unsigned e, unsigned idx) {
-  Kind kind = exp(e).kind;
-  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
-    // Either the index is really used in the tensor expression, or it is
-    // set to the undefined index in that dimension. An invariant expression
-    // is set to a synthetic tensor with undefined indices only.
-    unsigned s = addSet();
-    unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
-    set(s).push_back(addLat(t, idx, e));
-    return s;
-  }
-  unsigned s0 = buildLattices(exp(e).e0, idx);
-  unsigned s1 = buildLattices(exp(e).e1, idx);
-  switch (kind) {
-  case Kind::kTensor:
-  case Kind::kInvariant:
-    llvm_unreachable("handled above");
-  case Kind::kMulF:
-  case Kind::kMulI:
-    return takeConj(kind, s0, s1);
-  case Kind::kAddF:
-  case Kind::kAddI:
-    return takeDisj(kind, s0, s1);
-  }
-  llvm_unreachable("unexpected expression kind");
-}
-
 #ifndef NDEBUG
 
 //
@@ -173,6 +150,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
+    if (tensorExps[e].e0 == syntheticTensor)
+      llvm::dbgs() << "synthetic_";
+    else if (tensorExps[e].e0 == outTensor)
+      llvm::dbgs() << "output_";
     llvm::dbgs() << "tensor_" << tensorExps[e].e0;
     break;
   case Kind::kInvariant:
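Before the last hunk, which moves buildLattices under a new "Builder methods" banner, it helps to spell out the lattice algebra that takeConj and takeDisj implement: a product has a stored value only where both operands do, so conjunction pairs every lattice point of s0 with every point of s1 and unions their bits; a sum has a value wherever either operand does, so disjunction keeps the pairwise conjunctions plus both original sets. A minimal standalone model, assuming a lattice point is just a set of participating index bits and an expression string (the real Merger uses llvm::BitVector over (tensor, loop-index) pairs and shared tables):

#include <cstdio>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Bits = std::set<unsigned>;               // participating index bits
using LatPoint = std::pair<Bits, std::string>; // (bits, expression text)
using LatSet = std::vector<LatPoint>;

// Conjunction: every point of s0 paired with every point of s1; a product
// needs both operands present, so the bit sets are unioned.
LatSet takeConj(const LatSet &s0, const LatSet &s1, char op) {
  LatSet s;
  for (const LatPoint &p0 : s0)
    for (const LatPoint &p1 : s1) {
      Bits b = p0.first;
      b.insert(p1.first.begin(), p1.first.end());
      s.push_back({b, "(" + p0.second + op + p1.second + ")"});
    }
  return s;
}

// Disjunction: the conjunctions, plus each operand's points on their own,
// because a sum is nonzero wherever either operand is.
LatSet takeDisj(const LatSet &s0, const LatSet &s1, char op) {
  LatSet s = takeConj(s0, s1, op);
  s.insert(s.end(), s0.begin(), s0.end());
  s.insert(s.end(), s1.begin(), s1.end());
  return s;
}

int main() {
  LatSet a = {{{0}, "a"}}, b = {{{1}, "b"}};
  std::printf("%zu %zu\n", takeConj(a, b, '*').size(), // 1: (a*b)
              takeDisj(a, b, '+').size());             // 3: (a+b), a, b
}

Running it prints "1 3": a single conjunction point for the product, and three points (the conjunction plus each operand alone) for the sum, which is why disjunctions grow lattice sets and why a later optimizeSet pass over the result pays off.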
@@ -242,5 +223,82 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
 
 #endif // NDEBUG
 
+//
+// Builder methods.
+//
+
+unsigned Merger::buildLattices(unsigned e, unsigned idx) {
+  Kind kind = tensorExps[e].kind;
+  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
+    // Either the index is really used in the tensor expression, or it is
+    // set to the undefined index in that dimension. An invariant expression
+    // is set to a synthetic tensor with undefined indices only.
+    unsigned s = addSet();
+    unsigned t = kind == Kind::kTensor ? tensorExps[e].e0 : syntheticTensor;
+    latSets[s].push_back(addLat(t, idx, e));
+    return s;
+  }
+  unsigned s0 = buildLattices(tensorExps[e].e0, idx);
+  unsigned s1 = buildLattices(tensorExps[e].e1, idx);
+  switch (kind) {
+  case Kind::kTensor:
+  case Kind::kInvariant:
+    llvm_unreachable("handled above");
+  case Kind::kMulF:
+  case Kind::kMulI:
+    return takeConj(kind, s0, s1);
+  case Kind::kAddF:
+  case Kind::kAddI:
+    return takeDisj(kind, s0, s1);
+  }
+  llvm_unreachable("unexpected expression kind");
+}
+
+Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
+  Operation *yield = op.region().front().getTerminator();
+  return buildTensorExp(op, yield->getOperand(0));
+}
+
+Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
+  if (auto arg = val.dyn_cast<BlockArgument>()) {
+    unsigned argN = arg.getArgNumber();
+    // Any argument of the generic op that is not marked as a scalar
+    // argument is considered a tensor, indexed by the implicit loop
+    // bounds. This includes rank-0 tensor arguments.
+    if (arg.getOwner()->getParentOp() == op) {
+      OpOperand *t = op.getInputAndOutputOperands()[argN];
+      if (!op.isScalar(t))
+        return addExp(Kind::kTensor, argN);
+      val = t->get(); // get scalar value
+    }
+    // Any other argument (marked as scalar argument for the generic op
+    // or belonging to an enveloping op) is considered invariant.
+    return addExp(Kind::kInvariant, val);
+  }
+  // Something defined outside is invariant.
+  Operation *def = val.getDefiningOp();
+  if (def->getBlock() != &op.region().front())
+    return addExp(Kind::kInvariant, val);
+  // Construct binary operations if subexpressions could be built.
+  if (def->getNumOperands() == 2) {
+    auto x = buildTensorExp(op, def->getOperand(0));
+    auto y = buildTensorExp(op, def->getOperand(1));
+    if (x.hasValue() && y.hasValue()) {
+      unsigned e0 = x.getValue();
+      unsigned e1 = y.getValue();
+      if (isa<MulFOp>(def))
+        return addExp(Kind::kMulF, e0, e1);
+      if (isa<MulIOp>(def))
+        return addExp(Kind::kMulI, e0, e1);
+      if (isa<AddFOp>(def))
+        return addExp(Kind::kAddF, e0, e1);
+      if (isa<AddIOp>(def))
+        return addExp(Kind::kAddI, e0, e1);
+    }
+  }
+  // Cannot build.
+  return None;
+}
+
 } // namespace sparse_tensor
 } // namespace mlir
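Finally, a sketch of the contract buildTensorExp implements: walk the use-def chain down from the yielded value, map admissible leaves to kTensor or kInvariant, combine a recognized binary op only when both children built, and propagate None otherwise so the caller can reject the kernel as a whole. The toy AST below stands in for linalg::GenericOp block arguments and ops; Node, Builder, and every other name are hypothetical, and only the shape of the recursion mirrors the patch:

#include <memory>
#include <optional>
#include <vector>

enum class Kind { kTensor, kInvariant, kMulF, kAddF };
struct TensorExp { Kind kind; unsigned e0, e1; };

struct Node { // toy stand-in for an SSA value in the generic op's body
  enum Tag { Arg, Const, Mul, Add, Unsupported } tag;
  unsigned argN = 0; // operand number for Arg leaves
  std::unique_ptr<Node> lhs, rhs;
};

struct Builder {
  std::vector<TensorExp> tensorExps;

  unsigned addExp(Kind k, unsigned e0, unsigned e1 = 0) {
    unsigned e = tensorExps.size();
    tensorExps.push_back({k, e0, e1});
    return e;
  }

  // Leaves become kTensor/kInvariant; recognized binary ops combine two
  // successfully built children; anything else yields nullopt, which
  // bubbles up so the whole kernel is rejected.
  std::optional<unsigned> build(const Node &n) {
    switch (n.tag) {
    case Node::Arg:
      return addExp(Kind::kTensor, n.argN);
    case Node::Const:
      return addExp(Kind::kInvariant, 0);
    case Node::Mul:
    case Node::Add: {
      auto x = build(*n.lhs);
      auto y = build(*n.rhs);
      if (x && y)
        return addExp(n.tag == Node::Mul ? Kind::kMulF : Kind::kAddF, *x, *y);
      return std::nullopt;
    }
    default:
      return std::nullopt; // e.g. an op the builder does not recognize
    }
  }
};

int main() {
  // x(i) = a(i) * b(i): Mul over tensor arguments #0 and #1 builds the
  // table {kTensor(0), kTensor(1), kMulF(0, 1)}.
  Node mul{Node::Mul};
  mul.lhs = std::make_unique<Node>(Node{Node::Arg, 0});
  mul.rhs = std::make_unique<Node>(Node{Node::Arg, 1});
  Builder b;
  return b.build(mul) ? 0 : 1;
}

Under these simplifications, the resulting table mirrors what the patch builds from a linalg.generic body that multiplies two tensor block arguments and yields the result.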