Skip to content

Commit 47034c4

Browse files
bondhugula authored and
tensorflower-gardener committed
Introduce prefetch op: affine -> std -> llvm intrinsic
Introduce affine.prefetch: an op to prefetch using a multi-dimensional subscript on a memref; it is similar to affine.load, but it affects only performance, not semantics. Provide lowering through std.prefetch and llvm.prefetch, and map to LLVM's prefetch intrinsic. All attributes are reflected through the lowering: locality hint, rw, and instr/data cache. affine.prefetch %0[%i, %j + 5], false, 3, true : memref<400x400xi32> Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Closes tensorflow/mlir#225 COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#225 from bondhugula:prefetch 4c3b4e93bc64d9a5719504e6d6e1657818a2ead0 PiperOrigin-RevId: 286212997
1 parent 4562e38 commit 47034c4

File tree

21 files changed

+560
-30
lines changed

21 files changed

+560
-30
lines changed

mlir/g3doc/OpDefinitions.md

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -243,19 +243,21 @@ like `"0.5f"`, and an integer array default value should be specified as like
243243
`Confined` is provided as a general mechanism to help modelling further
244244
constraints on attributes beyond the ones brought by value types. You can use
245245
`Confined` to compose complex constraints out of more primitive ones. For
246-
example, a 32-bit integer attribute whose minimal value must be 10 can be
246+
example, a 32-bit integer attribute whose minimum value must be 10 can be
247247
expressed as `Confined<I32Attr, [IntMinValue<10>]>`.
248248

249249
Right now, the following primitive constraints are supported:
250250

251-
* `IntMinValue<N>`: Specifying an integer attribute to be greater than or equal
252-
to `N`
253-
* `ArrayMinCount<N>`: Specifying an array attribute to have at least `N`
254-
elements
255-
* `IntArrayNthElemEq<I, N>`: Specifying an integer array attribute's `I`-th
256-
element to be equal to `N`
257-
* `IntArrayNthElemMinValue<I, N>`: Specifying an integer array attribute's
258-
`I`-th element to be greater than or equal to `N`
251+
* `IntMinValue<N>`: Specifying an integer attribute to be greater than or
252+
equal to `N`
253+
* `IntMaxValue<N>`: Specifying an integer attribute to be less than or equal
254+
to `N`
255+
* `ArrayMinCount<N>`: Specifying an array attribute to have at least `N`
256+
elements
257+
* `IntArrayNthElemEq<I, N>`: Specifying an integer array attribute's `I`-th
258+
element to be equal to `N`
259+
* `IntArrayNthElemMinValue<I, N>`: Specifying an integer array attribute's
260+
`I`-th element to be greater than or equal to `N`
259261

260262
TODO: Design and implement more primitive constraints
261263

mlir/include/mlir/Dialect/AffineOps/AffineOps.td

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,81 @@ def AffineMinOp : Affine_Op<"min"> {
261261
let hasFolder = 1;
262262
}
263263

264+
def AffinePrefetchOp : Affine_Op<"prefetch"> {
265+
let summary = "affine prefetch operation";
266+
let description = [{
267+
The "affine.prefetch" op prefetches data from a memref location described
268+
with an affine subscript similar to affine.load, and has three attributes:
269+
a read/write specifier, a locality hint, and a cache type specifier as shown
270+
below:
271+
272+
affine.prefetch %0[%i, %j + 5], read, locality<3>, data
273+
: memref<400x400xi32>
274+
275+
The read/write specifier is either 'read' or 'write', the locality hint
276+
specifier ranges from locality<0> (no locality) to locality<3> (extremely
277+
local keep in cache). The cache type specifier is either 'data' or 'instr'
278+
and specifies whether the prefetch is performed on data cache or on
279+
instruction cache.
280+
}];
281+
282+
let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices,
283+
BoolAttr:$isWrite,
284+
Confined<I32Attr, [IntMinValue<0>,
285+
IntMaxValue<3>]>:$localityHint,
286+
BoolAttr:$isDataCache);
287+
288+
let builders = [OpBuilder<
289+
"Builder *builder, OperationState &result, Value *memref,"
290+
"AffineMap map, ArrayRef<Value *> mapOperands, bool isWrite,"
291+
"unsigned localityHint, bool isDataCache",
292+
[{
293+
assert(map.getNumInputs() == mapOperands.size()
294+
&& "inconsistent index info");
295+
auto localityHintAttr = builder->getI32IntegerAttr(localityHint);
296+
auto isWriteAttr = builder->getBoolAttr(isWrite);
297+
auto isDataCacheAttr = builder->getBoolAttr(isDataCache);
298+
result.addOperands(memref);
299+
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
300+
result.addOperands(mapOperands);
301+
result.addAttribute(getLocalityHintAttrName(), localityHintAttr);
302+
result.addAttribute(getIsWriteAttrName(), isWriteAttr);
303+
result.addAttribute(getIsDataCacheAttrName(), isDataCacheAttr);
304+
}]>];
305+
306+
let extraClassDeclaration = [{
307+
MemRefType getMemRefType() {
308+
return memref()->getType().cast<MemRefType>();
309+
}
310+
311+
/// Returns the affine map used to index the memref for this operation.
312+
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
313+
AffineMapAttr getAffineMapAttr() {
314+
return getAttr(getMapAttrName()).cast<AffineMapAttr>();
315+
}
316+
317+
/// Returns the AffineMapAttr associated with 'memref'.
318+
NamedAttribute getAffineMapAttrForMemRef(Value *mref) {
319+
assert(mref == memref());
320+
return {Identifier::get(getMapAttrName(), getContext()),
321+
getAffineMapAttr()};
322+
}
323+
324+
/// Get affine map operands.
325+
operand_range getMapOperands() {
326+
return {operand_begin() + 1, operand_end()};
327+
}
328+
329+
static StringRef getMapAttrName() { return "map"; }
330+
static StringRef getLocalityHintAttrName() { return "localityHint"; }
331+
static StringRef getIsWriteAttrName() { return "isWrite"; }
332+
static StringRef getIsDataCacheAttrName() { return "isDataCache"; }
333+
}];
334+
335+
let hasCanonicalizer = 1;
336+
let hasFolder = 1;
337+
}
338+
264339
def AffineTerminatorOp :
265340
Affine_Op<"terminator", [Terminator]> {
266341
let summary = "affine terminator operation";

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,17 @@ def LLVM_FMulAddOp : LLVM_Op<"intr.fmuladd", [NoSideEffect]>,
674674
}];
675675
}
676676

