Skip to content

Conversation

@NexMing
Copy link
Contributor

@NexMing NexMing commented Oct 22, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2025

@llvm/pr-subscribers-mlir-memref

Author: Ming Yan (NexMing)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/164585.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+7)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (-87)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+15)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b760251e..c06a48ee4b87c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
   atLeastOneReplacement |= replaceConstantUsesOf(
       builder, getLoc(), getStrides(), getConstifiedMixedStrides());
 
+  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+  if (auto prev = getSource().getDefiningOp<CastOp>())
+    if (isa<MemRefType>(prev.getSource().getType())) {
+      getSourceMutable().assign(prev.getSource());
+      atLeastOneReplacement = true;
+    }
+
   return success(atLeastOneReplacement);
 }
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..bd02516d5b527 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
   }
 };
 
-/// Replace `base, offset, sizes, strides =
-///              extract_strided_metadata(
-///                 cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-///            ? dstTy.srcOffset
-///            : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-///           !srcSize.isDynamic()
-///             ? srcSize
-//              : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-///             !srcStrides.isDynamic()
-///               ? srcStrides
-///               : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
-    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult
-  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
-                  PatternRewriter &rewriter) const override {
-    Value source = extractStridedMetadataOp.getSource();
-    auto castOp = source.getDefiningOp<memref::CastOp>();
-    if (!castOp)
-      return failure();
-
-    Location loc = extractStridedMetadataOp.getLoc();
-    // Check if the source is suitable for extract_strided_metadata.
-    SmallVector<Type> inferredReturnTypes;
-    if (failed(extractStridedMetadataOp.inferReturnTypes(
-            rewriter.getContext(), loc, {castOp.getSource()},
-            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
-            inferredReturnTypes)))
-      return rewriter.notifyMatchFailure(castOp,
-                                         "cast source's type is incompatible");
-
-    auto memrefType = cast<MemRefType>(source.getType());
-    unsigned rank = memrefType.getRank();
-    SmallVector<OpFoldResult> results;
-    results.resize_for_overwrite(rank * 2 + 2);
-
-    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
-        rewriter, loc, castOp.getSource());
-
-    // Register the base_buffer.
-    results[0] = newExtractStridedMetadata.getBaseBuffer();
-
-    auto getConstantOrValue = [&rewriter](int64_t constant,
-                                          OpFoldResult ofr) -> OpFoldResult {
-      return ShapedType::isStatic(constant)
-                 ? OpFoldResult(rewriter.getIndexAttr(constant))
-                 : ofr;
-    };
-
-    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
-    assert(sourceStrides.size() == rank && "unexpected number of strides");
-
-    // Register the new offset.
-    results[1] =
-        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
-    const unsigned sizeStartIdx = 2;
-    const unsigned strideStartIdx = sizeStartIdx + rank;
-    ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
-    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
-    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
-    for (unsigned i = 0; i < rank; ++i) {
-      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
-      results[strideStartIdx + i] =
-          getConstantOrValue(sourceStrides[i], strides[i]);
-    }
-    rewriter.replaceOp(extractStridedMetadataOp,
-                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
-    return success();
-  }
-};
-
 /// Replace `base, offset, sizes, strides = extract_strided_metadata(
 ///      memory_space_cast(src) to dstTy)`
 /// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpSubviewFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7160b52af6353..bab979bb86959 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -901,6 +901,21 @@ func.func @scope_merge_without_terminator() {
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>)
+//       CHECK: %[[C0:.*]] = arith.constant 0 : index
+//       CHECK: %[[C4:.*]] = arith.constant 4 : index
+//       CHECK: %[[C1:.*]] = arith.constant 1 : index
+//       CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: return %[[BASE]], %[[C0]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_cast(%arg0: memref<?xf32>) -> (memref<f32>, index, index, index) {
+  %cast = memref.cast %arg0 : memref<?xf32> to memref<4xf32, strided<[?]>>
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %cast : memref<4xf32, strided<[?]>> -> memref<f32>, index, index, index
+  return %base_buffer, %offset, %sizes, %strides : memref<f32>, index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_noop
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
 //  CHECK-NEXT: return %[[ARG]]

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2025

@llvm/pr-subscribers-mlir

Author: Ming Yan (NexMing)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/164585.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+7)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (-87)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+15)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b760251e..c06a48ee4b87c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
   atLeastOneReplacement |= replaceConstantUsesOf(
       builder, getLoc(), getStrides(), getConstifiedMixedStrides());
 
