[mlir][vector] Add support for unrolling vector.bitcast ops. #94064

Merged: 3 commits, Jun 3, 2024
@@ -89,6 +89,20 @@ def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyLowerBitCastPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_bitcast",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that vector bitcast operations should be lowered to
finer-grained vector primitives.

This is usually a late step that is run after bufferization as part of the
process of lowering to e.g. LLVM or NVVM.
}];

let assemblyFormat = "attr-dict";
}
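To make the op's semantics concrete: `vector.bitcast` reinterprets the underlying bits without converting any values, only regrouping them into a different element type. The scalar analogue can be sketched in plain C++ (the helper name `bitcastI64ToI32x2` is illustrative, not MLIR API), assuming well-defined type punning via `memcpy`:

```cpp
#include <array>
#include <cstdint>
#include <cstring>

// Reinterpret one 64-bit value as two 32-bit values, preserving the bit
// pattern -- the scalar analogue of
// `vector.bitcast : vector<1xi64> to vector<2xi32>`.
// No value conversion happens; only the grouping of bits changes.
std::array<uint32_t, 2> bitcastI64ToI32x2(uint64_t v) {
  std::array<uint32_t, 2> out;
  std::memcpy(out.data(), &v, sizeof(v)); // memcpy avoids strict-aliasing UB
  return out;
}
```

Round-tripping the two halves back through `memcpy` recovers the original 64-bit pattern on any endianness, which is the invariant the lowering must preserve.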

def ApplyLowerBroadcastPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_broadcast",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
@@ -276,6 +276,15 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populates the pattern set with the following patterns:
///
/// [UnrollBitCastOp]
/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
/// BitCastOp (of `targetRank`) + InsertOp.
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
int64_t targetRank = 1,
PatternBenefit benefit = 1);

} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
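The doc comment above describes the shape arithmetic of the unrolling: the trailing `targetRank` dimensions become the shape of each small `vector.bitcast`, and the product of the leading dimensions is the number of extract/bitcast/insert triples emitted. A minimal sketch of that computation (the `UnrollPlan`/`planUnroll` names are hypothetical, for illustration only):

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Given the result vector shape and a target rank, compute the shape of each
// small bitcast and how many extract/bitcast/insert triples unrolling emits.
struct UnrollPlan {
  std::vector<int64_t> innerShape; // shape of each unrolled vector.bitcast
  int64_t numPieces;               // number of extract/bitcast/insert triples
};

UnrollPlan planUnroll(const std::vector<int64_t> &shape, int64_t targetRank) {
  size_t leading = shape.size() - static_cast<size_t>(targetRank);
  UnrollPlan plan;
  // Trailing `targetRank` dims stay vectorized in each small bitcast.
  plan.innerShape.assign(shape.begin() + leading, shape.end());
  // Product of the leading dims = number of positions to iterate.
  plan.numPieces = std::accumulate(shape.begin(), shape.begin() + leading,
                                   int64_t{1}, std::multiplies<int64_t>());
  return plan;
}
```

For `vector<1x2x3x8xi32>` with `targetRank = 1`, this gives an inner shape of `[8]` and 6 pieces, matching the "Repeated 6x" example in the pattern's doc comment.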
@@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorBitCastLoweringPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
populateVectorMaskOpLoweringPatterns(patterns);
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -79,6 +79,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}

void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorBitCastLoweringPatterns(patterns);
}

void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorBroadcastLoweringPatterns(patterns);
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorTransforms
BufferizableOpInterfaceImpl.cpp
LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
LowerVectorGather.cpp
96 changes: 96 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -0,0 +1,96 @@
//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.bitcast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-bitcast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {

/// A one-shot unrolling of vector.bitcast to the `targetRank`.
///
/// Example:
///
/// %b = vector.bitcast %a : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
///
/// Would be unrolled to:
///
/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
/// %0 = vector.extract %a[0, 0, 0] ─┐
/// : vector<4xi64> from vector<1x2x3x4xi64> |
/// %1 = vector.bitcast %0 | - Repeated 6x for
/// : vector<4xi64> to vector<8xi32> | all leading positions
/// %2 = vector.insert %1, %result [0, 0, 0] |
/// : vector<8xi32> into vector<1x2x3x8xi32> ─┘
Review comment (Contributor) on lines +30 to +42:
Since this pattern is added to VectorToLLVM with default targetRank = 1, isn't this potentially a pessimization for programs that use vector.bitcast with shapes where the inner-most dimension's size in bits is smaller than the target SIMD vector size in bits?

For example, looking at the example given in this comment, this causes vectors to be tiled into 256-bit vectors. When the target prefers 512-bit vectors, could this cause codegen of neighbouring ops to use 256-bit vectors instead of 512-bit vectors?

This example is relatively mild in this sense, since the inner dimension's size of 256 bits is still fairly large, but what if the inner dimension was smaller still, such as vector<4x3x2x1xi32>.

Out of this consideration, I would expect vector unrolling patterns to default to targeting a certain target vector size in bits, rather than a certain target rank. So you could say that the target size in bits is the target's SIMD vector size, e.g. 512 bits, and keep enough inner dimensions in the vector tile to achieve that size. If a dimension is non-power-of-2, though, it's OK to stop there, as non-power-of-two vectors would have to be broken into smaller power-of-two tiles at codegen anyway.

Reply (Contributor, Author):
> Since this pattern is added to VectorToLLVM with default targetRank = 1, isn't this potentially a pessimization for programs that use vector.bitcast with shapes where the inner-most dimension's size in bits is smaller than the target SIMD vector size in bits?

This kind of case should be handled by the OptimizeVectorShape pass. We have done a lot for unit dims to avoid the issue you mention, and we are making progress on vector flattening: iree-org/iree#17530 We have some patterns to flatten/linearize vectors, which are expected to run before unrolling: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/linearize.mlir

The unrolling pattern is more of a complement to the LLVM end-to-end lowering story. Sadly, this op was added four years ago and nobody used it until 2024.
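For what the reviewer's alternative might look like, here is a minimal sketch of choosing how many trailing dimensions to keep so the unrolled tile reaches a target vector size in bits rather than a fixed target rank. This is hypothetical (not part of this PR, and `rankForTargetBits` is an invented name); it stops at the first dimension that would overflow the target, per the reviewer's note about non-power-of-two dims:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Walk dimensions from innermost outwards, keeping each one while the tile
// still fits within targetBits; return how many trailing dims to keep.
int64_t rankForTargetBits(const std::vector<int64_t> &shape,
                          int64_t elementBits, int64_t targetBits) {
  int64_t bits = elementBits;
  int64_t rank = 0;
  for (auto it = shape.rbegin(); it != shape.rend(); ++it) {
    if (bits * *it > targetBits)
      break; // keeping this dim would exceed the target tile size
    bits *= *it;
    ++rank;
  }
  // Always keep at least the innermost dim so extract/insert stay vector-typed.
  return std::max<int64_t>(rank, 1);
}
```

For the reviewer's `vector<4x3x2x1xi32>` example with a 512-bit target, this keeps the three trailing dims (1x2x3 elements, 192 bits) and unrolls only the outermost dimension of 4.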

///
/// Note: If any leading dimension before the `targetRank` is scalable the
/// unrolling will stop before the scalable dimension.
class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
public:
UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
PatternBenefit benefit = 1)
: OpRewritePattern(context, benefit), targetRank(targetRank) {}

LogicalResult matchAndRewrite(vector::BitCastOp op,
PatternRewriter &rewriter) const override {
VectorType resultType = op.getResultVectorType();
auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
if (!unrollIterator)
return failure();

// TODO: Support the scalable vector cases. It is not supported because
// the final rank could be values other than `targetRank`. It makes creating
// the result type of new vector.bitcast ops much harder.
if (resultType.isScalable()) {
return rewriter.notifyMatchFailure(op,
"unrolling vector.bitcast on scalable "
"vectors is not yet implemented");
}

ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank);
auto bitcastResType = VectorType::get(shape, resultType.getElementType());

Location loc = op.getLoc();
Value result = rewriter.create<arith::ConstantOp>(
loc, resultType, rewriter.getZeroAttr(resultType));
for (auto position : *unrollIterator) {
Value extract =
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
Value bitcast =
rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
result =
rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
}

rewriter.replaceOp(op, result);
return success();
}

private:
int64_t targetRank = 1;
};

} // namespace

void mlir::vector::populateVectorBitCastLoweringPatterns(
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
}
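The loop in `matchAndRewrite` visits every leading position yielded by `vector::createUnrollIterator` and emits one extract/bitcast/insert triple per position. A plain C++ stand-in for that iteration (the helper name `leadingPositions` is illustrative, not the MLIR API) enumerates the positions in row-major order:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Enumerate all index tuples over the leading dims in row-major order.
// For leading dims [1, 2, 3] this yields [0,0,0], [0,0,1], ..., [0,1,2]:
// one position per generated extract/bitcast/insert triple.
std::vector<std::vector<int64_t>>
leadingPositions(const std::vector<int64_t> &leadingDims) {
  std::vector<std::vector<int64_t>> positions{{}}; // start with empty prefix
  for (int64_t dim : leadingDims) {
    std::vector<std::vector<int64_t>> next;
    for (const auto &prefix : positions)
      for (int64_t i = 0; i < dim; ++i) {
        auto p = prefix;
        p.push_back(i); // extend each prefix by every index of this dim
        next.push_back(std::move(p));
      }
    positions = std::move(next);
  }
  return positions;
}
```

Note that with no leading dims (i.e. the vector's rank already equals `targetRank`) the result is a single empty position, which is exactly the case where `createUnrollIterator` declines and the pattern fails to match.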
10 changes: 10 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2564,3 +2564,13 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
%0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
}

// -----

// CHECK-LABEL: func.func @vector_bitcast_2d
// CHECK: llvm.bitcast
// CHECK-NOT: vector.bitcast
func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
return %0 : vector<2x2xi64>
}
53 changes: 53 additions & 0 deletions mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -0,0 +1,53 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

func.func @vector_bitcast_0d(%arg0: vector<i32>) -> vector<f32> {
%0 = vector.bitcast %arg0 : vector<i32> to vector<f32>
return %0 : vector<f32>
}
// CHECK-LABEL: func.func @vector_bitcast_0d
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<i32> to vector<f32>
// CHECK: return %[[RES]]

func.func @vector_bitcast_1d(%arg0: vector<10xi64>) -> vector<20xi32> {
%0 = vector.bitcast %arg0 : vector<10xi64> to vector<20xi32>
return %0 : vector<20xi32>
}
// CHECK-LABEL: func.func @vector_bitcast_1d
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
// CHECK: %[[RES:.+]] = vector.bitcast %[[IN]] : vector<10xi64> to vector<20xi32>
// CHECK: return %[[RES]]

func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
return %0 : vector<2x2xi64>
}
Review comment (Contributor) on lines +21 to +24:

Would it be possible to test 1d and 0d vectors as well?

// CHECK-LABEL: func.func @vector_bitcast_2d
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
// CHECK: %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64>
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32>
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64>
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0]
// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<4xi32> from vector<2x4xi32>
// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<4xi32> to vector<2xi64>
// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1]
// CHECK: return %[[R2]]

func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> vector<1x2x[3]x8xi32> {
%0 = vector.bitcast %arg0 : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
return %0 : vector<1x2x[3]x8xi32>
}
// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op

transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_bitcast
} : !transform.any_op
transform.yield
}
}