677+
def LLVM_Prefetch : LLVM_ZeroResultOp<"intr.prefetch">,
678+
Arguments<(ins LLVM_Type:$addr, LLVM_Type:$rw,
679+
LLVM_Type:$hint, LLVM_Type:$cache)> {
680+
let llvmBuilder = [{
681+
llvm::Module *module = builder.GetInsertBlock()->getModule();
682+
llvm::Function *fn = llvm::Intrinsic::getDeclaration(
683+
module, llvm::Intrinsic::prefetch, $addr->getType());
684+
builder.CreateCall(fn, {$addr, $rw, $hint, $cache});
685+
}];
686+
}
687+
677688
def LLVM_ExpOp : LLVM_Op<"intr.exp", [NoSideEffect]>,
678689
Arguments<(ins LLVM_Type:$in)>,
679690
Results<(outs LLVM_Type:$res)> {

mlir/include/mlir/Dialect/StandardOps/Ops.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,55 @@ def OrOp : IntArithmeticOp<"or", [Commutative]> {
928928
let hasFolder = 1;
929929
}
930930

931+
def PrefetchOp : Std_Op<"prefetch"> {
932+
let summary = "prefetch operation";
933+
let description = [{
934+
The "prefetch" op prefetches data from a memref location described with
935+
subscript indices similar to std.load, and with three attributes: a
936+
read/write specifier, a locality hint, and a cache type specifier as shown
937+
below:
938+
939+
prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32>
940+
941+
The read/write specifier is either 'read' or 'write', the locality hint
942+
ranges from locality<0> (no locality) to locality<3> (extremely local keep
943+
in cache). The cache type specifier is either 'data' or 'instr'
944+
and specifies whether the prefetch is performed on data cache or on
945+
instruction cache.
946+
}];
947+
948+
let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices,
949+
BoolAttr:$isWrite,
950+
Confined<I32Attr, [IntMinValue<0>,
951+
IntMaxValue<3>]>:$localityHint,
952+
BoolAttr:$isDataCache);
953+
954+
let builders = [OpBuilder<
955+
"Builder *builder, OperationState &result, Value *memref,"
956+
"ArrayRef<Value *> indices, bool isWrite, unsigned hint, bool isData",
957+
[{
958+
auto hintAttr = builder->getI32IntegerAttr(hint);
959+
auto isWriteAttr = builder->getBoolAttr(isWrite);
960+
auto isDataCacheAttr = builder->getBoolAttr(isData);
961+
result.addOperands(memref);
962+
result.addOperands(indices);
963+
result.addAttribute("localityHint", hintAttr);
964+
result.addAttribute("isWrite", isWriteAttr);
965+
result.addAttribute("isDataCache", isDataCacheAttr);
966+
}]>];
967+
968+
let extraClassDeclaration = [{
969+
MemRefType getMemRefType() {
970+
return memref()->getType().cast<MemRefType>();
971+
}
972+
static StringRef getLocalityHintAttrName() { return "localityHint"; }
973+
static StringRef getIsWriteAttrName() { return "isWrite"; }
974+
static StringRef getIsDataCacheAttrName() { return "isDataCache"; }
975+
}];
976+
977+
let hasFolder = 1;
978+
}
979+
931980
def RankOp : Std_Op<"rank", [NoSideEffect]> {
932981
let summary = "rank operation";
933982
let description = [{

mlir/include/mlir/IR/OpBase.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1242,7 +1242,11 @@ class AllAttrConstraintsOf<list<AttrConstraint> constraints> : AttrConstraint<
12421242

12431243
class IntMinValue<int n> : AttrConstraint<
12441244
CPred<"$_self.cast<IntegerAttr>().getInt() >= " # n>,
1245-
"whose minimal value is " # n>;
1245+
"whose minimum value is " # n>;
1246+
1247+
class IntMaxValue<int n> : AttrConstraint<
1248+
CPred<"$_self.cast<IntegerAttr>().getInt() <= " # n>,
1249+
"whose maximum value is " # n>;
12461250

12471251
class ArrayMinCount<int n> : AttrConstraint<
12481252
CPred<"$_self.cast<ArrayAttr>().size() >= " # n>,

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,12 @@ class OpAsmParser {
284284
/// Parse a `=` token.
285285
virtual ParseResult parseEqual() = 0;
286286

287+
/// Parse a '<' token.
288+
virtual ParseResult parseLess() = 0;
289+
290+
/// Parse a '>' token.
291+
virtual ParseResult parseGreater() = 0;
292+
287293
/// Parse a given keyword.
288294
ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
289295
auto loc = getCurrentLocation();

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,37 @@ class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
405405
PatternRewriter &rewriter) const override {
406406
// Expand affine map from 'affineLoadOp'.
407407
SmallVector<Value *, 8> indices(op.getMapOperands());
408-
auto maybeExpandedMap =
408+
auto resultOperands =
409409
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
410-
if (!maybeExpandedMap)
410+
if (!resultOperands)
411411
return matchFailure();
412412

413413
// Build std.load memref[expandedMap.results].
414-
rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *maybeExpandedMap);
414+
rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
415+
return matchSuccess();
416+
}
417+
};
418+
419+
// Apply the affine map from an 'affine.prefetch' operation to its operands, and
420+
// feed the results to a newly created 'std.prefetch' operation (which replaces
421+
// the original 'affine.prefetch').
422+
class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
423+
public:
424+
using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
425+
426+
PatternMatchResult matchAndRewrite(AffinePrefetchOp op,
427+
PatternRewriter &rewriter) const override {
428+
// Expand affine map from 'affinePrefetchOp'.
429+
SmallVector<Value *, 8> indices(op.getMapOperands());
430+
auto resultOperands =
431+
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
432+
if (!resultOperands)
433+
return matchFailure();
434+
435+
// Build std.prefetch memref[expandedMap.results].
436+
rewriter.replaceOpWithNewOp<PrefetchOp>(
437+
op, op.memref(), *resultOperands, op.isWrite(),
438+
op.localityHint().getZExtValue(), op.isDataCache());
415439
return matchSuccess();
416440
}
417441
};
@@ -506,11 +530,10 @@ class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
506530

