Skip to content

Commit

Permalink
[mlir][vector] Add n-d deinterleave lowering (#94237)
Browse files Browse the repository at this point in the history
This patch implements the lowering of vector.deinterleave
for n-dimensional vectors. The process unrolls the n-d
vector into a series of one-dimensional vectors, and a
one-dimensional vector.deinterleave is then applied to
each of these vectors.

From:
```
%0, %1 = vector.deinterleave %arg0 : vector<2x8xi32> -> vector<2x4xi32>
```

To:
```
%cst = arith.constant dense<0> : vector<2x4xi32>
%0 = vector.extract %arg0[0] : vector<8xi32> from vector<2x8xi32>
%res1, %res2 = vector.deinterleave %0 : vector<8xi32> -> vector<4xi32>
%1 = vector.insert %res1, %cst [0] : vector<4xi32> into vector<2x4xi32>
%2 = vector.insert %res2, %cst [0] : vector<4xi32> into vector<2x4xi32>
%3 = vector.extract %arg0[1] : vector<8xi32> from vector<2x8xi32>
%res1_0, %res2_1 = vector.deinterleave %3 : vector<8xi32> -> vector<4xi32>
%4 = vector.insert %res1_0, %1 [1] : vector<4xi32> into vector<2x4xi32>
%5 = vector.insert %res2_1, %2 [1] : vector<4xi32> into vector<2x4xi32>
...etc.
```
  • Loading branch information
mub-at-arm committed Jun 7, 2024
1 parent 8719cb8 commit b87a80d
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 1 deletion.
70 changes: 69 additions & 1 deletion mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,73 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};

/// Unrolls an n-D `vector.deinterleave` down to `targetRank` in one rewrite.
///
/// Every position produced by the unroll iterator is handled the same way:
/// the slice at that position is extracted from the source, deinterleaved,
/// and its two results are inserted at the same position into two
/// accumulators (one per op result) that both start out as a zero constant.
///
/// Example:
///
/// ```mlir
/// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
/// ```
/// is unrolled to:
/// ```mlir
/// %result = arith.constant dense<0> : vector<1x2x3x4xi64>
///
/// // First position: both inserts target the zero constant.
/// %0 = vector.extract %a[0, 0, 0]
///        : vector<8xi64> from vector<1x2x3x8xi64>
/// %1, %2 = vector.deinterleave %0
///        : vector<8xi64> -> vector<4xi64>
/// %3 = vector.insert %1, %result [0, 0, 0]
///        : vector<4xi64> into vector<1x2x3x4xi64>
/// %4 = vector.insert %2, %result [0, 0, 0]
///        : vector<4xi64> into vector<1x2x3x4xi64>
///
/// // Every later position chains onto the previous inserts (%3 / %4).
/// // This group is repeated 5x in this case.
/// %5 = vector.extract %a[0, 0, 1]
///        : vector<8xi64> from vector<1x2x3x8xi64>
/// %6, %7 = vector.deinterleave %5
///        : vector<8xi64> -> vector<4xi64>
/// %8 = vector.insert %6, %3 [0, 0, 1]
///        : vector<4xi64> into vector<1x2x3x4xi64>
/// %9 = vector.insert %7, %4 [0, 0, 1]
///        : vector<4xi64> into vector<1x2x3x4xi64>
/// ```
///
/// Note: If any leading dimension before the `targetRank` is scalable the
/// unrolling will stop before the scalable dimension.
class UnrollDeinterleaveOp final
    : public OpRewritePattern<vector::DeinterleaveOp> {
public:
  /// `targetRank` is the rank the op is unrolled down to; `context` and
  /// `benefit` are forwarded to the base pattern.
  UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
                       PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), targetRank(targetRank) {}

  /// Replaces `op` with the two fully assembled result vectors. Fails to
  /// match (leaving the IR untouched) when no unroll iterator can be created
  /// for the result type and `targetRank`.
  LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();
    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
    if (!unrollIterator)
      return rewriter.notifyMatchFailure(
          op, "unable to create an unroll iterator for the result vector type");

    auto loc = op.getLoc();
    // Both results are seeded with the same zero constant; the inserts below
    // fill in one slice per unrolled position.
    Value emptyResult = rewriter.create<arith::ConstantOp>(
        loc, resultType, rewriter.getZeroAttr(resultType));
    Value evenResult = emptyResult;
    Value oddResult = emptyResult;

    for (auto position : *unrollIterator) {
      auto extractSrc =
          rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
      auto deinterleave =
          rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
      // Thread each accumulator through its insert so the final values carry
      // the slices of every position.
      evenResult = rewriter.create<vector::InsertOp>(
          loc, deinterleave.getRes1(), evenResult, position);
      oddResult = rewriter.create<vector::InsertOp>(
          loc, deinterleave.getRes2(), oddResult, position);
    }
    rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
    return success();
  }

