[mlir][vector] Add vector.transpose with unit-dim to vector.shape_cast pattern #72105

Merged

Conversation

c-rhodes
Collaborator

This patch extends the vector.transpose lowering to replace:

vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>

with:

vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>

The inverse, where the source has a leading unit dim, is also replaced. The unit dim must be fixed; the non-unit dim can be scalable.

A check is also added to bail out for scalable vectors before unrolling.
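
As a concrete instance of the rewrite (taking n = 4 and f32 elements, mirroring the tests added in this patch):

  // Before:
  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>

  // After (the underlying data layout is unchanged, so a shape_cast suffices):
  %0 = vector.shape_cast %arg0 : vector<4x1xf32> to vector<1x4xf32>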

…t pattern

This patch extends the vector.transpose lowering to replace:

  vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>

with:

  vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>

Source with leading unit-dim (inverse) is also replaced. Unit dim must
be fixed. Non-unit dim can be scalable.

A check is also added to bail out for scalable vectors before unrolling.
@llvmbot
Collaborator

llvmbot commented Nov 13, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Cullen Rhodes (c-rhodes)

Changes

This patch extends the vector.transpose lowering to replace:

vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>

with:

vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>

The inverse, where the source has a leading unit dim, is also replaced. The unit dim must be fixed; the non-unit dim can be scalable.

A check is also added to bail out for scalable vectors before unrolling.


Full diff: https://github.com/llvm/llvm-project/pull/72105.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+21)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (+71)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 7d804ddcfa42ffe..cf35d64c0c6268d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -336,6 +336,27 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
+    // Replace:
+    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+    //                                 vector<1xnx<eltty>>
+    // with:
+    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>
+    //
+    // The inverse, where the source has a leading unit dim, is also replaced.
+    // The unit dim must be fixed; the non-unit dim can be scalable.
+    if (resType.getRank() == 2 &&
+        ((resType.getShape().front() == 1 &&
+          !resType.getScalableDims().front()) ||
+         (resType.getShape().back() == 1 &&
+          !resType.getScalableDims().back())) &&
+        transp[0] == 1 && transp[1] == 0) {
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+      return success();
+    }
+
+    if (inputType.isScalable())
+      return failure();
+
     // Handle a true 2-D matrix transpose differently when requested.
     if (vectorTransformOptions.vectorTransposeLowering ==
             vector::VectorTransposeLowering::Flat &&
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 22d9224838c49c4..c0b44428d5bcf30 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -74,6 +74,17 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8
   return %0 : vector<1x1x8x8xf32>
 }
 
+/// Scalable dim should not be unrolled.
+
+// CHECK-LABEL: func @transpose23_scalable
+// CHECK-NOT: vector.extract
+// CHECK-NOT: vector.insert
+// CHECK: vector.transpose
+func.func @transpose23_scalable(%arg0: vector<2x[3]xf32>) -> vector<[3]x2xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x[3]xf32> to vector<[3]x2xf32>
+  return %0 : vector<[3]x2xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
     transform.apply_patterns to %func_op {
@@ -778,3 +789,63 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
+
+// CHECK-LABEL: func @transpose10_4x1xf32
+func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func @transpose10_nx4x1xf32
+func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1x4xf32
+func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+  return %0 : vector<4x1xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1xnx4xf32
+func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
+  return %0 : vector<[4]x1xf32>
+}
+
+/// Scalable unit dim should not be lowered to shape_cast.
+
+// CHECK-LABEL: func @transpose10_4xnx1xf32
+func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
+  return %0 : vector<[1]x4xf32>
+}
+
+// CHECK-LABEL: func @transpose10_nx4xnx1xf32
+func.func @transpose10_nx4xnx1xf32(%arg0: vector<[4]x[1]xf32>) -> vector<[1]x[4]xf32> {
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: vector.transpose %{{.*}} : vector<[4]x[1]xf32> to vector<[1]x[4]xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[1]xf32> to vector<[1]x[4]xf32>
+  return %0 : vector<[1]x[4]xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_transpose
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
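
Assuming the file's usual RUN line (the transform.named_sequence above is what the transform interpreter picks up), the new tests can be exercised with something like:

  mlir-opt mlir/test/Dialect/Vector/vector-transpose-lowering.mlir \
    --transform-interpreter --split-input-file | \
    FileCheck mlir/test/Dialect/Vector/vector-transpose-lowering.mlir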

@c-rhodes c-rhodes merged commit b7b6d54 into llvm:main Nov 15, 2023
3 checks passed
@c-rhodes c-rhodes deleted the mlir-vector-transpose-with-unit-dim-to-shape-cast branch November 15, 2023 14:14
@dcaballe
Contributor

Could we make this a transpose canonicalization pattern instead? Having redundant transposes around before lowering is not helpful. We are actually hitting an issue related to 1xn-to-nx1 transposes and would need that type of canonicalization.

@c-rhodes
Collaborator Author

Could we make this a transpose canonicalization pattern instead? Having redundant transposes around before lowering is not helpful. We are actually hitting an issue related to 1xn-to-nx1 transposes and would need that type of canonicalization.

I've posted PR #72493 to move this to canonicalization.

zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…t pattern (llvm#72105)

This patch extends the vector.transpose lowering to replace:

vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>

with:

  vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>

Source with leading unit-dim (inverse) is also replaced. Unit dim must
be fixed. Non-unit dim can be scalable.

A check is also added to bail out for scalable vectors before unrolling.
@MaheshRavishankar
Contributor

This PR broke downstream on conversion to SPIR-V. Also, looking at this again, I am not sure lowering this to a vector.shape_cast is a good idea (sorry I didn't review it as I saw it go by). AFAIK, shape_casts are meant only for n-D to 1-D / 1-D to n-D conversions. So lowering the transpose to a shape_cast here seems to violate that, and is causing issues when lowering to SPIR-V specifically.
This might need more discussion, but in the meantime can I suggest we revert this (and the follow-up that moved it into a canonicalization, which is a bit more problematic in my view because that is done without any user control).

cc @antiagainst @qedawkins and @kuhar
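
For reference, the distinction being drawn here, with illustrative shapes (not taken from the patch):

  // Rank-reducing / rank-expanding casts, the traditionally intended use:
  %a = vector.shape_cast %0 : vector<2x4xf32> to vector<8xf32>
  %b = vector.shape_cast %a : vector<8xf32> to vector<4x2xf32>

  // Same-rank cast emitted by this patch, the case under discussion:
  %c = vector.shape_cast %1 : vector<4x1xf32> to vector<1x4xf32>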

Groverkss added a commit to shark-infra/llvm-project that referenced this pull request Nov 20, 2023
@c-rhodes
Collaborator Author

This PR broke downstream on conversion to SPIR-V. Also, looking at this again, I am not sure lowering this to a vector.shape_cast is a good idea (sorry I didn't review it as I saw it go by). AFAIK, shape_casts are meant only for n-D to 1-D / 1-D to n-D conversions. So lowering the transpose to a shape_cast here seems to violate that, and is causing issues when lowering to SPIR-V specifically. This might need more discussion, but in the meantime can I suggest we revert this (and the follow-up that moved it into a canonicalization, which is a bit more problematic in my view because that is done without any user control).

Sorry to hear it's causing issues. Looking at the docs for vector.shape_cast, although they only mention rank-reducing / rank-expanding casts, there's no mention that shape casting between same-rank vectors is not allowed?

Please could you provide any more info? There's nothing here that explains what the problem with the lowering is.

Groverkss added further commits to shark-infra/llvm-project that referenced this pull request Nov 20 and Nov 27, 2023
banach-space added a commit to banach-space/llvm-project that referenced this pull request Nov 30, 2023
Following the discussion here:
  * llvm#72105
this patch makes the `TransposeOpLowering` configurable so that one can
select whether to favour `vector.shape_cast` over `vector.transpose`.

As per the discussion in llvm#72105, using `vector.shape_cast` is very
beneficial and desirable when targeting `LLVM IR` (CPU lowering), but
simply won't work when targeting `SPIR-V` (GPU lowering). So we need a
mechanism to be able to disable/enable the pattern introduced in llvm#72105.
This patch proposes one such mechanism.

While this should solve the problem that we are facing today, we may
need to introduce something more elaborate to specialise for CPU vs GPU
lowering. Also, (once implemented) this proposal might make this
workaround redundant:
  * https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/
banach-space added a commit that referenced this pull request Dec 4, 2023
Following the discussion here:

  * #72105

this patch makes the `TransposeOpLowering` configurable so that one can select
whether to favour `vector.shape_cast` over `vector.transpose`.

As per the discussion in #72105, using `vector.shape_cast` is very beneficial
and desirable when targeting `LLVM IR` (CPU lowering), but won't work when
targeting `SPIR-V` today (GPU lowering). Hence the need for a mechanism to be
able to disable/enable the pattern introduced in #72105. This patch proposes one
such mechanism.

While this should solve the problem that we are facing today, it's understood to
be a temporary workaround. It should be removed once support for lowering
`vector.shape_cast` to SPIR-V is added. Also, (once implemented) the following
proposal might make this workaround redundant:

  * https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/
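
Illustratively, the two behaviours the option selects between for a unit-dim transpose (shapes chosen for the example; see #73915 for the actual option):

  // Input:
  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>

  // With the shape_cast-based lowering enabled (the LLVM/CPU path):
  %0 = vector.shape_cast %arg0 : vector<4x1xf32> to vector<1x4xf32>

  // With it disabled (the SPIR-V/GPU path), the vector.transpose is left
  // intact for other transpose lowerings to handle.
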
banach-space added a commit to banach-space/llvm-project that referenced this pull request Dec 11, 2023
…73915)"

Reverting a workaround intended specifically for SPIR-V. That workaround
emerged from this discussion:

  * llvm#72105

AFAIK, it hasn't been required in practice. This is based on IREE
(https://github.com/openxla/iree), which has just bumped its fork of LLVM
without using it (*).

(*) iree-org/iree@cef31e7

This reverts commit bbd2b08.
banach-space added a commit that referenced this pull request Dec 11, 2023
…" (#75062)

Reverting a workaround intended specifically for SPIR-V. That workaround
emerged from this discussion:

  * #72105

AFAIK, it hasn't been required in practice. This is based on IREE
(https://github.com/openxla/iree), which has just bumped its fork of
LLVM without using it (*).

(*) iree-org/iree@cef31e7

This reverts commit bbd2b08.