diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 1b656d82f3201..ea93085849e0b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -817,6 +817,50 @@ struct LinearizeVectorToElements final } }; +/// Convert broadcasts from scalars or 1-element vectors, such as +/// +/// ```mlir +/// vector.broadcast %value : f32 to vector<4x4xf32> +/// ``` +/// +/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed. +/// The above becomes, +/// +/// ```mlir +/// %out_1d = vector.broadcast %value : f32 to vector<16xf32> +/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> +/// ``` +struct LinearizeVectorBroadcast final + : public OpConversionPattern { + using Base::Base; + + LinearizeVectorBroadcast(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + int numElements = 1; + Type sourceType = broadcastOp.getSourceType(); + if (auto vecType = dyn_cast(sourceType)) { + numElements = vecType.getNumElements(); + } + + if (numElements != 1) { + return rewriter.notifyMatchFailure( + broadcastOp, "only broadcasts of single elements can be linearized."); + } + + auto dstTy = getTypeConverter()->convertType(broadcastOp.getType()); + rewriter.replaceOpWithNewOp(broadcastOp, dstTy, + adaptor.getSource()); + + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add( - typeConverter, patterns.getContext()); + LinearizeVectorBroadcast, LinearizeVectorFromElements, + LinearizeVectorToElements>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index ee5cfbcda5c19..cbbc833d7a51d 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -428,6 +428,47 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> { // ----- +// CHECK-LABEL: linearize_vector_broadcast_scalar_source +// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32> +func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> { + + // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> + // CHECK: return %[[CAST]] : vector<4x2xi32> + %0 = vector.broadcast %arg0 : i32 to vector<4x2xi32> + return %0 : vector<4x2xi32> +} + +// ----- + +// CHECK-LABEL: linearize_vector_broadcast_rank_two_source +// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32> +func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> { + + // CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32> + // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32> + // CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> + // CHECK: return %[[CAST1]] : vector<4x2xi32> + %0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32> + return %0 : vector<4x2xi32> +} + +// ----- + +// CHECK-LABEL: linearize_scalable_vector_broadcast +// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32> +func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> { + + // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<[8]xi32> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<[8]xi32> to vector<4x[2]xi32> + // CHECK: return %[[CAST]] : vector<4x[2]xi32> + %0 = vector.broadcast %arg0 : i32 to vector<4x[2]xi32> + return %0 : vector<4x[2]xi32> + +} + +// ----- + // CHECK-LABEL: linearize_create_mask // CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1> func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {