diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index bd02516d5b527..c9352e8f700d7 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp PatternRewriter &rewriter) const override { auto viewLikeOp = extractOp.getSource().getDefiningOp(); - if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest()) + // ViewLikeOpInterface by itself doesn't guarantee to preserve the base + // pointer in general and `memref.view` is one such example, so just check + // for a few specific cases. + if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() || + !isa(viewLikeOp)) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); rewriter.modifyOpInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index 18cdfb73f6ba8..4ed8d4b20229f 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1455,3 +1455,20 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2 // CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base // CHECK-NOT: memref.memory_space_cast + +// ----- + +func.func @negative_memref_view_extract_aligned_pointer(%arg0: memref) -> index { + // `extract_aligned_pointer_as_index` must not be folded as `memref.view` can change the base pointer + // CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer + // CHECK-SAME: (%[[ARG0:.*]]: memref) + // CHECK: %[[C10:.*]] = arith.constant 10 : index + // CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C10]]][] : memref to memref + // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref -> index + // CHECK: return %[[PTR]] : index + + %c10 = arith.constant 10 : index + %0 = memref.view %arg0[%c10][] : memref to memref + %1 = memref.extract_aligned_pointer_as_index %0: memref -> index + return %1 : index +}