diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index e404d75f7d243..cb74987024138 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -145,11 +145,11 @@ struct VectorFmaOpConvert final : public OpConversionPattern { LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) + Type dstType = getTypeConverter()->convertType(fmaOp.getType()); + if (!dstType) return failure(); rewriter.replaceOpWithNewOp( - fmaOp, fmaOp.getType(), adaptor.getLhs(), adaptor.getRhs(), - adaptor.getAcc()); + fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); return success(); } }; @@ -321,13 +321,18 @@ class VectorSplatPattern final : public OpConversionPattern { LogicalResult matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType dstVecType = op.getType(); - if (!spirv::CompositeType::isValid(dstVecType)) + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) return failure(); - SmallVector source(dstVecType.getNumElements(), - adaptor.getInput()); - rewriter.replaceOpWithNewOp(op, dstVecType, - source); + if (dstType.isa()) { + rewriter.replaceOp(op, adaptor.getInput()); + } else { + auto dstVecType = dstType.cast(); + SmallVector source(dstVecType.getNumElements(), + adaptor.getInput()); + rewriter.replaceOpWithNewOp(op, dstType, + source); + } return success(); } }; diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 8a3f71b380c4c..76ce449c51bde 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -171,6 +171,15 @@ func.func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vecto // ----- +// CHECK-LABEL: @fma_size1_vector +// CHECK: spv.GLSL.Fma %{{.+}} : f32 +func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf32>) -> vector<1xf32> { + %0 = vector.fma %a, %b, %c: vector<1xf32> + return %0 : vector<1xf32> +} + +// ----- + // CHECK-LABEL: func @splat // CHECK-SAME: (%[[A:.+]]: f32) // CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32> @@ -182,6 +191,17 @@ func.func @splat(%f : f32) -> vector<4xf32> { // ----- +// CHECK-LABEL: func @splat_size1_vector +// CHECK-SAME: (%[[A:.+]]: f32) +// CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %[[A]] +// CHECK: return %[[VAL]] +func.func @splat_size1_vector(%f : f32) -> vector<1xf32> { + %splat = vector.splat %f : vector<1xf32> + return %splat : vector<1xf32> +} + +// ----- + // CHECK-LABEL: func @shuffle // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32> // CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]