/
Tiling.cpp
433 lines (396 loc) · 18.3 KB
/
Tiling.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace iree_compiler {
namespace linalg_ext {
//===----------------------------------------------------------------------===//
// Utility methods for tiling a linalg_ext operation that implements a
// TiledOpInterface
//===----------------------------------------------------------------------===//
/// Checks that `options` only requests features that this tiling
/// implementation handles, reporting the first unsupported one through
/// `rewriter.notifyMatchFailure` on `op`.
///
/// The following configurations are rejected:
/// - a non-empty interchange vector,
/// - a padding value computation function (tile + pad),
/// - a loop type other than `linalg::LinalgTilingLoopType::Loops`,
/// - distribution options containing any method other than
///   `linalg::DistributionMethod::Cyclic`.
///
/// Returns `success()` when none of the above is requested.
static LogicalResult verifySupportedTilingOptions(
    PatternRewriter &rewriter, Operation *op,
    const linalg::LinalgTilingOptions &options) {
  if (!options.interchangeVector.empty()) {
    return rewriter.notifyMatchFailure(op,
                                       "unsupported interchange during tiling");
  }
  if (options.paddingValueComputationFunction) {
    return rewriter.notifyMatchFailure(op, "unsupported tile + pad option");
  }
  if (options.loopType != linalg::LinalgTilingLoopType::Loops) {
    return rewriter.notifyMatchFailure(op,
                                       "only tiling with scf.for is supported");
  }
  if (options.distribution) {
    // Every requested distribution method has to be cyclic.
    if (llvm::any_of(options.distribution->distributionMethod,
                     [](linalg::DistributionMethod method) {
                       return method != linalg::DistributionMethod::Cyclic;
                     })) {
      return rewriter.notifyMatchFailure(op,
                                         "only cyclic distribution is allowed");
    }
  }
  return success();
}
/// Converts a `Value` to an `OpFoldResult`. If `value` matches a constant
/// integer, the result is an `IntegerAttr` of the value's type holding that
/// constant; otherwise the `Value` itself is wrapped and returned.
static OpFoldResult getOpFoldResult(Value value) {
  IntegerAttr::ValueType constant;
  const bool isConstant = matchPattern(value, m_ConstantInt(&constant));
  if (!isConstant) return value;
  return IntegerAttr::get(value.getType(), constant);
}
/// Element-wise version of `getOpFoldResult(Value)`: converts every entry of
/// `values`, preserving order.
static SmallVector<OpFoldResult, 4> getOpFoldResult(ArrayRef<Value> values) {
  SmallVector<OpFoldResult, 4> converted;
  converted.reserve(values.size());
  for (Value value : values) {
    converted.push_back(getOpFoldResult(value));
  }
  return converted;
}
/// Converts an `OpFoldResult` to a `Value`. When `valueOrAttr` holds an
/// attribute, it is cast to `IntegerAttr` and materialized as a
/// `ConstantIndexOp` created with `builder` at `loc`; otherwise the held
/// `Value` is returned unchanged.
static Value getValue(OpBuilder &builder, Location loc,
                      OpFoldResult valueOrAttr) {
  auto attr = valueOrAttr.dyn_cast<Attribute>();
  if (!attr) return valueOrAttr.get<Value>();
  return builder.create<ConstantIndexOp>(loc,
                                         attr.cast<IntegerAttr>().getInt());
}
/// Returns true if the loop is untiled, i.e. `valueOrAttr` is an attribute
/// whose integer value is zero. Only the static case is checked: a `Value`
/// defined by a constant op is assumed to have been converted to an
/// `IntegerAttr` already, so anything that is not an attribute is treated as
/// tiled.
static bool isUntiledLoop(OpFoldResult valueOrAttr) {
  if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
    return attr.cast<IntegerAttr>().getValue() == 0;
  }
  return false;
}
/// Recursively generates the tiled loops and the body by invoking the
/// interface methods of TiledOpInterface. Each invocation handles the loop at
/// `loopDepth` and then recurses for `loopDepth + 1`.
/// - `outputs` are the operands to use for outputs of the tiled operation.
/// - `tileSizes` are tile sizes specified for all loops of the operation. If a
/// loop is to be untiled it is set to 0. Entries are overwritten in place as
/// the loops are processed: an untiled loop gets the full loop size, a tiled
/// loop gets the in-bounds tile size computed inside the generated loop.
/// - `iteratorTypes` are the types of the loop iterators returned by the
/// TiledOpInterface.
/// - `loopBounds` are the bounds of all the loops of the op returned by the
/// TiledOpInterface.
/// - `loopDepth` is the current loop depth being processed.
/// - `offsets` are the values that represent the position of the tile being
/// operated on. The offsets are computed as the tiled loops are being
/// generated; one entry is appended per loop depth.
/// - `distributionInfo` is the proc_id and nprocs `Value`s to be used for
/// distributed loops. The entry at the front is consumed when a tiled
/// parallel loop is distributed, and is dropped before processing the inner
/// loops.
///
/// Returns failure if the tiled implementation cannot be created or if a
/// tiled loop has a stride that is not the constant 1. Otherwise returns a
/// `TiledOp` where `op` is the tiled operation, `loops` holds the generated
/// scf.for operations (outermost first), and `results` are the results of the
/// outermost generated loop, or the inserted slices when no loop is generated.
static FailureOr<TiledOp> tileLinalgExtOpImpl(
OpBuilder &builder, TiledOpInterface op, ValueRange outputs,
MutableArrayRef<OpFoldResult> tileSizes, ArrayRef<StringRef> iteratorTypes,
ArrayRef<Range> loopBounds, unsigned loopDepth,
SmallVectorImpl<OpFoldResult> &offsets,
ArrayRef<linalg::ProcInfo> distributionInfo) {
Location loc = op.getLoc();
// Base case of the recursion: all loops have been processed, so generate the
// tiled implementation of the op by invoking the TiledOpInterface methods.
if (loopDepth == tileSizes.size()) {
SmallVector<SmallVector<OpFoldResult, 4>> resultOffsets;
Operation *tiledOp = op.getTiledImplementation(builder, outputs, offsets,
tileSizes, resultOffsets);
if (!tiledOp) {
return static_cast<LogicalResult>(
op.emitOpError("failed to get tiled implementation"));
}
// A tiled op that has results must report one offset list per result.
assert(tiledOp->getNumResults() == 0 ||
(resultOffsets.size() == tiledOp->getNumResults()));
TiledOp ret;
ret.op = tiledOp;
// If the operation has results, then each result of the tiled operation is
// inserted into the corresponding entry of `outputs` and the inserted value
// is returned.
if (tiledOp->getNumResults()) {
SmallVector<Value> results;
results.reserve(tiledOp->getNumResults());
for (auto en : llvm::enumerate(tiledOp->getResults())) {
Value result = en.value();
// Note: this local `offsets` shadows the function parameter; it holds the
// offsets reported for this particular result.
ArrayRef<OpFoldResult> offsets(resultOffsets[en.index()]);
auto resultType = result.getType().cast<ShapedType>();
// The slice is inserted with unit strides and with the sizes of the tiled
// result in every dimension.
auto oneAttr = builder.getI64IntegerAttr(1);
SmallVector<OpFoldResult> strides(resultType.getRank(), oneAttr);
auto sizes = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(0, resultType.getRank()),
[&](int64_t dim) { return getDim(builder, loc, result, dim); }));
Value insert = builder.create<tensor::InsertSliceOp>(
loc, result, outputs[en.index()], offsets, sizes, strides);
results.push_back(insert);
}
std::swap(ret.results, results);
}
return ret;
}
// If the tile size at this depth is statically zero the loop is untiled: no
// scf.for is generated, the offset is zero, the tile size is replaced by the
// full loop size, and processing continues with the next loop.
if (isUntiledLoop(tileSizes[loopDepth])) {
auto zeroAttr = builder.getI64IntegerAttr(0);
offsets.push_back(zeroAttr);
assert(matchPattern(loopBounds[loopDepth].offset, m_Zero()) &&
"expected loop bounds to have lower bound of zero");
tileSizes[loopDepth] = getOpFoldResult(loopBounds[loopDepth].size);
return tileLinalgExtOpImpl(builder, op, outputs, tileSizes, iteratorTypes,
loopBounds, loopDepth + 1, offsets,
distributionInfo);
}
// Generate an scf.for for the current loop depth. Only loops with a constant
// stride of 1 are supported.
Value lb = loopBounds[loopDepth].offset;
Value ub = loopBounds[loopDepth].size;
if (!matchPattern(loopBounds[loopDepth].stride, m_One())) {
return static_cast<LogicalResult>(
op.emitOpError("expected stride to be 1"));
}
Value step = getValue(builder, loc, tileSizes[loopDepth]);
// Update lb, ub and step for cyclic distribution. Only parallel loops are
// distributed, each consuming the front entry of `distributionInfo`.
if (!distributionInfo.empty() &&
iteratorTypes[loopDepth] == getParallelIteratorTypeName()) {
linalg::updateBoundsForCyclicDistribution(
builder, loc, distributionInfo.front().procId,
distributionInfo.front().nprocs, lb, ub, step);
distributionInfo = distributionInfo.drop_front();
}
FailureOr<TiledOp> innerReturnValue;
// An op without results carries no values through the loop; otherwise the
// outputs are threaded through the loop as its init values.
bool isBufferTiling = op->getNumResults() == 0;
ValueRange initValues(isBufferTiling ? ValueRange{} : outputs);
auto forOp = builder.create<scf::ForOp>(
loc, lb, ub, step, initValues,
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
// The induction variable is the offset of the tile along this loop.
offsets.push_back(iv);
// Map computing min(s0, s1 - d0); it is applied below with d0 = iv,
// s0 = tile size and s1 = ub.
auto affineMaps = AffineMap::inferFromExprList({ArrayRef<AffineExpr>{
b.getAffineSymbolExpr(0),
b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)}})[0];
// Similar to linalg tiling, the tile size is min(tileSize, ub - iv) to
// account for cases where the tile size does not divide (ub - lb)
// exactly.
Value inBoundsTileSize = b.create<AffineMinOp>(
loc, affineMaps,
ValueRange{iv, getValue(builder, loc, tileSizes[loopDepth]), ub});
tileSizes[loopDepth] = getOpFoldResult(inBoundsTileSize);
// Recursively proceed to generate the tiled loop for the next level. When
// the op has results, the loop-carried `args` take the place of `outputs`.
innerReturnValue = tileLinalgExtOpImpl(
b, op, (isBufferTiling ? outputs : args), tileSizes, iteratorTypes,
loopBounds, loopDepth + 1, offsets, distributionInfo);
// NOTE(review): on failure no scf.yield is created, so the loop body is left
// without a terminator; the failure is propagated by the check after the
// loop is built.
if (failed(innerReturnValue)) return;
b.create<scf::YieldOp>(loc, innerReturnValue->results);
});
if (failed(innerReturnValue)) {
return innerReturnValue;
}
// Prepend this loop so that `loops` is ordered from outermost to innermost,
// and expose the results of this loop as the results of the tiled nest.
innerReturnValue->loops.insert(innerReturnValue->loops.begin(),
forOp.getOperation());
innerReturnValue->results = forOp.getResults();
return innerReturnValue;
}
/// Tiles `op` using the tile sizes and distribution settings in `options`.
///
/// Returns an empty `TiledOp` if `op` does not implement `TiledOpInterface`,
/// a `TiledOp` referring to `op` itself if every tile size is statically zero,
/// and failure if a non-parallel loop is requested to be tiled. Otherwise the
/// result of generating the tiled loop nest is returned.
FailureOr<TiledOp> tileLinalgExtOp(OpBuilder &b, LinalgExtOp op,
                                   const linalg::LinalgTilingOptions &options) {
  auto tilableOp = dyn_cast<TiledOpInterface>(op.getOperation());
  if (!tilableOp) return TiledOp{};

  SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
  SmallVector<Value, 4> tileSizesVals =
      options.tileSizeComputationFunction(b, tilableOp.getOperation());

  // Tile sizes defined by a constant are turned into integer attributes, and
  // the list is resized to one entry per loop, filling with zero (untiled).
  SmallVector<OpFoldResult, 4> tileSizes = getOpFoldResult(tileSizesVals);
  tileSizes.resize(iteratorTypes.size(), b.getI64IntegerAttr(0));

  // Tiling is only implemented for "parallel" loops; every other loop must be
  // left untiled.
  for (unsigned idx = 0, numLoops = iteratorTypes.size(); idx != numLoops;
       ++idx) {
    const bool isParallel = iteratorTypes[idx] == getParallelIteratorTypeName();
    if (!isParallel && !isUntiledLoop(tileSizes[idx])) {
      return static_cast<LogicalResult>(op.emitOpError(
          "unimplemented tiling of non-parallel loop iterator type"));
    }
  }

  // Nothing to do when no loop is tiled: hand back the original op.
  if (llvm::all_of(tileSizes, isUntiledLoop)) {
    return TiledOp{op.getOperation(), {}, {}};
  }

  SmallVector<Range> loopBounds = tilableOp.getLoopBounds(b);

  // With distribution enabled, request proc_id/nprocs for the loops that are
  // both tiled (tile size != 0) and parallel.
  SmallVector<linalg::ProcInfo> distributionInfo;
  if (options.distribution) {
    SmallVector<Range> distributedLoopRange;
    for (auto en : llvm::enumerate(tileSizes)) {
      const bool isTiled = !isUntiledLoop(en.value());
      const bool isParallel =
          iteratorTypes[en.index()] == getParallelIteratorTypeName();
      if (isTiled && isParallel) {
        distributedLoopRange.push_back(loopBounds[en.index()]);
      }
    }
    distributionInfo =
        options.distribution->procInfo(b, op.getLoc(), distributedLoopRange);
  }

  SmallVector<OpFoldResult> offsets;
  return tileLinalgExtOpImpl(b, tilableOp, op.outputs(), tileSizes,
                             iteratorTypes, loopBounds, /*loopDepth=*/0,
                             offsets, distributionInfo);
}
//===----------------------------------------------------------------------===//
// Patterns for tiling LinalgExtOps.
//===----------------------------------------------------------------------===//
namespace {
/// Base pattern for tiling LinalgExtOps. It stores the tiling options and the
/// transformation filter, and declares `matchAndRewriteBase`, which derived
/// patterns call to perform the tiling.
struct LinalgExtBaseTilingPattern : public RewritePattern {
/// Creates a pattern for operations named `opName` with the given `benefit`.
/// `options` and `filter` are copied into the pattern.
LinalgExtBaseTilingPattern(StringRef opName, MLIRContext *context,
linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: RewritePattern(opName, benefit, context),
filter(filter),
options(options) {}
/// Attempts to tile `op`. On success `result` is set to the description of
/// the tiled operation; the caller is responsible for replacing or erasing
/// the original op.
LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
TiledOp &result) const;
private:
/// Filter that is checked before tiling and applied to the tiled op
/// afterwards.
linalg::LinalgTransformationFilter filter;
/// Options to control tiling.
linalg::LinalgTilingOptions options;
};
/// Tiling pattern for a specific LinalgExt operation type `OpTy`. The pattern
/// is registered on `OpTy::getOperationName()` and delegates the tiling to
/// `LinalgExtBaseTilingPattern::matchAndRewriteBase`.
template <typename OpTy>
struct LinalgExtTilingPattern : public LinalgExtBaseTilingPattern {
  LinalgExtTilingPattern(MLIRContext *context,
                         linalg::LinalgTilingOptions options,
                         linalg::LinalgTransformationFilter filter =
                             linalg::LinalgTransformationFilter(),
                         PatternBenefit benefit = 1)
      : LinalgExtBaseTilingPattern(OpTy::getOperationName(), context, options,
                                   filter, benefit) {}