private:
  /// Rank at which unrolling stops.
  int64_t targetRank = 1;
};
/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
/// applicable: `sourceType` must be 1D and non-scalable.
///
Expand Down Expand Up @@ -116,7 +183,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {

/// Adds the patterns that unroll n-D `vector.interleave` and
/// `vector.deinterleave` ops to `patterns`. Each pattern is constructed with
/// the same `targetRank`, the context of `patterns`, and `benefit`.
void mlir::vector::populateVectorInterleaveLoweringPatterns(
    RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
  // One variadic `add` registers both patterns; UnrollInterleaveOp must not be
  // added a second time on its own.
  patterns.add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
      targetRank, patterns.getContext(), benefit);
}

void mlir::vector::populateVectorInterleaveToShufflePatterns(
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2565,6 +2565,22 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
}

// Fixed-size 2-D deinterleave: the output must contain an LLVM shuffle and no
// deinterleave on the original 2-D type.
// CHECK-LABEL: @vector_deinterleave_2d
// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
func.func @vector_deinterleave_2d(%src: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
  // CHECK: llvm.shufflevector
  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x8xf32>
  %even, %odd = vector.deinterleave %src : vector<2x8xf32> -> vector<2x4xf32>
  return %even, %odd : vector<2x4xf32>, vector<2x4xf32>
}

// 2-D deinterleave with a scalable trailing dimension: the output must contain
// the LLVM deinterleave intrinsic and no deinterleave on the original 2-D type.
// The label anchors the directives below to this function, matching the
// fixed-size test above.
// CHECK-LABEL: @vector_deinterleave_2d_scalable
// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>)
func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
  // CHECK: llvm.intr.vector.deinterleave2
  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x[8]xf32>
  %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
  return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
}

// -----

// CHECK-LABEL: func.func @vector_bitcast_2d
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// 2-D fixed-size case: each of the two rows is extracted, deinterleaved, and
// its halves inserted into results that start from one shared zero constant.
// CHECK-LABEL: @vector_deinterleave_2d
// CHECK-SAME: %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>)
func.func @vector_deinterleave_2d(%src: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) {
  // CHECK: %[[CST:.*]] = arith.constant dense<0>
  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
  // CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
  // CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
  // CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32>
  %even, %odd = vector.deinterleave %src : vector<2x8xi32> -> vector<2x4xi32>
  return %even, %odd : vector<2x4xi32>, vector<2x4xi32>
}

// Same expectation as the fixed-size 2-D case, but with a scalable trailing
// dimension: the leading dimension of size 2 is still unrolled.
// CHECK-LABEL: @vector_deinterleave_2d_scalable
// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>)
func.func @vector_deinterleave_2d_scalable(%src: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) {
  // CHECK: %[[CST:.*]] = arith.constant dense<0>
  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
  // CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
  // CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
  // CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x[4]xi32>, vector<2x[4]xi32>
  %even, %odd = vector.deinterleave %src : vector<2x[8]xi32> -> vector<2x[4]xi32>
  return %even, %odd : vector<2x[4]xi32>, vector<2x[4]xi32>
}

// 4-D case: the first unrolled position is matched in full, then the five
// remaining 1-D deinterleaves (1 * 2 * 3 = 6 positions in total) are counted.
// CHECK-LABEL: @vector_deinterleave_4d
// CHECK-SAME: %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>)
func.func @vector_deinterleave_4d(%src: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) {
  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64>
  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi64> -> vector<4xi64>
  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
  // CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64>
  %even, %odd = vector.deinterleave %src : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
  return %even, %odd : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
}

// n-D case with a scalable dimension in the middle of the shape.
// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim
func.func @vector_deinterleave_nd_with_scalable_dim(
    %src: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) {
  // The scalable dim blocks unrolling so only the first two dims are unrolled.
  // CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16>
  %even, %odd = vector.deinterleave %src : vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16>
  return %even, %odd : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
}

// Transform script: matches every func.func in the payload module and applies
// the interleave lowering patterns to the matched functions.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root
      : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %funcs {
      transform.apply_patterns.vector.lower_interleave
    } : !transform.any_op
    transform.yield
  }
}

0 comments on commit b87a80d

Please sign in to comment.