@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
15
15
include "mlir/Interfaces/ControlFlowInterfaces.td"
16
16
include "mlir/Interfaces/InferIntRangeInterface.td"
17
17
include "mlir/Interfaces/InferTypeOpInterface.td"
18
+ include "mlir/Interfaces/MemOpInterfaces.td"
18
19
include "mlir/Interfaces/MemorySlotInterfaces.td"
19
20
include "mlir/Interfaces/ShapedOpInterfaces.td"
20
21
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
145
146
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
146
147
Pure,
147
148
ViewLikeOpInterface,
148
- SameOperandsAndResultType
149
+ SameOperandsAndResultType,
150
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
149
151
]> {
150
152
let summary =
151
153
"assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
456
458
def MemRef_CastOp : MemRef_Op<"cast", [
457
459
DeclareOpInterfaceMethods<CastOpInterface>,
458
460
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
461
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
459
462
MemRefsNormalizable,
460
463
Pure,
461
464
SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
1194
1197
"memref", "result",
1195
1198
"::llvm::cast<MemRefType>($_self).getElementType()">,
1196
1199
MemRefsNormalizable,
1200
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1197
1201
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
1198
1202
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
1199
1203
let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
1284
1288
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
1285
1289
DeclareOpInterfaceMethods<CastOpInterface>,
1286
1290
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1291
+ MemorySpaceCastOpInterface,
1287
1292
MemRefsNormalizable,
1288
1293
Pure,
1289
1294
SameOperandsAndResultElementType,
@@ -1302,6 +1307,10 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
1302
1307
1303
1308
If the source and target address spaces are the same, this operation is a noop.
1304
1309
1310
+ Finally, if the target memory-space is the generic/default memory-space,
1311
+ then it is assumed this cast can be bubbled down safely. See the docs of
1312
+ `MemorySpaceCastOpInterface` interface for more details.
1313
+
1305
1314
Example:
1306
1315
1307
1316
```mlir
@@ -1321,6 +1330,27 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
1321
1330
1322
1331
let extraClassDeclaration = [{
1323
1332
Value getViewSource() { return getSource(); }
1333
+
1334
+ //===------------------------------------------------------------------===//
1335
+ // MemorySpaceCastConsumerOpInterface
1336
+ //===------------------------------------------------------------------===//
1337
+ /// Returns the `source` memref.
1338
+ TypedValue<PtrLikeTypeInterface> getSourcePtr();
1339
+ /// Returns the `dest` memref.
1340
+ TypedValue<PtrLikeTypeInterface> getTargetPtr();
1341
+ /// Returns whether the memory-space cast is valid. Only casts between
1342
+ /// memrefs are considered valid. Further, the `tgt` and `src` should only
1343
+ /// differ on the memory-space parameter of the memref type.
1344
+ bool isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1345
+ PtrLikeTypeInterface src);
1346
+ /// Clones the operation using a new target type and source value.
1347
+ MemorySpaceCastOpInterface cloneMemorySpaceCastOp(
1348
+ OpBuilder &b, PtrLikeTypeInterface tgt,
1349
+ TypedValue<PtrLikeTypeInterface> src);
1350
+ /// Returns whether the `source` value can be promoted by the
1351
+ /// `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method. The only
1352
+ /// casts the op recognizes as promotable are to the generic memory-space.
1353
+ bool isSourcePromotable();
1324
1354
}];
1325
1355
1326
1356
let hasFolder = 1;
@@ -1376,6 +1406,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
1376
1406
def MemRef_ReinterpretCastOp
1377
1407
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
1378
1408
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1409
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1379
1410
AttrSizedOperandSegments,
1380
1411
MemRefsNormalizable,
1381
1412
Pure,
@@ -1603,6 +1634,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
1603
1634
1604
1635
def MemRef_ReshapeOp: MemRef_Op<"reshape", [
1605
1636
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1637
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1606
1638
Pure,
1607
1639
ViewLikeOpInterface]> {
1608
1640
let summary = "memref reshape operation";
@@ -1701,6 +1733,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
1701
1733
1702
1734
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1703
1735
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1736
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1704
1737
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
1705
1738
let summary = "operation to produce a memref with a higher rank.";
1706
1739
let description = [{
@@ -1822,7 +1855,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1822
1855
}
1823
1856
1824
1857
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1825
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
1858
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1859
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
1860
+ ]> {
1826
1861
let summary = "operation to produce a memref with a smaller rank.";
1827
1862
let description = [{
1828
1863
The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1964,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
1929
1964
"memref", "value",
1930
1965
"::llvm::cast<MemRefType>($_self).getElementType()">,
1931
1966
MemRefsNormalizable,
1967
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1932
1968
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
1933
1969
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
1934
1970
let summary = "store operation";
@@ -2006,6 +2042,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
2006
2042
2007
2043
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
2008
2044
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2045
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2009
2046
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
2010
2047
AttrSizedOperandSegments,
2011
2048
OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2318,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
2281
2318
2282
2319
def MemRef_TransposeOp : MemRef_Op<"transpose", [
2283
2320
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2321
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2284
2322
Pure]>,
2285
2323
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
2286
2324
Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2354,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
2316
2354
2317
2355
def MemRef_ViewOp : MemRef_Op<"view", [
2318
2356
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2357
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2319
2358
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
2320
2359
Pure]> {
2321
2360
let summary = "memref view operation";
@@ -2392,6 +2431,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
2392
2431
//===----------------------------------------------------------------------===//
2393
2432
2394
2433
def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
2434
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2395
2435
AllTypesMatch<["value", "result"]>,
2396
2436
TypesMatchWith<"value type matches element type of memref",
2397
2437
"memref", "value",
0 commit comments