diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 491b448e9e1e9..7dde6311fa809 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -762,6 +762,42 @@ struct LinearizeVectorStore final } }; +/// This pattern linearizes `vector.from_elements` operations by converting +/// the result type to a 1-D vector while preserving all element values. +/// The transformation creates a linearized `vector.from_elements` followed by +/// a `vector.shape_cast` to restore the original multidimensional shape. +/// +/// Example: +/// +/// %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32> +/// +/// is converted to: +/// +/// %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32> +/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32> +/// +struct LinearizeVectorFromElements final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorFromElements(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + LogicalResult + matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType dstTy = + getTypeConverter()->convertType(fromElementsOp.getType()); + assert(dstTy && "vector type destination expected."); + + OperandRange elements = fromElementsOp.getElements(); + assert(elements.size() == static_cast(dstTy.getNumElements()) && + "expected same number of elements"); + rewriter.replaceOpWithNewOp(fromElementsOp, dstTy, + elements); + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add(typeConverter, patterns.getContext()); + LinearizeVectorStore, LinearizeVectorFromElements>( + typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 2e630bf93622e..5e8bfd0698b33 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -524,3 +524,17 @@ func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32> return } + +// ----- + +// Test pattern LinearizeVectorFromElements. + +// CHECK-LABEL: test_vector_from_elements +// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32 +func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> { + // CHECK: %[[FROM_ELEMENTS:.*]] = vector.from_elements %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]] : vector<4xf32> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[FROM_ELEMENTS]] : vector<4xf32> to vector<2x2xf32> + // CHECK: return %[[CAST]] : vector<2x2xf32> + %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> + return %1 : vector<2x2xf32> +}