Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Extend floating point parsing support #90442

Merged
merged 1 commit into from
May 4, 2024

Conversation

orbiri
Copy link
Contributor

@orbiri orbiri commented Apr 29, 2024

Parsing support for floating point types was missing a few features:

  1. Parsing floating point attributes from integer literals was supported only for types with bitwidth smaller or equal to 64.
  2. Downstream users could not use AsmParser::parseFloat to parse float types which are printed as integer literals.

This commit addresses both these points. It extends Parser::parseFloatFromIntegerLiteral to support arbitrary bitwidth, and exposes a new API to parse arbitrary floating point given an fltSemantics as input. The usage of this new API is introduced in the Test Dialect.

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Apr 29, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: None (orbiri)

Changes

Parsing support for floating point types was missing a few features:

  1. Parsing floating point attributes from integer literals was supported only for types with bitwidth smaller or equal to 64.
  2. Downstream users could not use AsmParser::parseFloat to parse float types which are printed as integer literals.

This commit addresses both these points. It extends Parser::parseFloatFromIntegerLiteral to support arbitrary bitwidth, and exposes a new API to parse arbitrary floating point given an fltSemantics as input. The usage of this new API is introduced in the Test Dialect.


Full diff: https://github.com/llvm/llvm-project/pull/90442.diff

7 Files Affected:

  • (modified) mlir/include/mlir/IR/OpImplementation.h (+4)
  • (modified) mlir/lib/AsmParser/AsmParserImpl.h (+21-6)
  • (modified) mlir/lib/AsmParser/Parser.cpp (+6-10)
  • (added) mlir/test/IR/custom-float-attr-roundtrip.mlir (+57)
  • (modified) mlir/test/IR/parser.mlir (+24)
  • (modified) mlir/test/lib/Dialect/Test/TestAttrDefs.td (+11)
  • (modified) mlir/test/lib/Dialect/Test/TestAttributes.cpp (+41)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5ca..fa435cb3155ed4 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -700,6 +700,10 @@ class AsmParser {
   /// Parse a floating point value from the stream.
   virtual ParseResult parseFloat(double &result) = 0;
 
+  /// Parse a floating point value into APFloat from the stream.
+  virtual ParseResult parseFloat(const llvm::fltSemantics &semantics,
+                                 APFloat &result) = 0;
+
   /// Parse an integer value from the stream.
   template <typename IntT>
   ParseResult parseInteger(IntT &result) {
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 30c0079cda0861..8b88a3a6650a3e 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -269,8 +269,11 @@ class AsmParserImpl : public BaseT {
     return success();
   }
 
-  /// Parse a floating point value from the stream.
-  ParseResult parseFloat(double &result) override {
+  /// Parse a floating point value with given semantics from the stream. Since
+  /// this implementation parses the string as double precision and just than
+  /// converts the value to the requested semantic, precision may be lost.
+  ParseResult parseFloat(const llvm::fltSemantics &semantics,
+                         APFloat &result) override {
     bool isNegative = parser.consumeIf(Token::minus);
     Token curTok = parser.getToken();
     SMLoc loc = curTok.getLoc();
@@ -281,7 +284,9 @@ class AsmParserImpl : public BaseT {
       if (!val)
         return emitError(loc, "floating point value too large");
       parser.consumeToken(Token::floatliteral);
-      result = isNegative ? -*val : *val;
+      result = APFloat(isNegative ? -*val : *val);
+      bool losesInfo;
+      result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
       return success();
     }
 
@@ -289,18 +294,28 @@ class AsmParserImpl : public BaseT {
     if (curTok.is(Token::integer)) {
       std::optional<APFloat> apResult;
       if (failed(parser.parseFloatFromIntegerLiteral(
-              apResult, curTok, isNegative, APFloat::IEEEdouble(),
-              /*typeSizeInBits=*/64)))
+              apResult, curTok, isNegative, semantics,
+              APFloat::semanticsSizeInBits(semantics))))
         return failure();
 
+      result = *apResult;
       parser.consumeToken(Token::integer);
-      result = apResult->convertToDouble();
       return success();
     }
 
     return emitError(loc, "expected floating point literal");
   }
 
+  /// Parse a floating point value from the stream.
+  ParseResult parseFloat(double &result) override {
+    llvm::APFloat apResult(0.0);
+    if (parseFloat(APFloat::IEEEdouble(), apResult))
+      return failure();
+
+    result = apResult.convertToDouble();
+    return success();
+  }
+
   /// Parse an optional integer value from the stream.
   OptionalParseResult parseOptionalInteger(APInt &result) override {
     return parser.parseOptionalInteger(result);
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f..1b8b4bac1821e9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -326,19 +326,15 @@ ParseResult Parser::parseFloatFromIntegerLiteral(
                           "leading minus");
   }
 
-  std::optional<uint64_t> value = tok.getUInt64IntegerValue();
-  if (!value)
+  APInt intValue;
+  tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
+  if (intValue.getActiveBits() > typeSizeInBits)
     return emitError(loc, "hexadecimal float constant out of range for type");
 
-  if (&semantics == &APFloat::IEEEdouble()) {
-    result = APFloat(semantics, APInt(typeSizeInBits, *value));
-    return success();
-  }
+  APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
+                       intValue.getRawData());
 
