
Conversation

@kdmitry1

memref.expand_shape didn't have a memref.cast op folder. Added a canonicalization pattern to allow folding a static-to-dynamic memref.cast into memref.expand_shape.

Example:

  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
  %c0 = arith.constant 0 : index
  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%dim0, 1, 4]  : memref<?x4xf32> into memref<?x1x4xf32>

is converted to:

  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
  %cast = memref.cast %expand_shape : memref<8x1x4xf32> to memref<?x1x4xf32>
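For intuition, the shape bookkeeping behind this fold can be sketched outside MLIR (a plain Python analogy; `DYNAMIC`, `folded_result_shape`, and the dict of known sizes are illustrative names, not MLIR API):

```python
DYNAMIC = -1  # stands in for ShapedType::kDynamic ("?" in the IR)

def folded_result_shape(old_result_shape, known_sizes):
    """Sketch of what folding the cast buys: dims whose output_shape
    operands trace back to constants (given here as known_sizes, a
    dict of dim index -> size) become static; the rest stay dynamic.
    If the type got more static, a trailing cast restores the original
    result type for existing users."""
    new_shape = [known_sizes.get(i, d)
                 for i, d in enumerate(old_result_shape)]
    needs_cast_back = new_shape != old_result_shape
    return new_shape, needs_cast_back

# The example above: memref<?x1x4xf32> with dim 0 known to be 8
# becomes memref<8x1x4xf32>, plus a cast back to memref<?x1x4xf32>.
```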

@github-actions

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by pinging the PR with a comment saying "Ping". The common courtesy ping rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Nov 30, 2025

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: None (kdmitry1)

Changes

memref.expand_shape didn't have a memref.cast op folder. Added a canonicalization pattern to allow folding a static-to-dynamic memref.cast into memref.expand_shape.

Example:

  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
  %c0 = arith.constant 0 : index
  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%dim0, 1, 4] : memref<?x4xf32> into memref<?x1x4xf32>

is converted to:

  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
  %cast = memref.cast %expand_shape : memref<8x1x4xf32> to memref<?x1x4xf32>

Full diff: https://github.com/llvm/llvm-project/pull/170037.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+79-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+84)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..49dc23b702875 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,89 @@ LogicalResult ExpandShapeOp::verify() {
   return success();
 }
 
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto cast = op.getSrc().getDefiningOp<CastOp>();
+    if (!cast)
+      return failure();
+
+    if (!CastOp::canFoldIntoConsumerOp(cast))
+      return failure();
+
+    auto originalOutputShape = op.getMixedOutputShape();
+    auto newOutputShape = originalOutputShape;
+    SmallVector<int64_t> newOutputShapeSizes;
+    SmallVector<Value> newOperands;
+
+    // Convert output shape dims from dynamic to static where possible.
+    for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+      auto dimVal = dimSize.dyn_cast<Value>();
+      if (!dimVal) {
+        newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
+        continue;
+      }
+
+      auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
+      if (!constOp) {
+        newOperands.push_back(dimVal);
+        newOutputShapeSizes.push_back(ShapedType::kDynamic);
+        continue;
+      }
+
+      newOutputShape[dimIdx] = constOp.getValue();
+      newOutputShapeSizes.push_back(
+          getConstantIntValue(constOp.getValue()).value());
+    }
+
+    if (newOperands.size() == op->getNumOperands())
+      return rewriter.notifyMatchFailure(
+          op, "no static-to-dynamic conversions found");
+
+    auto castSource = cast.getSource();
+    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+    auto reassociationIndices = op.getReassociationIndices();
+    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+      int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
+      auto newOutputShapeSizesSlice =
+          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+      int64_t newOutputDynCount =
+          llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
+      if (castSourceDynCount != newOutputDynCount)
+        return rewriter.notifyMatchFailure(
+            op, "folding cast will result in changing dynamicity in "
+                "reassociation group");
+    }
+
+    auto newResultTypeOrFailure = ExpandShapeOp::computeExpandedType(
+        castSourceType, newOutputShapeSizes, reassociationIndices);
+
+    if (failed(newResultTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "could not compute new expanded type after folding cast");
+
+    if (*newResultTypeOrFailure == op.getResultType()) {
+      rewriter.modifyOpInPlace(
+          op, [&]() { op.getSrcMutable().assign(castSource); });
+    } else {
+      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
+                                          *newResultTypeOrFailure, castSource,
+                                          reassociationIndices, newOutputShape);
+      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+    }
+    return success();
+  }
+};
+
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
-      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      ExpandShapeOpMemRefCastFolder>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..c2d0376fc9723 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,90 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
 
 // -----
 
+// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+  %c4 = arith.constant 4 : index
+  %dim_ext = arith.divui %dim0 , %c4: index
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [1, %dim0, 1, %dim1]
+      : memref<?x?xf32> into memref<1x?x1x?xf32>
+  %2 = memref.cast %1 : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+  return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial1(
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%c1, %dim0, %c1, %dim1]
+      : memref<?x?xf32> into memref<?x?x?x?xf32>
+  %2 = memref.cast %1 : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+  return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK:           memref.cast
+// CHECK:           memref.expand_shape
+// CHECK:           return
+// CHECK:         }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+      : memref<8x4xf32> into memref<2x1x4x4xf32>
+  return %1 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK:           memref.cast
+// CHECK:           memref.expand_shape
+// CHECK:           memref.cast
+// CHECK:           return
+// CHECK:         }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+  %dim_ext = arith.divui %dim0 , %arg1: index
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL:   func @collapse_after_memref_cast_type_change(
 // CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
 // CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]

if (!CastOp::canFoldIntoConsumerOp(cast))
return failure();

auto originalOutputShape = op.getMixedOutputShape();
Member

Can you spell out the type here and for newOutputShape?

Author

done


// Convert output shape dims from dynamic to static where possible.
for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
auto dimVal = dimSize.dyn_cast<Value>();
Member

This notation is deprecated. Use dyn_cast<Value>(dimSize).

Author

Fixed

for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
auto dimVal = dimSize.dyn_cast<Value>();
if (!dimVal) {
newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
Member

getConstantIntValue works on OpFoldResult. There's no need to check whether it's a Value or Attribute.

Author

Modified the whole loop to follow your proposal

continue;
}

auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
Member

No need to hard-code ConstantIndexOp here. getConstantIntValue will extract the int64_t. If it fails, you can assume it's a dynamic value.

Author

Done

Author

Thank you for your review

// -----

// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
// CHECK-NOT: memref.cast
Member

Can you update all the new test cases to include the types of the memref.expand_shape? E.g., // CHECK: memref.expand_shape {{.*}} : memref<...> to memref<...>

Author

Done

@kdmitry1 kdmitry1 requested a review from qedawkins December 2, 2025 12:40
Contributor

@qedawkins qedawkins left a comment

This LGTM % a few nits, make sure you get an approval from Matthias as well before landing.

// CHECK-NOT: memref.cast
// CHECK: memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
// CHECK-NOT: memref.cast
// CHECK: return
Contributor

nit: Here and below, instead of using CHECK-NOT lines to make sure the cast propagated/folded the way you expected, you can match the labels like

// CHECK:  %[[EXPAND:.+]] = memref.expand_shape ...
// CHECK:  return %[[EXPAND]]

which is both a strong check and doesn't need CHECK-NOT's.

Author

Done

// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>,
// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<2x1x4x4xf32> {
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index
Contributor

nit: This is a more standard label name.

Suggested change
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index

Author

Done

// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<8x1x?x?xf32> {
// CHECK-NOT: memref.cast
// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32>
// CHECK-NOT: memref.cast
Contributor

the CHECK-NOT's for memref.cast aren't really needed here. We already know the live IR was transformed the way we wanted so we don't really care if there is still a cast sticking around.

Author

Done. Got the idea. Thank you.

@kdmitry1
Author

kdmitry1 commented Dec 7, 2025

@qedawkins, @matthias-springer, please note I've added some minor changes
