Skip to content

Commit 077a796

Browse files
fabianmcgjoker-eph
andauthored
[mlir] Implement a memory-space cast bubbling-down transform (#159454)
This commit adds functionality to bubble down memory-space casts operations, allowing consumer operations to use the original memory-space rather than first casting to a different memory space. Changes: - Introduce `MemorySpaceCastOpInterface` to handle memory-space cast operations - Create a `MemorySpaceCastConsumerOpInterface` pass that identifies and bubbles down eligible casts - Add implementation for memref and vector operations to handle memory-space cast propagation - Add `bubbleDownCasts` method to relevant operations to support the fusion In particular, in the current implementation only memory-space casts into the default memory-space can be bubbled-down. Example: ```mlir func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> { %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32> %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32> %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32> %loaded = memref.load %collapsed[%c0] : memref<16xf32> %added = arith.addf %loaded, %arg2 : f32 memref.store %added, %collapsed[%c0] : memref<16xf32> %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32 return %collapsed : memref<16xf32> } // mlir-opt --bubble-down-memory-space-casts func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> { %c4 = arith.constant 4 : index %c0 = arith.constant 0 : index %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1> %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1> %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32> %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1> %1 = arith.addf %0, %arg2 : f32 memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1> %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32 return %memspacecast : memref<16xf32> } ``` --------- Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com> Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
1 parent 0a268f8 commit 077a796

File tree

18 files changed

+939
-7
lines changed

18 files changed

+939
-7
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1919
#include "mlir/Interfaces/InferIntRangeInterface.h"
2020
#include "mlir/Interfaces/InferTypeOpInterface.h"
21+
#include "mlir/Interfaces/MemOpInterfaces.h"
2122
#include "mlir/Interfaces/MemorySlotInterfaces.h"
2223
#include "mlir/Interfaces/ShapedOpInterfaces.h"
2324
#include "mlir/Interfaces/SideEffectInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
1515
include "mlir/Interfaces/ControlFlowInterfaces.td"
1616
include "mlir/Interfaces/InferIntRangeInterface.td"
1717
include "mlir/Interfaces/InferTypeOpInterface.td"
18+
include "mlir/Interfaces/MemOpInterfaces.td"
1819
include "mlir/Interfaces/MemorySlotInterfaces.td"
1920
include "mlir/Interfaces/ShapedOpInterfaces.td"
2021
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
145146
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
146147
Pure,
147148
ViewLikeOpInterface,
148-
SameOperandsAndResultType
149+
SameOperandsAndResultType,
150+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
149151
]> {
150152
let summary =
151153
"assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
456458
def MemRef_CastOp : MemRef_Op<"cast", [
457459
DeclareOpInterfaceMethods<CastOpInterface>,
458460
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
461+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
459462
MemRefsNormalizable,
460463
Pure,
461464
SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
11941197
"memref", "result",
11951198
"::llvm::cast<MemRefType>($_self).getElementType()">,
11961199
MemRefsNormalizable,
1200+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
11971201
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
11981202
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
11991203
let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
12841288
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
12851289
DeclareOpInterfaceMethods<CastOpInterface>,
12861290
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1291+
MemorySpaceCastOpInterface,
12871292
MemRefsNormalizable,
12881293
Pure,
12891294
SameOperandsAndResultElementType,
@@ -1302,6 +1307,10 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
13021307

13031308
If the source and target address spaces are the same, this operation is a noop.
13041309

1310+
Finally, if the target memory-space is the generic/default memory-space,
1311+
then it is assumed this cast can be bubbled down safely. See the docs of
1312+
`MemorySpaceCastOpInterface` interface for more details.
1313+
13051314
Example:
13061315

13071316
```mlir
@@ -1321,6 +1330,27 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
13211330

13221331
let extraClassDeclaration = [{
13231332
Value getViewSource() { return getSource(); }
1333+
1334+
//===------------------------------------------------------------------===//
1335+
// MemorySpaceCastConsumerOpInterface
1336+
//===------------------------------------------------------------------===//
1337+
/// Returns the `source` memref.
1338+
TypedValue<PtrLikeTypeInterface> getSourcePtr();
1339+
/// Returns the `dest` memref.
1340+
TypedValue<PtrLikeTypeInterface> getTargetPtr();
1341+
/// Returns whether the memory-space cast is valid. Only casts between
1342+
/// memrefs are considered valid. Further, the `tgt` and `src` should only
1343+
/// differ on the memory-space parameter of the memref type.
1344+
bool isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1345+
PtrLikeTypeInterface src);
1346+
/// Clones the operation using a new target type and source value.
1347+
MemorySpaceCastOpInterface cloneMemorySpaceCastOp(
1348+
OpBuilder &b, PtrLikeTypeInterface tgt,
1349+
TypedValue<PtrLikeTypeInterface> src);
1350+
/// Returns whether the `source` value can be promoted by the
1351+
/// `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method. The only
1352+
/// casts the op recognizes as promotable are to the generic memory-space.
1353+
bool isSourcePromotable();
13241354
}];
13251355

