- 
                Notifications
    
You must be signed in to change notification settings  - Fork 15.1k
 
[mlir][memref] Fold extract_strided_metadata(cast(x)) into extract_strided_metadata(x) #164585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…rided_metadata(x)
| 
          
 @llvm/pr-subscribers-mlir-memref Author: Ming Yan (NexMing) ChangesFull diff: https://github.com/llvm/llvm-project/pull/164585.diff 3 Files Affected: 
 diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b760251e..c06a48ee4b87c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
   atLeastOneReplacement |= replaceConstantUsesOf(
       builder, getLoc(), getStrides(), getConstifiedMixedStrides());
 
+  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+  if (auto prev = getSource().getDefiningOp<CastOp>())
+    if (isa<MemRefType>(prev.getSource().getType())) {
+      getSourceMutable().assign(prev.getSource());
+      atLeastOneReplacement = true;
+    }
+
   return success(atLeastOneReplacement);
 }
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..bd02516d5b527 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
   }
 };
 
-/// Replace `base, offset, sizes, strides =
-///              extract_strided_metadata(
-///                 cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-///            ? dstTy.srcOffset
-///            : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-///           !srcSize.isDynamic()
-///             ? srcSize
-//              : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-///             !srcStrides.isDynamic()
-///               ? srcStrides
-///               : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
-    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult
-  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
-                  PatternRewriter &rewriter) const override {
-    Value source = extractStridedMetadataOp.getSource();
-    auto castOp = source.getDefiningOp<memref::CastOp>();
-    if (!castOp)
-      return failure();
-
-    Location loc = extractStridedMetadataOp.getLoc();
-    // Check if the source is suitable for extract_strided_metadata.
-    SmallVector<Type> inferredReturnTypes;
-    if (failed(extractStridedMetadataOp.inferReturnTypes(
-            rewriter.getContext(), loc, {castOp.getSource()},
-            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
-            inferredReturnTypes)))
-      return rewriter.notifyMatchFailure(castOp,
-                                         "cast source's type is incompatible");
-
-    auto memrefType = cast<MemRefType>(source.getType());
-    unsigned rank = memrefType.getRank();
-    SmallVector<OpFoldResult> results;
-    results.resize_for_overwrite(rank * 2 + 2);
-
-    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
-        rewriter, loc, castOp.getSource());
-
-    // Register the base_buffer.
-    results[0] = newExtractStridedMetadata.getBaseBuffer();
-
-    auto getConstantOrValue = [&rewriter](int64_t constant,
-                                          OpFoldResult ofr) -> OpFoldResult {
-      return ShapedType::isStatic(constant)
-                 ? OpFoldResult(rewriter.getIndexAttr(constant))
-                 : ofr;
-    };
-
-    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
-    assert(sourceStrides.size() == rank && "unexpected number of strides");
-
-    // Register the new offset.
-    results[1] =
-        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
-    const unsigned sizeStartIdx = 2;
-    const unsigned strideStartIdx = sizeStartIdx + rank;
-    ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
-    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
-    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
-    for (unsigned i = 0; i < rank; ++i) {
-      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
-      results[strideStartIdx + i] =
-          getConstantOrValue(sourceStrides[i], strides[i]);
-    }
-    rewriter.replaceOp(extractStridedMetadataOp,
-                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
-    return success();
-  }
-};
-
 /// Replace `base, offset, sizes, strides = extract_strided_metadata(
 ///      memory_space_cast(src) to dstTy)`
 /// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpSubviewFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7160b52af6353..bab979bb86959 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -901,6 +901,21 @@ func.func @scope_merge_without_terminator() {
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>)
+//       CHECK: %[[C0:.*]] = arith.constant 0 : index
+//       CHECK: %[[C4:.*]] = arith.constant 4 : index
+//       CHECK: %[[C1:.*]] = arith.constant 1 : index
+//       CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: return %[[BASE]], %[[C0]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_cast(%arg0: memref<?xf32>) -> (memref<f32>, index, index, index) {
+  %cast = memref.cast %arg0 : memref<?xf32> to memref<4xf32, strided<[?]>>
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %cast : memref<4xf32, strided<[?]>> -> memref<f32>, index, index, index
+  return %base_buffer, %offset, %sizes, %strides : memref<f32>, index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_noop
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
 //  CHECK-NEXT: return %[[ARG]]
 | 
    
