-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Move transpose with unit-dim to shape_cast pattern #72493
[mlir][vector] Move transpose with unit-dim to shape_cast pattern #72493
Conversation
Moved from lowering to canonicalization.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Cullen Rhodes (c-rhodes) Changes: Moved from lowering to canonicalization. Full diff: https://github.com/llvm/llvm-project/pull/72493.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95f49fa32bc0ae2..432c11e3c449e0e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5564,12 +5564,51 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose with non-scalable unit dims into a shape_cast.
+///
+/// Replace:
+/// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+/// vector<1xnxelty>
+/// with:
+/// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
+///
+/// Source with leading unit dim (inverse) is also replaced. Unit dim must
+/// be fixed. Non-unit dims can be scalable.
+class FoldTransposeWithNonScalableUnitDimsToShapeCast final
+ : public OpRewritePattern<TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TransposeOp transpOp,
+ PatternRewriter &rewriter) const override {
+ Value input = transpOp.getVector();
+ VectorType resType = transpOp.getResultVectorType();
+
+ SmallVector<int64_t> permutation;
+ transpOp.getTransp(permutation);
+
+ if (resType.getRank() == 2 &&
+ ((resType.getShape().front() == 1 &&
+ !resType.getScalableDims().front()) ||
+ (resType.getShape().back() == 1 &&
+ !resType.getScalableDims().back())) &&
+ permutation == ArrayRef<int64_t>({1, 0})) {
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resType,
+ input);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat>(context);
+ TransposeFolder, FoldTransposeSplat,
+ FoldTransposeWithNonScalableUnitDimsToShapeCast>(context);
}
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index dee786007c80630..25a53b31163432e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -336,24 +336,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
- // Replace:
- // vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
- // vector<1xnxelty>
- // with:
- // vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
- //
- // Source with leading unit dim (inverse) is also replaced. Unit dim must
- // be fixed. Non-unit can be scalable.
- if (resType.getRank() == 2 &&
- ((resType.getShape().front() == 1 &&
- !resType.getScalableDims().front()) ||
- (resType.getShape().back() == 1 &&
- !resType.getScalableDims().back())) &&
- transp == ArrayRef<int64_t>({1, 0})) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
- return success();
- }
-
if (inputType.isScalable())
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d341..b3902d2d9b4dde0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2524,3 +2524,54 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
return %r : vector<1x100x4x5xf32>
}
+
+// -----
+
+/// Transpose of rank-2 vector with leading or trailing non-scalable unit dim to shape_cast.
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+ return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+ return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+ return %0 : vector<4x1xf32>
+}
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
+ return %0 : vector<[4]x1xf32>
+}
+
+/// Scalable unit dim should not be lowered to shape_cast.
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
+ // CHECK-NOT: vector.shape_cast
+ // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
+ return %0 : vector<[1]x4xf32>
+}
+
+// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32
+func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
+ // CHECK-NOT: vector.shape_cast
+ // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
+
+ return %0 : vector<[1]x4xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index c0b44428d5bcf30..72be5e4dbe3ee16 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -790,57 +790,6 @@ module attributes {transform.with_named_sequence} {
}
}
-// -----
-
-/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
-
-// CHECK-LABEL: func @transpose10_4x1xf32
-func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
- // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
- %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
- return %0 : vector<1x4xf32>
-}
-
-// CHECK-LABEL: func @transpose10_nx4x1xf32
-func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
- // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
- %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %0 : vector<1x[4]xf32>
-}
-
-// CHECK-LABEL: func @transpose10_1x4xf32
-func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
- // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
- %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
- return %0 : vector<4x1xf32>
-}
-
-// CHECK-LABEL: func @transpose10_1xnx4xf32
-func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
- // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
- %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
- return %0 : vector<[4]x1xf32>
-}
-
-/// Scalable unit dim should not be lowered to shape_cast.
-
-// CHECK-LABEL: func @transpose10_4xnx1xf32
-func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
- // CHECK-NOT: vector.shape_cast
- // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
- %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
- return %0 : vector<[1]x4xf32>
-}
-
-// CHECK-LABEL: func @transpose10_nx4xnx1xf32
-func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
- // CHECK-NOT: vector.shape_cast
- // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
- %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
-
- return %0 : vector<[1]x4xf32>
-}
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
transform.apply_patterns to %func_op {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot! This fixes one of the issues I'm hitting right now :)
if (resType.getRank() == 2 &&
    ((resType.getShape().front() == 1 &&
      !resType.getScalableDims().front()) ||
     (resType.getShape().back() == 1 &&
      !resType.getScalableDims().back())) &&
    permutation == ArrayRef<int64_t>({1, 0})) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a follow-up patch, I wonder if we could generalize this to n-D dimensions where 0 or 1 of them is != 1? If I'm not missing something, the permutation pattern itself shouldn't even matter for those cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good I'll look into it soon 👍
…vm#72493) Moved from lowering to canonicalization.
…vm#72493) Moved from lowering to canonicalization.
…tern (llvm#72493)" This reverts commit 95acb33.
Based on #72105 (comment) I would hope that we can revert the move to canonicalization at least since that does not give any user control. The revert is in #72918 |
That's unfortunate. Maybe we should make this be a separate pattern that is added where needed. It seems like it cannot be handled everywhere. |
+1. I'm also missing the context of why going down vector shape cast is preferable. Can somebody explain or give me some pointers? |
…tern (llvm#72493)" This reverts commit 95acb33.
I'm sorry that this is causing issues
This is important in the context of scalable vectors:
This canonicalisation merely flips the dimension to make things super easy further down the compilation stack (
Would you be able to make a specific suggestion? What's available to make sure that a pattern doesn't trigger for a particular target? Rather than reverting this, I'd much prefer for us to "move" or refactor this, so that it's no longer problematic. Currently this is unblocking us and solving an issue for @dcaballe, so there are benefits to keeping this in tree. |
I think creating a |
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
…tern (llvm#72493)" This reverts commit 95acb33.
Moved from lowering to canonicalization.