Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Drop inner unit dims for transfer ops on dynamic shapes. #79752

Merged
merged 1 commit into from
Jan 29, 2024

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Jan 28, 2024

No description provided.

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 28, 2024

@llvm/pr-subscribers-mlir-vector

Author: Han-Chung Wang (hanhanW)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/79752.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+17-12)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+40)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 12aa11e9e33f5e7..8363e73857e5c54 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1236,7 +1236,7 @@ class DropInnerMostUnitDimsTransferRead
       return failure();
 
     auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
-    if (!srcType || !srcType.hasStaticShape())
+    if (!srcType)
       return failure();
 
     if (!readOp.getPermutationMap().isMinorIdentity())
@@ -1260,19 +1260,21 @@ class DropInnerMostUnitDimsTransferRead
                         targetType.getElementType());
 
     auto loc = readOp.getLoc();
+    SmallVector<OpFoldResult> sizes =
+        memref::getMixedSizes(rewriter, loc, readOp.getSource());
+    SmallVector<OpFoldResult> offsets(srcType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(srcType.getRank(),
+                                      rewriter.getIndexAttr(1));
     MemRefType resultMemrefType =
         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
-    SmallVector<int64_t> offsets(srcType.getRank(), 0);
-    SmallVector<int64_t> strides(srcType.getRank(), 1);
-
     ArrayAttr inBoundsAttr =
         readOp.getInBounds()
             ? rewriter.getArrayAttr(
                   readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
             : ArrayAttr();
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
-        strides);
+        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
     Value result = rewriter.create<vector::TransferReadOp>(
@@ -1318,7 +1320,7 @@ class DropInnerMostUnitDimsTransferWrite
       return failure();
 
     auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
-    if (!srcType || !srcType.hasStaticShape())
+    if (!srcType)
       return failure();
 
     if (!writeOp.getPermutationMap().isMinorIdentity())
@@ -1341,20 +1343,23 @@ class DropInnerMostUnitDimsTransferWrite
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                         targetType.getElementType());
 
+    Location loc = writeOp.getLoc();
+    SmallVector<OpFoldResult> sizes =
+        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
+    SmallVector<OpFoldResult> offsets(srcType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(srcType.getRank(),
+                                      rewriter.getIndexAttr(1));
     MemRefType resultMemrefType =
         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
-    SmallVector<int64_t> offsets(srcType.getRank(), 0);
-    SmallVector<int64_t> strides(srcType.getRank(), 1);
     ArrayAttr inBoundsAttr =
         writeOp.getInBounds()
             ? rewriter.getArrayAttr(
                   writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
             : ArrayAttr();
 
-    Location loc = writeOp.getLoc();
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
-        strides);
+        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index d6d69c8af88508d..3984f17f9e8cdb7 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -16,6 +16,25 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,
 
 // -----
 
+func.func @contiguous_outer_dyn_inner_most_view(%in: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
+  return %0 : vector<1x8x1xf32>
+}
+//      CHECK: func @contiguous_outer_dyn_inner_most_view(
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
+//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1]
+// CHECK-SAME:    memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
+// CHECK-SAME:    memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
+//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
   %c0 = arith.constant 0 : index
   %f0 = arith.constant 0.0 : f32
@@ -119,6 +138,27 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
 
 // -----
 
+func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
+    {in_bounds = [true, true, true, true]}
+    : vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
+  return
+}
+// CHECK:      func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
+// CHECK-SAME:     memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[IDX]], %[[C0]], %[[C0]]]
+
+// -----
+
 func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 28, 2024

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/79752.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+17-12)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+40)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 12aa11e9e33f5e..8363e73857e5c5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1236,7 +1236,7 @@ class DropInnerMostUnitDimsTransferRead
       return failure();
 
     auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
-    if (!srcType || !srcType.hasStaticShape())
+    if (!srcType)
       return failure();
 
     if (!readOp.getPermutationMap().isMinorIdentity())
@@ -1260,19 +1260,21 @@ class DropInnerMostUnitDimsTransferRead
                         targetType.getElementType());
 
     auto loc = readOp.getLoc();
+    SmallVector<OpFoldResult> sizes =
+        memref::getMixedSizes(rewriter, loc, readOp.getSource());
+    SmallVector<OpFoldResult> offsets(srcType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(srcType.getRank(),
+                                      rewriter.getIndexAttr(1));
     MemRefType resultMemrefType =
         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
-    SmallVector<int64_t> offsets(srcType.getRank(), 0);
-    SmallVector<int64_t> strides(srcType.getRank(), 1);
-
     ArrayAttr inBoundsAttr =
         readOp.getInBounds()
             ? rewriter.getArrayAttr(
                   readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
             : ArrayAttr();
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
-        strides);
+        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
     Value result = rewriter.create<vector::TransferReadOp>(
@@ -1318,7 +1320,7 @@ class DropInnerMostUnitDimsTransferWrite
       return failure();
 
     auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
-    if (!srcType || !srcType.hasStaticShape())
+    if (!srcType)
       return failure();
 
     if (!writeOp.getPermutationMap().isMinorIdentity())
@@ -1341,20 +1343,23 @@ class DropInnerMostUnitDimsTransferWrite
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                         targetType.getElementType());
 
+    Location loc = writeOp.getLoc();
+    SmallVector<OpFoldResult> sizes =
+        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
+    SmallVector<OpFoldResult> offsets(srcType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(srcType.getRank(),
+                                      rewriter.getIndexAttr(1));
     MemRefType resultMemrefType =
         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
-    SmallVector<int64_t> offsets(srcType.getRank(), 0);
-    SmallVector<int64_t> strides(srcType.getRank(), 1);
     ArrayAttr inBoundsAttr =
         writeOp.getInBounds()
             ? rewriter.getArrayAttr(
                   writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
             : ArrayAttr();
 
-    Location loc = writeOp.getLoc();
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
-        strides);
+        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index d6d69c8af88508..3984f17f9e8cdb 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -16,6 +16,25 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,
 
 // -----
 
+func.func @contiguous_outer_dyn_inner_most_view(%in: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
+  return %0 : vector<1x8x1xf32>
+}
+//      CHECK: func @contiguous_outer_dyn_inner_most_view(
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
+//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1]
+// CHECK-SAME:    memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>
+//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
+// CHECK-SAME:    memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
+//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
   %c0 = arith.constant 0 : index
   %f0 = arith.constant 0.0 : f32
@@ -119,6 +138,27 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
 
 // -----
 
+func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
+    {in_bounds = [true, true, true, true]}
+    : vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
+  return
+}
+// CHECK:      func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
+// CHECK-SAME:     memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[IDX]], %[[C0]], %[[C0]]]
+
+// -----
+
 func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]

@hanhanW hanhanW merged commit 66347e5 into llvm:main Jan 29, 2024
6 of 7 checks passed
@hanhanW hanhanW deleted the transfer-write-inner-unit-dims-dynamic branch January 29, 2024 08:30
dcaballe added a commit to iree-org/llvm-project that referenced this pull request Feb 4, 2024
hanhanW added a commit that referenced this pull request Feb 5, 2024
…ic shapes." (#80712)

Reverts #79752 because it is causing regressions in
downstream projects.
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
…ic shapes." (llvm#80712)

Reverts llvm#79752 because it is causing regressions in
downstream projects.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants