/
MHLOToLinalgOnTensors.cpp
288 lines (255 loc) · 11.4 KB
/
MHLOToLinalgOnTensors.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- XLAToLinalgOnTensors.cpp - Pass to convert XLA to Linalg on tensors-===//
//
// Pass to convert from XLA to linalg on tensors. Uses the patterns from
// tensorflow/compiler/mlir/xla/transforms/legalize_to_linalg.cc along with
// some IREE specific patterns.
//
//===----------------------------------------------------------------------===//
#include <cmath>
#include <memory>

#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h"
#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
#include "iree/compiler/InputConversion/MHLO/Passes.h"
#include "iree/compiler/InputConversion/MHLO/Rewriters.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
namespace iree_compiler {
namespace {
//===----------------------------------------------------------------------===//
// mhlo.concatenate conversion patterns.
//===----------------------------------------------------------------------===//
namespace {
/// Converts mhlo.concatenate into a zero-filled linalg.init_tensor followed by
/// one tensor.insert_slice per operand. Each operand is inserted at the running
/// offset along the concatenation dimension.
///
/// Only results whose converted type is a fully static ranked tensor are
/// handled. Anything else reports a match failure.
struct ConcatenateOpConversion
    : public OpConversionPattern<mhlo::ConcatenateOp> {
  using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConcatenateOp op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const override {
    if (args.empty()) {
      return rewriter.notifyMatchFailure(op, "expected at least one operand");
    }
    // convertType may return a null Type. dyn_cast_or_null avoids asserting
    // on it and folds that case into the "not a static ranked tensor" failure.
    Type convertedType =
        this->typeConverter->convertType(op.getResult().getType());
    auto resultType = convertedType.dyn_cast_or_null<RankedTensorType>();
    if (!resultType || !resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected static shape for output");
    }

    Location loc = op.getLoc();
    const int64_t rank = resultType.getRank();
    const int64_t dim = static_cast<int64_t>(op.dimension());
    if (dim < 0 || dim >= rank) {
      return rewriter.notifyMatchFailure(op, "concat dimension out of range");
    }

    // Slice parameters shared by every insert_slice. All offsets start at 0
    // and all strides are 1. The extents of the non-concatenated dimensions
    // come from the first operand. The `dim` entries of offsets and sizes are
    // rewritten per operand below.
    Value zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
    Value oneIndex = rewriter.create<ConstantIndexOp>(loc, 1);
    SmallVector<Value, 3> offsets(rank, zeroIndex);
    SmallVector<Value, 3> strides(rank, oneIndex);
    SmallVector<Value, 3> sizes(rank, Value());
    for (int64_t i = 0; i < rank; ++i) {
      if (i != dim) {
        sizes[i] = rewriter.create<tensor::DimOp>(loc, args[0], i);
      }
    }

    // Destination tensor: the static result shape, filled with zeros.
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultType.getShape(), resultType.getElementType());
    Value zero = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(resultType.getElementType()));
    Value result =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    // Insert each operand at the accumulated offset along `dim`.
    Value accBound = zeroIndex;
    for (Value arg : args) {
      offsets[dim] = accBound;
      sizes[dim] = rewriter.create<tensor::DimOp>(loc, arg, dim);
      result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
                                                      sizes, strides);
      accBound = rewriter.create<AddIOp>(loc, accBound, sizes[dim]);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// mhlo.fft conversion patterns.
//===----------------------------------------------------------------------===//
namespace {
/// Creates a constant DFT coefficient matrix, based on the definition at
/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform
///
/// For a `matrixType` of shape [N, K], entry (i, j) is cos(2*pi*i*j/N) when
/// `isRealPart` is true, and -sin(2*pi*i*j/N) otherwise. These are the real
/// and imaginary parts of exp(-2*pi*I*i*j/N). A length-N row vector multiplied
/// by this matrix therefore yields the first K DFT coefficients.
///
/// Coefficients are materialized in the element type of `matrixType`, which
/// must be a floating-point type.
Value getDFTMatmulCoeff(OpBuilder &b, Location loc, RankedTensorType matrixType,
                        bool isRealPart) {
  assert(matrixType.getRank() == 2 && "expected 2D matrix");
  const int64_t rows = matrixType.getDimSize(0);
  const int64_t cols = matrixType.getDimSize(1);
  Type elementType = matrixType.getElementType();
  // scale = 2 * pi / N
  const double scale = 2 * M_PI / rows;

  SmallVector<Attribute> values;
  values.reserve(rows * cols);
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      // The twiddle factor is periodic in i*j with period N. Reducing first
      // keeps the trig argument in [0, 2*pi), avoiding precision loss for
      // large i*j.
      const double angle = scale * static_cast<double>((i * j) % rows);
      const double v = isRealPart ? std::cos(angle) : -std::sin(angle);
      values.push_back(b.getFloatAttr(elementType, v));
    }
  }
  return b.create<ConstantOp>(loc, matrixType,
                              DenseFPElementsAttr::get(matrixType, values));
}
/// Builds `lhs x rhs` as a Linalg named op on tensors. The product is
/// accumulated into a zero-filled tensor of `resultType`.
///
/// A rank-1 `lhs` produces linalg.vecmat and a rank-2 `lhs` produces
/// linalg.matmul. Callers must not pass any other rank.
///
/// The builder is taken by reference: passing a ConversionPatternRewriter by
/// value would slice it into a plain OpBuilder copy.
Value createLinalgMatmulOnTensors(OpBuilder &b, Location loc,
                                  RankedTensorType resultType, Value lhs,
                                  Value rhs) {
  Type elementType = resultType.getElementType();
  Value zero = b.create<ConstantOp>(loc, b.getZeroAttr(elementType));
  auto initTensor = b.create<linalg::InitTensorOp>(
      loc, /*dyn_size=*/ValueRange{}, resultType.getShape(), elementType);
  Value zeroTensor =
      b.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

  switch (lhs.getType().cast<RankedTensorType>().getRank()) {
    case 1:
      return b
          .create<linalg::VecmatOp>(loc, TypeRange{resultType},
                                    ValueRange{lhs, rhs},
                                    ValueRange{zeroTensor})
          .getResult(0);
    case 2:
      return b
          .create<linalg::MatmulOp>(loc, TypeRange{resultType},
                                    ValueRange{lhs, rhs},
                                    ValueRange{zeroTensor})
          .getResult(0);
    default:
      llvm_unreachable("unhandled matmul type");
  }
}
/// Converts a 1-D RFFT mhlo.fft on a static 1-D or 2-D input into Linalg ops.
///
/// The transform runs over the innermost dimension. It is computed as two
/// matmuls against constant DFT coefficient matrices, one for the real part
/// and one for the imaginary part. The two results are packed back into an
/// mhlo.complex.
struct FftOpConversion : public OpConversionPattern<mhlo::FftOp> {
  using OpConversionPattern<mhlo::FftOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::FftOp op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const override {
    if (op.fft_type() != "RFFT") {
      return rewriter.notifyMatchFailure(op,
                                         "non RFFT types are not supported yet");
    }
    // Only a single-dimension transform is lowered. A multi-element
    // fft_length (e.g. a 2-D RFFT) would otherwise be computed as a 1-D
    // transform over the innermost dimension only.
    if (op.fft_length().getNumElements() != 1) {
      return rewriter.notifyMatchFailure(op, "only 1D RFFT is supported");
    }
    mhlo::FftOpAdaptor adaptor(args);
    auto inputType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
    if (!inputType || !inputType.hasStaticShape() || inputType.getRank() < 1 ||
        inputType.getRank() > 2) {
      return rewriter.notifyMatchFailure(op, "only static 1D or 2D dft ops");
    }
    const int64_t rank = inputType.getRank();
    const int64_t n = inputType.getDimSize(rank - 1);
    const int64_t fftLength =
        op.fft_length().getSplatValue().cast<IntegerAttr>().getInt();
    // The coefficient matrix is scaled by the input extent, so it is only
    // correct when that extent equals the requested transform length.
    if (fftLength != n) {
      return rewriter.notifyMatchFailure(
          op, "fft_length must match the innermost input dimension");
    }
    // A real N-point FFT has only N/2 + 1 non-redundant outputs.
    const int64_t numCoeffs = fftLength / 2 + 1;

    Location loc = op.getLoc();
    auto matrixType =
        RankedTensorType::get({n, numCoeffs}, inputType.getElementType());
    auto resultType =
        RankedTensorType::get(op.getType().cast<RankedTensorType>().getShape(),
                              inputType.getElementType());

    auto realMatrix =
        getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true);
    auto real = createLinalgMatmulOnTensors(rewriter, loc, resultType,
                                            adaptor.operand(), realMatrix);
    auto imagMatrix =
        getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false);
    auto imag = createLinalgMatmulOnTensors(rewriter, loc, resultType,
                                            adaptor.operand(), imagMatrix);

