-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][vector] Add LinearizeVectorToElements #157740
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add LinearizeVectorToElements #157740
Conversation
@llvm/pr-subscribers-mlir-vector Author: Erick Ochoa Lopez (amd-eochoalo) ChangesFull diff: https://github.com/llvm/llvm-project/pull/157740.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7dde6311fa809..54eb182a9680f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,6 +798,49 @@ struct LinearizeVectorFromElements final
}
};
+/// This pattern linearizes the operand in `vector.to_elements` operations
+/// by converting the result type to a 1-D vector while preserving all element
+/// values. The transformation creates a linearized `vector.shape_cast`
+/// followed by a `vector.to_elements`.
+///
+/// Example:
+///
+/// %0:4 = vector.to_elements %v : vector<2x2xf32>
+///
+/// is converted to:
+///
+/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
+/// %0:4 = vector.to_elements %vector_cast : vector<4xf32>
+///
+struct LinearizeVectorToElements final
+ : public OpConversionPattern<vector::ToElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorToElements(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType vecType = toElementsOp.getSource().getType();
+ if (vecType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(
+ toElementsOp, "the rank is already less than or equal to 1");
+
+ assert(vecType.getNumScalableDims() == 0 &&
+ "scalable vector is not yet supported");
+ auto vec1DType =
+ VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+ Value shapeCast = vector::ShapeCastOp::create(
+ rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
+ rewriter.replaceOpWithNewOp<vector::ToElementsOp>(
+ toElementsOp, toElementsOp.getResultTypes(), shapeCast);
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -890,8 +933,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements>(
- typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements,
+ LinearizeVectorToElements>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e8bfd0698b33..fe697c8b9c057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
%1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
return %1 : vector<2x2xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
|
@llvm/pr-subscribers-mlir Author: Erick Ochoa Lopez (amd-eochoalo) ChangesFull diff: https://github.com/llvm/llvm-project/pull/157740.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7dde6311fa809..54eb182a9680f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,6 +798,49 @@ struct LinearizeVectorFromElements final
}
};
+/// This pattern linearizes the operand in `vector.to_elements` operations
+/// by converting the result type to a 1-D vector while preserving all element
+/// values. The transformation creates a linearized `vector.shape_cast`
+/// followed by a `vector.to_elements`.
+///
+/// Example:
+///
+/// %0:4 = vector.to_elements %v : vector<2x2xf32>
+///
+/// is converted to:
+///
+/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
+/// %0:4 = vector.to_elements %vector_cast : vector<4xf32>
+///
+struct LinearizeVectorToElements final
+ : public OpConversionPattern<vector::ToElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorToElements(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType vecType = toElementsOp.getSource().getType();
+ if (vecType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(
+ toElementsOp, "the rank is already less than or equal to 1");
+
+ assert(vecType.getNumScalableDims() == 0 &&
+ "scalable vector is not yet supported");
+ auto vec1DType =
+ VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+ Value shapeCast = vector::ShapeCastOp::create(
+ rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
+ rewriter.replaceOpWithNewOp<vector::ToElementsOp>(
+ toElementsOp, toElementsOp.getResultTypes(), shapeCast);
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -890,8 +933,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements>(
- typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements,
+ LinearizeVectorToElements>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e8bfd0698b33..fe697c8b9c057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
%1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
return %1 : vector<2x2xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
|
: OpConversionPattern(typeConverter, context, benefit) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy to use the adaptor, but I noticed that in the pattern above adaptor is not used. To continue the style of the pattern above I also did not use it here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only minor comments. LGTM; but please give others some time to review
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Co-authored-by: James Newling <james.newling@gmail.com>
After: * llvm/llvm-project#157740 * llvm/llvm-project#157142 the linearization of vector.to_elements pattern can be changed to either the one now upstream or to the unrolling version. This commit changes the strategy from linearizing to unrolling.
After: * llvm/llvm-project#157740 * llvm/llvm-project#157142 the linearization of vector.to_elements pattern can be changed to either the one now upstream or to the unrolling version. This commit changes the strategy from linearizing to unrolling. Signed-off-by: Erick Ochoa <erick.ochoalopez@amd.com>
After: * llvm/llvm-project#157740 * llvm/llvm-project#157142 the linearization of vector.to_elements pattern can be changed to either the one now upstream or to the unrolling version. This commit changes the strategy from linearizing to unrolling. Signed-off-by: Erick Ochoa <erick.ochoalopez@amd.com>
After: * llvm/llvm-project#157740 * llvm/llvm-project#157142 the linearization of vector.to_elements pattern can be changed to either the one now upstream or to the unrolling version. This commit changes the strategy from linearizing to unrolling. Signed-off-by: Erick Ochoa <erick.ochoalopez@amd.com>
No description provided.