Skip to content

Commit 0ee7c94

Browse files
authored
[mlir][vector] Tidy-up testing for to/from_elements unrolling (#158309)
1. Remove `TestUnrollVectorToElements` and `TestUnrollVectorFromElements` test passes - these are not required. 2. Make "vector-from-elements-lowering.mlir" use TD Op for testing (for consistency "vector-to-elements-lowering.mlir" and to make sure that the TD Op, `transform.apply_patterns.vector.unroll_from_elements`, is tested). 3. Unify `CHECK` prefixes (`CHECK-UNROLL` -> `CHECK`). 4. Rename `@to_elements_1d` as `@negative_unroll_to_elements_1d`, for consistency with it's counterpart for `vector.from_elements` and to align with our testing guide (*). (*) https://mlir.llvm.org/getting_started/TestingGuide/#after-step-3-add-the-newly-identified-missing-case
1 parent 3371375 commit 0ee7c94

File tree

5 files changed

+42
-80
lines changed

5 files changed

+42
-80
lines changed
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
module attributes {transform.with_named_sequence} {
22
transform.named_sequence @unroll_to_elements(%module_op: !transform.any_op {transform.readonly}) {
3-
%f = transform.structured.match ops{["func.func"]} in %module_op
3+
4+
%func_op = transform.structured.match ops{["func.func"]} in %module_op
45
: (!transform.any_op) -> !transform.any_op
5-
transform.apply_patterns to %f {
6-
transform.apply_patterns.vector.transfer_permutation_patterns
6+
transform.apply_patterns to %func_op {
7+
// Test patterns
78
transform.apply_patterns.vector.unroll_to_elements
9+
transform.apply_patterns.vector.unroll_from_elements
810
} : !transform.any_op
11+
912
transform.yield
1013
}
1114
}

mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module @transforms attributes { transform.with_named_sequence } {
33

44
%func_op = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
55
transform.apply_patterns to %func_op {
6+
// Test patterns
67
transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
78
} : !transform.op<"func.func">
89

mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,45 @@
1-
// RUN: mlir-opt %s -test-unroll-vector-from-elements | FileCheck %s --check-prefix=CHECK-UNROLL
1+
// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \
2+
// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s
23

34
//===----------------------------------------------------------------------===//
45
// Test UnrollFromElements.
56
//===----------------------------------------------------------------------===//
67

7-
// CHECK-UNROLL-LABEL: @unroll_from_elements_2d
8-
// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
9-
// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x2xf32>
10-
// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
11-
// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RES]] [0] : vector<2xf32> into vector<2x2xf32>
12-
// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
13-
// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[VEC_1]], %[[RES_0]] [1] : vector<2xf32> into vector<2x2xf32>
14-
// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x2xf32>
8+
// CHECK-LABEL: @unroll_from_elements_2d
9+
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
10+
// CHECK-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x2xf32>
11+
// CHECK-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
12+
// CHECK-NEXT: %[[RES_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RES]] [0] : vector<2xf32> into vector<2x2xf32>
13+
// CHECK-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
14+
// CHECK-NEXT: %[[RES_1:.*]] = vector.insert %[[VEC_1]], %[[RES_0]] [1] : vector<2xf32> into vector<2x2xf32>
15+
// CHECK-NEXT: return %[[RES_1]] : vector<2x2xf32>
1516
func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
1617
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
1718
return %0 : vector<2x2xf32>
1819
}
1920

20-
// CHECK-UNROLL-LABEL: @unroll_from_elements_3d
21-
// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
22-
// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
23-
// CHECK-UNROLL-NEXT: %[[UNDEF_RANK_2:.*]] = ub.poison : vector<1x2xf32>
24-
// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
25-
// CHECK-UNROLL-NEXT: %[[RANK_2_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
26-
// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[RANK_2_0]], %[[UNDEF_RES]] [0] : vector<1x2xf32> into vector<2x1x2xf32>
27-
// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
28-
// CHECK-UNROLL-NEXT: %[[RANK_2_1:.*]] = vector.insert %[[VEC_1]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
29-
// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[RANK_2_1]], %[[RES_0]] [1] : vector<1x2xf32> into vector<2x1x2xf32>
30-
// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x1x2xf32>
21+
// CHECK-LABEL: @unroll_from_elements_3d
22+
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
23+
// CHECK-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
24+
// CHECK-NEXT: %[[UNDEF_RANK_2:.*]] = ub.poison : vector<1x2xf32>
25+
// CHECK-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
26+
// CHECK-NEXT: %[[RANK_2_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
27+
// CHECK-NEXT: %[[RES_0:.*]] = vector.insert %[[RANK_2_0]], %[[UNDEF_RES]] [0] : vector<1x2xf32> into vector<2x1x2xf32>
28+
// CHECK-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
29+
// CHECK-NEXT: %[[RANK_2_1:.*]] = vector.insert %[[VEC_1]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
30+
// CHECK-NEXT: %[[RES_1:.*]] = vector.insert %[[RANK_2_1]], %[[RES_0]] [1] : vector<1x2xf32> into vector<2x1x2xf32>
31+
// CHECK-NEXT: return %[[RES_1]] : vector<2x1x2xf32>
3132
func.func @unroll_from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
3233
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
3334
return %0 : vector<2x1x2xf32>
3435
}
3536

3637
// 1-D vector.from_elements should not be unrolled.
3738

38-
// CHECK-UNROLL-LABEL: @negative_unroll_from_elements_1d
39-
// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
40-
// CHECK-UNROLL-NEXT: %[[RES:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
41-
// CHECK-UNROLL-NEXT: return %[[RES]] : vector<2xf32>
39+
// CHECK-LABEL: @negative_unroll_from_elements_1d
40+
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
41+
// CHECK-NEXT: %[[RES:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
42+
// CHECK-NEXT: return %[[RES]] : vector<2xf32>
4243
func.func @negative_unroll_from_elements_1d(%arg0: f32, %arg1: f32) -> vector<2xf32> {
4344
%0 = vector.from_elements %arg0, %arg1 : vector<2xf32>
4445
return %0 : vector<2xf32>
Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
1-
// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
21
// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \
32
// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s
43

5-
// CHECK-LABEL: func.func @to_elements_1d(
4+
//===----------------------------------------------------------------------===//
5+
// Test UnrollToElements.
6+
//===----------------------------------------------------------------------===//
7+
8+
// 1-D vector.from_elements should not be unrolled.
9+
10+
// CHECK-LABEL: func.func @negative_unroll_to_elements_1d(
611
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
712
// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
813
// CHECK: return %[[RES]]#0, %[[RES]]#1
9-
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
14+
func.func @negative_unroll_to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
1015
%0:2 = vector.to_elements %arg0 : vector<2xf32>
1116
return %0#0, %0#1 : f32, f32
1217
}
1318

1419
// -----
1520

16-
// CHECK-LABEL: func.func @to_elements_2d(
21+
// CHECK-LABEL: func.func @unroll_to_elements_2d(
1722
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
1823
// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
1924
// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
2025
// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
2126
// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
2227
// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
23-
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
28+
func.func @unroll_to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
2429
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
2530
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
2631
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -756,50 +756,6 @@ struct TestVectorGatherLowering
756756
}
757757
};
758758

759-
struct TestUnrollVectorFromElements
760-
: public PassWrapper<TestUnrollVectorFromElements,
761-
OperationPass<func::FuncOp>> {
762-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorFromElements)
763-
764-
StringRef getArgument() const final {
765-
return "test-unroll-vector-from-elements";
766-
}
767-
StringRef getDescription() const final {
768-
return "Test unrolling patterns for from_elements ops";
769-
}
770-
void getDependentDialects(DialectRegistry &registry) const override {
771-
registry.insert<func::FuncDialect, vector::VectorDialect, ub::UBDialect>();
772-
}
773-
774-
void runOnOperation() override {
775-
RewritePatternSet patterns(&getContext());
776-
populateVectorFromElementsLoweringPatterns(patterns);
777-
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
778-
}
779-
};
780-
781-
struct TestUnrollVectorToElements
782-
: public PassWrapper<TestUnrollVectorToElements,
783-
OperationPass<func::FuncOp>> {
784-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
785-
786-
StringRef getArgument() const final {
787-
return "test-unroll-vector-to-elements";
788-
}
789-
StringRef getDescription() const final {
790-
return "Test unrolling patterns for to_elements ops";
791-
}
792-
void getDependentDialects(DialectRegistry &registry) const override {
793-
registry.insert<func::FuncDialect, vector::VectorDialect>();
794-
}
795-
796-
void runOnOperation() override {
797-
RewritePatternSet patterns(&getContext());
798-
populateVectorToElementsLoweringPatterns(patterns);
799-
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
800-
}
801-
};
802-
803759
struct TestFoldArithExtensionIntoVectorContractPatterns
804760
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
805761
OperationPass<func::FuncOp>> {
@@ -1071,10 +1027,6 @@ void registerTestVectorLowerings() {
10711027

10721028
PassRegistration<TestVectorGatherLowering>();
10731029

1074-
PassRegistration<TestUnrollVectorFromElements>();
1075-
1076-
PassRegistration<TestUnrollVectorToElements>();
1077-
10781030
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
10791031

10801032
PassRegistration<TestVectorEmulateMaskedLoadStore>();

0 commit comments

Comments
 (0)