  /// Tiles `op` and replaces it with the results of the tiled computation, or
  /// erases it when the tiled computation has no results. Fails if tiling
  /// fails or if no tiled op was produced.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    TiledOp tiled;
    LogicalResult status =
        LinalgExtBaseTilingPattern::matchAndRewriteBase(op, rewriter, tiled);
    // Both an outright failure and the do-nothing case are reported as failure.
    if (failed(status) || !tiled.op) return failure();

    // The original op was returned as-is, so there is nothing to replace.
    if (tiled.op == op) return success();

    if (tiled.results.empty()) {
      rewriter.eraseOp(op);
    } else {
      rewriter.replaceOp(op, tiled.results);
    }
    return success();
  }
};
} // namespace
/// Tiles `op` if it is a `LinalgExtOp`, passes the transformation filter and
/// only uses supported tiling options. On success `result` holds the tiled
/// operation description and, when a tiled op exists, the filter has been
/// applied to it.
LogicalResult LinalgExtBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledOp &result) const {
  auto linalgExtOp = dyn_cast<LinalgExtOp>(op);
  // The checks are evaluated in order and stop at the first one that fails.
  if (!linalgExtOp || failed(filter.checkAndNotify(rewriter, op)) ||
      failed(verifySupportedTilingOptions(rewriter, op, options))) {
    return failure();
  }

  FailureOr<TiledOp> tiled = tileLinalgExtOp(rewriter, linalgExtOp, options);
  if (failed(tiled)) return failure();

  result = *tiled;
  if (result.op) {
    filter.replaceLinalgTransformationFilter(rewriter, result.op);
  }
  return success();
}
//===----------------------------------------------------------------------===//
// Test pass for tiling Linalg Ext ops
//===----------------------------------------------------------------------===//
namespace {
/// Test pass that applies the LinalgExt tiling patterns to a function. The
/// patterns themselves are set up in `runOnOperation`.
struct LinalgExtTilingPass : public LinalgExtTilingBase<LinalgExtTilingPass> {
  /// Registers the dialects this pass depends on.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
                memref::MemRefDialect, StandardOpsDialect,
                tensor::TensorDialect, scf::SCFDialect>();
  }
  void runOnOperation() override;
};
} // namespace
/// Creates an operation of type `OpTy` for dimension `dim` and returns it as
/// a `Value`. The location is taken from the operation at the builder's
/// current insertion point.
template <typename OpTy>
static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
  Location loc = b.getInsertionPoint()->getLoc();
  return b.template create<OpTy>(loc, dim);
}
void LinalgExtTilingPass::runOnOperation() {
FuncOp funcOp = getOperation();
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
patterns.add<LinalgExtTilingPattern<ScatterOp>>(
context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
linalg::LinalgTransformationFilter(
Identifier::get("tiling_input", context),
Identifier::get("tiling_output", context)));
patterns.add<LinalgExtTilingPattern<ScatterOp>>(
context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{0}),
linalg::LinalgTransformationFilter(
Identifier::get("no_tiling_input", context),
Identifier::get("no_tiling_output", context)));
patterns.add<LinalgExtTilingPattern<SortOp>>(
context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
linalg::LinalgTransformationFilter(
Identifier::get("outer_reduce_input", context),
Identifier::get("outer_reduce_output", context)));
patterns.add<LinalgExtTilingPattern<SortOp>>(
context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}),
linalg::LinalgTransformationFilter(
Identifier::get("inner_reduce_input", context),
Identifier::get("inner_reduce_output", context)));
static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
[](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
auto numParallelDims = parallelLoopRanges.size();
SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
for (size_t dim = 0; dim < numParallelDims; ++dim) {
procInfo[numParallelDims - dim - 1] = {
buildFlowWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupIDOp>(
builder, dim),
buildFlowWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupCountOp>(
builder, dim)};
}
return procInfo;
},
{linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
linalg::DistributionMethod::Cyclic},
DenseMap<StringRef,
std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
patterns
.add<LinalgExtTilingPattern<ScatterOp>, LinalgExtTilingPattern<SortOp>>(
context,
linalg::LinalgTilingOptions()
.setTileSizes(ArrayRef<int64_t>{10, 0, 30})
.setDistributionOptions(workgroupDistributionOptions),
linalg::LinalgTransformationFilter(
Identifier::get("distribute_input", context),
Identifier::get("distribute_output", context)));
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
/// Creates a new instance of the LinalgExt tiling test pass.
std::unique_ptr<OperationPass<FuncOp>> createLinalgExtTilingPass() {
  auto tilingPass = std::make_unique<LinalgExtTilingPass>();
  return tilingPass;
}
} // namespace linalg_ext
} // namespace iree_compiler
} // namespace mlir