 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -29,6 +32,26 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//

+/// Replaces the memref::StoreOp with two new memref::AtomicRMWOps. The first
+/// memref::AtomicRMWOp sets the destination bits to all zero to prepare the
+/// destination byte to be written to. The second memref::AtomicRMWOp writes
+/// the value to store with an `ori` operation. The value to store and the
+/// write mask should both have the destination type bitwidth, and the bits of
+/// the value to store should be all zero except for the bits aligned with the
+/// store destination.
+static void replaceStoreWithAtomics(ConversionPatternRewriter &rewriter,
+                                    memref::StoreOp op, Value writeMask,
+                                    Value storeVal, Value memref,
+                                    ValueRange storeIndices) {
+  // Clear the destination bits.
+  rewriter.create<memref::AtomicRMWOp>(op.getLoc(), arith::AtomicRMWKind::andi,
+                                       writeMask, memref, storeIndices);
+  // Write the source bits to the destination.
+  rewriter.create<memref::AtomicRMWOp>(op->getLoc(), arith::AtomicRMWKind::ori,
+                                       storeVal, memref, storeIndices);
+  rewriter.eraseOp(op);
+}
+
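The clear-then-set sequence is easiest to see on a plain byte. A minimal standalone sketch, assuming srcBits = 4, dstBits = 8 and made-up values (not taken from the patch):

```cpp
#include <cstdio>

// Standalone illustration of the two-step atomic store for srcBits = 4,
// dstBits = 8: overwrite the high nibble of a byte with the value 0xA.
int main() {
  unsigned byte = 0x3C;      // destination byte: high nibble 0x3, low nibble 0xC
  unsigned storeVal = 0xA0;  // value to store, already shifted to bits [4, 8)
  unsigned writeMask = 0x0F; // zeros over the destination bits, ones elsewhere

  byte &= writeMask; // step 1: atomic_rmw andi clears the destination bits -> 0x0C
  byte |= storeVal;  // step 2: atomic_rmw ori writes the new bits          -> 0xAC
  std::printf("byte after store = 0x%02X\n", byte);
  return 0;
}
```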
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -43,13 +66,67 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }

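For a concrete feel of the folded expression `(s0 % scaleFactor) * sourceBits`, here is an illustrative standalone sketch of the same arithmetic (values assumed, not from the patch):

```cpp
#include <cstdio>

// Standalone sketch of the affine expression (s0 % scaleFactor) * sourceBits
// for sourceBits = 4, targetBits = 8 (two i4 elements per i8 byte).
int main() {
  int sourceBits = 4, targetBits = 8;
  int scaleFactor = targetBits / sourceBits; // 2
  for (int srcIdx = 0; srcIdx < 4; ++srcIdx)
    std::printf("i4 element %d -> bit offset %d within its byte\n", srcIdx,
                (srcIdx % scaleFactor) * sourceBits);
  return 0;
}
```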
+/// When writing a subbyte size, the write needs to happen atomically in case
+/// another write happens on the same byte at the same time. To do the write,
+/// we first must clear the destination bits at the `linearizedIndices` of the
+/// subbyte store. This function returns the appropriate mask for clearing
+/// these bits.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                int64_t srcBits, int64_t dstBits,
+                                Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned =
+      builder
+          .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+          .getResult();
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+          .getResult();
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
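A numeric walk-through of the mask construction may help; this is an illustrative sketch only, assuming srcBits = 4, dstBits = 8 and a bit offset of 4:

```cpp
#include <cstdio>

// Standalone sketch of the mask arithmetic in getAtomicWriteMask for
// srcBits = 4, dstBits = 8 and a bit offset of 4 (the high nibble).
int main() {
  unsigned srcBits = 4, bitwidthOffset = 4;
  unsigned maskRightAligned = (1u << srcBits) - 1;                // 0x0F
  unsigned writeMaskInverse = maskRightAligned << bitwidthOffset; // 0xF0
  unsigned flipVal = 0xFF;                                        // -1 in i8
  unsigned writeMask = writeMaskInverse ^ flipVal;                // 0x0F
  // An `andi` with writeMask clears exactly bits [4, 8) of the destination.
  std::printf("writeMask = 0x%02X\n", writeMask);
  return 0;
}
```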
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
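The scaling this computes is a plain floor division; a small illustrative check (assumed srcBits = 4, dstBits = 8, not part of the patch):

```cpp
#include <cstdio>

// Standalone sketch of the floorDiv scaling in getIndicesForLoadOrStore for
// srcBits = 4, dstBits = 8 (scaler = 2).
int main() {
  int srcBits = 4, dstBits = 8;
  int scaler = dstBits / srcBits;
  for (int linearizedIndex = 0; linearizedIndex < 6; ++linearizedIndex)
    std::printf("i4 index %d -> i8 index %d\n", linearizedIndex,
                linearizedIndex / scaler); // floorDiv; indices are non-negative
  return 0;
}
```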
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
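The linearization itself is the usual strided dot product. An illustrative sketch for an assumed row-major memref<3x4xi4>; the real helper takes the strides (and base offset) from memref::ExtractStridedMetadataOp rather than hard-coding them:

```cpp
#include <cstdio>

// Standalone sketch of index linearization for an assumed memref<3x4xi4>
// with row-major strides [4, 1].
int main() {
  int strides[2] = {4, 1};
  int indices[2] = {2, 3}; // element (2, 3)
  int linearized = indices[0] * strides[0] + indices[1] * strides[1];
  std::printf("linearized i4 index = %d\n", linearized); // 11
  return 0;
}
```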
 namespace {

 //===----------------------------------------------------------------------===//
@@ -155,32 +232,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, getAsOpFoldResult(adaptor.getIndices()),
+          op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));

       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
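On the load side, the extraction after the i8 load amounts to a shift plus a narrowing. An illustrative sketch (srcBits = 4, dstBits = 8, values made up); the final masking/truncation happens in code outside this hunk, so the `& 0x0F` below merely stands in for it:

```cpp
#include <cstdio>

// Standalone sketch of the bit extraction on the load path for srcBits = 4,
// dstBits = 8 and a bit offset of 4.
int main() {
  unsigned loaded = 0xAC;                 // i8 byte holding two i4 elements
  unsigned bitOffset = 4;                 // wanted element is the high nibble
  unsigned shifted = loaded >> bitOffset; // shift the bits to the rightmost
  unsigned value = shifted & 0x0F;        // keep the srcBits-wide element
  std::printf("loaded i4 value = 0x%X\n", value); // 0xA
  return 0;
}
```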
@@ -211,6 +271,60 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };

+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    int srcBits = op.getMemRefType().getElementTypeBitWidth();
+    int dstBits = convertedType.getElementTypeBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. We can compute the mask at compile
+    // time.
+    if (convertedType.getRank() == 0) {
+      // Create mask to clear destination bits.
+      auto writeMaskValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, ~((1 << srcBits) - 1));
+      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                           writeMaskValAttr);
+
+      replaceStoreWithAtomics(rewriter, op, writeMask, extendedInput,
+                              adaptor.getMemref(), ValueRange{});
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, getAsOpFoldResult(adaptor.getIndices()),
+        op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
+    Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+                                         dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits.
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+    replaceStoreWithAtomics(rewriter, op, writeMask, alignedVal,
+                            adaptor.getMemref(), storeIndices);
+    return success();
+  }
+};
+
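Putting the pieces together, here is an illustrative end-to-end walk-through of the numbers this pattern computes for one assumed store (memref<8xi4> emulated as memref<4xi8>, storing 0xA at element 5; all values made up for the example):

```cpp
#include <cstdio>

// Standalone walk-through of the index, offset, mask, and aligned value that
// ConvertMemrefStore derives before emitting the two memref.atomic_rmw ops.
int main() {
  int srcBits = 4, dstBits = 8;
  int linearizedIndex = 5;                                  // 1-D, already linear
  int storeIndex = linearizedIndex / (dstBits / srcBits);   // 2: target byte
  int bitOffset = (linearizedIndex % (dstBits / srcBits)) * srcBits; // 4
  unsigned extendedInput = 0x0A;                            // value zext'ed to i8
  unsigned alignedVal = extendedInput << bitOffset;         // 0xA0
  unsigned writeMask = 0xFF & ~(((1u << srcBits) - 1) << bitOffset); // 0x0F
  // Emitted as memref.atomic_rmw andi (writeMask), then ori (alignedVal).
  std::printf("byte %d: andi 0x%02X, then ori 0x%02X\n", storeIndex, writeMask,
              alignedVal);
  return 0;
}
```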
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -292,7 +406,7 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(

   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview, ConvertMemrefStore>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }