-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][vector] Add support for unrolling vector.bitcast ops. #94064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements target-independent rewrites and utilities to lower the | ||
// 'vector.bitcast' operation. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" | ||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
|
||
#define DEBUG_TYPE "vector-bitcast-lowering" | ||
|
||
using namespace mlir; | ||
using namespace mlir::vector; | ||
|
||
namespace { | ||
|
||
/// A one-shot unrolling of vector.bitcast to the `targetRank`. | ||
/// | ||
/// Example: | ||
/// | ||
/// vector.bitcast %a, %b : vector<1x2x3x4xi64> to vector<1x2x3x8xi32> | ||
/// | ||
/// Would be unrolled to: | ||
/// | ||
/// %result = arith.constant dense<0> : vector<1x2x3x8xi32> | ||
/// %0 = vector.extract %a[0, 0, 0] ─┐ | ||
/// : vector<4xi64> from vector<1x2x3x4xi64> | | ||
/// %1 = vector.bitcast %0 | - Repeated 6x for | ||
/// : vector<4xi64> to vector<8xi32> | all leading positions | ||
/// %2 = vector.insert %1, %result [0, 0, 0] | | ||
/// : vector<8xi64> into vector<1x2x3x8xi32> ─┘ | ||
/// | ||
/// Note: If any leading dimension before the `targetRank` is scalable the | ||
/// unrolling will stop before the scalable dimension. | ||
class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> { | ||
public: | ||
UnrollBitCastOp(int64_t targetRank, MLIRContext *context, | ||
PatternBenefit benefit = 1) | ||
: OpRewritePattern(context, benefit), targetRank(targetRank) {}; | ||
|
||
LogicalResult matchAndRewrite(vector::BitCastOp op, | ||
PatternRewriter &rewriter) const override { | ||
VectorType resultType = op.getResultVectorType(); | ||
auto unrollIterator = vector::createUnrollIterator(resultType, targetRank); | ||
if (!unrollIterator) | ||
return failure(); | ||
|
||
// TODO: Support the scalable vector cases. It is not supported because | ||
// the final rank could be values other than `targetRank`. It makes creating | ||
// the result type of new vector.bitcast ops much harder. | ||
if (resultType.isScalable()) { | ||
return rewriter.notifyMatchFailure(op, | ||
"unrolling vector.bitcast on scalable " | ||
"vectors is not yet implemented"); | ||
} | ||
|
||
ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank); | ||
auto bitcastResType = VectorType::get(shape, resultType.getElementType()); | ||
|
||
Location loc = op.getLoc(); | ||
Value result = rewriter.create<arith::ConstantOp>( | ||
loc, resultType, rewriter.getZeroAttr(resultType)); | ||
for (auto position : *unrollIterator) { | ||
Value extract = | ||
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position); | ||
Value bitcast = | ||
rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract); | ||
result = | ||
rewriter.create<vector::InsertOp>(loc, bitcast, result, position); | ||
} | ||
|
||
rewriter.replaceOp(op, result); | ||
return success(); | ||
} | ||
|
||
private: | ||
int64_t targetRank = 1; | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::vector::populateVectorBitCastLoweringPatterns( | ||
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) { | ||
patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
53 changes: 53 additions & 0 deletions
53
mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s | ||
|
||
func.func @vector_bitcast_0d(%arg0: vector<i32>) -> vector<f32> { | ||
%0 = vector.bitcast %arg0 : vector<i32> to vector<f32> | ||
return %0 : vector<f32> | ||
} | ||
// CHECK-LABEL: func.func @vector_bitcast_0d | ||
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] | ||
// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<i32> to vector<f32> | ||
// CHECK: return %[[RES]] | ||
|
||
func.func @vector_bitcast_1d(%arg0: vector<10xi64>) -> vector<20xi32> { | ||
%0 = vector.bitcast %arg0 : vector<10xi64> to vector<20xi32> | ||
return %0 : vector<20xi32> | ||
} | ||
// CHECK-LABEL: func.func @vector_bitcast_1d | ||
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] | ||
// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<10xi64> to vector<20xi32> | ||
// CHECK: return %[[RES]] | ||
|
||
func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> { | ||
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64> | ||
return %0 : vector<2x2xi64> | ||
} | ||
Comment on lines
+21
to
+24
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be possible to tests 1d and 0d vectors as well? |
||
// CHECK-LABEL: func.func @vector_bitcast_2d | ||
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] | ||
// CHECK: %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64> | ||
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32> | ||
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64> | ||
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] | ||
// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<4xi32> from vector<2x4xi32> | ||
// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<4xi32> to vector<2xi64> | ||
// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1] | ||
// CHECK: return %[[R2]] | ||
|
||
func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> vector<1x2x[3]x8xi32> { | ||
%0 = vector.bitcast %arg0 : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32> | ||
return %0 : vector<1x2x[3]x8xi32> | ||
} | ||
// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim | ||
// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32> | ||
|
||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { | ||
%f = transform.structured.match ops{["func.func"]} in %module_op | ||
: (!transform.any_op) -> !transform.any_op | ||
|
||
transform.apply_patterns to %f { | ||
transform.apply_patterns.vector.lower_bitcast | ||
} : !transform.any_op | ||
transform.yield | ||
} | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this pattern is added to
VectorToLLVM
with defaulttargetRank = 1
, isn't this potentially a pessimization for programs that usevector.bitcast
with shapes where the inner-most dimension's size in bits is smaller than the target SIMD vector size in bit?For example, looking at the example given in this comment, this causes vectors to be tiled into 256-bit vectors. When the target prefers 512-bit vectors, could this cause codegen of neighbouring ops to use 256-bit vectors instead of 512-bit vectors?
This example is relatively mild in this sense, since the inner dimension's size of 256 bits is still fairly large, but what if the inner dimension was smaller still, such as
vector<4x3x2x1xi32>
.Out of this consideration, I would expect vector unrolling patterns to default to targeting a certain target vector size in bits, rather than a certain target rank. So you could say that the target size in bits is the target's SIMD vector size, e.g. 512 bits, and keep enough inner dimensions in the vector tile to achieve that size. If a dimension is non-power-of-2, though, it's OK to stop there, as non-power-of-two vectors would have to be broken into smaller power-of-two tiles at codegen anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This kind of operation should be handled by OptimizeVectorShape pass. We have many things done for unit dims, to get rid of the issue you mentioned, and we are making progress on vector flattening: iree-org/iree#17530 We have some patterns to flatten/linearize vectors, which are expected to be run before unrolling: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/linearize.mlir
The unrolling pattern is more like complement of llvm e2e lowering story. Sad to say that, this op was added 4 years ago, and nobody is using it until 2024.