    // Pack the results back to mhlo::ComplexOp.
    rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, op.getType(), real, imag);
    return success();
  }
};
} // namespace
/// Pass that converts CHLO and MHLO ops using the Flow, broadcasting,
/// Linalg-on-tensors and complex-to-real pattern sets populated below.
///
/// Both the CHLO and MHLO dialects are marked illegal, so the pass fails if
/// any of their ops remain. Ops from every other dialect are left as-is.
struct ConvertMHLOToLinalgOnTensorsPass
    : public ConvertMHLOToLinalgOnTensorsBase<
          ConvertMHLOToLinalgOnTensorsPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREE::Flow::FlowDialect, linalg::LinalgDialect,
                    mhlo::MhloDialect, ShapeDialect, math::MathDialect,
                    memref::MemRefDialect, complex::ComplexDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    OwningRewritePatternList patterns(context);
    auto typeConverter = mhlo::createHloToLinalgSignedIntegerConverter();

    // NOTE: not using corresponding setupMHLOToFlowPatterns because the entire
    // MHLO dialects are marked illegal by this pass.
    // TODO: Collapse/rework all of these patterns once the consolidation
    // lands. There is little reason to have these so spread out.
    populateMHLOToFlowPatterns(context, patterns);
    chlo::PopulateDecomposeChloPatterns(context, &patterns);
    populateMHLOBroadcastingToLinalgPatterns(context, *typeConverter, patterns);
    populateMHLOToLinalgOnTensorsConversionPatterns(context, *typeConverter,
                                                    patterns);
    populateMHLOComplexToRealPatterns(context, *typeConverter, patterns);

    ConversionTarget target(*context);
    target.addIllegalDialect<chlo::HloClientDialect>();
    target.addIllegalDialect<mhlo::MhloDialect>();
    // Let the rest fall through.
    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};