507531
void mlir::populateAffineToStdConversionPatterns(
508532
OwningRewritePatternList &patterns, MLIRContext *ctx) {
509-
patterns
510-
.insert<AffineApplyLowering, AffineDmaStartLowering,
511-
AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering,
512-
AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(
513-
ctx);
533+
patterns.insert<
534+
AffineApplyLowering, AffineDmaStartLowering, AffineDmaWaitLowering,
535+
AffineLoadLowering, AffinePrefetchLowering, AffineStoreLowering,
536+
AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(ctx);
514537
}
515538

516539
namespace {

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,6 +1462,39 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
14621462
}
14631463
};
14641464

1465+
// The prefetch operation is lowered in a way similar to the load operation
1466+
// except that the llvm.prefetch operation is used for replacement.
1467+
struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
1468+
using Base::Base;
1469+
1470+
PatternMatchResult
1471+
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
1472+
ConversionPatternRewriter &rewriter) const override {
1473+
auto prefetchOp = cast<PrefetchOp>(op);
1474+
OperandAdaptor<PrefetchOp> transformed(operands);
1475+
auto type = prefetchOp.getMemRefType();
1476+
1477+
Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
1478+
transformed.indices(), rewriter, getModule());
1479+
1480+
// Replace with llvm.prefetch.
1481+
auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32));
1482+
auto isWrite = rewriter.create<LLVM::ConstantOp>(
1483+
op->getLoc(), llvmI32Type,
1484+
rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
1485+
auto localityHint = rewriter.create<LLVM::ConstantOp>(
1486+
op->getLoc(), llvmI32Type,
1487+
rewriter.getI32IntegerAttr(prefetchOp.localityHint().getZExtValue()));
1488+
auto isData = rewriter.create<LLVM::ConstantOp>(
1489+
op->getLoc(), llvmI32Type,
1490+
rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
1491+
1492+
rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite,
1493+
localityHint, isData);
1494+
return matchSuccess();
1495+
}
1496+
};
1497+
14651498
// The lowering of index_cast becomes an integer conversion since index becomes
14661499
// an integer. If the bit width of the source and target integer types is the
14671500
// same, just erase the cast. If the target type is wider, sign-extend the
@@ -2041,6 +2074,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
20412074
MulFOpLowering,
20422075
MulIOpLowering,
20432076
OrOpLowering,
2077+
PrefetchOpLowering,
20442078
RemFOpLowering,
20452079
RemISOpLowering,
20462080
RemIUOpLowering,

0 commit comments

Comments
 (0)