-
Notifications
You must be signed in to change notification settings - Fork 10.8k
/
ConvertStandardToLLVM.h
453 lines (370 loc) · 18.2 KB
/
ConvertStandardToLLVM.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
//===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides a dialect conversion targeting the LLVM IR dialect. By default, it
// converts Standard ops and types and provides hooks for dialect-specific
// extensions to the conversion.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
#include "mlir/Transforms/DialectConversion.h"
// Forward declarations of LLVM IR classes named in this header. Module and
// LLVMContext are only used through pointers/references below, so their full
// definitions are not needed here.
// NOTE(review): llvm::IntegerType and llvm::Type are not referenced anywhere in
// this header; they may be leftovers — confirm before removing.
namespace llvm {
class IntegerType;
class LLVMContext;
class Module;
class Type;
} // namespace llvm
namespace mlir {
// Forward declarations. LLVMTypeConverter is needed early because the
// customization callback type below refers to it before its definition.
class LLVMTypeConverter;
class UnrankedMemRefType;
namespace LLVM {
class LLVMDialect;
class LLVMType;
} // namespace LLVM
/// Set of callbacks that allows the customization of LLVMTypeConverter.
struct LLVMTypeConverterCustomization {
  /// Signature of a customization callback: given the converter and a source
  /// type, populate the output vector with the converted type(s) and report
  /// success or failure through the returned LogicalResult.
  using CustomCallback = std::function<LogicalResult(LLVMTypeConverter &, Type,
                                                     SmallVectorImpl<Type> &)>;

  /// Customize the type conversion of function arguments.
  CustomCallback funcArgConverter;

  /// Initialize customization to default callbacks. The constructor is defined
  /// out of line. NOTE(review): presumably it installs
  /// structFuncArgTypeConverter as `funcArgConverter` — confirm.
  LLVMTypeConverterCustomization();
};
/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct. The converted types are
/// written to `result`. This function is a friend of LLVMTypeConverter so it
/// can use the converter's private memref signature helpers.
LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
                                         Type type,
                                         SmallVectorImpl<Type> &result);
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type. The converted types
/// are written to `result`.
/// NOTE(review): the handling of non-MemRef (and unranked MemRef) arguments is
/// not visible in this header — confirm against the implementation.
LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
                                          Type type,
                                          SmallVectorImpl<Type> &result);
/// Conversion from types in the Standard dialect to the LLVM IR dialect.
class LLVMTypeConverter : public TypeConverter {
  /// Give structFuncArgTypeConverter access to memref-specific functions.
  friend LogicalResult
  structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
                             SmallVectorImpl<Type> &result);

public:
  // Bring every TypeConverter::convertType overload into this class's scope.
  using TypeConverter::convertType;

  /// Create an LLVMTypeConverter using the default
  /// LLVMTypeConverterCustomization.
  LLVMTypeConverter(MLIRContext *ctx);

  /// Create an LLVMTypeConverter using 'custom' customizations.
  LLVMTypeConverter(MLIRContext *ctx,
                    const LLVMTypeConverterCustomization &custom);

  /// Convert a function type. The arguments and results are converted one by
  /// one and results are packed into a wrapped LLVM IR structure type. `result`
  /// is populated with argument mapping.
  LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
                                          SignatureConversion &result);

  /// Convert a non-empty list of types to be returned from a function into a
  /// supported LLVM IR type. In particular, if more than one value is
  /// returned, create an LLVM IR structure type with elements that correspond
  /// to each of the MLIR types converted with `convertType`.
  Type packFunctionResults(ArrayRef<Type> types);

  /// Returns the MLIR context.
  MLIRContext &getContext();

  /// Returns the LLVM context.
  llvm::LLVMContext &getLLVMContext();

  /// Returns the LLVM dialect.
  LLVM::LLVMDialect *getDialect() { return llvmDialect; }

  /// Promote the LLVM struct representation of all MemRef descriptors to stack
  /// and use pointers to struct to avoid the complexity of the
  /// platform-specific C/C++ ABI lowering related to struct argument passing.
  /// NOTE(review): presumably `opOperands` are the original operands (used to
  /// identify which ones are memrefs) and `operands` their converted
  /// counterparts — confirm against the implementation.
  SmallVector<Value, 4> promoteMemRefDescriptors(Location loc,
                                                 ValueRange opOperands,
                                                 ValueRange operands,
                                                 OpBuilder &builder);

  /// Promote the LLVM struct representation of one MemRef descriptor to stack
  /// and use pointer to struct to avoid the complexity of the platform-specific
  /// C/C++ ABI lowering related to struct argument passing.
  Value promoteOneMemRefDescriptor(Location loc, Value operand,
                                   OpBuilder &builder);

  /// Converts the function type to a C-compatible format, in particular using
  /// pointers to memref descriptors for arguments.
  LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);

  /// Creates descriptor structs from individual values constituting them.
  Operation *materializeConversion(PatternRewriter &rewriter, Type type,
                                   ArrayRef<Value> values,
                                   Location loc) override;

