From b7e34f84fd8e22082eba19b7c59f8e40eeaeb94b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 12 Nov 2025 02:28:48 +0100 Subject: [PATCH 1/2] [mlir][memref] Remove invalid `extract_aligned_pointer_as_index` folding in `ExpandStridedMetadata` ViewLikeOpInterface by itself doesn't guarantee to preserve the base pointer in general and `memref.view` is one such example, so limit folder to a few specific ops. --- .../MemRef/Transforms/ExpandStridedMetadata.cpp | 6 +++++- .../Dialect/MemRef/expand-strided-metadata.mlir | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index bd02516d5b527..c9352e8f700d7 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp PatternRewriter &rewriter) const override { auto viewLikeOp = extractOp.getSource().getDefiningOp(); - if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest()) + // ViewLikeOpInterface by itself doesn't guarantee to preserve the base + // pointer in general and `memref.view` is one such example, so just check + // for a few specific cases. + if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() || + !isa(viewLikeOp)) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); rewriter.modifyOpInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index 18cdfb73f6ba8..ec17282ed20ec 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1455,3 +1455,20 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2 // CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base // CHECK-NOT: memref.memory_space_cast + +// ----- + +func.func @negative_memref_view_extract_aligned_pointer(%arg0: memref) -> index { + // `extract_aligned_pointer_as_index` must not be folded as `memref.view` can change the base pointer + // CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer + // CHECK-SAME: (%[[ARG0:.*]]: memref) + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C0]]][] : memref to memref + // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref -> index + // CHECK: return %[[PTR]] : index + + %c0 = arith.constant 0 : index + %0 = memref.view %arg0[%c0][] : memref to memref + %1 = memref.extract_aligned_pointer_as_index %0: memref -> index + return %1 : index +} From 5dfb7a5244fabbb4a489b1747233dce6c788c04f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 12 Nov 2025 20:13:07 +0100 Subject: [PATCH 2/2] update test --- mlir/test/Dialect/MemRef/expand-strided-metadata.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index ec17282ed20ec..4ed8d4b20229f 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1462,13 +1462,13 @@ func.func @negative_memref_view_extract_aligned_pointer(%arg0: memref) -> // `extract_aligned_pointer_as_index` must not be folded as `memref.view` can change the base pointer // CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer // CHECK-SAME: (%[[ARG0:.*]]: memref) - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C0]]][] : memref to memref + // CHECK: %[[C10:.*]] = arith.constant 10 : index + // CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C10]]][] : memref to memref // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref -> index // CHECK: return %[[PTR]] : index - %c0 = arith.constant 0 : index - %0 = memref.view %arg0[%c0][] : memref to memref + %c10 = arith.constant 10 : index + %0 = memref.view %arg0[%c10][] : memref to memref %1 = memref.extract_aligned_pointer_as_index %0: memref -> index return %1 : index }