Skip to content

Commit 684dfe8

Browse files
committed
[mlir] factor out ConvertToLLVMPattern
This class and classes that extend it are general utilities for any dialect that is being converted into the LLVM dialect. They are in no way specific to Standard-to-LLVM conversion and should not make their users depend on it. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D105542
1 parent 9ced1e4 commit 684dfe8

File tree

12 files changed

+732
-641
lines changed

12 files changed

+732
-641
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
//===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10+
#define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11+
12+
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
13+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14+
#include "mlir/Transforms/DialectConversion.h"
15+
16+
namespace mlir {
17+
18+
namespace LLVM {
19+
namespace detail {
20+
/// Replaces the given operation "op" with a new operation of type "targetOp"
21+
/// and given operands.
22+
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
23+
ValueRange operands,
24+
LLVMTypeConverter &typeConverter,
25+
ConversionPatternRewriter &rewriter);
26+
} // namespace detail
27+
} // namespace LLVM
28+
29+
/// Base class for operation conversions targeting the LLVM IR dialect. It
30+
/// provides the conversion patterns with access to the LLVMTypeConverter and
31+
/// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
32+
/// LowerToLLVMOptions by reference meaning the references have to remain alive
33+
/// during the entire pattern lifetime.
34+
class ConvertToLLVMPattern : public ConversionPattern {
35+
public:
36+
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
37+
LLVMTypeConverter &typeConverter,
38+
PatternBenefit benefit = 1);
39+
40+
protected:
41+
/// Returns the LLVM dialect.
42+
LLVM::LLVMDialect &getDialect() const;
43+
44+
LLVMTypeConverter *getTypeConverter() const;
45+
46+
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
47+
/// defined by the used type converter.
48+
Type getIndexType() const;
49+
50+
/// Gets the MLIR type wrapping the LLVM integer type whose bit width
51+
/// corresponds to that of a LLVM pointer type.
52+
Type getIntPtrType(unsigned addressSpace = 0) const;
53+
54+
/// Gets the MLIR type wrapping the LLVM void type.
55+
Type getVoidType() const;
56+
57+
/// Get the MLIR type wrapping the LLVM i8* type.
58+
Type getVoidPtrType() const;
59+
60+
/// Create a constant Op producing a value of `resultType` from an index-typed
61+
/// integer attribute.
62+
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
63+
Type resultType, int64_t value);
64+
65+
/// Create an LLVM dialect operation defining the given index constant.
66+
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
67+
uint64_t value) const;
68+
69+
// This is a strided getElementPtr variant that linearizes subscripts as:
70+
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
71+
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
72+
ValueRange indices,
73+
ConversionPatternRewriter &rewriter) const;
74+
75+
/// Returns if the given memref has identity maps and the element type is
76+
/// convertible to LLVM.
77+
bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
78+
79+
/// Returns the type of a pointer to an element of the memref.
80+
Type getElementPtrType(MemRefType type) const;
81+
82+
/// Computes sizes, strides and buffer size in bytes of `memRefType` with
83+
/// identity layout. Emits constant ops for the static sizes of `memRefType`,
84+
/// and uses `dynamicSizes` for the others. Emits instructions to compute
85+
/// strides and buffer size from these sizes.
86+
///
87+
/// For example, memref<4x?xf32> emits:
88+
/// `sizes[0]` = llvm.mlir.constant(4 : index) : i64
89+
/// `sizes[1]` = `dynamicSizes[0]`
90+
/// `strides[1]` = llvm.mlir.constant(1 : index) : i64
91+
/// `strides[0]` = `sizes[0]`
92+
/// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64
93+
/// %nullptr = llvm.mlir.null : !llvm.ptr<f32>
94+
/// %gep = llvm.getelementptr %nullptr[%size]
95+
/// : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
96+
/// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
97+
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
98+
ValueRange dynamicSizes,
99+
ConversionPatternRewriter &rewriter,
100+
SmallVectorImpl<Value> &sizes,
101+
SmallVectorImpl<Value> &strides,
102+
Value &sizeBytes) const;
103+
104+
/// Computes the size of type in bytes.
105+
Value getSizeInBytes(Location loc, Type type,
106+
ConversionPatternRewriter &rewriter) const;
107+
108+
/// Computes total number of elements for the given shape.
109+
Value getNumElements(Location loc, ArrayRef<Value> shape,
110+
ConversionPatternRewriter &rewriter) const;
111+
112+
/// Creates and populates a canonical memref descriptor struct.
113+
MemRefDescriptor
114+
createMemRefDescriptor(Location loc, MemRefType memRefType,
115+
Value allocatedPtr, Value alignedPtr,
116+
ArrayRef<Value> sizes, ArrayRef<Value> strides,
117+
ConversionPatternRewriter &rewriter) const;
118+
};
119+
120+
/// Utility class for operation conversions targeting the LLVM dialect that
121+
/// match exactly one source operation.
122+
template <typename SourceOp>
123+
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
124+
public:
125+
explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
126+
PatternBenefit benefit = 1)
127+
: ConvertToLLVMPattern(SourceOp::getOperationName(),
128+
&typeConverter.getContext(), typeConverter,
129+
benefit) {}
130+
131+
/// Wrappers around the RewritePattern methods that pass the derived op type.
132+
void rewrite(Operation *op, ArrayRef<Value> operands,
133+
ConversionPatternRewriter &rewriter) const final {
134+
rewrite(cast<SourceOp>(op), operands, rewriter);
135+
}
136+
LogicalResult match(Operation *op) const final {
137+
return match(cast<SourceOp>(op));
138+
}
139+
LogicalResult
140+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
141+
ConversionPatternRewriter &rewriter) const final {
142+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
143+
}
144+
145+
/// Rewrite and Match methods that operate on the SourceOp type. These must be
146+
/// overridden by the derived pattern class.
147+
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
148+
ConversionPatternRewriter &rewriter) const {
149+
llvm_unreachable("must override rewrite or matchAndRewrite");
150+
}
151+
virtual LogicalResult match(SourceOp op) const {
152+
llvm_unreachable("must override match or matchAndRewrite");
153+
}
154+
virtual LogicalResult
155+
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
156+
ConversionPatternRewriter &rewriter) const {
157+
if (succeeded(match(op))) {
158+
rewrite(op, operands, rewriter);
159+
return success();
160+
}
161+
return failure();
162+
}
163+
164+
private:
165+
using ConvertToLLVMPattern::match;
166+
using ConvertToLLVMPattern::matchAndRewrite;
167+
};
168+
169+
/// Generic implementation of one-to-one conversion from "SourceOp" to
170+
/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
171+
/// Upholds a convention that multi-result operations get converted into an
172+
/// operation returning the LLVM IR structure type, in which case individual
173+
/// values must be extracted from using LLVM::ExtractValueOp before being used.
174+
template <typename SourceOp, typename TargetOp>
175+
class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
176+
public:
177+
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
178+
using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
179+
180+
/// Converts the type of the result to an LLVM type, pass operands as is,
181+
/// preserve attributes.
182+
LogicalResult
183+
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
184+
ConversionPatternRewriter &rewriter) const override {
185+
return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
186+
operands, *this->getTypeConverter(),
187+
rewriter);
188+
}
189+
};
190+
191+
} // namespace mlir
192+
193+
#endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
//===- VectorPattern.h - Conversion pattern to the LLVM dialect -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
10+
#define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
11+
12+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
13+
#include "mlir/Transforms/DialectConversion.h"
14+
15+
namespace mlir {
16+
17+
namespace LLVM {
18+
namespace detail {
19+
// Helper struct to "unroll" operations on n-D vectors in terms of operations on
20+
// 1-D LLVM vectors.
21+
struct NDVectorTypeInfo {
22+
// LLVM array struct which encodes n-D vectors.
23+
Type llvmNDVectorTy;
24+
// LLVM vector type which encodes the inner 1-D vector type.
25+
Type llvm1DVectorTy;
26+
// Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
27+
SmallVector<int64_t, 4> arraySizes;
28+
};
29+
30+
// For >1-D vector types, extracts the necessary information to iterate over all
31+
// 1-D subvectors in the underlying llrepresentation of the n-D vector
32+
// Iterates on the llvm array type until we hit a non-array type (which is
33+
// asserted to be an llvm vector type).
34+
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
35+
LLVMTypeConverter &converter);
36+
37+
// Express `linearIndex` in terms of coordinates of `basis`.
38+
// Returns the empty vector when linearIndex is out of the range [0, P] where
39+
// P is the product of all the basis coordinates.
40+
//
41+
// Prerequisites:
42+
// Basis is an array of nonnegative integers (signed type inherited from
43+
// vector shape type).
44+
SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
45+
unsigned linearIndex);
46+
47+
// Iterate of linear index, convert to coords space and insert splatted 1-D
48+
// vector in each position.
49+
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
50+
function_ref<void(ArrayAttr)> fun);
51+
52+
LogicalResult handleMultidimensionalVectors(
53+
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
54+
std::function<Value(Type, ValueRange)> createOperand,
55+
ConversionPatternRewriter &rewriter);
56+
57+
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
58+
ValueRange operands,
59+
LLVMTypeConverter &typeConverter,
60+
ConversionPatternRewriter &rewriter);
61+
} // namespace detail
62+
} // namespace LLVM
63+
64+
/// Basic lowering implementation to rewrite Ops with just one result to the
65+
/// LLVM Dialect. This supports higher-dimensional vector types.
66+
template <typename SourceOp, typename TargetOp>
67+
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
68+
public:
69+
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
70+
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
71+
72+
LogicalResult
73+
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
74+
ConversionPatternRewriter &rewriter) const override {
75+
static_assert(
76+
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
77+
"expected single result op");
78+
return LLVM::detail::vectorOneToOneRewrite(
79+
op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
80+
rewriter);
81+
}
82+
};
83+
} // namespace mlir
84+
85+
#endif // MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H

0 commit comments

Comments
 (0)