[mlir][vector] Add support for unrolling vector.bitcast ops. #94064
Conversation
The revision unrolls vector.bitcast like:

```mlir
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
```

to

```mlir
%cst = arith.constant dense<0> : vector<2x2xi64>
%0 = vector.extract %arg0[0] : vector<4xi32> from vector<2x4xi32>
%1 = vector.bitcast %0 : vector<4xi32> to vector<2xi64>
%2 = vector.insert %1, %cst [0] : vector<2xi64> into vector<2x2xi64>
%3 = vector.extract %arg0[1] : vector<4xi32> from vector<2x4xi32>
%4 = vector.bitcast %3 : vector<4xi32> to vector<2xi64>
%5 = vector.insert %4, %2 [1] : vector<2xi64> into vector<2x2xi64>
```

Scalable vectors are not supported because of a limitation of `vector::createUnrollIterator`: `targetRank` can mismatch the final rank during unrolling, and there is no direct way to query the final rank from the iterator object.
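For context, wiring these patterns into a rewrite driver could look like the following minimal sketch; the `unrollBitCasts` wrapper is an illustrative assumption, not an API added by this PR:

```cpp
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver (not part of this PR): greedily applies the new
// unrolling pattern under `root`. targetRank = 1 unrolls all the way
// down to rank-1 vector.bitcast ops.
static LogicalResult unrollBitCasts(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorBitCastLoweringPatterns(patterns, /*targetRank=*/1);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```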
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Full diff: https://github.com/llvm/llvm-project/pull/94064.diff

8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bc3c16d40520e..c91e8fbbae90f 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -89,6 +89,20 @@ def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyLowerBitCastPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.lower_bitcast",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector bitcast operations should be lowered to
+ finer-grained vector primitives.
+
+ This is usually a late step that is run after bufferization as part of the
+ process of lowering to e.g. LLVM or NVVM.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyLowerBroadcastPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_broadcast",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 8fd9904fabc0e..1976b8399c7f9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -276,6 +276,15 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populates the pattern set with the following patterns:
+///
+/// [UnrollBitCastOp]
+/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
+/// BitCastOp (of `targetRank`) + InsertOp.
+void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
+ int64_t targetRank = 1,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a436c4a9400..55143d5939ba2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorBitCastLoweringPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
populateVectorMaskOpLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 61fd6bd972e3a..23960269095e5 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -79,6 +79,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
+void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorBitCastLoweringPatterns(patterns);
+}
+
void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorBroadcastLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4dbefdd376a8b..723b2f62d65d4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorTransforms
BufferizableOpInterfaceImpl.cpp
+ LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
LowerVectorGather.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
new file mode 100644
index 0000000000000..581ee54fb2935
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -0,0 +1,94 @@
+//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.bitcast' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-bitcast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// A one-shot unrolling of vector.bitcast to the `targetRank`.
+///
+/// Example:
+///
+/// vector.bitcast %a : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
+///
+/// Would be unrolled to:
+///
+/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
+/// %0 = vector.extract %a[0, 0, 0] ─┐
+/// : vector<4xi64> from vector<1x2x3x4xi64> |
+/// %1 = vector.bitcast %0 | - Repeated 6x for
+/// : vector<4xi64> to vector<8xi32> | all leading positions
+/// %2 = vector.insert %1, %result [0, 0, 0] |
+/// : vector<8xi32> into vector<1x2x3x8xi32> ─┘
+///
+/// Note: If any leading dimension before the `targetRank` is scalable the
+/// unrolling will stop before the scalable dimension.
+class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
+public:
+ UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), targetRank(targetRank) {}
+
+ LogicalResult matchAndRewrite(vector::BitCastOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+ auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+ if (!unrollIterator)
+ return failure();
+
+ // TODO: Support the scalable vector cases. It is not supported because
+ // the final rank could be values other than `targetRank`. It makes creating
+ // the result type of new vector.bitcast ops much harder.
+ if (resultType.isScalable())
+ return rewriter.notifyMatchFailure(
+ op, "unrolling vector.bitcast on scalable vectors is NIY");
+
+ SmallVector<int64_t> shape(resultType.getShape().take_back(targetRank));
+ auto bitcastResType = VectorType::get(shape, resultType.getElementType());
+
+ Location loc = op.getLoc();
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resultType, rewriter.getZeroAttr(resultType));
+ for (auto position : *unrollIterator) {
+ Value extract =
+ rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+ Value bitcast =
+ rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
+ result =
+ rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+private:
+ int64_t targetRank = 1;
+};
+
+} // namespace
+
+void mlir::vector::populateVectorBitCastLoweringPatterns(
+ RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
+ patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 245edb6789d30..12121ea0dd70e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2564,3 +2564,13 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
%0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_bitcast_2d
+// CHECK: llvm.bitcast
+// CHECK-NOT: vector.bitcast
+func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
+ %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
+ return %0 : vector<2x2xi64>
+}
diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
new file mode 100644
index 0000000000000..e8c529dcacc75
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
+ %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
+ return %0 : vector<2x2xi64>
+}
+// CHECK-LABEL: func.func @vector_bitcast_2d
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
+// CHECK: %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64>
+// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32>
+// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64>
+// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0]
+// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<4xi32> from vector<2x4xi32>
+// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<4xi32> to vector<2xi64>
+// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1]
+// CHECK: return %[[R2]]
+
+func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> vector<1x2x[3]x8xi32> {
+ %0 = vector.bitcast %arg0 : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+ return %0 : vector<1x2x[3]x8xi32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
+// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_bitcast
+ } : !transform.any_op
+ transform.yield
+ }
+}
✅ With the latest revision this PR passed the C/C++ code formatter.
// TODO: Support the scalable vector cases. It is not supported because
// the final rank could be values other than `targetRank`. It makes creating
// the result type of new vector.bitcast ops much harder.
if (resultType.isScalable())
  return rewriter.notifyMatchFailure(
      op, "unrolling vector.bitcast on scalable vectors is NIY");
I've created #94197 to make this easier 🙂
NIY?
Not yet implemented (I guess)
Yes, it means "not implemented yet". I learned it from comments in some linalg transforms. But yeah, I agree that "not yet implemented" is clearer. I'll update it.
LGTM! % nit
I don't mind if we wait for #94197 and handle the scalable case in this PR, or keep things as-is and fix it in a later follow up PR. 🙂
SmallVector<int64_t> shape(resultType.getShape().take_back(targetRank));
auto bitcastResType = VectorType::get(shape, resultType.getElementType());
nit: You shouldn't need to copy this to a vector. Suggested change:

ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank);
auto bitcastResType = VectorType::get(shape, resultType.getElementType());
LGTM, thanks!
func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
  %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
  return %0 : vector<2x2xi64>
}
Would it be possible to test 1d and 0d vectors as well?
/// Example:
///
///   vector.bitcast %a : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
///
/// Would be unrolled to:
///
///   %result = arith.constant dense<0> : vector<1x2x3x8xi32>
///   %0 = vector.extract %a[0, 0, 0]                  ─┐
///          : vector<4xi64> from vector<1x2x3x4xi64>   |
///   %1 = vector.bitcast %0                            | - Repeated 6x for
///          : vector<4xi64> to vector<8xi32>           |   all leading positions
///   %2 = vector.insert %1, %result [0, 0, 0]          |
///          : vector<8xi32> into vector<1x2x3x8xi32>  ─┘
Since this pattern is added to VectorToLLVM with default `targetRank = 1`, isn't this potentially a pessimization for programs that use `vector.bitcast` with shapes where the inner-most dimension's size in bits is smaller than the target SIMD vector size in bits?

For example, looking at the example given in this comment, this causes vectors to be tiled into 256-bit vectors. When the target prefers 512-bit vectors, could this cause codegen of neighbouring ops to use 256-bit vectors instead of 512-bit vectors? This example is relatively mild in this sense, since the inner dimension's size of 256 bits is still fairly large, but what if the inner dimension was smaller still, such as `vector<4x3x2x1xi32>`?

Out of this consideration, I would expect vector unrolling patterns to default to targeting a certain vector size in bits, rather than a certain target rank. So you could say that the target size in bits is the target's SIMD vector size, e.g. 512 bits, and keep enough inner dimensions in the vector tile to achieve that size. If a dimension is non-power-of-2, though, it's OK to stop there, as non-power-of-two vectors would have to be broken into smaller power-of-two tiles at codegen anyway.
> Since this pattern is added to VectorToLLVM with default `targetRank = 1`, isn't this potentially a pessimization for programs that use `vector.bitcast` with shapes where the inner-most dimension's size in bits is smaller than the target SIMD vector size in bits?

This kind of shape concern should be handled by an optimize-vector-shape pass. A lot has been done for unit dims to avoid the issue you mentioned, and we are making progress on vector flattening: iree-org/iree#17530. There are also patterns to flatten/linearize vectors, which are expected to run before unrolling: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/linearize.mlir

The unrolling pattern is more of a complement to the LLVM end-to-end lowering story. Sadly, this op was added four years ago and nobody used it until 2024.
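For illustration, here is roughly what that flattening approach could produce for the 2-D example in this PR (hypothetical IR modeled on the linked linearize tests, not output of this patch):

```mlir
// Flatten to rank 1, bitcast once, then restore the original rank.
%flat = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
%cast = vector.bitcast %flat : vector<8xi32> to vector<4xi64>
%res  = vector.shape_cast %cast : vector<4xi64> to vector<2x2xi64>
```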
Discussed offline.