+  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+  if (auto prev = getSource().getDefiningOp<CastOp>())
+    if (isa<MemRefType>(prev.getSource().getType())) {
+      getSourceMutable().assign(prev.getSource());
+      atLeastOneReplacement = true;
+    }
+
   return success(atLeastOneReplacement);
 }
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..bd02516d5b527 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
   }
 };
 
-/// Replace `base, offset, sizes, strides =
-///              extract_strided_metadata(
-///                 cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-///            ? dstTy.srcOffset
-///            : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-///           !srcSize.isDynamic()
-///             ? srcSize
-//              : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-///             !srcStrides.isDynamic()
-///               ? srcStrides
-///               : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
-    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult
-  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
-                  PatternRewriter &rewriter) const override {
-    Value source = extractStridedMetadataOp.getSource();
-    auto castOp = source.getDefiningOp<memref::CastOp>();
-    if (!castOp)
-      return failure();
-
-    Location loc = extractStridedMetadataOp.getLoc();
-    // Check if the source is suitable for extract_strided_metadata.
-    SmallVector<Type> inferredReturnTypes;
-    if (failed(extractStridedMetadataOp.inferReturnTypes(
-            rewriter.getContext(), loc, {castOp.getSource()},
-            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
-            inferredReturnTypes)))
-      return rewriter.notifyMatchFailure(castOp,
-                                         "cast source's type is incompatible");
-
-    auto memrefType = cast<MemRefType>(source.getType());
-    unsigned rank = memrefType.getRank();
-    SmallVector<OpFoldResult> results;
-    results.resize_for_overwrite(rank * 2 + 2);
-
-    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
-        rewriter, loc, castOp.getSource());
-
-    // Register the base_buffer.
-    results[0] = newExtractStridedMetadata.getBaseBuffer();
-
-    auto getConstantOrValue = [&rewriter](int64_t constant,
-                                          OpFoldResult ofr) -> OpFoldResult {
-      return ShapedType::isStatic(constant)
-                 ? OpFoldResult(rewriter.getIndexAttr(constant))
-                 : ofr;
-    };
-
-    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
-    assert(sourceStrides.size() == rank && "unexpected number of strides");
-
-    // Register the new offset.
-    results[1] =
-        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
-    const unsigned sizeStartIdx = 2;
-    const unsigned strideStartIdx = sizeStartIdx + rank;
-    ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
-    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
-    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
-    for (unsigned i = 0; i < rank; ++i) {
-      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
-      results[strideStartIdx + i] =
-          getConstantOrValue(sourceStrides[i], strides[i]);
-    }
-    rewriter.replaceOp(extractStridedMetadataOp,
-                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
-    return success();
-  }
-};
-
 /// Replace `base, offset, sizes, strides = extract_strided_metadata(
 ///      memory_space_cast(src) to dstTy)`
 /// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpSubviewFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7160b52af6353..bab979bb86959 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -901,6 +901,21 @@ func.func @scope_merge_without_terminator() {
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>)
+//       CHECK: %[[C0:.*]] = arith.constant 0 : index
+//       CHECK: %[[C4:.*]] = arith.constant 4 : index
+//       CHECK: %[[C1:.*]] = arith.constant 1 : index
+//       CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: return %[[BASE]], %[[C0]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_cast(%arg0: memref<?xf32>) -> (memref<f32>, index, index, index) {
+  %cast = memref.cast %arg0 : memref<?xf32> to memref<4xf32, strided<[?]>>
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %cast : memref<4xf32, strided<[?]>> -> memref<f32>, index, index, index
+  return %base_buffer, %offset, %sizes, %strides : memref<f32>, index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_noop
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
 //  CHECK-NEXT: return %[[ARG]]

@kuhar kuhar requested a review from krzysz00 October 22, 2025 11:55
@kuhar
Copy link
Member

kuhar commented Oct 22, 2025

@krzysz00 Can you take a look?

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fine as a folder to me, if I remember the rules for folders right

@NexMing NexMing merged commit 4eaeeab into llvm:main Oct 23, 2025
10 checks passed
@NexMing NexMing deleted the dev/fold_extract_strided_metadata branch October 23, 2025 00:54
mikolaj-pirog pushed a commit to mikolaj-pirog/llvm-project that referenced this pull request Oct 23, 2025
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Oct 29, 2025
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants