-
Notifications
You must be signed in to change notification settings - Fork 15.6k
[mlir][LLVM] refactor FailOnUnsupportedFP #172054
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
b2571a4 to
7583a52
Compare
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Maksim Levental (makslevental) ChangesEnable Full diff: https://github.com/llvm/llvm-project/pull/172054.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f8e0ccc093f8b..cacd500d41291 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -54,6 +54,15 @@ LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter);
+/// Return "true" if the given type is an unsupported floating point type.
+/// In case of a vector type, return "true" if the element type is an
+/// unsupported floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+ Type type);
+/// Return "true" if the given op has any unsupported floating point
+/// types (either operands or results).
+bool opHasUnsupportedFloatingPointTypes(Operation *op,
+ const TypeConverter &typeConverter);
} // namespace detail
/// Decomposes a `src` value into a set of values of type `dstType` through
@@ -203,7 +212,7 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// Utility class for operation conversions targeting the LLVM dialect that
/// match exactly one source operation.
-template <typename SourceOp>
+template <typename SourceOp, bool FailOnUnsupportedFP = false>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
@@ -220,12 +229,24 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
+ // Bail on unsupported floating point types. (These are type-converted to
+ // integer types.)
+ if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+ op, *this->typeConverter)) {
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+ }
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
+ // Bail on unsupported floating point types. (These are type-converted to
+ // integer types.)
+ if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+ op, *this->typeConverter)) {
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+ }
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 32dd8ba2bc391..65988a2466318 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -60,12 +60,6 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
Attribute propertiesAttr,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
-
-/// Return "true" if the given type is an unsupported floating point type. In
-/// case of a vector type, return "true" if the element type is an unsupported
-/// floating point type.
-bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
- Type type);
} // namespace detail
} // namespace LLVM
@@ -98,9 +92,11 @@ template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough,
bool FailOnUnsupportedFP = false>
-class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
+class VectorConvertToLLVMPattern
+ : public ConvertOpToLLVMPattern<SourceOp, FailOnUnsupportedFP> {
public:
- using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+ using ConvertOpToLLVMPattern<SourceOp,
+ FailOnUnsupportedFP>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
LogicalResult
@@ -112,16 +108,9 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
// Bail on unsupported floating point types. (These are type-converted to
// integer types.)
- if (FailOnUnsupportedFP) {
- for (Value operand : op->getOperands())
- if (LLVM::detail::isUnsupportedFloatingPointType(
- *this->getTypeConverter(), operand.getType()))
- return rewriter.notifyMatchFailure(op,
- "unsupported floating point type");
- if (LLVM::detail::isUnsupportedFloatingPointType(
- *this->getTypeConverter(), op->getResult(0).getType()))
- return rewriter.notifyMatchFailure(op,
- "unsupported floating point type");
+ if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+ op, *this->typeConverter)) {
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
}
// Determine attributes for the target op
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index f28a6ccb42455..640ff3d7c3c7d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -516,3 +516,34 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
base, index, noWrapFlags)
: base;
}
+
+/// Return the given type if it's a floating point type. If the given type is
+/// a vector type, return its element type if it's a floating point type.
+static FloatType getFloatingPointType(Type type) {
+ if (auto floatType = dyn_cast<FloatType>(type))
+ return floatType;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return dyn_cast<FloatType>(vecType.getElementType());
+ return nullptr;
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+ const TypeConverter &typeConverter, Type type) {
+ FloatType floatType = getFloatingPointType(type);
+ if (!floatType)
+ return false;
+ Type convertedType = typeConverter.convertType(floatType);
+ if (!convertedType)
+ return true;
+ return !isa<FloatType>(convertedType);
+}
+
+bool LLVM::detail::opHasUnsupportedFloatingPointTypes(
+ Operation *op, const TypeConverter &typeConverter) {
+ for (Value operand : op->getOperands())
+ if (isUnsupportedFloatingPointType(typeConverter, operand.getType()))
+ return true;
+ return llvm::any_of(op->getResults(), [&typeConverter](OpResult r) {
+ return isUnsupportedFloatingPointType(typeConverter, r.getType());
+ });
+}
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e5969c2539566..24b01259f0499 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -130,24 +130,3 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}
-
-/// Return the given type if it's a floating point type. If the given type is
-/// a vector type, return its element type if it's a floating point type.
-static FloatType getFloatingPointType(Type type) {
- if (auto floatType = dyn_cast<FloatType>(type))
- return floatType;
- if (auto vecType = dyn_cast<VectorType>(type))
- return dyn_cast<FloatType>(vecType.getElementType());
- return nullptr;
-}
-
-bool LLVM::detail::isUnsupportedFloatingPointType(
- const TypeConverter &typeConverter, Type type) {
- FloatType floatType = getFloatingPointType(type);
- if (!floatType)
- return false;
- Type convertedType = typeConverter.convertType(floatType);
- if (!convertedType)
- return true;
- return !isa<FloatType>(convertedType);
-}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 7cce324f94295..faa4182943f67 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -32,9 +32,10 @@ namespace {
template <typename SourceOp, typename TargetOp>
using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
-template <typename SourceOp, typename TargetOp>
+template <typename SourceOp, typename TargetOp, bool FailOnUnsupportedFP = true>
using ConvertFMFMathToLLVMPattern =
- VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
+ VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
+ FailOnUnsupportedFP>;
using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
@@ -44,7 +45,9 @@ using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
using CtPopFOpLowering =
- VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
+ VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP*/ true>;
using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using FloorOpLowering =
@@ -76,8 +79,10 @@ using ATan2OpLowering =
// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
// may be better to separate the patterns.
template <typename MathOp, typename LLVMOp>
-struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
- using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
+struct IntOpWithFlagLowering
+ : public ConvertOpToLLVMPattern<MathOp, /*FailOnUnsupportedFP*/ true> {
+ using ConvertOpToLLVMPattern<
+ MathOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
LogicalResult
@@ -122,8 +127,11 @@ using CountTrailingZerosOpLowering =
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
-struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
- using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+struct SincosOpLowering
+ : public ConvertOpToLLVMPattern<math::SincosOp,
+ /*FailOnUnsupportedFP*/ true> {
+ using ConvertOpToLLVMPattern<
+ math::SincosOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
@@ -154,8 +162,11 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
};
// A `expm1` is converted into `exp - 1`.
-struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
- using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
+struct ExpM1OpLowering
+ : public ConvertOpToLLVMPattern<math::ExpM1Op,
+ /*FailOnUnsupportedFP*/ true> {
+ using ConvertOpToLLVMPattern<
+ math::ExpM1Op, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
@@ -216,8 +227,11 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
};
// A `log1p` is converted into `log(1 + ...)`.
-struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
- using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
+struct Log1pOpLowering
+ : public ConvertOpToLLVMPattern<math::Log1pOp,
+ /*FailOnUnsupportedFP*/ true> {
+ using ConvertOpToLLVMPattern<
+ math::Log1pOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
@@ -278,8 +292,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
};
// A `rsqrt` is converted into `1 / sqrt`.
-struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
- using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
+struct RsqrtOpLowering
+ : public ConvertOpToLLVMPattern<math::RsqrtOp,
+ /*FailOnUnsupportedFP*/ true> {
+ using ConvertOpToLLVMPattern<
+ math::RsqrtOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
@@ -339,8 +356,11 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
}
};
-struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
- using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
+struct IsNaNOpLowering
+ : public ConvertOpToLLVMPattern<math::IsNaNOp,
+ /*FailOnUnsupportedFP=*/true> {
+ using ConvertOpToLLVMPattern<
+ math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
@@ -358,8 +378,11 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
}
};
-struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
- using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
+struct IsFiniteOpLowering
+ : public ConvertOpToLLVMPattern<math::IsFiniteOp,
+ /*FailOnUnsupportedFP=*/true> {
+ using ConvertOpToLLVMPattern<
+ math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index f7d27120d4207..394aca876ff08 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -628,3 +628,16 @@ func.func @fastmath(%arg0 : f32, %arg1 : vector<4xf32>) {
%3 = math.fma %arg0, %arg0, %arg0 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @unsupported_fp_type
+// CHECK: math.absf {{.*}} : f4E2M1FN
+// CHECK: math.cos {{.*}} : f4E2M1FN
+// CHECK: math.fma {{.*}} : f4E2M1FN
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: f4E2M1FN, %arg2: f4E2M1FN) {
+ %0 = math.absf %arg0 : f4E2M1FN
+ %1 = math.cos %arg0 : f4E2M1FN
+ %2 = math.fma %arg1, %arg1, %arg2 : f4E2M1FN
+ return
+}
\ No newline at end of file
|
7583a52 to
25fdee8
Compare
| using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; | ||
| struct RsqrtOpLowering | ||
| : public ConvertOpToLLVMPattern<math::RsqrtOp, | ||
| /*FailOnUnsupportedFP*/ true> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Some comments say /*FailOnUnsupportedFP*/ true and some /*FailOnUnsupportedFP=*/true. (I'd recommend the latter.)
matthias-springer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's mention in the commit message that this fixes an invalid lowering.
Enable `FailOnUnsupportedFP` for `ConvertToLLVMPattern` and set it to `true` for all `math-to-llvm` patterns. This fixes various invalid lowerings of `math` ops on `fp8`/`fp4` types.
Enable `FailOnUnsupportedFP` for `ConvertToLLVMPattern` and set it to `true` for all `math-to-llvm` patterns. This fixes various invalid lowerings of `math` ops on `fp8`/`fp4` types.
Enable `FailOnUnsupportedFP` for `ConvertToLLVMPattern` and set it to `true` for all `math-to-llvm` patterns. This fixes various invalid lowerings of `math` ops on `fp8`/`fp4` types.
Enable
FailOnUnsupportedFPforConvertToLLVMPatternand set it totruefor allmath-to-llvmpatterns. This fixes various invalid lowerings ofmathops onfp8/fp4types.