Skip to content

Commit e575539

Browse files
[mlir][memref]: Fold expand_shape + transfer_read (#167679)
Extend the load-of-expand-shape rewrite pattern to support folding a `memref.expand_shape` into a `vector.transfer_read` when the permutation map on the `vector.transfer_read` is a minor identity. --------- Signed-off-by: Jack Frankland <jack.frankland@arm.com>
1 parent 999deef commit e575539

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,28 +347,55 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
347347
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
348348
isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
349349
return failure();
350-
llvm::TypeSwitch<Operation *, void>(loadOp)
350+
351+
return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
351352
.Case([&](affine::AffineLoadOp op) {
352353
rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
353354
loadOp, expandShapeOp.getViewSource(), sourceIndices);
355+
return success();
354356
})
355357
.Case([&](memref::LoadOp op) {
356358
rewriter.replaceOpWithNewOp<memref::LoadOp>(
357359
loadOp, expandShapeOp.getViewSource(), sourceIndices,
358360
op.getNontemporal());
361+
return success();
359362
})
360363
.Case([&](vector::LoadOp op) {
361364
rewriter.replaceOpWithNewOp<vector::LoadOp>(
362365
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
363366
op.getNontemporal());
367+
return success();
364368
})
365369
.Case([&](vector::MaskedLoadOp op) {
366370
rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
367371
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
368372
op.getMask(), op.getPassThru());
373+
return success();
374+
})
375+
.Case([&](vector::TransferReadOp op) {
376+
// We only support minor identity maps in the permutation attribute.
377+
if (!op.getPermutationMap().isMinorIdentity())
378+
return failure();
379+
380+
// We only support the case where the source of the expand shape has
381+
// rank greater than or equal to the vector rank.
382+
const int64_t sourceRank = sourceIndices.size();
383+
const int64_t vectorRank = op.getVectorType().getRank();
384+
if (sourceRank < vectorRank)
385+
return failure();
386+
387+
// We need to construct a new minor identity map since we will have lost
388+
// some dimensions in folding away the expand shape.
389+
auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank,
390+
op.getContext());
391+
392+
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
393+
op, op.getVectorType(), expandShapeOp.getViewSource(),
394+
sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
395+
op.getInBounds());
396+
return success();
369397
})
370398
.DefaultUnreachable("unexpected operation");
371-
return success();
372399
}
373400

374401
template <typename OpTy>
@@ -659,6 +686,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
659686
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
660687
LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
661688
LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
689+
LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
662690
StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
663691
StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
664692
StoreOpOfExpandShapeOpFolder<vector::StoreOp>,

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,55 @@ func.func @fold_vector_maskedstore_expand_shape(
992992

993993
// -----
994994

995+
func.func @fold_vector_transfer_read_expand_shape(
996+
%arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
997+
%c0 = arith.constant 0 : index
998+
%pad = ub.poison : f32
999+
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
1000+
%1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true]} : memref<4x8xf32>, vector<8xf32>
1001+
return %1 : vector<8xf32>
1002+
}
1003+
1004+
// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape
1005+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
1006+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
1007+
// CHECK: %[[C0:.*]] = arith.constant 0
1008+
// CHECK: %[[PAD:.*]] = ub.poison : f32
1009+
// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
1010+
// CHECK: vector.transfer_read %[[ARG0]][%[[IDX]]], %[[PAD]] {in_bounds = [true]}
1011+
1012+
// -----
1013+
1014+
func.func @fold_vector_transfer_read_with_perm_map(
1015+
%arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
1016+
%c0 = arith.constant 0 : index
1017+
%pad = ub.poison : f32
1018+
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
1019+
%1 = vector.transfer_read %0[%arg1, %c0], %pad { permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<4x8xf32>, vector<4x4xf32>
1020+
return %1 : vector<4x4xf32>
1021+
}
1022+
1023+
// CHECK-LABEL: func @fold_vector_transfer_read_with_perm_map
1024+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
1025+
// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
1026+
1027+
// -----
1028+
1029+
func.func @fold_vector_transfer_read_rank_mismatch(
1030+
%arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
1031+
%c0 = arith.constant 0 : index
1032+
%pad = ub.poison : f32
1033+
%0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32>
1034+
%1 = vector.transfer_read %0[%arg1, %c0, %c0], %pad {in_bounds = [true, true]} : memref<2x4x4xf32>, vector<4x4xf32>
1035+
return %1 : vector<4x4xf32>
1036+
}
1037+
1038+
// CHECK-LABEL: func @fold_vector_transfer_read_rank_mismatch
1039+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
1040+
// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32>
1041+
1042+
// -----
1043+
9951044
func.func @fold_vector_load_collapse_shape(
9961045
%arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
9971046
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>

0 commit comments

Comments
 (0)