protected:
  /// LLVM IR module used to parse/create types.
  llvm::Module *module;
  /// LLVM dialect instance returned by getDialect().
  LLVM::LLVMDialect *llvmDialect;

private:
  /// Convert a function type. The arguments and results are converted one by
  /// one. Additionally, if the function returns more than one value, pack the
  /// results into an LLVM IR structure type so that the converted function
  /// type returns at most one result.
  Type convertFunctionType(FunctionType type);

  /// Convert the index type. Uses the data layout of the LLVM module to create
  /// an integer of the pointer bitwidth.
  Type convertIndexType(IndexType type);

  /// Convert an integer type `i*` to `!llvm<"i*">`.
  Type convertIntegerType(IntegerType type);

  /// Convert a floating point type: `f16` to `!llvm.half`, `f32` to
  /// `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported
  /// by LLVM.
  Type convertFloatType(FloatType type);

  /// Convert a memref type into an LLVM type that captures the relevant data.
  Type convertMemRefType(MemRefType type);

  /// Convert a memref type into a list of non-aggregate LLVM IR types that
  /// contain all the relevant data. In particular, the list will contain:
  /// - two pointers to the memref element type, followed by
  /// - an integer offset, followed by
  /// - one integer size per dimension of the memref, followed by
  /// - one integer stride per dimension of the memref.
  /// For example, memref<?x?xf32> is converted to the following list:
  /// - `!llvm<"float*">` (allocated pointer),
  /// - `!llvm<"float*">` (aligned pointer),
  /// - `!llvm.i64` (offset),
  /// - `!llvm.i64`, `!llvm.i64` (sizes),
  /// - `!llvm.i64`, `!llvm.i64` (strides).
  /// These types can be recomposed to a memref descriptor struct.
  SmallVector<Type, 5> convertMemRefSignature(MemRefType type);

  /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
  /// that contain all the relevant data. In particular, this list contains:
  /// - an integer rank, followed by
  /// - a pointer to the memref descriptor struct.
  /// For example, memref<*xf32> is converted to the following list:
  /// !llvm.i64 (rank)
  /// !llvm<"i8*"> (type-erased pointer).
  /// These types can be recomposed to an unranked memref descriptor struct.
  SmallVector<Type, 2> convertUnrankedMemRefSignature();

  /// Convert an unranked memref type to an LLVM type that captures the
  /// runtime rank and a pointer to the static ranked memref descriptor.
  Type convertUnrankedMemRefType(UnrankedMemRefType type);

  /// Convert a 1D vector type into an LLVM vector type.
  Type convertVectorType(VectorType type);

  /// Get the LLVM representation of the index type based on the bitwidth of
  /// the pointer as defined by the data layout of the module.
  LLVM::LLVMType getIndexType();

  /// Callbacks for customizing the type conversion.
  LLVMTypeConverterCustomization customizations;
};
/// Helper class to produce LLVM dialect operations extracting or inserting
/// values to a struct. Wraps a Value holding the struct; every IR-building
/// method takes the OpBuilder and Location to build with.
class StructBuilder {
public:
  /// Construct a helper for the given value.
  explicit StructBuilder(Value v);

  /// Builds IR creating an `undef` value of the descriptor type.
  static StructBuilder undef(OpBuilder &builder, Location loc,
                             Type descriptorType);

  /// Implicitly converts to the wrapped struct Value, so a builder (or any
  /// derived descriptor) can be passed wherever a Value is expected. This is
  /// const because it only reads `value`; that also lets const instances and
  /// const references convert.
  /*implicit*/ operator Value() const { return value; }

protected:
  // LLVM value
  Value value;
  // Cached struct type.
  Type structType;

