Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/test/Dialect/Vector/td/unroll-elements.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
module attributes {transform.with_named_sequence} {
transform.named_sequence @unroll_to_elements(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op

%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %f {
transform.apply_patterns.vector.transfer_permutation_patterns
transform.apply_patterns to %func_op {
// Test patterns
transform.apply_patterns.vector.unroll_to_elements
transform.apply_patterns.vector.unroll_from_elements
} : !transform.any_op

transform.yield
}
}
1 change: 1 addition & 0 deletions mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module @transforms attributes { transform.with_named_sequence } {

%func_op = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
// Test patterns
transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
} : !transform.op<"func.func">

Expand Down
49 changes: 25 additions & 24 deletions mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir
Original file line number Diff line number Diff line change
@@ -1,44 +1,45 @@
// RUN: mlir-opt %s -test-unroll-vector-from-elements | FileCheck %s --check-prefix=CHECK-UNROLL
// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \
// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s

//===----------------------------------------------------------------------===//
// Test UnrollFromElements.
//===----------------------------------------------------------------------===//

// CHECK-UNROLL-LABEL: @unroll_from_elements_2d
// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x2xf32>
// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RES]] [0] : vector<2xf32> into vector<2x2xf32>
// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[VEC_1]], %[[RES_0]] [1] : vector<2xf32> into vector<2x2xf32>
// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x2xf32>
// CHECK-LABEL: @unroll_from_elements_2d
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
// CHECK-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x2xf32>
// CHECK-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
// CHECK-NEXT: %[[RES_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RES]] [0] : vector<2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
// CHECK-NEXT: %[[RES_1:.*]] = vector.insert %[[VEC_1]], %[[RES_0]] [1] : vector<2xf32> into vector<2x2xf32>
// CHECK-NEXT: return %[[RES_1]] : vector<2x2xf32>
func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
return %0 : vector<2x2xf32>
}

// CHECK-UNROLL-LABEL: @unroll_from_elements_3d
// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
// CHECK-UNROLL-NEXT: %[[UNDEF_RANK_2:.*]] = ub.poison : vector<1x2xf32>
// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
// CHECK-UNROLL-NEXT: %[[RANK_2_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[RANK_2_0]], %[[UNDEF_RES]] [0] : vector<1x2xf32> into vector<2x1x2xf32>
// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
// CHECK-UNROLL-NEXT: %[[RANK_2_1:.*]] = vector.insert %[[VEC_1]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[RANK_2_1]], %[[RES_0]] [1] : vector<1x2xf32> into vector<2x1x2xf32>
// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x1x2xf32>
// CHECK-LABEL: @unroll_from_elements_3d
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
// CHECK-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
// CHECK-NEXT: %[[UNDEF_RANK_2:.*]] = ub.poison : vector<1x2xf32>
// CHECK-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
// CHECK-NEXT: %[[RANK_2_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
// CHECK-NEXT: %[[RES_0:.*]] = vector.insert %[[RANK_2_0]], %[[UNDEF_RES]] [0] : vector<1x2xf32> into vector<2x1x2xf32>
// CHECK-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
// CHECK-NEXT: %[[RANK_2_1:.*]] = vector.insert %[[VEC_1]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
// CHECK-NEXT: %[[RES_1:.*]] = vector.insert %[[RANK_2_1]], %[[RES_0]] [1] : vector<1x2xf32> into vector<2x1x2xf32>
// CHECK-NEXT: return %[[RES_1]] : vector<2x1x2xf32>
func.func @unroll_from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
return %0 : vector<2x1x2xf32>
}

// 1-D vector.from_elements should not be unrolled.

// CHECK-UNROLL-LABEL: @negative_unroll_from_elements_1d
// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
// CHECK-UNROLL-NEXT: %[[RES:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
// CHECK-UNROLL-NEXT: return %[[RES]] : vector<2xf32>
// CHECK-LABEL: @negative_unroll_from_elements_1d
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
// CHECK-NEXT: %[[RES:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
// CHECK-NEXT: return %[[RES]] : vector<2xf32>
func.func @negative_unroll_from_elements_1d(%arg0: f32, %arg1: f32) -> vector<2xf32> {
%0 = vector.from_elements %arg0, %arg1 : vector<2xf32>
return %0 : vector<2xf32>
Expand Down
15 changes: 10 additions & 5 deletions mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \
// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s

// CHECK-LABEL: func.func @to_elements_1d(
//===----------------------------------------------------------------------===//
// Test UnrollToElements.
//===----------------------------------------------------------------------===//

// 1-D vector.from_elements should not be unrolled.

// CHECK-LABEL: func.func @negative_unroll_to_elements_1d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
// CHECK: return %[[RES]]#0, %[[RES]]#1
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
func.func @negative_unroll_to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
%0:2 = vector.to_elements %arg0 : vector<2xf32>
return %0#0, %0#1 : f32, f32
}

// -----

// CHECK-LABEL: func.func @to_elements_2d(
// CHECK-LABEL: func.func @unroll_to_elements_2d(
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
func.func @unroll_to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
48 changes: 0 additions & 48 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,50 +756,6 @@ struct TestVectorGatherLowering
}
};

struct TestUnrollVectorFromElements
: public PassWrapper<TestUnrollVectorFromElements,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorFromElements)

StringRef getArgument() const final {
return "test-unroll-vector-from-elements";
}
StringRef getDescription() const final {
return "Test unrolling patterns for from_elements ops";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect, vector::VectorDialect, ub::UBDialect>();
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorFromElementsLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

struct TestUnrollVectorToElements
: public PassWrapper<TestUnrollVectorToElements,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)

StringRef getArgument() const final {
return "test-unroll-vector-to-elements";
}
StringRef getDescription() const final {
return "Test unrolling patterns for to_elements ops";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect, vector::VectorDialect>();
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorToElementsLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
Expand Down Expand Up @@ -1071,10 +1027,6 @@ void registerTestVectorLowerings() {

PassRegistration<TestVectorGatherLowering>();

PassRegistration<TestUnrollVectorFromElements>();

PassRegistration<TestUnrollVectorToElements>();

PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

PassRegistration<TestVectorEmulateMaskedLoadStore>();
Expand Down