-
Notifications
You must be signed in to change notification settings - Fork 11.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Extend floating point parsing support #90442
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: None (orbiri) ChangesParsing support for floating point types was missing a few features:
This commit addresses both these points. It extends Full diff: https://github.com/llvm/llvm-project/pull/90442.diff 7 Files Affected:
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5ca..fa435cb3155ed4 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -700,6 +700,10 @@ class AsmParser {
/// Parse a floating point value from the stream.
virtual ParseResult parseFloat(double &result) = 0;
+ /// Parse a floating point value into APFloat from the stream.
+ virtual ParseResult parseFloat(const llvm::fltSemantics &semantics,
+ APFloat &result) = 0;
+
/// Parse an integer value from the stream.
template <typename IntT>
ParseResult parseInteger(IntT &result) {
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 30c0079cda0861..8b88a3a6650a3e 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -269,8 +269,11 @@ class AsmParserImpl : public BaseT {
return success();
}
- /// Parse a floating point value from the stream.
- ParseResult parseFloat(double &result) override {
+ /// Parse a floating point value with given semantics from the stream. Since
+ /// this implementation parses the string as double precision and just than
+ /// converts the value to the requested semantic, precision may be lost.
+ ParseResult parseFloat(const llvm::fltSemantics &semantics,
+ APFloat &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
SMLoc loc = curTok.getLoc();
@@ -281,7 +284,9 @@ class AsmParserImpl : public BaseT {
if (!val)
return emitError(loc, "floating point value too large");
parser.consumeToken(Token::floatliteral);
- result = isNegative ? -*val : *val;
+ result = APFloat(isNegative ? -*val : *val);
+ bool losesInfo;
+ result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
return success();
}
@@ -289,18 +294,28 @@ class AsmParserImpl : public BaseT {
if (curTok.is(Token::integer)) {
std::optional<APFloat> apResult;
if (failed(parser.parseFloatFromIntegerLiteral(
- apResult, curTok, isNegative, APFloat::IEEEdouble(),
- /*typeSizeInBits=*/64)))
+ apResult, curTok, isNegative, semantics,
+ APFloat::semanticsSizeInBits(semantics))))
return failure();
+ result = *apResult;
parser.consumeToken(Token::integer);
- result = apResult->convertToDouble();
return success();
}
return emitError(loc, "expected floating point literal");
}
+ /// Parse a floating point value from the stream.
+ ParseResult parseFloat(double &result) override {
+ llvm::APFloat apResult(0.0);
+ if (parseFloat(APFloat::IEEEdouble(), apResult))
+ return failure();
+
+ result = apResult.convertToDouble();
+ return success();
+ }
+
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(APInt &result) override {
return parser.parseOptionalInteger(result);
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f..1b8b4bac1821e9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -326,19 +326,15 @@ ParseResult Parser::parseFloatFromIntegerLiteral(
"leading minus");
}
- std::optional<uint64_t> value = tok.getUInt64IntegerValue();
- if (!value)
+ APInt intValue;
+ tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
+ if (intValue.getActiveBits() > typeSizeInBits)
return emitError(loc, "hexadecimal float constant out of range for type");
- if (&semantics == &APFloat::IEEEdouble()) {
- result = APFloat(semantics, APInt(typeSizeInBits, *value));
- return success();
- }
+ APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
+ intValue.getRawData());
- APInt apInt(typeSizeInBits, *value);
- if (apInt != *value)
- return emitError(loc, "hexadecimal float constant out of range for type");
- result = APFloat(semantics, apInt);
+ result.emplace(semantics, truncatedValue);
return success();
}
diff --git a/mlir/test/IR/custom-float-attr-roundtrip.mlir b/mlir/test/IR/custom-float-attr-roundtrip.mlir
new file mode 100644
index 00000000000000..e0913e58d29538
--- /dev/null
+++ b/mlir/test/IR/custom-float-attr-roundtrip.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
+
+// CHECK-LABEL: @test_enum_attr_roundtrip
+func.func @test_enum_attr_roundtrip() -> () {
+ // CHECK: attr = #test.custom_float<"float" : 2.000000e+00>
+ "test.op"() {attr = #test.custom_float<"float" : 2.>} : () -> ()
+ // CHECK: attr = #test.custom_float<"double" : 2.000000e+00>
+ "test.op"() {attr =#test.custom_float<"double" : 2.>} : () -> ()
+ // CHECK: attr = #test.custom_float<"fp80" : 2.000000e+00>
+ "test.op"() {attr =#test.custom_float<"fp80" : 2.>} : () -> ()
+ // CHECK: attr = #test.custom_float<"float" : 0x7FC00000>
+ "test.op"() {attr =#test.custom_float<"float" : 0x7FC00000>} : () -> ()
+ // CHECK: attr = #test.custom_float<"double" : 0x7FF0000001000000>
+ "test.op"() {attr =#test.custom_float<"double" : 0x7FF0000001000000>} : () -> ()
+ // CHECK: attr = #test.custom_float<"fp80" : 0x7FFFC000000000100000>
+ "test.op"() {attr =#test.custom_float<"fp80" : 0x7FFFC000000000100000>} : () -> ()
+ return
+}
+
+// -----
+
+// Verify literal must be hex or float
+
+// expected-error @below {{unexpected decimal integer literal for a floating point value}}
+// expected-note @below {{add a trailing dot to make the literal a float}}
+"test.op"() {attr =#test.custom_float<"float" : 42>} : () -> ()
+
+// -----
+
+// Integer value must be in the width of the floating point type
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr =#test.custom_float<"float" : 0x7FC000000>} : () -> ()
+
+
+// -----
+
+// Integer value must be in the width of the floating point type
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr =#test.custom_float<"double" : 0x7FC000007FC0000000>} : () -> ()
+
+
+// -----
+
+// Integer value must be in the width of the floating point type
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr =#test.custom_float<"fp80" : 0x7FC0000007FC0000007FC000000>} : () -> ()
+
+// -----
+
+// Value must be a floating point literal or integer literal
+
+// expected-error @below {{expected floating point literal}}
+"test.op"() {attr =#test.custom_float<"float" : "blabla">} : () -> ()
+
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index bebbb876391d07..020942e7f4c11b 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1105,6 +1105,30 @@ func.func @bfloat16_special_values() {
return
}
+// CHECK-LABEL: @f80_special_values
+func.func @f80_special_values() {
+ // F80 signaling NaNs.
+ // CHECK: arith.constant 0x7FFFE000000000000001 : f80
+ %0 = arith.constant 0x7FFFE000000000000001 : f80
+ // CHECK: arith.constant 0x7FFFB000000000000011 : f80
+ %1 = arith.constant 0x7FFFB000000000000011 : f80
+
+ // F80 quiet NaNs.
+ // CHECK: arith.constant 0x7FFFC000000000100000 : f80
+ %2 = arith.constant 0x7FFFC000000000100000 : f80
+ // CHECK: arith.constant 0x7FFFE000000001000000 : f80
+ %3 = arith.constant 0x7FFFE000000001000000 : f80
+
+ // F80 positive infinity.
+ // CHECK: arith.constant 0x7FFF8000000000000000 : f80
+ %4 = arith.constant 0x7FFF8000000000000000 : f80
+ // F80 negative infinity.
+ // CHECK: arith.constant 0xFFFF8000000000000000 : f80
+ %5 = arith.constant 0xFFFF8000000000000000 : f80
+
+ return
+}
+
// We want to print floats in exponential notation with 6 significant digits,
// but it may lead to precision loss when parsing back, in which case we print
// the decimal form instead.
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 40f035a3e3a4e5..12635e107bd42c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -340,4 +340,15 @@ def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
}];
}
+// Test AsmParser::parseFloat(const fltSemnatics&, APFloat&) API through the
+// custom parser and printer.
+def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
+ let mnemonic = "custom_float";
+ let parameters = (ins "mlir::StringAttr":$type_str, APFloatParameter<"">:$value);
+
+ let assemblyFormat = [{
+ `<` custom<CustomFloatAttr>($type_str, $value) `>`
+ }];
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 2cc051e664beec..d7e40d35238d91 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -240,6 +241,46 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
p.printKeywordOrString(value);
}
+//===----------------------------------------------------------------------===//
+// Custom Float Attribute
+//===----------------------------------------------------------------------===//
+
+static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr,
+ APFloat value) {
+ p << typeStrAttr << " : " << value;
+}
+
+static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
+ FailureOr<APFloat> &value) {
+
+ std::string str;
+ if (p.parseString(&str))
+ return failure();
+
+ typeStrAttr = StringAttr::get(p.getContext(), str);
+
+ if (p.parseColon())
+ return failure();
+
+ const llvm::fltSemantics *semantics;
+ if (str == "float")
+ semantics = &llvm::APFloat::IEEEsingle();
+ else if (str == "double")
+ semantics = &llvm::APFloat::IEEEdouble();
+ else if (str == "fp80")
+ semantics = &llvm::APFloat::x87DoubleExtended();
+ else
+ return p.emitError(p.getCurrentLocation(), "unknown float type, expected "
+ "'float', 'double' or 'fp80'");
+
+ APFloat parsedValue(0.0);
+ if (p.parseFloat(*semantics, parsedValue))
+ return failure();
+
+ value.emplace(parsedValue);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
|
46f9583
to
e19b143
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems a reasonable change. LGTM % a few nit comments.
As this is core infra, I would suggest to wait for a while such that all the relevant people have the opportunity to take a look.
Parsing support for floating point types was missing a few features: 1. Parsing floating point attributes from integer literals was supported only for types with bitwidth smaller or equal to 64. 2. Downstream users could not use `AsmParser::parseFloat` to parse float types which are printed as integer literals. This commit addresses both these points. It extends `Parser::parseFloatFromIntegerLiteral` to support arbitrary bitwidth, and exposes a new API to parse arbitrary floating point given an fltSemantics as input. The usage of this new API is introduced in the Test Dialect.
e19b143
to
7593830
Compare
Fixed nits. If there are no other objections - I'd appreciate help with committing this to main 🙏 |
@orbiri Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested Please check whether problems have been caused by your change specifically, as How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
Parsing support for floating point types was missing a few features:
AsmParser::parseFloat
to parse float types which are printed as integer literals.This commit addresses both these points. It extends
Parser::parseFloatFromIntegerLiteral
to support arbitrary bitwidth, and exposes a new API to parse arbitrary floating point given an fltSemantics as input. The usage of this new API is introduced in the Test Dialect.