  /// Builds IR to extract a value from the struct at position pos
  Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
  /// Builds IR to set a value in the struct at position pos
  void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
};
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor : public StructBuilder {
public:
  /// Construct a helper for the given descriptor value.
  explicit MemRefDescriptor(Value descriptor);
  /// Builds IR creating an `undef` value of the descriptor type.
  static MemRefDescriptor undef(OpBuilder &builder, Location loc,
                                Type descriptorType);
  /// Builds IR creating a MemRef descriptor that represents `type` and
  /// populates it with static shape and stride information extracted from the
  /// type.
  static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
                                          LLVMTypeConverter &typeConverter,
                                          MemRefType type, Value memory);

  /// Builds IR extracting the allocated pointer from the descriptor.
  Value allocatedPtr(OpBuilder &builder, Location loc);
  /// Builds IR inserting the allocated pointer into the descriptor.
  void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);

  /// Builds IR extracting the aligned pointer from the descriptor.
  Value alignedPtr(OpBuilder &builder, Location loc);
  /// Builds IR inserting the aligned pointer into the descriptor.
  void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);

  /// Builds IR extracting the offset from the descriptor.
  Value offset(OpBuilder &builder, Location loc);
  /// Builds IR inserting the offset into the descriptor.
  void setOffset(OpBuilder &builder, Location loc, Value offset);
  /// Like setOffset, but takes the offset as a compile-time integer instead of
  /// an SSA Value. NOTE(review): presumably materialized as a constant of the
  /// cached index type — confirm against the implementation.
  void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);

  /// Builds IR extracting the pos-th size from the descriptor.
  Value size(OpBuilder &builder, Location loc, unsigned pos);
  /// Builds IR inserting the pos-th size into the descriptor.
  void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
  /// Like setSize, but takes the size as a compile-time integer.
  void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
                       uint64_t size);

  /// Builds IR extracting the pos-th stride from the descriptor.
  Value stride(OpBuilder &builder, Location loc, unsigned pos);
  /// Builds IR inserting the pos-th stride into the descriptor.
  void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
  /// Like setStride, but takes the stride as a compile-time integer.
  void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
                         uint64_t stride);

  /// Returns the (LLVM) type this descriptor points to.
  LLVM::LLVMType getElementType();

  /// Builds IR populating a MemRef descriptor structure from a list of
  /// individual values composing that descriptor, in the following order:
  /// - allocated pointer;
  /// - aligned pointer;
  /// - offset;
  /// - <rank> sizes;
  /// - <rank> strides;
  /// where <rank> is the MemRef rank as provided in `type`.
  static Value pack(OpBuilder &builder, Location loc,
                    LLVMTypeConverter &converter, MemRefType type,
                    ValueRange values);

  /// Builds IR extracting individual elements of a MemRef descriptor structure
  /// and returning them as `results` list.
  static void unpack(OpBuilder &builder, Location loc, Value packed,
                     MemRefType type, SmallVectorImpl<Value> &results);

  /// Returns the number of non-aggregate values that would be produced by
  /// `unpack`.
  static unsigned getNumUnpackedValues(MemRefType type);

private:
  // Cached index type.
  Type indexType;
};
/// Helper class allowing the user to access a range of Values that correspond
/// to an unpacked memref descriptor using named accessors. This does not own
/// the values.
/// NOTE(review): presumably expects the element order documented on
/// MemRefDescriptor::pack (allocated ptr, aligned ptr, offset, sizes,
/// strides) — confirm against the implementation.
class MemRefDescriptorView {
public:
  /// Constructs the view from a range of values. Infers the rank from the size
  /// of the range.
  explicit MemRefDescriptorView(ValueRange range);

