diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index 964281592cc65..cad6cec761ab8 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -92,12 +92,43 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = VectorConvertToLLVMPattern; + /// Return the given type if it's a floating point type. If the given type is + /// a vector type, return its element type if it's a floating point type. + static FloatType getFloatingPointType(Type type) { + if (auto floatType = dyn_cast(type)) + return floatType; + if (auto vecType = dyn_cast(type)) + return dyn_cast(vecType.getElementType()); + return nullptr; + } + LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { static_assert( std::is_base_of, SourceOp>::value, "expected single result op"); + + // The pattern should not apply if a floating-point operand is converted to + // a non-floating-point type. This indicates that the floating point type + // is not supported by the LLVM lowering. (Such types are converted to + // integers.) + auto checkType = [&](Value v) -> LogicalResult { + FloatType floatType = getFloatingPointType(v.getType()); + if (!floatType) + return success(); + Type convertedType = this->getTypeConverter()->convertType(floatType); + if (!isa_and_nonnull(convertedType)) + return rewriter.notifyMatchFailure(op, + "unsupported floating point type"); + return success(); + }; + for (Value operand : op->getOperands()) + if (failed(checkType(operand))) + return failure(); + if (failed(checkType(op->getResult(0)))) + return failure(); + // Determine attributes for the target op AttrConvert attrConvert(op); diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index ba12ff29ebef9..b5dcb01d3dc6b 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -747,3 +747,29 @@ func.func @memref_bitcast(%1: memref) -> memref { %2 = arith.bitcast %1 : memref to memref func.return %2 : memref } + +// ----- + +// CHECK-LABEL: func @unsupported_fp_type +// CHECK: arith.addf {{.*}} : f4E2M1FN +// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN> +// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN> +func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) { + %0 = arith.addf %arg0, %arg0 : f4E2M1FN + %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN> + %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN> + return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN> +} + +// ----- + +// CHECK-LABEL: func @supported_fp_type +// CHECK: llvm.fadd {{.*}} : f32 +// CHECK: llvm.fadd {{.*}} : vector<4xf32> +// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32> +func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) { + %0 = arith.addf %arg0, %arg0 : f32 + %1 = arith.addf %arg1, %arg1 : vector<4xf32> + %2 = arith.addf %arg2, %arg2 : vector<4x8xf32> + return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32> +}