From a6287dcc2f904a5e1ac8f25dea6747b390468198 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 6 Nov 2025 01:29:04 +0000 Subject: [PATCH] [mlir][arith] Fix `arith.cmpf` lowering with unsupported FP types --- .../Conversion/LLVMCommon/VectorPattern.h | 44 +++++++------------ .../Conversion/ArithToLLVM/ArithToLLVM.cpp | 4 ++ .../Conversion/LLVMCommon/VectorPattern.cpp | 21 +++++++++ .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 10 +++-- 4 files changed, 48 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index 47b8381eefda8..32dd8ba2bc391 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -60,6 +60,12 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); + +/// Return "true" if the given type is an unsupported floating point type. In +/// case of a vector type, return "true" if the element type is an unsupported +/// floating point type. +bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, + Type type); } // namespace detail } // namespace LLVM @@ -97,16 +103,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = VectorConvertToLLVMPattern; - /// Return the given type if it's a floating point type. If the given type is - /// a vector type, return its element type if it's a floating point type. - static FloatType getFloatingPointType(Type type) { - if (auto floatType = dyn_cast(type)) - return floatType; - if (auto vecType = dyn_cast(type)) - return dyn_cast(vecType.getElementType()); - return nullptr; - } - LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -114,26 +110,18 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { std::is_base_of, SourceOp>::value, "expected single result op"); - // The pattern should not apply if a floating-point operand is converted to - // a non-floating-point type. This indicates that the floating point type - // is not supported by the LLVM lowering. (Such types are converted to - // integers.) - auto checkType = [&](Value v) -> LogicalResult { - FloatType floatType = getFloatingPointType(v.getType()); - if (!floatType) - return success(); - Type convertedType = this->getTypeConverter()->convertType(floatType); - if (!isa_and_nonnull(convertedType)) - return rewriter.notifyMatchFailure(op, - "unsupported floating point type"); - return success(); - }; + // Bail on unsupported floating point types. (These are type-converted to + // integer types.) if (FailOnUnsupportedFP) { for (Value operand : op->getOperands()) - if (failed(checkType(operand))) - return failure(); - if (failed(checkType(op->getResult(0)))) - return failure(); + if (LLVM::detail::isUnsupportedFloatingPointType( + *this->getTypeConverter(), operand.getType())) + return rewriter.notifyMatchFailure(op, + "unsupported floating point type"); + if (LLVM::detail::isUnsupportedFloatingPointType( + *this->getTypeConverter(), op->getResult(0).getType())) + return rewriter.notifyMatchFailure(op, + "unsupported floating point type"); } // Determine attributes for the target op diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index cc3e8468f298b..220826dc5f3ac 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -483,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(), + op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "unsupported floating point type"); + Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); LLVM::FastmathFlags fmf = diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index 24b01259f0499..e5969c2539566 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -130,3 +130,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( return handleMultidimensionalVectors(op, operands, typeConverter, callback, rewriter); } + +/// Return the given type if it's a floating point type. If the given type is +/// a vector type, return its element type if it's a floating point type. +static FloatType getFloatingPointType(Type type) { + if (auto floatType = dyn_cast(type)) + return floatType; + if (auto vecType = dyn_cast(type)) + return dyn_cast(vecType.getElementType()); + return nullptr; +} + +bool LLVM::detail::isUnsupportedFloatingPointType( + const TypeConverter &typeConverter, Type type) { + FloatType floatType = getFloatingPointType(type); + if (!floatType) + return false; + Type convertedType = typeConverter.convertType(floatType); + if (!convertedType) + return true; + return !isa(convertedType); +} diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 6fdc1104d2609..b53c52d75c0aa 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -770,12 +770,14 @@ func.func @memref_bitcast(%1: memref) -> memref { // CHECK: arith.addf {{.*}} : f4E2M1FN // CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN> // CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN> +// CHECK: arith.cmpf {{.*}} : f4E2M1FN // CHECK: llvm.select {{.*}} : i1, i4 func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) { %0 = arith.addf %arg0, %arg0 : f4E2M1FN %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN> %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN> - %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN + %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN + %4 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN return } @@ -785,9 +787,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2 // CHECK: llvm.fadd {{.*}} : f32 // CHECK: llvm.fadd {{.*}} : vector<4xf32> // CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32> -func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) { +// CHECK: llvm.fcmp {{.*}} : f32 +func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) { %0 = arith.addf %arg0, %arg0 : f32 %1 = arith.addf %arg1, %arg1 : vector<4xf32> %2 = arith.addf %arg2, %arg2 : vector<4x8xf32> - return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32> + %3 = arith.cmpf oeq, %arg0, %arg3 : f32 + return }