Skip to content

Conversation

@makslevental
Copy link
Contributor

@makslevental makslevental commented Dec 12, 2025

Enable FailOnUnsupportedFP for ConvertToLLVMPattern and set it to true for all math-to-llvm patterns. This fixes various invalid lowerings of math ops on fp8/fp4 types.

@makslevental makslevental force-pushed the users/makslevental/matpattern branch 3 times, most recently from b2571a4 to 7583a52 Compare December 12, 2025 18:35
@makslevental makslevental marked this pull request as ready for review December 12, 2025 18:36
@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Maksim Levental (makslevental)

Changes

Enable FailOnUnsupportedFP for ConvertToLLVMPattern.


Full diff: https://github.com/llvm/llvm-project/pull/172054.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+22-1)
  • (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+7-18)
  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+31)
  • (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (-21)
  • (modified) mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp (+40-17)
  • (modified) mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir (+13)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f8e0ccc093f8b..cacd500d41291 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -54,6 +54,15 @@ LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
                                const LLVMTypeConverter &typeConverter,
                                RewriterBase &rewriter);
 
+/// Return "true" if the given type is an unsupported floating point type.
+/// In case of a vector type, return "true" if the element type is an
+/// unsupported floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+                                    Type type);
+/// Return "true" if the given op has any unsupported floating point
+/// types (either operands or results).
+bool opHasUnsupportedFloatingPointTypes(Operation *op,
+                                        const TypeConverter &typeConverter);
 } // namespace detail
 
 /// Decomposes a `src` value into a set of values of type `dstType` through
@@ -203,7 +212,7 @@ class ConvertToLLVMPattern : public ConversionPattern {
 
 /// Utility class for operation conversions targeting the LLVM dialect that
 /// match exactly one source operation.
-template <typename SourceOp>
+template <typename SourceOp, bool FailOnUnsupportedFP = false>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
@@ -220,12 +229,24 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    // Bail on unsupported floating point types. (These are type-converted to
+    // integer types.)
+    if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+                                   op, *this->typeConverter)) {
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+    }
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    // Bail on unsupported floating point types. (These are type-converted to
+    // integer types.)
+    if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+                                   op, *this->typeConverter)) {
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+    }
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
                            rewriter);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 32dd8ba2bc391..65988a2466318 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -60,12 +60,6 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
                                     Attribute propertiesAttr,
                                     const LLVMTypeConverter &typeConverter,
                                     ConversionPatternRewriter &rewriter);
-
-/// Return "true" if the given type is an unsupported floating point type. In
-/// case of a vector type, return "true" if the element type is an unsupported
-/// floating point type.
-bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
-                                    Type type);
 } // namespace detail
 } // namespace LLVM
 
@@ -98,9 +92,11 @@ template <typename SourceOp, typename TargetOp,
           template <typename, typename> typename AttrConvert =
               AttrConvertPassThrough,
           bool FailOnUnsupportedFP = false>
-class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
+class VectorConvertToLLVMPattern
+    : public ConvertOpToLLVMPattern<SourceOp, FailOnUnsupportedFP> {
 public:
-  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<SourceOp,
+                               FailOnUnsupportedFP>::ConvertOpToLLVMPattern;
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
   LogicalResult
@@ -112,16 +108,9 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 
     // Bail on unsupported floating point types. (These are type-converted to
     // integer types.)
-    if (FailOnUnsupportedFP) {
-      for (Value operand : op->getOperands())
-        if (LLVM::detail::isUnsupportedFloatingPointType(
-                *this->getTypeConverter(), operand.getType()))
-          return rewriter.notifyMatchFailure(op,
-                                             "unsupported floating point type");
-      if (LLVM::detail::isUnsupportedFloatingPointType(
-              *this->getTypeConverter(), op->getResult(0).getType()))
-        return rewriter.notifyMatchFailure(op,
-                                           "unsupported floating point type");
+    if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+                                   op, *this->typeConverter)) {
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
     }
 
     // Determine attributes for the target op
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index f28a6ccb42455..640ff3d7c3c7d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -516,3 +516,34 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
                                    base, index, noWrapFlags)
              : base;
 }
+
+/// Return the given type if it's a floating point type. If the given type is
+/// a vector type, return its element type if it's a floating point type.
+static FloatType getFloatingPointType(Type type) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return floatType;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return dyn_cast<FloatType>(vecType.getElementType());
+  return nullptr;
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+    const TypeConverter &typeConverter, Type type) {
+  FloatType floatType = getFloatingPointType(type);
+  if (!floatType)
+    return false;
+  Type convertedType = typeConverter.convertType(floatType);
+  if (!convertedType)
+    return true;
+  return !isa<FloatType>(convertedType);
+}
+
+bool LLVM::detail::opHasUnsupportedFloatingPointTypes(
+    Operation *op, const TypeConverter &typeConverter) {
+  for (Value operand : op->getOperands())
+    if (isUnsupportedFloatingPointType(typeConverter, operand.getType()))
+      return true;
+  return llvm::any_of(op->getResults(), [&typeConverter](OpResult r) {
+    return isUnsupportedFloatingPointType(typeConverter, r.getType());
+  });
+}
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e5969c2539566..24b01259f0499 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -130,24 +130,3 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
                                        rewriter);
 }
