[mlir][vector] Add vector.transpose with unit-dim to vector.shape_cast pattern #72105
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Cullen Rhodes (c-rhodes)

This patch extends the vector.transpose lowering to replace:

  vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>

with:

  vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>

A source with a leading unit dim (the inverse) is also replaced. The unit dim must be fixed; the non-unit dim can be scalable. A check is also added to bail out for scalable vectors before unrolling.

Full diff: https://github.com/llvm/llvm-project/pull/72105.diff (2 files affected)
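The rewrite is valid because transposing a 2-D vector with a unit dimension only relabels the shape; the linearized element order is unchanged, which is exactly what `vector.shape_cast` expresses. A minimal plain-Python sketch (illustration only, not part of the patch) of why the two ops agree in the unit-dim case:

```python
n = 4
v = [[float(i)] for i in range(n)]  # stands in for vector<4x1xf32>

# vector.transpose %v, [1, 0] : (4, 1) -> (1, 4)
transposed = [[v[i][0] for i in range(n)]]

# vector.shape_cast keeps the linearized element order, i.e. a reshape.
flat = [x for row in v for x in row]
shape_cast = [flat]

# With a unit dim, transpose and shape_cast produce identical results.
assert transposed == shape_cast
```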
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 7d804ddcfa42ffe..cf35d64c0c6268d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -336,6 +336,27 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
+    // Replace:
+    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+    //   vector<1xnx<eltty>>
+    // with:
+    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>
+ //
+ // Source with leading unit dim (inverse) is also replaced. Unit dim must
+ // be fixed. Non-unit can be scalable.
+ if (resType.getRank() == 2 &&
+ ((resType.getShape().front() == 1 &&
+ !resType.getScalableDims().front()) ||
+ (resType.getShape().back() == 1 &&
+ !resType.getScalableDims().back())) &&
+ transp[0] == 1 && transp[1] == 0) {
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+ return success();
+ }
+
+ if (inputType.isScalable())
+ return failure();
+
// Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 22d9224838c49c4..c0b44428d5bcf30 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -74,6 +74,17 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8
return %0 : vector<1x1x8x8xf32>
}
+/// Scalable dim should not be unrolled.
+
+// CHECK-LABEL: func @transpose23_scalable
+// CHECK-NOT: vector.extract
+// CHECK-NOT: vector.insert
+// CHECK: vector.transpose
+func.func @transpose23_scalable(%arg0: vector<2x[3]xf32>) -> vector<[3]x2xf32> {
+ %0 = vector.transpose %arg0, [1, 0] : vector<2x[3]xf32> to vector<[3]x2xf32>
+ return %0 : vector<[3]x2xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
transform.apply_patterns to %func_op {
@@ -778,3 +789,63 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
+
+// CHECK-LABEL: func @transpose10_4x1xf32
+func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+ return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func @transpose10_nx4x1xf32
+func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+ return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1x4xf32
+func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+ return %0 : vector<4x1xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1xnx4xf32
+func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
+ return %0 : vector<[4]x1xf32>
+}
+
+/// Scalable unit dim should not be lowered to shape_cast.
+
+// CHECK-LABEL: func @transpose10_4xnx1xf32
+func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
+ // CHECK-NOT: vector.shape_cast
+ // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
+ return %0 : vector<[1]x4xf32>
+}
+
+// CHECK-LABEL: func @transpose10_nx4xnx1xf32
+func.func @transpose10_nx4xnx1xf32(%arg0: vector<[4]x[1]xf32>) -> vector<[1]x[4]xf32> {
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: vector.transpose %{{.*}} : vector<[4]x[1]xf32> to vector<[1]x[4]xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[1]xf32> to vector<[1]x[4]xf32>
+  return %0 : vector<[1]x[4]xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.lower_transpose
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
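The guard added in the C++ pattern above can be paraphrased in plain Python. The names below (`shape`, `scalable_dims`, `perm`) are hypothetical stand-ins for the MLIR `VectorType` accessors, used only to illustrate which transposes the pattern accepts:

```python
def matches_unit_dim_transpose(shape, scalable_dims, perm):
    """Mirror of the pattern's guard: rank-2 result, permutation [1, 0],
    and a fixed (non-scalable) unit dim at either end. The non-unit dim
    may be scalable."""
    if len(shape) != 2 or perm != [1, 0]:
        return False
    leading_fixed_unit = shape[0] == 1 and not scalable_dims[0]
    trailing_fixed_unit = shape[-1] == 1 and not scalable_dims[-1]
    return leading_fixed_unit or trailing_fixed_unit

# vector<4x1xf32> -> vector<1x4xf32>: rewritten to shape_cast.
assert matches_unit_dim_transpose([4, 1], [False, False], [1, 0])
# vector<[4]x1xf32>: a scalable non-unit dim is fine.
assert matches_unit_dim_transpose([4, 1], [True, False], [1, 0])
# vector<4x[1]xf32>: a scalable unit dim must not be rewritten.
assert not matches_unit_dim_transpose([4, 1], [False, True], [1, 0])
```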
Could we make this a transpose canonicalization pattern instead? Having redundant transposes around before lowering is not helpful. We are actually hitting an issue related to 1xn-to-nx1 transposes and would need that type of canonicalization.
I've posted PR #72493 to move this to canonicalization.
This PR broke downstream on conversion to SPIR-V. Also, looking at this again, I am not sure lowering this to a `vector.shape_cast` is the right thing to do. cc @antiagainst @qedawkins and @kuhar
Revert "[mlir][vector] Add vector.transpose with unit-dim to vector.shape_cast pattern (llvm#72105)" This reverts commit b7b6d54.
Sorry to hear it's causing issues. I've just been looking at the docs for `vector.shape_cast`. Could you please provide more info? There's nothing here that explains what the problem with the lowering is.
Following the discussion here: * llvm#72105 this patch makes the `TransposeOpLowering` configurable so that one can select whether to favour `vector.shape_cast` over `vector.transpose`. As per the discussion in llvm#72105, using `vector.shape_cast` is very beneficial and desirable when targeting `LLVM IR` (CPU lowering), but simply won't work when targeting `SPIR-V` (GPU lowering). So we need a mechanism to be able to disable/enable the pattern introduced in llvm#72105. This patch proposes one such mechanism. While this should solve the problem that we are facing today, we may need to introduce something more elaborate to specialise for CPU vs GPU lowering. Also, (once implemented) this proposal might make this workaround redundant: * https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/
Following the discussion here: * #72105 this patch makes the `TransposeOpLowering` configurable so that one can select whether to favour `vector.shape_cast` over `vector.transpose`. As per the discussion in #72105, using `vector.shape_cast` is very beneficial and desirable when targeting `LLVM IR` (CPU lowering), but won't work when targeting `SPIR-V` today (GPU lowering). Hence the need for a mechanism to be able to disable/enable the pattern introduced in #72105. This patch proposes one such mechanism. While this should solve the problem that we are facing today, it's understood to be a temporary workaround. It should be removed once support for lowering `vector.shape_cast` to SPIR-V is added. Also, (once implemented) the following proposal might make this workaround redundant: * https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/
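The configuration knob described above can be sketched in Python. This is a hypothetical illustration of the design, not the actual MLIR API: a single boolean on the lowering options decides whether a unit-dim transpose is rewritten to `vector.shape_cast` (beneficial for the LLVM/CPU path) or left as a `vector.transpose` (required for the SPIR-V/GPU path until a `vector.shape_cast` lowering exists there).

```python
class TransposeLoweringOptions:
    """Hypothetical stand-in for the pattern's options struct."""
    def __init__(self, use_shape_cast=True):
        # True: favour vector.shape_cast (CPU/LLVM lowering).
        # False: keep vector.transpose (SPIR-V/GPU lowering).
        self.use_shape_cast = use_shape_cast

def lower_unit_dim_transpose(options):
    # Returns the op the unit-dim transpose is lowered to under the
    # given options; the real pattern bails out (no rewrite) when the
    # shape_cast path is disabled.
    if options.use_shape_cast:
        return "vector.shape_cast"
    return "vector.transpose"

assert lower_unit_dim_transpose(TransposeLoweringOptions(True)) == "vector.shape_cast"
assert lower_unit_dim_transpose(TransposeLoweringOptions(False)) == "vector.transpose"
```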
…" (#75062) Reverting a workaround intended specifically for SPIR-V. That workaround emerged from this discussion: * #72105 AFAIK, it hasn't been required in practice. This is based on IREE (https://github.com/openxla/iree), which has just bumped its fork of LLVM without using it (*). (*) iree-org/iree@cef31e7 This reverts commit bbd2b08.