-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Take dim sizes into account in DropInnerMostUnitDims. #71752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The `stride == 1` does not imply that we can drop it. Because it could be out of bounds. We should also take vector sizes into account.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Han-Chung Wang (hanhanW) ChangesThe Full diff: https://github.com/llvm/llvm-project/pull/71752.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 16f9d7e4d57eef7..a96d75c5804f555 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1185,9 +1185,12 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
return failure();
size_t dimsToDrop = 0;
- for (size_t i = 1; i < srcStrides.size(); ++i) {
- int dim = srcType.getRank() - i - 1;
- if (srcStrides[dim] == 1) {
+ int rankDiff = srcType.getRank() - readOp.getVectorType().getRank();
+ for (int64_t i = 0; i < targetType.getRank(); ++i) {
+ int dim = targetType.getRank() - i - 1;
+ if (srcStrides[dim + rankDiff] == 1 &&
+ srcType.getDimSize(dim + rankDiff) == 1 &&
+ targetType.getDimSize(dim) == 1) {
dimsToDrop++;
} else {
break;
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index ef0bd9ddf8abf4b..1e97c482dc97ee7 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -60,3 +60,17 @@ func.func @contiguous_inner_most_dim_bounds_2d(%A: memref<1000x1x1xf32>, %i:inde
// CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_1]]
// CHECK-SAME: {in_bounds = [true]}
// CHECK-SAME: vector<4xf32>
+
+// -----
+
+func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) -> vector<4x8xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<1x1xf32>, vector<4x8xf32>
+ return %0 : vector<4x8xf32>
+}
+// CHECK: func.func @contiguous_inner_most_dim_out_of_bounds_2d
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-NOT: memref.subview
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[SRC]]
+// CHECK: return %[[READ]] : vector<4x8xf32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! A few comments that you can address before submitting.
…lvm#71752) The `stride == 1` does not imply that we can drop it. Because it could load more than 1 elements. We should also take source sizes and vector sizes into account. Otherwise it generates invalid IRs. E.g., ```mlir func.func @foo(%arg0: memref<1x1xf32>) -> vector<4x8xf32> { %c0 = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<1x1xf32>, vector<4x8xf32> return %0 : vector<4x8xf32> } ``` Fixes iree-org/iree#15493
The
stride == 1
does not imply that we can drop it. Because it could load more than 1 elements. We should also take source sizes and vector sizes into account. Otherwise it generates invalid IRs. E.g.,Fixes iree-org/iree#15493