-  APInt apInt(typeSizeInBits, *value);
-  if (apInt != *value)
-    return emitError(loc, "hexadecimal float constant out of range for type");
-  result = APFloat(semantics, apInt);
+  result.emplace(semantics, truncatedValue);
 
   return success();
 }
diff --git a/mlir/test/IR/custom-float-attr-roundtrip.mlir b/mlir/test/IR/custom-float-attr-roundtrip.mlir
new file mode 100644
index 00000000000000..e0913e58d29538
--- /dev/null
+++ b/mlir/test/IR/custom-float-attr-roundtrip.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
+
+// CHECK-LABEL: @test_enum_attr_roundtrip
+func.func @test_enum_attr_roundtrip() -> () {
+  // CHECK: attr = #test.custom_float<"float" : 2.000000e+00>
+  "test.op"() {attr = #test.custom_float<"float" : 2.>} : () -> ()
+  // CHECK: attr = #test.custom_float<"double" : 2.000000e+00>
+  "test.op"() {attr =#test.custom_float<"double" : 2.>} : () -> ()
+   // CHECK: attr = #test.custom_float<"fp80" : 2.000000e+00>
+  "test.op"() {attr =#test.custom_float<"fp80" : 2.>} : () -> ()
+  // CHECK: attr = #test.custom_float<"float" : 0x7FC00000>
+  "test.op"() {attr =#test.custom_float<"float" : 0x7FC00000>} : () -> ()
+  // CHECK: attr = #test.custom_float<"double" : 0x7FF0000001000000>
+  "test.op"() {attr =#test.custom_float<"double" : 0x7FF0000001000000>} : () -> ()
+  // CHECK: attr = #test.custom_float<"fp80" : 0x7FFFC000000000100000>
+  "test.op"() {attr =#test.custom_float<"fp80" : 0x7FFFC000000000100000>} : () -> ()
+  return
+}
+
+// -----
+
+// Verify literal must be hex or float
+
+// expected-error @below {{unexpected decimal integer literal for a floating point value}}
+// expected-note @below {{add a trailing dot to make the literal a float}}
+"test.op"() {attr =#test.custom_float<"float" : 42>} : () -> ()
+
+// -----
+
+// Integer value must be in the width of the floating point type
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr =#test.custom_float<"float" : 0x7FC000000>} : () -> ()
+
+
+// -----
+
+// Integer value must be in the width of the floating point type
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr =#test.custom_float<"double" : 0x7FC000007FC0000000>} : () -> ()
+
+
+// -----
+
+// Integer value must be in the width of the floating point type
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr =#test.custom_float<"fp80" : 0x7FC0000007FC0000007FC000000>} : () -> ()
+
+// -----
+
+// Value must be a floating point literal or integer literal
+
+// expected-error @below {{expected floating point literal}}
+"test.op"() {attr =#test.custom_float<"float" : "blabla">} : () -> ()
+
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index bebbb876391d07..020942e7f4c11b 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1105,6 +1105,30 @@ func.func @bfloat16_special_values() {
   return
 }
 
+// CHECK-LABEL: @f80_special_values
+func.func @f80_special_values() {
+  // F80 signaling NaNs.
+  // CHECK: arith.constant 0x7FFFE000000000000001 : f80
+  %0 = arith.constant 0x7FFFE000000000000001 : f80
+  // CHECK: arith.constant 0x7FFFB000000000000011 : f80
+  %1 = arith.constant 0x7FFFB000000000000011 : f80
+
+  // F80 quiet NaNs.
+  // CHECK: arith.constant 0x7FFFC000000000100000 : f80
+  %2 = arith.constant 0x7FFFC000000000100000 : f80
+  // CHECK: arith.constant 0x7FFFE000000001000000 : f80
+  %3 = arith.constant 0x7FFFE000000001000000 : f80
+
+  // F80 positive infinity.
+  // CHECK: arith.constant 0x7FFF8000000000000000 : f80
+  %4 = arith.constant 0x7FFF8000000000000000 : f80
+  // F80 negative infinity.
+  // CHECK: arith.constant 0xFFFF8000000000000000 : f80
+  %5 = arith.constant 0xFFFF8000000000000000 : f80
+
+  return
+}
+
 // We want to print floats in exponential notation with 6 significant digits,
 // but it may lead to precision loss when parsing back, in which case we print
 // the decimal form instead.
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 40f035a3e3a4e5..12635e107bd42c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -340,4 +340,15 @@ def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
   }];
 }
 
+// Test AsmParser::parseFloat(const fltSemnatics&, APFloat&) API through the
+// custom parser and printer.
+def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
+  let mnemonic = "custom_float";
+  let parameters = (ins "mlir::StringAttr":$type_str, APFloatParameter<"">:$value);
+
+  let assemblyFormat = [{
+    `<` custom<CustomFloatAttr>($type_str, $value) `>`
+  }];
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 2cc051e664beec..d7e40d35238d91 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -240,6 +241,46 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
   p.printKeywordOrString(value);
 }
 
+//===----------------------------------------------------------------------===//
+// Custom Float Attribute
+//===----------------------------------------------------------------------===//
+
+static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr,
+                                 APFloat value) {
+  p << typeStrAttr << " : " << value;
+}
+
+static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
+                                        FailureOr<APFloat> &value) {
+
+  std::string str;
+  if (p.parseString(&str))
+    return failure();
+
+  typeStrAttr = StringAttr::get(p.getContext(), str);
+
+  if (p.parseColon())
+    return failure();
+
+  const llvm::fltSemantics *semantics;
+  if (str == "float")
+    semantics = &llvm::APFloat::IEEEsingle();
+  else if (str == "double")
+    semantics = &llvm::APFloat::IEEEdouble();
+  else if (str == "fp80")
+    semantics = &llvm::APFloat::x87DoubleExtended();
+  else
+    return p.emitError(p.getCurrentLocation(), "unknown float type, expected "
+                                               "'float', 'double' or 'fp80'");
+
+  APFloat parsedValue(0.0);
+  if (p.parseFloat(*semantics, parsedValue))
+    return failure();
+
+  value.emplace(parsedValue);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems a reasonable change. LGTM % a few nit comments.

As this is core infra, I would suggest to wait for a while such that all the relevant people have the opportunity to take a look.

mlir/lib/AsmParser/AsmParserImpl.h Outdated Show resolved Hide resolved
mlir/test/IR/custom-float-attr-roundtrip.mlir Outdated Show resolved Hide resolved
Parsing support for floating point types was missing a few features:
1. Parsing floating point attributes from integer literals was supported only
   for types with bitwidth smaller or equal to 64.
2. Downstream users could not use `AsmParser::parseFloat` to parse float types
   which are printed as integer literals.

This commit addresses both these points. It extends
`Parser::parseFloatFromIntegerLiteral` to support arbitrary bitwidth, and
exposes a new API to parse arbitrary floating point given an fltSemantics as
input. The usage of this new API is introduced in the Test Dialect.
@orbiri
Copy link
Contributor Author

orbiri commented May 4, 2024

Fixed nits. If there are no other objections - I'd appreciate help with committing this to main 🙏

@Dinistro Dinistro merged commit 1e3c630 into llvm:main May 4, 2024
2 of 3 checks passed
Copy link

github-actions bot commented May 4, 2024

@orbiri Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested
by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as
the builds can include changes from many authors. It is not uncommon for your
change to be included in a build that fails due to someone else's changes, or
infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself.
This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants