From d4251875f8c346634c564e5498f40f86b33bd20f Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 9 Sep 2025 12:46:01 -0700 Subject: [PATCH 1/3] [mlir][vector] Add LinearizeVectorToElements --- .../Vector/Transforms/VectorLinearize.cpp | 47 ++++++++++++++++++- mlir/test/Dialect/Vector/linearize.mlir | 23 +++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 7dde6311fa809..54eb182a9680f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -798,6 +798,49 @@ struct LinearizeVectorFromElements final } }; +/// This pattern linearizes the operand in `vector.to_elements` operations +/// by converting the result type to a 1-D vector while preserving all element +/// values. The transformation creates a linearized `vector.shape_cast` +/// followed by a `vector.to_elements`. +/// +/// Example: +/// +/// %0:4 = vector.to_elements %v : vector<2x2xf32> +/// +/// is converted to: +/// +/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32> +/// %0:4 = vector.to_elements %vector_cast : vector<4xf32> +/// +struct LinearizeVectorToElements final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorToElements(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + VectorType vecType = toElementsOp.getSource().getType(); + if (vecType.getRank() <= 1) + return rewriter.notifyMatchFailure( + toElementsOp, "the rank is already less than or equal to 1"); + + assert(vecType.getNumScalableDims() == 0 && + "scalable vector is not yet supported"); + auto vec1DType = + VectorType::get({vecType.getNumElements()}, vecType.getElementType()); + Value shapeCast = vector::ShapeCastOp::create( + rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource()); + rewriter.replaceOpWithNewOp( + toElementsOp, toElementsOp.getResultTypes(), shapeCast); + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -890,8 +933,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add( - typeConverter, patterns.getContext()); + LinearizeVectorStore, LinearizeVectorFromElements, + LinearizeVectorToElements>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 5e8bfd0698b33..fe697c8b9c057 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> return %1 : vector<2x2xf32> } + +// ----- + +// CHECK-LABEL: func.func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1 +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// CHECK-LABEL: func.func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] +// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} From 2e4399595eee79d3487947326dbb0516e2f1e262 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 9 Sep 2025 13:27:35 -0700 Subject: [PATCH 2/3] Do not use builder create method --- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 54eb182a9680f..763c738d66c3a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -835,8 +835,10 @@ struct LinearizeVectorToElements final VectorType::get({vecType.getNumElements()}, vecType.getElementType()); Value shapeCast = vector::ShapeCastOp::create( rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource()); - rewriter.replaceOpWithNewOp( - toElementsOp, toElementsOp.getResultTypes(), shapeCast); + auto newToElementsOp = + vector::ToElementsOp::create(rewriter, toElementsOp.getLoc(), + toElementsOp.getResultTypes(), shapeCast); + rewriter.replaceOp(toElementsOp, newToElementsOp); return success(); } }; From 51c633c07d2dbb3fd9df683ed90884a6a88444be Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 10 Sep 2025 09:22:49 -0400 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: James Newling --- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 763c738d66c3a..12acf4b3f07f5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -799,7 +799,7 @@ struct LinearizeVectorFromElements final }; /// This pattern linearizes the operand in `vector.to_elements` operations -/// by converting the result type to a 1-D vector while preserving all element +/// by converting the source type to a 1-D vector while preserving all element /// values. The transformation creates a linearized `vector.shape_cast` /// followed by a `vector.to_elements`. /// @@ -830,7 +830,7 @@ struct LinearizeVectorToElements final toElementsOp, "the rank is already less than or equal to 1"); assert(vecType.getNumScalableDims() == 0 && - "scalable vector is not yet supported"); + "to_elements does not support scalable vectors"); auto vec1DType = VectorType::get({vecType.getNumElements()}, vecType.getElementType()); Value shapeCast = vector::ShapeCastOp::create(