
Commit f1c15b5

[mlir] Add subbyte emulation support for memref.store.
This adds a conversion for narrow type emulation of memref.store ops. The conversion replaces the memref.store with two memref.atomic_rmw ops. Atomics are used to prevent race conditions on same-byte accesses, in the event that two threads are storing into the same byte.
1 parent c1146f3 commit f1c15b5
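
For context, the rewrite turns a sub-byte store into a read-modify-write of the backing byte. A rough before/after sketch in MLIR, assuming an i4 element type emulated on i8 storage; the value names, the converted memref, and %byteIdx/%bitOffset are illustrative, and the actual index/offset arithmetic is produced by the affine helpers added in this patch:

  // Before emulation:
  memref.store %v, %src[%i] : memref<8xi4>

  // After emulation (conceptually), on the converted memref<4xi8>:
  %ext   = arith.extui %v : i4 to i8
  %val   = arith.shli %ext, %bitOffset : i8
  %low   = arith.constant 15 : i8
  %ones  = arith.constant -1 : i8
  %shift = arith.shli %low, %bitOffset : i8
  %mask  = arith.xori %shift, %ones : i8
  %0 = memref.atomic_rmw andi %mask, %dst[%byteIdx] : (i8, memref<4xi8>) -> i8
  %1 = memref.atomic_rmw ori %val, %dst[%byteIdx] : (i8, memref<4xi8>) -> i8

Two atomics are used instead of a plain load/store pair so that the clear-then-write sequence never loses bits written concurrently to the other sub-byte lanes of the same byte.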

File tree: 2 files changed (+308 −25 lines)


mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 139 additions & 25 deletions
@@ -17,6 +17,9 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -29,6 +32,26 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Replaces the memref::StoreOp with two new memref::AtomicRMWOps. The first
+/// memref::AtomicRMWOp sets the destination bits to all zero to prepare the
+/// destination byte to be written to. The second memref::AtomicRMWOp does the
+/// writing of the value to store, using an `ori` type operation. The value
+/// to store and the write mask should both have the destination type bitwidth,
+/// and the bits of the value to store should be all zero except for the bits
+/// aligned with the store destination.
+static void replaceStoreWithAtomics(ConversionPatternRewriter &rewriter,
+                                    memref::StoreOp op, Value writeMask,
+                                    Value storeVal, Value memref,
+                                    ValueRange storeIndices) {
+  // Clear destination bits
+  rewriter.create<memref::AtomicRMWOp>(op.getLoc(), arith::AtomicRMWKind::andi,
+                                       writeMask, memref, storeIndices);
+  // Write srcs bits to destination
+  rewriter.create<memref::AtomicRMWOp>(op->getLoc(), arith::AtomicRMWKind::ori,
+                                       storeVal, memref, storeIndices);
+  rewriter.eraseOp(op);
+}
+
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
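
As a concrete trace of the two atomics in replaceStoreWithAtomics, assuming srcBits = 4, dstBits = 8, a destination byte currently holding 0b1011_0110, and a store of 0b0101 at bit offset 4 (so writeMask = 0b0000_1111 and storeVal = 0b0101_0000):

  andi: 0b1011_0110 & 0b0000_1111 = 0b0000_0110   // destination nibble cleared
  ori:  0b0000_0110 | 0b0101_0000 = 0b0101_0110   // new nibble written, other nibble preserved

Because each store's mask and value only touch its own sub-byte lane, concurrent stores to different lanes of the same byte compose correctly; the atomics only serialize the byte-level read-modify-write.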
@@ -43,13 +66,67 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
+/// When writing a subbyte size, writing needs to happen atomically in case of
+/// another write happening on the same byte at the same time. To do the write,
+/// we first must clear `dstBits` at the `linearizedIndices` of the subbyte
+/// store. This function returns the appropriate mask for clearing these bits.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                int64_t srcBits, int64_t dstBits,
+                                Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned =
+      builder
+          .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+          .getResult();
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+          .getResult();
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
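
To make the index and mask arithmetic of the helpers above concrete, take srcBits = 4, dstBits = 8, and a linearized i4 index of 5 (values shown for illustration only):

  byte index = 5 floordiv 2 = 2                                  // getIndicesForLoadOrStore
  bit offset = (5 mod 2) * 4 = 4                                 // getOffsetForBitwidth
  maskRightAligned = (1 << 4) - 1            = 0b0000_1111
  writeMaskInverse = 0b0000_1111 << 4        = 0b1111_0000
  writeMask        = 0b1111_0000 xor 0b1111_1111 = 0b0000_1111   // getAtomicWriteMask

So element 5 lives in byte 2, its nibble starts at bit 4, and the returned writeMask keeps every bit of the byte except that nibble when used with the `andi` atomic.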
@@ -155,32 +232,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));
 
       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
@@ -211,6 +271,60 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    int srcBits = op.getMemRefType().getElementTypeBitWidth();
+    int dstBits = convertedType.getElementTypeBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. We can compute the mask at compile
+    // time.
+    if (convertedType.getRank() == 0) {
+      // Create mask to clear destination bits
+      auto writeMaskValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, ~(1 << (srcBits)) - 1);
+      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                           writeMaskValAttr);
+
+      replaceStoreWithAtomics(rewriter, op, writeMask, extendedInput,
+                              adaptor.getMemref(), ValueRange{});
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
+    Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+                                         dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+    replaceStoreWithAtomics(rewriter, op, writeMask, alignedVal,
+                            adaptor.getMemref(), storeIndices);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
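
The 0-rank special case in ConvertMemrefStore above needs no index or offset computation, since the sub-byte value always lands at a fixed position in the single backing byte and the clearing mask is a compile-time constant. A minimal sketch of the emitted IR, assuming a store into memref<i4> converted to memref<i8>; the constant is whatever writeMaskValAttr evaluates to for srcBits = 4 and is elided here:

  %ext  = arith.extui %v : i4 to i8
  %mask = arith.constant ... : i8     // compile-time clearing mask
  %0 = memref.atomic_rmw andi %mask, %dst[] : (i8, memref<i8>) -> i8
  %1 = memref.atomic_rmw ori %ext, %dst[] : (i8, memref<i8>) -> i8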
@@ -292,7 +406,7 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview, ConvertMemrefStore>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
