[mlir][vector] Move transpose with unit-dim to shape_cast pattern #72493

Conversation

c-rhodes (Collaborator) commented:

Moved from lowering to canonicalization.

llvmbot (Collaborator) commented Nov 16, 2023:

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes

Moved from lowering to canonicalization.
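
For illustration, a minimal example of the rewrite this patch moves (mirroring the tests added below):

  // A rank-2 transpose whose result has a fixed (non-scalable) unit dim
  // preserves element order...
  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
  // ...so it canonicalizes to a shape_cast:
  %0 = vector.shape_cast %arg0 : vector<4x1xf32> to vector<1x4xf32>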


Full diff: https://github.com/llvm/llvm-project/pull/72493.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+40-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (-18)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+51)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (-51)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95f49fa32bc0ae2..432c11e3c449e0e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5564,12 +5564,51 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose with non-scalable unit dims into a shape_cast.
+///
+/// Replace:
+///   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+///                                 vector<1xnxelty>
+/// with:
+///   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
+///
+/// Source with leading unit dim (inverse) is also replaced. Unit dim must
+/// be fixed. Non-unit dims can be scalable.
+class FoldTransposeWithNonScalableUnitDimsToShapeCast final
+    : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transpOp.getVector();
+    VectorType resType = transpOp.getResultVectorType();
+
+    SmallVector<int64_t> permutation;
+    transpOp.getTransp(permutation);
+
+    if (resType.getRank() == 2 &&
+        ((resType.getShape().front() == 1 &&
+          !resType.getScalableDims().front()) ||
+         (resType.getShape().back() == 1 &&
+          !resType.getScalableDims().back())) &&
+        permutation == ArrayRef<int64_t>({1, 0})) {
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resType,
+                                                       input);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              TransposeFolder, FoldTransposeSplat,
+              FoldTransposeWithNonScalableUnitDimsToShapeCast>(context);
 }
 
 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index dee786007c80630..25a53b31163432e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -336,24 +336,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
-    // Replace:
-    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
-    //                                 vector<1xnxelty>
-    // with:
-    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
-    //
-    // Source with leading unit dim (inverse) is also replaced. Unit dim must
-    // be fixed. Non-unit can be scalable.
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
     if (inputType.isScalable())
       return failure();
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d341..b3902d2d9b4dde0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2524,3 +2524,54 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+/// Transpose of rank-2 vector with leading or trailing non-scalable unit dim to shape_cast.
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+  return %0 : vector<4x1xf32>
+}
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
+  return %0 : vector<[4]x1xf32>
+}
+
+/// Scalable unit dim should not be lowered to shape_cast.
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
+  return %0 : vector<[1]x4xf32>
+}
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
+
+  return %0 : vector<[1]x4xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index c0b44428d5bcf30..72be5e4dbe3ee16 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -790,57 +790,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
-
-/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
-
-// CHECK-LABEL: func @transpose10_4x1xf32
-func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
-  return %0 : vector<1x4xf32>
-}
-
-// CHECK-LABEL: func @transpose10_nx4x1xf32
-func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// CHECK-LABEL: func @transpose10_1x4xf32
-func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
-  return %0 : vector<4x1xf32>
-}
-
-// CHECK-LABEL: func @transpose10_1xnx4xf32
-func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
-  return %0 : vector<[4]x1xf32>
-}
-
-/// Scalable unit dim should not be lowered to shape_cast.
-
-// CHECK-LABEL: func @transpose10_4xnx1xf32
-func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
-  // CHECK-NOT: vector.shape_cast
-  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
-  return %0 : vector<[1]x4xf32>
-}
-
-// CHECK-LABEL: func @transpose10_nx4xnx1xf32
-func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
-  // CHECK-NOT: vector.shape_cast
-  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
-
-  return %0 : vector<[1]x4xf32>
-}
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
     transform.apply_patterns to %func_op {

dcaballe (Contributor) left a comment:

Thanks a lot! This fixes one of the issues I'm hitting right now :)

Comment on lines +5590 to +5595:

    if (resType.getRank() == 2 &&
        ((resType.getShape().front() == 1 &&
          !resType.getScalableDims().front()) ||
         (resType.getShape().back() == 1 &&
          !resType.getScalableDims().back())) &&
        permutation == ArrayRef<int64_t>({1, 0})) {
dcaballe (Contributor):

As a follow-up patch, I wonder if we could generalize this to n-D vectors where at most one dimension is != 1? If I'm not missing something, the permutation itself shouldn't even matter in those cases?
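
A hypothetical example of the suggested generalization (illustrative only; this patch handles just the rank-2 case):

  // With at most one non-unit dim, element order is preserved by any
  // permutation, so the transpose could fold to a shape_cast:
  %0 = vector.transpose %arg0, [2, 0, 3, 1] : vector<1x1x4x1xf32> to vector<4x1x1x1xf32>
  // ==>
  %0 = vector.shape_cast %arg0 : vector<1x1x4x1xf32> to vector<4x1x1x1xf32>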

c-rhodes (Collaborator, Author):

Sounds good, I'll look into it soon 👍

c-rhodes merged commit 95acb33 into llvm:main on Nov 17, 2023.
6 checks passed
c-rhodes deleted the mlir-vector-transpose-with-unit-dim-to-shape-cast-canonicalization branch on November 17, 2023 at 14:06.
sr-tream pushed a commit to sr-tream/llvm-project that referenced this pull request Nov 20, 2023
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
MaheshRavishankar added a commit to MaheshRavishankar/llvm-project that referenced this pull request Nov 20, 2023
MaheshRavishankar (Contributor) commented Nov 20, 2023:

Based on #72105 (comment), I would hope that we can at least revert the move to canonicalization, since that does not give the user any control. The revert is in #72918.

MaheshRavishankar (Contributor), replying to dcaballe:

> Thanks a lot! This fixes one of the issues I'm hitting right now :)

That's unfortunate. Maybe we should make this a separate pattern that is added where needed; it seems it cannot be handled everywhere.

antiagainst (Member):

+1. I'm also missing the context on why going down vector.shape_cast is preferable. Can somebody explain or give me some pointers?

Groverkss added a commit to iree-org/llvm-project that referenced this pull request Nov 20, 2023
banach-space (Contributor):

I'm sorry that this is causing issues.

> I'm also missing the context on why going down vector.shape_cast is preferable. Can somebody explain or give me some pointers?

This is important in the context of scalable vectors:

  • LLVM can handle arrays of scalable vectors (which means that vector<1x[4]xf32> is supported), however
  • LLVM cannot deal with "scalable" arrays of vectors (which means that vector<[4]x1xf32> is not supported).

This canonicalisation merely flips the dimensions to make things much easier further down the compilation stack (vector<[4]x1xf32> --> vector<1x[4]xf32>).
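
Concretely, a sketch of the corresponding LLVM-dialect types (the exact lowering is assumed here for illustration):

  // vector<1x[4]xf32>: a fixed-size array of scalable vectors — representable.
  !llvm.array<1 x vector<[4]xf32>>
  // vector<[4]x1xf32>: would need a scalable-length array of fixed vectors,
  // a type LLVM does not have.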

> Maybe we should make this a separate pattern that is added where needed.

Would you be able to make a specific suggestion? What's available to make sure that a pattern doesn't trigger for a particular target?

Rather than reverting this, I'd much prefer for us to "move" or refactor it so that it's no longer problematic. Currently this is unblocking us and solving an issue for @dcaballe, so there are benefits to keeping it in tree.

MaheshRavishankar added a commit that referenced this pull request Nov 21, 2023
antiagainst (Member):

I think creating a populate*Patterns() function to expose this pattern should work; there are lots of existing vector-related patterns in the codebase doing that. That way it gives more control to different CodeGen flows, and you can explicitly call it in your flow. Patterns put in canonicalization trigger everywhere, as we don't have a way to control them.
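
A sketch of what that could look like (the function name and placement are hypothetical, following existing populate*Patterns conventions in the Vector dialect):

  // Hypothetical opt-in entry point: a CodeGen flow that wants this fold
  // calls it explicitly rather than getting it via canonicalization.
  void mlir::vector::populateFoldTransposeWithUnitDimsPatterns(
      RewritePatternSet &patterns) {
    patterns.add<FoldTransposeWithNonScalableUnitDimsToShapeCast>(
        patterns.getContext());
  }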

Groverkss added a series of commits to iree-org/llvm-project that referenced this pull request, Nov 26–27, 2023.