-
Notifications
You must be signed in to change notification settings - Fork 11.6k
/
VectorTransforms.h
308 lines (275 loc) · 12.1 KB
/
VectorTransforms.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
//===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include <functional>
#include <utility>
namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
class VectorTransferOpInterface;
namespace scf {
class IfOp;
} // namespace scf
/// Appends to `patterns` the rewrites that turn Vector dialect ops into other
/// Vector dialect ops (a Vector -> Vector conversion).
///
/// `coarseVectorShape` and `fineVectorShape` both default to an empty shape.
/// NOTE(review): presumably these describe the source/target shapes of the
/// conversion; confirm against the out-of-line implementation.
///
/// TODO: should be merged with populateVectorToSCFLoweringPattern.
void populateVectorToVectorConversionPatterns(
    MLIRContext *context,
    OwningRewritePatternList &patterns,
    ArrayRef<int64_t> coarseVectorShape = {},
    ArrayRef<int64_t> fineVectorShape = {});
namespace vector {
// NOTE: no line of the diagram below may end in a backslash character; the
// preprocessor would splice the following line into the comment (-Wcomment).
/// Entry point for unrolling declarative pattern rewrites.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
///      `numUnrolledInstances` are computed from the `targetShape`. For now it
///      is assumed the unrolling factors divide the vector sizes.
///   2. a fakeFork cast op is inserted that takes the operand and returns
///      `numUnrolledInstances` results of type `unrolledVectorType`.
///   3. the original op is cloned `numUnrolledInstances` times, once for each
///      result of the fakeFork cast op.
///   4. a fakeJoin cast op takes all these results and merges them into a
///      single aggregate vector result whose size matches the original
///      non-unrolled op operand types.
///
/// Example:
///
///      opA(operand0, operand1)  // numUnrolledInstances = 3
///
///            operand0                   operand1
///               |                          |
///             fork                       fork
///        <----------gather all fork ops --------->
///              /|\                        /|\     (3 results each)
///          f00 f01 f02                f10 f11 f12
///        <---------- clone op 3 times --------->
///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
///                 \            |            /     (results joined)
///        <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// until all the fakeFork and fakeJoin ops are removed.
///
/// This will be extended in the future to support more advanced use cases than
/// simple pointwise ops.
SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
                                                 Operation *op,
                                                 ArrayRef<int64_t> targetShape);
/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
/// declaratively.
///
/// An op of type `OpTy` is rewritten only if:
///   1. the optional `filter` accepts it (an empty filter accepts everything),
///   2. it has exactly one result,
///   3. it implements `VectorUnrollOpInterface` and reports a shape to unroll,
///   4. `shapeRatio` of that shape and `targetShape` exists and at least one
///      ratio differs from 1 (otherwise unrolling would be a no-op).
/// The op is then replaced by the single value returned by
/// `unrollSingleResultVectorOp`.
template <typename OpTy>
struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
  /// Extra user-provided predicate restricting which ops get unrolled.
  using FilterConstraintType = std::function<LogicalResult(OpTy op)>;

  UnrollVectorPattern(
      ArrayRef<int64_t> targetShape, MLIRContext *context,
      FilterConstraintType constraint = [](OpTy op) { return success(); })
      : OpRewritePattern<OpTy>(context),
        targetShape(targetShape.begin(), targetShape.end()),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Invoking an empty std::function throws (or aborts without exceptions);
    // treat an empty filter like the default accept-all filter instead.
    if (filter && failed(filter(op)))
      return failure();
    // Cheapest structural check first: only single-result ops are handled.
    if (op.getOperation()->getNumResults() != 1)
      return failure();
    auto unrollableVectorOp =
        dyn_cast<VectorUnrollOpInterface>(op.getOperation());
    if (!unrollableVectorOp)
      return failure();
    auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
    if (!maybeUnrollShape)
      return failure();
    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
    // Bail out when no ratio could be computed, or when every ratio is 1,
    // i.e. the op already has the target shape.
    if (!maybeShapeRatio ||
        llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
      return failure();
    auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
    // NOTE(review): this failure() is returned after IR has already been
    // created, which patterns should avoid; with the single-result check above
    // it is presumably unreachable.
    if (resultVector.size() != 1)
      return failure();
    rewriter.replaceOp(op, resultVector.front());
    return success();
  }

private:
  /// Shape each unrolled instance is split into.
  SmallVector<int64_t, 4> targetShape;
  /// Optional constraint; ops it rejects are left untouched.
  FilterConstraintType filter;
};
/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success()`, `*ifOp` holds the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref_cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, masked vector.transfer or linalg.copy.
///      memref_cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// `splitFullAndPartialTransferPrecondition` checks the preconditions below
/// (presumably returning failure() when they do not hold).
///
/// Preconditions:
///   1. `xferOp.permutation_map()` must be a minor identity map
///   2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///      must be equal. This will be relaxed in the future but requires
///      rank-reducing subviews.
LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
LogicalResult splitFullAndPartialTransfer(
    OpBuilder &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options = VectorTransformsOptions(),
    scf::IfOp *ifOp = nullptr);
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
///
/// The pattern is registered for any op type (`MatchAnyOpTypeTag`).
/// NOTE(review): restricting the match to vector transfer ops presumably
/// happens in the out-of-line `matchAndRewrite`; confirm there.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  /// User predicate selecting which transfer ops get split.
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
        // `filter` is a by-value sink parameter: move it instead of copying
        // the std::function (a copy may allocate).
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options forwarded to the full/partial transfer split.
  VectorTransformsOptions options;
  /// Extra selection constraint supplied at construction.
  FilterConstraintType filter;
};
} // namespace vector
//===----------------------------------------------------------------------===//
// Finer-grained patterns exposed for more control over individual lowerings.
//===----------------------------------------------------------------------===//
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %flattened_a = vector.shape_cast %a
///    %flattened_b = vector.shape_cast %b
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions selects the `Matmul`
/// contraction lowering and the vector.contract op is a row-major matrix
/// multiply.
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  /// Predicate restricting which contractions may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accepts every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformsOptions,
      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformsOptions(vectorTransformsOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformsOptions;
  /// Selection constraint. The default member initializer matters for the
  /// inherited `OpRewritePattern(MLIRContext *, ...)` constructor, which would
  /// otherwise leave this std::function empty (calling it would throw/abort).
  FilterConstraintType filter = defaultFilter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  /// Predicate restricting which contractions may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accepts every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformsOptions,
      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformsOptions(vectorTransformsOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformsOptions;
  /// Selection constraint. The default member initializer matters for the
  /// inherited `OpRewritePattern(MLIRContext *, ...)` constructor, which would
  /// otherwise leave this std::function empty (calling it would throw/abort).
  FilterConstraintType filter = defaultFilter;
};
/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  /// Predicate restricting which contractions may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accepts every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
                        MLIRContext *context,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformsOptions(vectorTransformsOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformsOptions;
  /// Selection constraint. The default member initializer matters for the
  /// inherited `OpRewritePattern(MLIRContext *, ...)` constructor, which would
  /// otherwise leave this std::function empty (calling it would throw/abort).
  FilterConstraintType filter = defaultFilter;

  /// Lower one parallel dimension.
  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
                      int64_t rhsIndex, PatternRewriter &rewriter) const;

  /// Lower one reduction dimension.
  Value lowerReduction(vector::ContractionOp op,
                       PatternRewriter &rewriter) const;
};
} // namespace mlir
#endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_