-
-/// Return the given type if it's a floating point type. If the given type is
-/// a vector type, return its element type if it's a floating point type.
-static FloatType getFloatingPointType(Type type) {
-  if (auto floatType = dyn_cast<FloatType>(type))
-    return floatType;
-  if (auto vecType = dyn_cast<VectorType>(type))
-    return dyn_cast<FloatType>(vecType.getElementType());
-  return nullptr;
-}
-
-bool LLVM::detail::isUnsupportedFloatingPointType(
-    const TypeConverter &typeConverter, Type type) {
-  FloatType floatType = getFloatingPointType(type);
-  if (!floatType)
-    return false;
-  Type convertedType = typeConverter.convertType(floatType);
-  if (!convertedType)
-    return true;
-  return !isa<FloatType>(convertedType);
-}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 7cce324f94295..faa4182943f67 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -32,9 +32,10 @@ namespace {
 template <typename SourceOp, typename TargetOp>
 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
 
-template <typename SourceOp, typename TargetOp>
+template <typename SourceOp, typename TargetOp, bool FailOnUnsupportedFP = true>
 using ConvertFMFMathToLLVMPattern =
-    VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
+    VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
+                               FailOnUnsupportedFP>;
 
 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
@@ -44,7 +45,9 @@ using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
 using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
 using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
 using CtPopFOpLowering =
-    VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
+    VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp,
+                               AttrConvertPassThrough,
+                               /*FailOnUnsupportedFP*/ true>;
 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
 using FloorOpLowering =
@@ -76,8 +79,10 @@ using ATan2OpLowering =
 // TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
 // may be better to separate the patterns.
 template <typename MathOp, typename LLVMOp>
-struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
-  using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
+struct IntOpWithFlagLowering
+    : public ConvertOpToLLVMPattern<MathOp, /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      MathOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
   using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
 
   LogicalResult
@@ -122,8 +127,11 @@ using CountTrailingZerosOpLowering =
 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
 
 // A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
-struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
-  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+struct SincosOpLowering
+    : public ConvertOpToLLVMPattern<math::SincosOp,
+                                    /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      math::SincosOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
@@ -154,8 +162,11 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
 };
 
 // A `expm1` is converted into `exp - 1`.
-struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
-  using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
+struct ExpM1OpLowering
+    : public ConvertOpToLLVMPattern<math::ExpM1Op,
+                                    /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      math::ExpM1Op, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
@@ -216,8 +227,11 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
 };
 
 // A `log1p` is converted into `log(1 + ...)`.
-struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
-  using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
+struct Log1pOpLowering
+    : public ConvertOpToLLVMPattern<math::Log1pOp,
+                                    /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      math::Log1pOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
@@ -278,8 +292,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
 };
 
 // A `rsqrt` is converted into `1 / sqrt`.
-struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
-  using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
+struct RsqrtOpLowering
+    : public ConvertOpToLLVMPattern<math::RsqrtOp,
+                                    /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      math::RsqrtOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
@@ -339,8 +356,11 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
   }
 };
 
-struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
-  using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
+struct IsNaNOpLowering
+    : public ConvertOpToLLVMPattern<math::IsNaNOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
@@ -358,8 +378,11 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
   }
 };
 
-struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
-  using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
+struct IsFiniteOpLowering
+    : public ConvertOpToLLVMPattern<math::IsFiniteOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index f7d27120d4207..394aca876ff08 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -628,3 +628,16 @@ func.func @fastmath(%arg0 : f32, %arg1 : vector<4xf32>) {
   %3 = math.fma %arg0, %arg0, %arg0 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @unsupported_fp_type
+//       CHECK:   math.absf {{.*}} : f4E2M1FN
+//       CHECK:   math.cos {{.*}} : f4E2M1FN
+//       CHECK:   math.fma {{.*}} : f4E2M1FN
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: f4E2M1FN, %arg2: f4E2M1FN) {
+  %0 = math.absf %arg0 : f4E2M1FN
+  %1 = math.cos %arg0 : f4E2M1FN
+  %2 = math.fma %arg1, %arg1, %arg2 : f4E2M1FN
+  return
+}
\ No newline at end of file

@makslevental makslevental force-pushed the users/makslevental/matpattern branch from 7583a52 to 25fdee8 Compare December 12, 2025 18:37
using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
struct RsqrtOpLowering
: public ConvertOpToLLVMPattern<math::RsqrtOp,
/*FailOnUnsupportedFP*/ true> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Some comments say /*FailOnUnsupportedFP*/ true and some /*FailOnUnsupportedFP=*/true. (I'd recommend the latter.)

Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's mention in the commit message that this fixes an invalid lowering.

@makslevental makslevental enabled auto-merge (squash) December 12, 2025 22:27
@makslevental makslevental merged commit 5361636 into main Dec 12, 2025
10 checks passed
@makslevental makslevental deleted the users/makslevental/matpattern branch December 12, 2025 22:31
anonymouspc pushed a commit to anonymouspc/llvm that referenced this pull request Dec 15, 2025
Enable `FailOnUnsupportedFP` for `ConvertToLLVMPattern` and set it to
`true` for all `math-to-llvm` patterns. This fixes various invalid
lowerings of `math` ops on `fp8`/`fp4` types.
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Dec 19, 2025
Enable `FailOnUnsupportedFP` for `ConvertToLLVMPattern` and set it to
`true` for all `math-to-llvm` patterns. This fixes various invalid
lowerings of `math` ops on `fp8`/`fp4` types.
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Dec 20, 2025
Enable `FailOnUnsupportedFP` for `ConvertToLLVMPattern` and set it to
`true` for all `math-to-llvm` patterns. This fixes various invalid
lowerings of `math` ops on `fp8`/`fp4` types.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants