
Conversation

@fabianmcg (Contributor) commented Sep 17, 2025

This commit adds functionality to bubble down memory-space cast operations, allowing consumer operations to use the original memory space rather than first casting to a different one.

Changes:

  • Introduce MemorySpaceCastOpInterface to handle memory-space cast operations
  • Create a bubble-down-memory-space-casts pass that identifies and bubbles down eligible casts
  • Add implementation for memref and vector operations to handle memory-space cast propagation
  • Add a bubbleDownCasts method to relevant operations to support the transformation

In particular, in the current implementation only memory-space casts into the default memory space can be bubbled down.

Example:

```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
    %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
    %loaded = memref.load %collapsed[%c0] : memref<16xf32>
    %added = arith.addf %loaded, %arg2 : f32
    memref.store %added, %collapsed[%c0] : memref<16xf32>
    %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
    return %collapsed : memref<16xf32>
}
// mlir-opt --bubble-down-memory-space-casts
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
    %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
    %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
    %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
    %1 = arith.addf %0, %arg2 : f32
    memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
    %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
    return %memspacecast : memref<16xf32>
}
```
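For contrast, a cast in the opposite direction, into a non-default memory space, is not a candidate under the current restriction. A hypothetical input like the following (not from the patch) would be left unchanged by the pass:

```mlir
// Hypothetical input: the cast targets memory space 1 rather than the
// default memory space, so the pass leaves the cast and its consumer as-is.
func.func @cast_into_nondefault(%arg0: memref<4xf32>, %i: index) -> f32 {
  %cast = memref.memory_space_cast %arg0 : memref<4xf32> to memref<4xf32, 1>
  %v = memref.load %cast[%i] : memref<4xf32, 1>
  return %v : f32
}
```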

This commit adds functionality to fuse memory-space casts into consumer operations,
allowing operations to be performed directly on the original memory-space rather
than first casting to a different memory space.

Key changes:
- Introduce `MemorySpaceCastOpInterface` to handle memory-space cast operations
- Create a `FuseMemorySpaceCastsIntoConsumers` pass that identifies and fuses eligible casts
- Add implementation for memref and vector operations to handle memory-space cast fusion
- Add fuseCastOperands method to relevant operations to support the fusion

In particular, in the current implementation only memory-space casts into the default
memory-space can be fused.

Example:

```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
    %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
    %loaded = memref.load %collapsed[%c0] : memref<16xf32>
    %added = arith.addf %loaded, %arg2 : f32
    memref.store %added, %collapsed[%c0] : memref<16xf32>
    %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
    return %collapsed : memref<16xf32>
}
// mlir-opt --fuse-memory-space-casts-into-consumers
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
    %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
    %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
    %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
    %1 = arith.addf %0, %arg2 : f32
    memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
    %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
    return %memspacecast : memref<16xf32>
}
```

Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
@llvmbot (Member) commented Sep 17, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-vector

Author: Fabian Mora (fabianmcg)

Changes: (see the PR description above)

Patch is 64.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159454.diff

18 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRef.h (+1)
  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+17-2)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+1)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+11-5)
  • (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+1)
  • (added) mlir/include/mlir/Interfaces/MemOpInterfaces.h (+37)
  • (added) mlir/include/mlir/Interfaces/MemOpInterfaces.td (+114)
  • (added) mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h (+20)
  • (modified) mlir/include/mlir/Transforms/Passes.h (+1)
  • (modified) mlir/include/mlir/Transforms/Passes.td (+40)
  • (modified) mlir/lib/Dialect/MemRef/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+158)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+64)
  • (modified) mlir/lib/Interfaces/CMakeLists.txt (+2)
  • (added) mlir/lib/Interfaces/MemOpInterfaces.cpp (+73)
  • (modified) mlir/lib/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp (+73)
  • (added) mlir/test/Transforms/test-fuse-casts-into-consumers.mlir (+281)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index bdec699eb4ce4..30f33ed2fd1d6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 671cc05e963b4..238a767ac8b73 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure,
       ViewLikeOpInterface,
-      SameOperandsAndResultType
+      SameOperandsAndResultType,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
     ]> {
   let summary =
       "assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
 def MemRef_CastOp : MemRef_Op<"cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
                      "memref", "result",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
 def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
       DeclareOpInterfaceMethods<CastOpInterface>,
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
       MemRefsNormalizable,
       Pure,
       SameOperandsAndResultElementType,
@@ -1376,6 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
 def MemRef_ReinterpretCastOp
   : MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       MemRefsNormalizable,
       Pure,
@@ -1603,6 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
 
 def MemRef_ReshapeOp: MemRef_Op<"reshape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     Pure,
     ViewLikeOpInterface]>  {
   let summary = "memref reshape operation";
@@ -1701,6 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
@@ -1822,7 +1830,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
 }
 
 def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
+  ]> {
   let summary = "operation to produce a memref with a smaller rank.";
   let description = [{
     The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
                      "memref", "value",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
       DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "store operation";
@@ -2006,6 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
 
 def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     AttrSizedOperandSegments,
     OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
 
 def MemRef_TransposeOp : MemRef_Op<"transpose", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     Pure]>,
     Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
 
 def MemRef_ViewOp : MemRef_Op<"view", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>,
     Pure]> {
   let summary = "memref view operation";
@@ -2392,6 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
 //===----------------------------------------------------------------------===//
 
 def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AllTypesMatch<["value", "result"]>,
       TypesMatchWith<"value type matches element type of memref",
                      "memref", "value",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 63410b8bea747..bbf55f5d507e3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/IndexingMapOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 26d06624cb976..93e9bfc78ea75 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/IndexingMapOpInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
     ]>,
@@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+      DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
   ]>,
@@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :
 
 def Vector_LoadOp : Vector_Op<"load", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "reads an n-D slice of memory into an n-D vector";
   let description = [{
@@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
 
 def Vector_StoreOp : Vector_Op<"store", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
   ]> {
   let summary = "writes an n-D vector to an n-D slice of memory";
   let description = [{
@@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
 }
 
 def Vector_MaskedLoadOp :
-  Vector_Op<"maskedload">,
+  Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
 }
 
 def Vector_MaskedStoreOp :
-  Vector_Op<"maskedstore">,
+  Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
 def Vector_GatherOp :
   Vector_Op<"gather", [
     DeclareOpInterfaceMethods<MaskableOpInterface>,
+    DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
     Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2144,7 +2150,7 @@ def Vector_GatherOp :
 }
 
 def Vector_ScatterOp :
-  Vector_Op<"scatter">,
+  Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$offsets,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
 }
 
 def Vector_ExpandLoadOp :
-  Vector_Op<"expandload">,
+  Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
 }
 
 def Vector_CompressStoreOp :
-  Vector_Op<"compressstore">,
+  Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 2add220fdfb7c..a5feb592045c0 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
+add_mlir_interface(MemOpInterfaces)
 add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
new file mode 100644
index 0000000000000..cc9f4c6b3882e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -0,0 +1,37 @@
+//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of interfaces for operations that interact
+// with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
+#define MLIR_INTERFACES_MEMOPINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+/// Attempt to verify the given memory space cast operation.
+LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
+
+/// Tries to fuse inplace a `MemorySpaceCastOpInterface` operation referenced by
+/// `operand`. On success, it returns `results`, and sets `modifiedInPlace` to
+/// true. It returns failure if `operand` doesn't reference a
+/// `MemorySpaceCastOpInterface` op.
+FailureOr<SmallVector<Value>>
+fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
+                               bool &modifiedInPlace);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/MemOpInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
new file mode 100644
index 0000000000000..0b8ba19171fb7
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -0,0 +1,114 @@
+//===- MemOpInterfaces.td - Memory operation interfaces -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces for operations that interact with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
+#define MLIR_INTERFACES_MEMOPINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def FuseMemorySpaceCastConsumerOpInterface :
+    OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
+  let description = [{
+    An interface to fuse memory-space cast operands into a consumer operation.
+    It is the responsibility of the interface to determine which casts can be
+    fused into the operation.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Attempt to fuse the incoming cast-like operands. Returns `success`
+        and any new results on fusion success, otherwise it returns failure.
+        If new results are produced, these must be compatible with the original
+        operation results.
+
+        The `modifiedInPlace` parameter indicates whether the operation was
+        modified in place. If `false` and the fusion succeeded, then the
+        interface guarantees it is valid to erase the original operation.
+        If `true`, then the interface must guarantee no operations were created
+        by the method, and that no further IR modification is necessary. It is
+        considered an error if `modifiedInPlace` is true and the fusion failed.
+
+        Any implementations of this method must not erase/replace the original
+        operation, instead it is the caller responsibility to erase or replace
+        the op with the results provided by the method.
+
+        Finally, any implementations of this method have to guarantee that the
+        IR remains valid at all times.
+      }],
+      "::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
+      (ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
+    >,
+  ];
+}
+
+def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
+  let description = [{
+    An interface for operations that perform memory-space casts. This
+    interface assumes that the cast operation is `pure`.
+
+    These operations expect to have a well-defined ptr-like operand, and
+    a well-defined target ptr-like result.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Returns the source ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>",  "getSourcePtr"
+    >,
+    InterfaceMethod<[{
+        Returns the target ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
+    >,
+    InterfaceMethod<[{
+        Returns whether the memory space cast specified by `tgt` and `src`
+        is supported.
+      }],
+      "bool", "isValidMemorySpaceCast",
+      (ins "::mlir::PtrLikeTypeInterface":$tgt,
+           "::mlir::PtrLikeTypeInterface":$src)
+    >,
+    InterfaceMethod<[{
+        Clones the memory space cast op with the given source and target type.
+      }],
+      "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
+      (ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
+           "::mlir::Value":$src)
+    >,
+    InterfaceMethod<[{
+        Returns whether the cast allows to be fused.
+      }],
+      "bool", "isFusableMemorySpaceCast"
+    >
+  ];
+  let verify = [{
+    return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
+  }];
+  let dependentTraits = [Pure];
+  let extraClassDeclaration = [{
+    /// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
+    /// is produced by a `MemorySpaceCastOpInterface` op, and
+    /// `isFusableMemorySpaceCast` returns true, otherwise it returns null.
+    static ::mlir::MemorySpaceCastOpInterface
+    getIfFusableCast(::mlir::Value value) {
+      auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
+        value.getDefiningOp());
+      if (!op || !op.isFusableMemorySpaceCast())
+        return nullptr;
+      return op;
+    }
+  }];
+}
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD
diff --git a/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
new file mode 100644
index 0000000000000..9333f92a10289
--- /dev/null
+++ b/mlir/include/mlir/Transforms/FuseMemorySpaceCastsIntoConsumers.h
@@ -0,0 +1,20 @@
+//===-- FuseMemorySpaceCastsIntoConsumers.h - Cast fusion patterns -C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+#define MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
+
+namespace mlir {
+class RewritePatternSet;
+/// Collect a set of patterns to fuse memory-space cast operations into
+/// consumers.
+void populateFuseMemorySpaceCastIntoConsumersPatterns(
+    RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_FUSEMEMORYSPACECASTSINTOCONSUMERS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..610a9671fede8 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_FUSEMEMORYSPACECASTSINTO...
[truncated]
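Since the patch also attaches the consumer interface to vector operations (vector.load, vector.store, maskedload, gather, and friends), the same rewrite applies on the vector side. A hypothetical before/after sketch, analogous to the memref example above and not taken from the patch:

```mlir
// Before: the load goes through a cast into the default memory space.
func.func @vector_load_cast(%arg0: memref<16xf32, 1>, %i: index) -> vector<4xf32> {
  %cast = memref.memory_space_cast %arg0 : memref<16xf32, 1> to memref<16xf32>
  %v = vector.load %cast[%i] : memref<16xf32>, vector<4xf32>
  return %v : vector<4xf32>
}
// After: the load reads directly from memory space 1; the cast becomes
// dead and can be erased if it has no other uses.
func.func @vector_load_cast(%arg0: memref<16xf32, 1>, %i: index) -> vector<4xf32> {
  %v = vector.load %arg0[%i] : memref<16xf32, 1>, vector<4xf32>
  return %v : vector<4xf32>
}
```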

@fabianmcg fabianmcg requested a review from Copilot September 17, 2025 21:42
@Copilot Copilot AI left a comment


Pull Request Overview

This PR implements memory-space cast operand fusion into consumer operations for MLIR, allowing operations to work directly on the original memory space rather than casting to a different space first.

  • Introduces new interfaces for handling memory-space cast operations and their fusion
  • Creates a transformation pass that identifies and fuses eligible cast operations into their consumers
  • Adds fusion support for memref and vector operations that consume cast operands

Reviewed Changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 3 comments.

Summary per file:

  • mlir/test/Transforms/test-fuse-casts-into-consumers.mlir: Test file demonstrating the fusion transformation with various memref and vector operations
  • mlir/lib/Transforms/FuseMemorySpaceCastsIntoConsumers.cpp: Main pass implementation that applies fusion patterns
  • mlir/lib/Interfaces/MemOpInterfaces.cpp: Interface implementation with verification and fusion helper functions
  • mlir/include/mlir/Interfaces/MemOpInterfaces.td: Interface definitions for memory-space cast operations and their consumers
  • mlir/include/mlir/Interfaces/MemOpInterfaces.h: Header declarations for the memory operation interfaces
  • mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp: Fusion implementation for memref operations
  • mlir/lib/Dialect/Vector/IR/VectorOps.cpp: Fusion implementation for vector operations
  • Various CMakeLists.txt files: Build system updates to include the new interface library

@kuhar kuhar removed their request for review September 18, 2025 14:07
@Groverkss Groverkss requested a review from krzysz00 September 19, 2025 09:59
@Groverkss (Member) left a comment
Did an initial review, mostly looks fine to me. Can we find a smaller name for the interface? Otherwise looks all good, I'll do an in-depth review later.

@nicolasvasilache (Contributor) left a comment

LGTM modulo comments on tests.


using namespace mlir;

LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
Contributor:

we usually write invalid tests for these error messages

Contributor Author:

There are no operations upstream that would break any of the conditions being verified. Either we move the verified conditions to the assumptions in the interface description, or we add several test operations to check failures.

@joker-eph (Collaborator):

> This commit adds functionality to fuse memory-space casts into consumer operations,

I don't see a "fusion" on your example: the before/after both have the same number of operations, with the cast being just "bubbled-down" here?

@fabianmcg (Contributor Author):

> I don't see a "fusion" on your example: the before/after both have the same number of operations, with the cast being just "bubbled-down" here?

Yeah, I've been thinking that a better name for the pass is "propagate memory-space casts". For the interface, what about `ConsumeMemorySpaceOperandsCastOpInterface`?

@fabianmcg fabianmcg changed the title [mlir] Implement memory-space cast operand fusion into consumers [mlir] Implement a memory-space cast bubbling-down transform Sep 19, 2025
@Copilot (Copilot AI) left a comment:

Pull Request Overview

Copilot reviewed 18 out of 18 changed files in this pull request and generated 1 comment.

fabianmcg and others added 2 commits September 19, 2025 09:04
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@krzysz00 (Contributor):

Re naming: we could follow LLVM here and call it `-memref-infer-memory-spaces`

@krzysz00 (Contributor):

As a general comment, the notion of a lossless memory cast is something that'll need an attribute interface on memory spaces, not an op interface.

And probably DLTI

So ... what if we just made this infer-memory-spaces to rewrite out as many casts to the null memory space as you can get your hands on?

@fabianmcg (Contributor Author):

fabianmcg commented Sep 19, 2025

> Re naming: we could follow LLVM here and call it `-memref-infer-memory-spaces`

This is not exclusive to memrefs.

> As a general comment, the notion of a lossless memory cast is something that'll need an attribute interface on memory spaces, not an op interface.

See my other comment, that was a poorly chosen name and I'll change.

However, this is an op interface because it is the cast op that determines whether it can be propagated or not without altering program semantics. For example, an amdgpu.fat_raw_buffer_cast will never allow propagation.

> So ... what if we just made this infer-memory-spaces to rewrite out as many casts to the null memory space as you can get your hands on?

ptr doesn't allow null memory space.

@krzysz00 (Contributor):

Noted on ptr - so what you'd really want, I think, is a notion of which memory space(s) you're trying to eliminate as a pass parameter.

But yeah, so while the fat raw buffer cast widget can't be freely rewritten away like you said, I'm not sure that memspacecast can be either without more information? I forget what assumptions we've made.

fabianmcg and others added 3 commits September 22, 2025 06:37
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
@krzysz00 (Contributor) left a comment:

Many of my previous comments about the generality or lack thereof of this pass were based on me missing the fact that memref.memory_space_cast's answer to "is promotable" is "does this cast to the null/default/generic memory space" and not what I thought it was, which was "yes".

With that context taken into account, this pass makes a lot of sense and is approved

@joker-eph (Collaborator):

> Many of my previous comments about the generality or lack thereof of this pass were based on me missing the fact that memref.memory_space_cast's answer to "is promotable" is "does this cast to the null/default/generic memory space" and not what I thought it was, which was "yes".
>
> With that context taken into account, this pass makes a lot of sense and is approved

That means the doc needs to be improved on all this (both the pass description and the op/attrs likely)

@fabianmcg (Contributor Author):

> > Many of my previous comments about the generality or lack thereof of this pass were based on me missing the fact that memref.memory_space_cast's answer to "is promotable" is "does this cast to the null/default/generic memory space" and not what I thought it was, which was "yes".
> > With that context taken into account, this pass makes a lot of sense and is approved
>
> That means the doc needs to be improved on all this (both the pass description and the op/attrs likely)

I updated the docs, @krzysz00 is it clearer now?

@fabianmcg fabianmcg merged commit 077a796 into llvm:main Sep 24, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…9454)

This commit adds functionality to bubble down memory-space casts
operations, allowing consumer operations to use the original
memory-space rather than first casting to a different memory space.

Changes:
- Introduce `MemorySpaceCastOpInterface` to handle memory-space cast
operations
- Create a `MemorySpaceCastConsumerOpInterface` pass that identifies and
bubbles down eligible casts
- Add implementation for memref and vector operations to handle
memory-space cast propagation
- Add `bubbleDownCasts` method to relevant operations to support the
fusion

In particular, in the current implementation only memory-space casts
into the default memory-space can be bubbled-down.

Example:

```mlir
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index
    %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
    %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
    %loaded = memref.load %collapsed[%c0] : memref<16xf32>
    %added = arith.addf %loaded, %arg2 : f32
    memref.store %added, %collapsed[%c0] : memref<16xf32>
    %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
    return %collapsed : memref<16xf32>
}
// mlir-opt --bubble-down-memory-space-casts
func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
    %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
    %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
    %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
    %1 = arith.addf %0, %arg2 : f32
    memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
    %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
    return %memspacecast : memref<16xf32>
}
```

---------

Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>