13261356
let hasFolder = 1;
@@ -1376,6 +1406,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
13761406
def MemRef_ReinterpretCastOp
13771407
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
13781408
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1409+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
13791410
AttrSizedOperandSegments,
13801411
MemRefsNormalizable,
13811412
Pure,
@@ -1603,6 +1634,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
16031634

16041635
def MemRef_ReshapeOp: MemRef_Op<"reshape", [
16051636
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1637+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
16061638
Pure,
16071639
ViewLikeOpInterface]> {
16081640
let summary = "memref reshape operation";
@@ -1701,6 +1733,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
17011733

17021734
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
17031735
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1736+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
17041737
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
17051738
let summary = "operation to produce a memref with a higher rank.";
17061739
let description = [{
@@ -1822,7 +1855,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
18221855
}
18231856

18241857
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1825-
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
1858+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1859+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
1860+
]> {
18261861
let summary = "operation to produce a memref with a smaller rank.";
18271862
let description = [{
18281863
The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1964,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
19291964
"memref", "value",
19301965
"::llvm::cast<MemRefType>($_self).getElementType()">,
19311966
MemRefsNormalizable,
1967+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
19321968
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
19331969
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
19341970
let summary = "store operation";
@@ -2006,6 +2042,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
20062042

20072043
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20082044
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2045+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
20092046
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
20102047
AttrSizedOperandSegments,
20112048
OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2318,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
22812318

22822319
def MemRef_TransposeOp : MemRef_Op<"transpose", [
22832320
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2321+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
22842322
Pure]>,
22852323
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
22862324
Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2354,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
23162354

23172355
def MemRef_ViewOp : MemRef_Op<"view", [
23182356
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2357+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
23192358
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
23202359
Pure]> {
23212360
let summary = "memref view operation";
@@ -2392,6 +2431,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
23922431
//===----------------------------------------------------------------------===//
23932432

23942433
def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
2434+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
23952435
AllTypesMatch<["value", "result"]>,
23962436
TypesMatchWith<"value type matches element type of memref",
23972437
"memref", "value",

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2828
#include "mlir/Interfaces/IndexingMapOpInterface.h"
2929
#include "mlir/Interfaces/InferTypeOpInterface.h"
30+
#include "mlir/Interfaces/MemOpInterfaces.h"
3031
#include "mlir/Interfaces/SideEffectInterfaces.h"
3132
#include "mlir/Interfaces/VectorInterfaces.h"
3233
#include "mlir/Interfaces/ViewLikeInterface.h"

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
2424
include "mlir/Interfaces/IndexingMapOpInterface.td"
2525
include "mlir/Interfaces/InferIntRangeInterface.td"
2626
include "mlir/Interfaces/InferTypeOpInterface.td"
27+
include "mlir/Interfaces/MemOpInterfaces.td"
2728
include "mlir/Interfaces/SideEffectInterfaces.td"
2829
include "mlir/Interfaces/VectorInterfaces.td"
2930
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
12461247
DeclareOpInterfaceMethods<MaskableOpInterface>,
12471248
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
12481249
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
1250+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
12491251
AttrSizedOperandSegments,
12501252
DestinationStyleOpInterface
12511253
]>,
@@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
14931495
DeclareOpInterfaceMethods<MaskableOpInterface>,
14941496
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
14951497
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
1498+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
14961499
AttrSizedOperandSegments,
14971500
DestinationStyleOpInterface
14981501
]>,
@@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :
16491652

16501653
def Vector_LoadOp : Vector_Op<"load", [
16511654
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1655+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
16521656
]> {
16531657
let summary = "reads an n-D slice of memory into an n-D vector";
16541658
let description = [{
@@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
17651769

17661770
def Vector_StoreOp : Vector_Op<"store", [
17671771
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1772+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
17681773
]> {
17691774
let summary = "writes an n-D vector to an n-D slice of memory";
17701775
let description = [{
@@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
18691874
}
18701875

18711876
def Vector_MaskedLoadOp :
1872-
Vector_Op<"maskedload">,
1877+
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
18731878
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18741879
Variadic<Index>:$indices,
18751880
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
19611966
}
19621967

19631968
def Vector_MaskedStoreOp :
1964-
Vector_Op<"maskedstore">,
1969+
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
19651970
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
19661971
Variadic<Index>:$indices,
19671972
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
20412046
def Vector_GatherOp :
20422047
Vector_Op<"gather", [
20432048
DeclareOpInterfaceMethods<MaskableOpInterface>,
2049+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
20442050
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
20452051
]>,
20462052
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2144,7 +2150,7 @@ def Vector_GatherOp :
21442150
}
21452151

21462152
def Vector_ScatterOp :
2147-
Vector_Op<"scatter">,
2153+
Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
21482154
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
21492155
Variadic<Index>:$offsets,
21502156
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
22292235
}
22302236

22312237
def Vector_ExpandLoadOp :
2232-
Vector_Op<"expandload">,
2238+
Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
22332239
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
22342240
Variadic<Index>:$indices,
22352241
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
23172323
}
23182324

23192325
def Vector_CompressStoreOp :
2320-
Vector_Op<"compressstore">,
2326+
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
23212327
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
23222328
Variadic<Index>:$indices,
23232329
FixedVectorOfNonZeroRankOf<[I1]>:$mask,

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
88
add_mlir_interface(InferIntRangeInterface)
99
add_mlir_interface(InferTypeOpInterface)
1010
add_mlir_interface(LoopLikeInterface)
11+
add_mlir_interface(MemOpInterfaces)
1112
add_mlir_interface(ParallelCombiningOpInterface)
1213
add_mlir_interface(RuntimeVerifiableOpInterface)
1314
add_mlir_interface(ShapedOpInterfaces)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains declarations of interfaces for operations that interact
10+
// with memory.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
15+
#define MLIR_INTERFACES_MEMOPINTERFACES_H
16+
17+
#include "mlir/IR/OpDefinition.h"
18+
19+
namespace mlir {
20+
namespace detail {
21+
/// Attempt to verify the given memory space cast operation.
22+
LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
23+
24+
/// Tries to bubble-down inplace a `MemorySpaceCastOpInterface` operation
25+
/// referenced by `operand`. On success, it returns `std::nullopt`. It
26+
/// returns failure if `operand` doesn't reference a
27+
/// `MemorySpaceCastOpInterface` op.
28+
FailureOr<std::optional<SmallVector<Value>>>
29+
bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
30+
} // namespace detail
31+
} // namespace mlir
32+
33+
/// Include the generated interface declarations.
34+
#include "mlir/Interfaces/MemOpInterfaces.h.inc"
35+
36+
#endif // MLIR_INTERFACES_MEMOPINTERFACES_H

0 commit comments

Comments
 (0)