/// Converts mhlo.constant into std.constant.
///
/// The element type of the stored attribute is passed through the pattern's
/// type converter. When the converted type differs from the original, only an
/// integer-to-integer change of equal bit width is accepted, and the stored
/// bits are copied unchanged. Other changes, and element types the converter
/// rejects, report a match failure.
struct ConstOpConversion : public OpConversionPattern<mhlo::ConstOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConstOp op, ArrayRef<Value> /*operands*/,
      ConversionPatternRewriter &rewriter) const override {
    auto valueAttr = op.value();
    Type oldElType = valueAttr.getType().getElementType();
    Type newElType = this->typeConverter->convertType(oldElType);
    if (!newElType) {
      return rewriter.notifyMatchFailure(op, "unconvertible element type");
    }

    ElementsAttr newValueAttr = valueAttr;
    if (newElType != oldElType) {
      // The identity APInt remap below is only meaningful when both types are
      // integers of the same width (e.g. si32 -> i32).
      if (!oldElType.isa<IntegerType>() || !newElType.isa<IntegerType>() ||
          oldElType.getIntOrFloatBitWidth() !=
              newElType.getIntOrFloatBitWidth()) {
        return rewriter.notifyMatchFailure(
            op, "unsupported constant element type conversion");
      }
      // Values don't change, just their reported type.
      newValueAttr = valueAttr.mapValues(
          newElType, [](const APInt &oldEl) { return oldEl; });
    }
    rewriter.replaceOpWithNewOp<ConstantOp>(op, newValueAttr);
    return success();
  }
};
} // namespace
/// Adds the MHLO -> Linalg-on-tensors conversion patterns to `patterns`.
///
/// First come the upstream mhlo lowerings. Then come IREE's own patterns for
/// mhlo.constant, mhlo.concatenate and mhlo.fft, all registered with an
/// elevated benefit. A higher benefit makes the driver prefer these over any
/// lower-benefit pattern rooted on the same op.
void populateMHLOToLinalgOnTensorsConversionPatterns(
    MLIRContext *context, TypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  // Upstream MHLO -> Linalg lowerings.
  mhlo::populateHLOToLinalgConversionPattern(context, typeConverter, &patterns);

  // TODO(#5809): Drop ConcatenateOp lowering in favor of the upstream version
  // then remove the PatternBenefit here
  const PatternBenefit ireePatternBenefit(1000);
  patterns.insert<ConstOpConversion, ConcatenateOpConversion, FftOpConversion>(
      typeConverter, context, ireePatternBenefit);
}
/// Creates a new instance of the MHLO -> Linalg-on-tensors conversion pass,
/// which runs on FuncOps. Ownership is transferred to the caller.
std::unique_ptr<OperationPass<FuncOp>> createMHLOToLinalgOnTensorsPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<ConvertMHLOToLinalgOnTensorsPass>();
  return pass;
}
} // namespace iree_compiler
} // namespace mlir