  /// Returns the allocated pointer Value.
  Value allocatedPtr();
  /// Returns the aligned pointer Value.
  Value alignedPtr();
  /// Returns the offset Value.
  Value offset();
  /// Returns the pos-th size Value.
  Value size(unsigned pos);
  /// Returns the pos-th stride Value.
  Value stride(unsigned pos);

private:
  /// Rank of the memref the descriptor is pointing to.
  int rank;
  /// Underlying range of Values (not owned).
  ValueRange elements;
};
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of an unranked MemRef descriptor: a struct holding the runtime rank
/// of the memref and a pointer to a ranked memref descriptor. Wraps a Value
/// holding that struct.
class UnrankedMemRefDescriptor : public StructBuilder {
public:
  /// Construct a helper for the given descriptor value.
  explicit UnrankedMemRefDescriptor(Value descriptor);
  /// Builds IR creating an `undef` value of the descriptor type.
  static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
                                        Type descriptorType);

  /// Builds IR extracting the rank from the descriptor.
  Value rank(OpBuilder &builder, Location loc);
  /// Builds IR setting the rank in the descriptor.
  void setRank(OpBuilder &builder, Location loc, Value value);
  /// Builds IR extracting the pointer to the ranked memref descriptor.
  Value memRefDescPtr(OpBuilder &builder, Location loc);
  /// Builds IR setting the pointer to the ranked memref descriptor.
  void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);

  /// Builds IR populating an unranked MemRef descriptor structure from a list
  /// of individual constituent values in the following order:
  /// - rank of the memref;
  /// - pointer to the memref descriptor.
  static Value pack(OpBuilder &builder, Location loc,
                    LLVMTypeConverter &converter, UnrankedMemRefType type,
                    ValueRange values);

  /// Builds IR extracting individual elements that compose an unranked memref
  /// descriptor and returns them as `results` list.
  static void unpack(OpBuilder &builder, Location loc, Value packed,
                     SmallVectorImpl<Value> &results);

  /// Returns the number of non-aggregate values that would be produced by
  /// `unpack`: the rank and the descriptor pointer.
  static unsigned getNumUnpackedValues() { return 2; }
};
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
/// conversion patterns with access to an LLVMTypeConverter.
class ConvertToLLVMPattern : public ConversionPattern {
public:
  /// Creates a pattern rooted at operations named `rootOpName` with the given
  /// `benefit`. The pattern keeps a reference to `typeConverter` (see the
  /// protected member below), so the converter presumably must outlive the
  /// pattern.
  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
                       LLVMTypeConverter &typeConverter,
                       PatternBenefit benefit = 1);

  /// Returns the LLVM dialect.
  LLVM::LLVMDialect &getDialect() const;

  /// Returns the LLVM IR context.
  llvm::LLVMContext &getContext() const;

  /// Returns the LLVM IR module associated with the LLVM dialect.
  llvm::Module &getModule() const;

  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
  /// defined by the pointer size used in the LLVM module.
  LLVM::LLVMType getIndexType() const;

  /// Gets the MLIR type wrapping the LLVM void type.
  LLVM::LLVMType getVoidType() const;

  /// Get the MLIR type wrapping the LLVM i8* type.
  LLVM::LLVMType getVoidPtrType() const;

  /// Create an LLVM dialect operation defining the given index constant.
  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                            uint64_t value) const;

protected:
  /// Reference to the type converter, with potential extensions.
  LLVMTypeConverter &typeConverter;
};
/// Convenience base for LLVM-dialect conversion patterns that match exactly one
/// source operation type, `OpTy`. The root operation name comes from `OpTy`
/// and the MLIR context from the type converter, so a derived pattern only
/// supplies the converter and, optionally, a benefit.
template <typename OpTy>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
  ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                         PatternBenefit benefit = 1)
      : ConvertToLLVMPattern(/*rootOpName=*/OpTy::getOperationName(),
                             /*context=*/&typeConverter.getContext(),
                             /*typeConverter=*/typeConverter,
                             /*benefit=*/benefit) {}
};
namespace LLVM {
namespace detail {
/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands. This is the non-template helper that
/// OneToOneConvertToLLVMPattern::matchAndRewrite delegates to.
/// NOTE(review): `typeConverter` is presumably used to convert the result
/// types of "op" — confirm against the implementation.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
                              ValueRange operands,
                              LLVMTypeConverter &typeConverter,
                              ConversionPatternRewriter &rewriter);
} // namespace detail
} // namespace LLVM
/// Generic implementation of one-to-one conversion from "SourceOp" to
/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
/// Upholds a convention that multi-result operations get converted into an
/// operation returning the LLVM IR structure type, in which case individual
/// values must be extracted from it using LLVM::ExtractValueOp before being
/// used.
template <typename SourceOp, typename TargetOp>
class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;

  /// Converts the type of the result to an LLVM type, passes operands as is,
  /// and preserves attributes. All the work is done by the shared helper
  /// LLVM::detail::oneToOneRewrite.
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto targetOpName = TargetOp::getOperationName();
    LLVMTypeConverter &converter = this->typeConverter;
    return LLVM::detail::oneToOneRewrite(op, targetOpName, operands, converter,
                                         rewriter);
  }
};
/// Derived class that automatically populates legalization information for
/// different LLVM ops.
class LLVMConversionTarget : public ConversionTarget {
public:
  /// The legality rules are set up by this out-of-line constructor.
  /// NOTE(review): presumably it marks LLVM dialect ops as legal — confirm
  /// against the implementation.
  explicit LLVMConversionTarget(MLIRContext &ctx);
};
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H