-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[milr][memref]: Fold expand_shape + transfer_read #167679
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Jack Frankland (FranklandJack) ChangesExtend the load of a expand shape rewrite pattern to support folding a Full diff: https://github.com/llvm/llvm-project/pull/167679.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 214410f78e51c..30df10c1deedc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -347,28 +347,49 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
return failure();
- llvm::TypeSwitch<Operation *, void>(loadOp)
+
+ return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
.Case([&](affine::AffineLoadOp op) {
rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices);
+ return success();
})
.Case([&](memref::LoadOp op) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
+ return success();
})
.Case([&](vector::LoadOp op) {
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getNontemporal());
+ return success();
})
.Case([&](vector::MaskedLoadOp op) {
rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getMask(), op.getPassThru());
+ return success();
+ })
+ .Case([&](vector::TransferReadOp op) {
+ // We only support minor identity maps in the permutation attribute.
+ if (!op.getPermutationMap().isMinorIdentity())
+ return failure();
+
+ // We need to construct a new minor identity map since we will have lost
+ // some dimensions in folding away the expand shape.
+ auto minorIdMap = AffineMap::getMinorIdentityMap(
+ sourceIndices.size(), op.getVectorType().getRank(),
+ op.getContext());
+
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ op, op.getVectorType(), expandShapeOp.getViewSource(),
+ sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
+ op.getInBounds());
+ return success();
})
.DefaultUnreachable("unexpected operation");
- return success();
}
template <typename OpTy>
@@ -659,6 +680,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
+ LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 106652623933f..87f23457644ae 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -992,6 +992,40 @@ func.func @fold_vector_maskedstore_expand_shape(
// -----
+func.func @fold_vector_transfer_read_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = ub.poison : f32
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true]} : memref<4x8xf32>, vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[PAD:.*]] = ub.poison : f32
+// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+// CHECK: vector.transfer_read %[[ARG0]][%[[IDX]]], %[[PAD]] {in_bounds = [true]}
+
+// -----
+
+func.func @fold_vector_transfer_read_with_perm_map(
+ %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = ub.poison : f32
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.transfer_read %0[%arg1, %c0], %pad { permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<4x8xf32>, vector<4x4xf32>
+ return %1 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_with_perm_map
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+
+// -----
+
func.func @fold_vector_load_collapse_shape(
%arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
|
e921f18 to
ffbb4ac
Compare
| // We need to construct a new minor identity map since we will have lost | ||
| // some dimensions in folding away the expand shape. | ||
| auto minorIdMap = AffineMap::getMinorIdentityMap( | ||
| sourceIndices.size(), op.getVectorType().getRank(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't you need to check that the vector rank is compatible with the base rank after folding?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I've added a check and a negative test.
Extend the load of a expand shape rewrite pattern to support folding a `memref.expand_shape` and `vector.transfer_read` when the permutation map on `vector.transfer_read` is a minor identity. Signed-off-by: Jack Frankland <jack.frankland@arm.com>
Add a missing check and negative test. Signed-off-by: Jack Frankland <jack.frankland@arm.com>
ffbb4ac to
8ae2e8d
Compare
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/21362 Here is the relevant piece of the build log for the reference |
Extend the load of a expand shape rewrite pattern to support folding a `memref.expand_shape` and `vector.transfer_read` when the permutation map on `vector.transfer_read` is a minor identity. --------- Signed-off-by: Jack Frankland <jack.frankland@arm.com>
Extend the load of a expand shape rewrite pattern to support folding a `memref.expand_shape` and `vector.transfer_read` when the permutation map on `vector.transfer_read` is a minor identity. --------- Signed-off-by: Jack Frankland <jack.frankland@arm.com>
Extend the load of a expand shape rewrite pattern to support folding a
memref.expand_shapeandvector.transfer_readwhen the permutation map onvector.transfer_readis a minor identity.