-
Notifications
You must be signed in to change notification settings - Fork 11.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][vector] Add n-d deinterleave lowering (#94237)
This patch implements the lowering for vector deinterleave for vector of n-dimensions. Process involves unrolling the n-d vector to a series of one-dimensional vectors. The deinterleave operation is then used on these vectors. From: ``` %0, %1 = vector.deinterleave %a : vector<2x8xi8> -> vector<2x4xi8> ``` To: ``` %cst = arith.constant dense<0> : vector<2x4xi32> %0 = vector.extract %arg0[0] : vector<8xi32> from vector<2x8xi32> %res1, %res2 = vector.deinterleave %0 : vector<8xi32> -> vector<4xi32> %1 = vector.insert %res1, %cst [0] : vector<4xi32> into vector<2x4xi32> %2 = vector.insert %res2, %cst [0] : vector<4xi32> into vector<2x4xi32> %3 = vector.extract %arg0[1] : vector<8xi32> from vector<2x8xi32> %res1_0, %res2_1 = vector.deinterleave %3 : vector<8xi32> -> vector<4xi32> %4 = vector.insert %res1_0, %1 [1] : vector<4xi32> into vector<2x4xi32> %5 = vector.insert %res2_1, %2 [1] : vector<4xi32> into vector<2x4xi32> ...etc. ```
- Loading branch information
1 parent
8719cb8
commit b87a80d
Showing
3 changed files
with
153 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s | ||
|
||
// CHECK-LABEL: @vector_deinterleave_2d | ||
// CHECK-SAME: %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) | ||
func.func @vector_deinterleave_2d(%a: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) { | ||
// CHECK: %[[CST:.*]] = arith.constant dense<0> | ||
// CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0] | ||
// CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] | ||
// CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0] | ||
// CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0] | ||
// CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1] | ||
// CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]] | ||
// CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1] | ||
// CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1] | ||
// CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32> | ||
%0, %1 = vector.deinterleave %a : vector<2x8xi32> -> vector<2x4xi32> | ||
return %0, %1 : vector<2x4xi32>, vector<2x4xi32> | ||
} | ||
|
||
// CHECK-LABEL: @vector_deinterleave_2d_scalable | ||
// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) | ||
func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) { | ||
// CHECK: %[[CST:.*]] = arith.constant dense<0> | ||
// CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0] | ||
// CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] | ||
// CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0] | ||
// CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0] | ||
// CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1] | ||
// CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]] | ||
// CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1] | ||
// CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1] | ||
// CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x[4]xi32>, vector<2x[4]xi32> | ||
%0, %1 = vector.deinterleave %a : vector<2x[8]xi32> -> vector<2x[4]xi32> | ||
return %0, %1 : vector<2x[4]xi32>, vector<2x[4]xi32> | ||
} | ||
|
||
// CHECK-LABEL: @vector_deinterleave_4d | ||
// CHECK-SAME: %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) | ||
func.func @vector_deinterleave_4d(%a: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) { | ||
// CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64> | ||
// CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi64> -> vector<4xi64> | ||
// CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64> | ||
// CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64> | ||
// CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64> | ||
%0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64> | ||
return %0, %1 : vector<1x2x3x4xi64>, vector<1x2x3x4xi64> | ||
} | ||
|
||
// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim | ||
func.func @vector_deinterleave_nd_with_scalable_dim( | ||
%a: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) { | ||
// The scalable dim blocks unrolling so only the first two dims are unrolled. | ||
// CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16> | ||
%0, %1 = vector.deinterleave %a: vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16> | ||
return %0, %1 : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16> | ||
} | ||
|
||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { | ||
%f = transform.structured.match ops{["func.func"]} in %module_op | ||
: (!transform.any_op) -> !transform.any_op | ||
|
||
transform.apply_patterns to %f { | ||
transform.apply_patterns.vector.lower_interleave | ||
} : !transform.any_op | ||
transform.yield | ||
} | ||
} |