| 
          
 @llvm/pr-subscribers-mlir Author: Ming Yan (NexMing) ChangesFull diff: https://github.com/llvm/llvm-project/pull/164585.diff 3 Files Affected: 
 diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b760251e..c06a48ee4b87c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
   atLeastOneReplacement |= replaceConstantUsesOf(
       builder, getLoc(), getStrides(), getConstifiedMixedStrides());
 
+  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+  if (auto prev = getSource().getDefiningOp<CastOp>())
+    if (isa<MemRefType>(prev.getSource().getType())) {
+      getSourceMutable().assign(prev.getSource());
+      atLeastOneReplacement = true;
+    }
+
   return success(atLeastOneReplacement);
 }
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..bd02516d5b527 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
   }
 };
 
-/// Replace `base, offset, sizes, strides =
-///              extract_strided_metadata(
-///                 cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-///            ? dstTy.srcOffset
-///            : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-///           !srcSize.isDynamic()
-///             ? srcSize
-//              : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-///             !srcStrides.isDynamic()
-///               ? srcStrides
-///               : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
-    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult
-  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
-                  PatternRewriter &rewriter) const override {
-    Value source = extractStridedMetadataOp.getSource();
-    auto castOp = source.getDefiningOp<memref::CastOp>();
-    if (!castOp)
-      return failure();
-
-    Location loc = extractStridedMetadataOp.getLoc();
-    // Check if the source is suitable for extract_strided_metadata.
-    SmallVector<Type> inferredReturnTypes;
-    if (failed(extractStridedMetadataOp.inferReturnTypes(
-            rewriter.getContext(), loc, {castOp.getSource()},
-            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
-            inferredReturnTypes)))
-      return rewriter.notifyMatchFailure(castOp,
-                                         "cast source's type is incompatible");
-
-    auto memrefType = cast<MemRefType>(source.getType());
-    unsigned rank = memrefType.getRank();
-    SmallVector<OpFoldResult> results;
-    results.resize_for_overwrite(rank * 2 + 2);
-
-    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
-        rewriter, loc, castOp.getSource());
-
-    // Register the base_buffer.
-    results[0] = newExtractStridedMetadata.getBaseBuffer();
-
-    auto getConstantOrValue = [&rewriter](int64_t constant,
-                                          OpFoldResult ofr) -> OpFoldResult {
-      return ShapedType::isStatic(constant)
-                 ? OpFoldResult(rewriter.getIndexAttr(constant))
-                 : ofr;
-    };
-
-    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
-    assert(sourceStrides.size() == rank && "unexpected number of strides");
-
-    // Register the new offset.
-    results[1] =
-        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
-    const unsigned sizeStartIdx = 2;
-    const unsigned strideStartIdx = sizeStartIdx + rank;
-    ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
-    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
-    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
-    for (unsigned i = 0; i < rank; ++i) {
-      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
-      results[strideStartIdx + i] =
-          getConstantOrValue(sourceStrides[i], strides[i]);
-    }
-    rewriter.replaceOp(extractStridedMetadataOp,
-                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
-    return success();
-  }
-};
-
 /// Replace `base, offset, sizes, strides = extract_strided_metadata(
 ///      memory_space_cast(src) to dstTy)`
 /// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpSubviewFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7160b52af6353..bab979bb86959 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -901,6 +901,21 @@ func.func @scope_merge_without_terminator() {
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>)
+//       CHECK: %[[C0:.*]] = arith.constant 0 : index
+//       CHECK: %[[C4:.*]] = arith.constant 4 : index
+//       CHECK: %[[C1:.*]] = arith.constant 1 : index
+//       CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: return %[[BASE]], %[[C0]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_cast(%arg0: memref<?xf32>) -> (memref<f32>, index, index, index) {
+  %cast = memref.cast %arg0 : memref<?xf32> to memref<4xf32, strided<[?]>>
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %cast : memref<4xf32, strided<[?]>> -> memref<f32>, index, index, index
+  return %base_buffer, %offset, %sizes, %strides : memref<f32>, index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_noop
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
 //  CHECK-NEXT: return %[[ARG]]
 | 
    
| 
           @krzysz00 Can you take a look?  | 
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks fine as a folder to me, if I remember the rules for folders right
…rided_metadata(x) (llvm#164585)
…rided_metadata(x) (llvm#164585)
…rided_metadata(x) (llvm#164585)
…rided_metadata(x) (llvm#164585)
No description provided.