From 41d948a2769bbb2091a77832751c3694eebc006f Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Wed, 12 Nov 2025 10:57:10 +0000 Subject: [PATCH 1/2] [milr][memref]: Fold expand_shape + transfer_read Extend the load of a expand shape rewrite pattern to support folding a `memref.expand_shape` and `vector.transfer_read` when the permutation map on `vector.transfer_read` is a minor identity. Signed-off-by: Jack Frankland --- .../MemRef/Transforms/FoldMemRefAliasOps.cpp | 26 ++++++++++++-- .../Dialect/MemRef/fold-memref-alias-ops.mlir | 34 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 214410f78e51c..30df10c1deedc 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -347,28 +347,49 @@ LogicalResult LoadOpOfExpandShapeOpFolder::matchAndRewrite( loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices, isa(loadOp.getOperation())))) return failure(); - llvm::TypeSwitch(loadOp) + + return llvm::TypeSwitch(loadOp) .Case([&](affine::AffineLoadOp op) { rewriter.replaceOpWithNewOp( loadOp, expandShapeOp.getViewSource(), sourceIndices); + return success(); }) .Case([&](memref::LoadOp op) { rewriter.replaceOpWithNewOp( loadOp, expandShapeOp.getViewSource(), sourceIndices, op.getNontemporal()); + return success(); }) .Case([&](vector::LoadOp op) { rewriter.replaceOpWithNewOp( op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, op.getNontemporal()); + return success(); }) .Case([&](vector::MaskedLoadOp op) { rewriter.replaceOpWithNewOp( op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, op.getMask(), op.getPassThru()); + return success(); + }) + .Case([&](vector::TransferReadOp op) { + // We only support minor identity maps in the permutation attribute. + if (!op.getPermutationMap().isMinorIdentity()) + return failure(); + + // We need to construct a new minor identity map since we will have lost + // some dimensions in folding away the expand shape. + auto minorIdMap = AffineMap::getMinorIdentityMap( + sourceIndices.size(), op.getVectorType().getRank(), + op.getContext()); + + rewriter.replaceOpWithNewOp( + op, op.getVectorType(), expandShapeOp.getViewSource(), + sourceIndices, minorIdMap, op.getPadding(), op.getMask(), + op.getInBounds()); + return success(); }) .DefaultUnreachable("unexpected operation"); - return success(); } template @@ -659,6 +680,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { LoadOpOfExpandShapeOpFolder, LoadOpOfExpandShapeOpFolder, LoadOpOfExpandShapeOpFolder, + LoadOpOfExpandShapeOpFolder, StoreOpOfExpandShapeOpFolder, StoreOpOfExpandShapeOpFolder, StoreOpOfExpandShapeOpFolder, diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 106652623933f..87f23457644ae 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -992,6 +992,40 @@ func.func @fold_vector_maskedstore_expand_shape( // ----- +func.func @fold_vector_transfer_read_expand_shape( + %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + %1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true]} : memref<4x8xf32>, vector<8xf32> + return %1 : vector<8xf32> +} + +// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[C0:.*]] = arith.constant 0 +// CHECK: %[[PAD:.*]] = ub.poison : f32 +// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8) +// CHECK: vector.transfer_read %[[ARG0]][%[[IDX]]], %[[PAD]] {in_bounds = [true]} + +// ----- + +func.func @fold_vector_transfer_read_with_perm_map( + %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + %1 = vector.transfer_read %0[%arg1, %c0], %pad { permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<4x8xf32>, vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// CHECK-LABEL: func @fold_vector_transfer_read_with_perm_map +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + +// ----- + func.func @fold_vector_load_collapse_shape( %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> { %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32> From 8ae2e8de6dffa016bfb94d0df55a9779665266ba Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Mon, 24 Nov 2025 11:51:08 +0000 Subject: [PATCH 2/2] [mlir][memref]: Add Check and Negative Test Add a missing check and negative test. Signed-off-by: Jack Frankland --- .../MemRef/Transforms/FoldMemRefAliasOps.cpp | 12 +++++++++--- .../Dialect/MemRef/fold-memref-alias-ops.mlir | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 30df10c1deedc..3667fdb2bb728 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -377,11 +377,17 @@ LogicalResult LoadOpOfExpandShapeOpFolder::matchAndRewrite( if (!op.getPermutationMap().isMinorIdentity()) return failure(); + // We only support the case where the source of the expand shape has + // rank greater than or equal to the vector rank. + const int64_t sourceRank = sourceIndices.size(); + const int64_t vectorRank = op.getVectorType().getRank(); + if (sourceRank < vectorRank) + return failure(); + // We need to construct a new minor identity map since we will have lost // some dimensions in folding away the expand shape. - auto minorIdMap = AffineMap::getMinorIdentityMap( - sourceIndices.size(), op.getVectorType().getRank(), - op.getContext()); + auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank, + op.getContext()); rewriter.replaceOpWithNewOp( op, op.getVectorType(), expandShapeOp.getViewSource(), diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 87f23457644ae..ca91b0141f593 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -1026,6 +1026,21 @@ func.func @fold_vector_transfer_read_with_perm_map( // ----- +func.func @fold_vector_transfer_read_rank_mismatch( + %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32> + %1 = vector.transfer_read %0[%arg1, %c0, %c0], %pad {in_bounds = [true, true]} : memref<2x4x4xf32>, vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// CHECK-LABEL: func @fold_vector_transfer_read_rank_mismatch +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32> + +// ----- + func.func @fold_vector_load_collapse_shape( %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> { %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>