Skip to content

[mlir][vector] Take dim sizes into account in DropInnerMostUnitDims. #71752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 10, 2023

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Nov 9, 2023

The stride == 1 does not imply that we can drop it. Because it could load more than 1 elements. We should also take source sizes and vector sizes into account. Otherwise it generates invalid IRs. E.g.,

func.func @foo(%arg0: memref<1x1xf32>) -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<1x1xf32>, vector<4x8xf32>
  return %0 : vector<4x8xf32>
}

Fixes iree-org/iree#15493

The `stride == 1` does not imply that we can drop it. Because it could
be out of bounds. We should also take vector sizes into account.
@llvmbot
Copy link
Member

llvmbot commented Nov 9, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

The stride == 1 does not imply that we can drop it. Because it could be out of bounds. We should also take source sizes and vector sizes into account.


Full diff: https://github.com/llvm/llvm-project/pull/71752.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+6-3)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+14)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 16f9d7e4d57eef7..a96d75c5804f555 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1185,9 +1185,12 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
       return failure();
 
     size_t dimsToDrop = 0;
-    for (size_t i = 1; i < srcStrides.size(); ++i) {
-      int dim = srcType.getRank() - i - 1;
-      if (srcStrides[dim] == 1) {
+    int rankDiff = srcType.getRank() - readOp.getVectorType().getRank();
+    for (int64_t i = 0; i < targetType.getRank(); ++i) {
+      int dim = targetType.getRank() - i - 1;
+      if (srcStrides[dim + rankDiff] == 1 &&
+          srcType.getDimSize(dim + rankDiff) == 1 &&
+          targetType.getDimSize(dim) == 1) {
         dimsToDrop++;
       } else {
         break;
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index ef0bd9ddf8abf4b..1e97c482dc97ee7 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -60,3 +60,17 @@ func.func @contiguous_inner_most_dim_bounds_2d(%A: memref<1000x1x1xf32>, %i:inde
 //      CHECK:   %[[V:.+]] = vector.transfer_read %[[SRC_1]]
 // CHECK-SAME:       {in_bounds = [true]}
 // CHECK-SAME:       vector<4xf32>
+
+// -----
+
+func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) -> vector<4x8xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<1x1xf32>, vector<4x8xf32>
+  return %0 : vector<4x8xf32>
+}
+//      CHECK: func.func @contiguous_inner_most_dim_out_of_bounds_2d
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9]+]]
+//  CHECK-NOT:   memref.subview
+//      CHECK:   %[[READ:.+]] = vector.transfer_read %[[SRC]]
+//      CHECK:   return %[[READ]] : vector<4x8xf32>

@hanhanW hanhanW changed the title [mlir][vector] Take vector sizes into account in DropInnerMostUnitDims. [mlir][vector] Take dim sizes into account in DropInnerMostUnitDims. Nov 9, 2023
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! A few comments that you can address before submitting.

@hanhanW hanhanW merged commit 2bac720 into llvm:main Nov 10, 2023
@hanhanW hanhanW deleted the fix-iree-15493 branch November 10, 2023 17:28
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…lvm#71752)

The `stride == 1` does not imply that we can drop it. Because it could
load more than 1 elements. We should also take source sizes and vector
sizes into account. Otherwise it generates invalid IRs. E.g.,

```mlir
func.func @foo(%arg0: memref<1x1xf32>) -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<1x1xf32>, vector<4x8xf32>
  return %0 : vector<4x8xf32>
}
```

Fixes iree-org/iree#15493
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Invalid vector.shape_cast op is generated by DropInnerMostUnitDims pattern
3 participants