-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[mlir][math] Add FP software implementation lowering pass: math-to-apfloat #171221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
49f813a to
37ef4b2
Compare
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
debb4a9 to
b0028ca
Compare
2e4e026 to
7533d56
Compare
d9e35f0 to
9d203b2
Compare
6d89bc0 to
8e4a9ac
Compare
8e4a9ac to
8cba26c
Compare
|
@llvm/pr-subscribers-mlir-math Author: Maksim Levental (makslevental) ChangesPatch is 42.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171221.diff 17 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
new file mode 100644
index 0000000000000..86179a1611d5e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
@@ -0,0 +1,21 @@
+//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
+//
+// Part of the APFloat Project, under the Apache License v2.0 with APFloat
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 82bdfd02661a6..05ec2f8ce2538 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -44,6 +44,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fcbaf3ccc1486..7f24e58671aab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToAPFloat
+//===----------------------------------------------------------------------===//
+
+def MathToAPFloatConversionPass
+ : Pass<"convert-math-to-apfloat", "ModuleOp"> {
+ let summary = "Convert Math ops to APFloat runtime library calls";
+ let description = [{
+ This pass converts supported Math ops to APFloat-based runtime library
+ calls (APFloatWrappers.cpp). APFloat is a software implementation of
+ floating-point mathmetic operations.
+ }];
+ let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MathToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 00d50874a2e8d..079c1f461b6ed 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT,
SymbolTableCollection *symbolTables = nullptr);
+/// Create a FuncOp decl and insert it into `symTable` operation. If
+/// `symbolTables` is provided, then the decl will be inserted into the
+/// SymbolTableCollection.
+FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ FunctionType funcT, bool setPrivate,
+ SymbolTableCollection *symbolTables = nullptr);
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function with the given parameter types. Returns an int64_t, unless a
+/// different result type is specified.
+FailureOr<FuncOp>
+lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ TypeRange paramTypes,
+ SymbolTableCollection *symbolTables = nullptr,
+ Type resultType = {});
+
} // namespace func
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
similarity index 88%
rename from mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
rename to mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 79816fc6e3bf1..b9ba94ef08098 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -25,47 +26,6 @@ namespace mlir {
using namespace mlir;
using namespace mlir::func;
-static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, FunctionType funcT, bool setPrivate,
- SymbolTableCollection *symbolTables = nullptr) {
- OpBuilder::InsertionGuard g(b);
- assert(!symTable->getRegion(0).empty() && "expected non-empty region");
- b.setInsertionPointToStart(&symTable->getRegion(0).front());
- FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
- if (setPrivate)
- funcOp.setPrivate();
- if (symbolTables) {
- SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
- symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
- }
- return funcOp;
-}
-
-/// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Returns an int64_t, unless a
-/// different result type is specified.
-static FailureOr<FuncOp>
-lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, TypeRange paramTypes,
- SymbolTableCollection *symbolTables = nullptr,
- Type resultType = {}) {
- if (!resultType)
- resultType = IntegerType::get(symTable->getContext(), 64);
- std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
- auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
- FailureOr<FuncOp> func =
- lookupFnDecl(symTable, funcName, funcT, symbolTables);
- // Failed due to type mismatch.
- if (failed(func))
- return func;
- // Successfully matched existing decl.
- if (*func)
- return *func;
-
- return createFnDecl(b, symTable, funcName, funcT,
- /*setPrivate=*/true, symbolTables);
-}
-
/// Helper function to look up or create the symbol for a runtime library
/// function for a binary arithmetic operation.
///
@@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
SymbolTableCollection *symbolTables = nullptr) {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
- symbolTables);
-}
-
-static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
- int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- return arith::ConstantOp::create(b, loc, b.getI32Type(),
- b.getIntegerAttr(b.getI32Type(), sem));
+ std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+ return lookupOrCreateFnDecl(b, symTable, funcName,
+ {i32Type, i64Type, i64Type}, symbolTables);
}
/// Given two operands of vector type and vector result type (with the same
@@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
- rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+ FailureOr<FuncOp> fn =
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
+ {i32Type, i32Type, i64Type});
if (failed(fn))
return fn;
@@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
- {i32Type, i32Type, i1Type, i64Type});
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outIntTy = cast<IntegerType>(resultType);
Value outWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
@@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
auto i1Type = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
- {i32Type, i32Type, i1Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_convert_from_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
// Call APFloat function.
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
Value inWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
@@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "compare",
- {i32Type, i64Type, i64Type}, nullptr, i8Type);
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
+ {i32Type, i64Type, i64Type}, nullptr, i8Type);
if (failed(fn))
return fn;
@@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
Value comparisonResult =
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
@@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
if (failed(fn))
return fn;
@@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, operand1));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operandBits};
Value negatedBits =
func::CallOp::create(rewriter, loc, TypeRange(i64Type),
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..bad8226ac88ec
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
@@ -0,0 +1,49 @@
+add_mlir_library(ArithAndMathToAPFloatUtils
+ Utils.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ )
+
+add_mlir_conversion_library(MLIRArithToAPFloat
+ ArithToAPFloat.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ ArithAndMathToAPFloatUtils
+ MLIRArithDialect
+ MLIRArithTransforms
+ MLIRFuncDialect
+ MLIRFuncUtils
+ MLIRVectorDialect
+ )
+
+add_mlir_conversion_library(MLIRMathToAPFloat
+ MathToAPFloat.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ ArithAndMathToAPFloatUtils
+ MLIRMathDialect
+ MLIRFuncDialect
+ MLIRFuncUtils
+ )
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
new file mode 100644
index 0000000000000..9cd5a41daf7d8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -0,0 +1,210 @@
+//===- MathToAPFloat.cpp - Mathmetic to APFloat Conversion ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+#include "Utils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
+ AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(math::AbsFOp op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
+ if (failed(fn))
+ return fn;
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ auto operand = op.getOperand();
+ auto floatTy = cast<FloatType>(operand.getType());
+ if (floatTy.getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+
+ // Truncate result to the original width.
+ Value truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ rewriter.replaceOp(
+ op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
+template <typename OpTy>
+struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+ IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
+ SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ APFloatName(APFloatName) {};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i1 = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ std::string funcName =
+ (llvm::Twine("_mlir_apfloat_is") + APFloatName).str();
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1);
+ if (failed(fn))
+ return fn;
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ auto operand = op.getOperand();
+ auto floatTy = cast<FloatType>(operand.getType());
+ if (floatTy.getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue...
[truncated]
|
|
@llvm/pr-subscribers-mlir-execution-engine Author: Maksim Levental (makslevental) ChangesPatch is 42.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171221.diff 17 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
new file mode 100644
index 0000000000000..86179a1611d5e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
@@ -0,0 +1,21 @@
+//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
+//
+// Part of the APFloat Project, under the Apache License v2.0 with APFloat
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 82bdfd02661a6..05ec2f8ce2538 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -44,6 +44,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fcbaf3ccc1486..7f24e58671aab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToAPFloat
+//===----------------------------------------------------------------------===//
+
+def MathToAPFloatConversionPass
+ : Pass<"convert-math-to-apfloat", "ModuleOp"> {
+ let summary = "Convert Math ops to APFloat runtime library calls";
+ let description = [{
+ This pass converts supported Math ops to APFloat-based runtime library
+ calls (APFloatWrappers.cpp). APFloat is a software implementation of
+ floating-point mathmetic operations.
+ }];
+ let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MathToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 00d50874a2e8d..079c1f461b6ed 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT,
SymbolTableCollection *symbolTables = nullptr);
+/// Create a FuncOp decl and insert it into `symTable` operation. If
+/// `symbolTables` is provided, then the decl will be inserted into the
+/// SymbolTableCollection.
+FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ FunctionType funcT, bool setPrivate,
+ SymbolTableCollection *symbolTables = nullptr);
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function with the given parameter types. Returns an int64_t, unless a
+/// different result type is specified.
+FailureOr<FuncOp>
+lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ TypeRange paramTypes,
+ SymbolTableCollection *symbolTables = nullptr,
+ Type resultType = {});
+
} // namespace func
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
similarity index 88%
rename from mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
rename to mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 79816fc6e3bf1..b9ba94ef08098 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -25,47 +26,6 @@ namespace mlir {
using namespace mlir;
using namespace mlir::func;
-static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, FunctionType funcT, bool setPrivate,
- SymbolTableCollection *symbolTables = nullptr) {
- OpBuilder::InsertionGuard g(b);
- assert(!symTable->getRegion(0).empty() && "expected non-empty region");
- b.setInsertionPointToStart(&symTable->getRegion(0).front());
- FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
- if (setPrivate)
- funcOp.setPrivate();
- if (symbolTables) {
- SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
- symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
- }
- return funcOp;
-}
-
-/// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Returns an int64_t, unless a
-/// different result type is specified.
-static FailureOr<FuncOp>
-lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, TypeRange paramTypes,
- SymbolTableCollection *symbolTables = nullptr,
- Type resultType = {}) {
- if (!resultType)
- resultType = IntegerType::get(symTable->getContext(), 64);
- std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
- auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
- FailureOr<FuncOp> func =
- lookupFnDecl(symTable, funcName, funcT, symbolTables);
- // Failed due to type mismatch.
- if (failed(func))
- return func;
- // Successfully matched existing decl.
- if (*func)
- return *func;
-
- return createFnDecl(b, symTable, funcName, funcT,
- /*setPrivate=*/true, symbolTables);
-}
-
/// Helper function to look up or create the symbol for a runtime library
/// function for a binary arithmetic operation.
///
@@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
SymbolTableCollection *symbolTables = nullptr) {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
- symbolTables);
-}
-
-static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
- int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- return arith::ConstantOp::create(b, loc, b.getI32Type(),
- b.getIntegerAttr(b.getI32Type(), sem));
+ std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+ return lookupOrCreateFnDecl(b, symTable, funcName,
+ {i32Type, i64Type, i64Type}, symbolTables);
}
/// Given two operands of vector type and vector result type (with the same
@@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
- rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+ FailureOr<FuncOp> fn =
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
+ {i32Type, i32Type, i64Type});
if (failed(fn))
return fn;
@@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
- {i32Type, i32Type, i1Type, i64Type});
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outIntTy = cast<IntegerType>(resultType);
Value outWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
@@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
auto i1Type = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
- {i32Type, i32Type, i1Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_convert_from_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
// Call APFloat function.
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
Value inWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
@@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "compare",
- {i32Type, i64Type, i64Type}, nullptr, i8Type);
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
+ {i32Type, i64Type, i64Type}, nullptr, i8Type);
if (failed(fn))
return fn;
@@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
Value comparisonResult =
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
@@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
if (failed(fn))
return fn;
@@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, operand1));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operandBits};
Value negatedBits =
func::CallOp::create(rewriter, loc, TypeRange(i64Type),
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..bad8226ac88ec
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
@@ -0,0 +1,49 @@
+add_mlir_library(ArithAndMathToAPFloatUtils
+ Utils.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ )
+
+add_mlir_conversion_library(MLIRArithToAPFloat
+ ArithToAPFloat.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ ArithAndMathToAPFloatUtils
+ MLIRArithDialect
+ MLIRArithTransforms
+ MLIRFuncDialect
+ MLIRFuncUtils
+ MLIRVectorDialect
+ )
+
+add_mlir_conversion_library(MLIRMathToAPFloat
+ MathToAPFloat.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ ArithAndMathToAPFloatUtils
+ MLIRMathDialect
+ MLIRFuncDialect
+ MLIRFuncUtils
+ )
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
new file mode 100644
index 0000000000000..9cd5a41daf7d8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -0,0 +1,210 @@
+//===- MathToAPFloat.cpp - Mathmetic to APFloat Conversion ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+#include "Utils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
+ AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(math::AbsFOp op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
+ if (failed(fn))
+ return fn;
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ auto operand = op.getOperand();
+ auto floatTy = cast<FloatType>(operand.getType());
+ if (floatTy.getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+
+ // Truncate result to the original width.
+ Value truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ rewriter.replaceOp(
+ op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
+template <typename OpTy>
+struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+ IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
+ SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ APFloatName(APFloatName) {};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i1 = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ std::string funcName =
+ (llvm::Twine("_mlir_apfloat_is") + APFloatName).str();
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1);
+ if (failed(fn))
+ return fn;
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ auto operand = op.getOperand();
+ auto floatTy = cast<FloatType>(operand.getType());
+ if (floatTy.getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue...
[truncated]
|
|
@llvm/pr-subscribers-mlir-func Author: Maksim Levental (makslevental) ChangesPatch is 42.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171221.diff 17 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
new file mode 100644
index 0000000000000..86179a1611d5e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
@@ -0,0 +1,21 @@
+//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
+//
+// Part of the APFloat Project, under the Apache License v2.0 with APFloat
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 82bdfd02661a6..05ec2f8ce2538 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -44,6 +44,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fcbaf3ccc1486..7f24e58671aab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToAPFloat
+//===----------------------------------------------------------------------===//
+
+def MathToAPFloatConversionPass
+ : Pass<"convert-math-to-apfloat", "ModuleOp"> {
+ let summary = "Convert Math ops to APFloat runtime library calls";
+ let description = [{
+ This pass converts supported Math ops to APFloat-based runtime library
+ calls (APFloatWrappers.cpp). APFloat is a software implementation of
+ floating-point mathmetic operations.
+ }];
+ let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MathToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 00d50874a2e8d..079c1f461b6ed 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT,
SymbolTableCollection *symbolTables = nullptr);
+/// Create a FuncOp decl and insert it into `symTable` operation. If
+/// `symbolTables` is provided, then the decl will be inserted into the
+/// SymbolTableCollection.
+FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ FunctionType funcT, bool setPrivate,
+ SymbolTableCollection *symbolTables = nullptr);
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function with the given parameter types. Returns an int64_t, unless a
+/// different result type is specified.
+FailureOr<FuncOp>
+lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ TypeRange paramTypes,
+ SymbolTableCollection *symbolTables = nullptr,
+ Type resultType = {});
+
} // namespace func
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
similarity index 88%
rename from mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
rename to mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 79816fc6e3bf1..b9ba94ef08098 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -25,47 +26,6 @@ namespace mlir {
using namespace mlir;
using namespace mlir::func;
-static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, FunctionType funcT, bool setPrivate,
- SymbolTableCollection *symbolTables = nullptr) {
- OpBuilder::InsertionGuard g(b);
- assert(!symTable->getRegion(0).empty() && "expected non-empty region");
- b.setInsertionPointToStart(&symTable->getRegion(0).front());
- FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
- if (setPrivate)
- funcOp.setPrivate();
- if (symbolTables) {
- SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
- symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
- }
- return funcOp;
-}
-
-/// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Returns an int64_t, unless a
-/// different result type is specified.
-static FailureOr<FuncOp>
-lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, TypeRange paramTypes,
- SymbolTableCollection *symbolTables = nullptr,
- Type resultType = {}) {
- if (!resultType)
- resultType = IntegerType::get(symTable->getContext(), 64);
- std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
- auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
- FailureOr<FuncOp> func =
- lookupFnDecl(symTable, funcName, funcT, symbolTables);
- // Failed due to type mismatch.
- if (failed(func))
- return func;
- // Successfully matched existing decl.
- if (*func)
- return *func;
-
- return createFnDecl(b, symTable, funcName, funcT,
- /*setPrivate=*/true, symbolTables);
-}
-
/// Helper function to look up or create the symbol for a runtime library
/// function for a binary arithmetic operation.
///
@@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
SymbolTableCollection *symbolTables = nullptr) {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
- symbolTables);
-}
-
-static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
- int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- return arith::ConstantOp::create(b, loc, b.getI32Type(),
- b.getIntegerAttr(b.getI32Type(), sem));
+ std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+ return lookupOrCreateFnDecl(b, symTable, funcName,
+ {i32Type, i64Type, i64Type}, symbolTables);
}
/// Given two operands of vector type and vector result type (with the same
@@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
- rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+ FailureOr<FuncOp> fn =
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
+ {i32Type, i32Type, i64Type});
if (failed(fn))
return fn;
@@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
- {i32Type, i32Type, i1Type, i64Type});
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outIntTy = cast<IntegerType>(resultType);
Value outWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
@@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
auto i1Type = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
- {i32Type, i32Type, i1Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_convert_from_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
// Call APFloat function.
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
Value inWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
@@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "compare",
- {i32Type, i64Type, i64Type}, nullptr, i8Type);
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
+ {i32Type, i64Type, i64Type}, nullptr, i8Type);
if (failed(fn))
return fn;
@@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
Value comparisonResult =
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
@@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
if (failed(fn))
return fn;
@@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, operand1));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operandBits};
Value negatedBits =
func::CallOp::create(rewriter, loc, TypeRange(i64Type),
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..bad8226ac88ec
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
@@ -0,0 +1,49 @@
+add_mlir_library(ArithAndMathToAPFloatUtils
+ Utils.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ )
+
+add_mlir_conversion_library(MLIRArithToAPFloat
+ ArithToAPFloat.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ ArithAndMathToAPFloatUtils
+ MLIRArithDialect
+ MLIRArithTransforms
+ MLIRFuncDialect
+ MLIRFuncUtils
+ MLIRVectorDialect
+ )
+
+add_mlir_conversion_library(MLIRMathToAPFloat
+ MathToAPFloat.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ ArithAndMathToAPFloatUtils
+ MLIRMathDialect
+ MLIRFuncDialect
+ MLIRFuncUtils
+ )
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
new file mode 100644
index 0000000000000..9cd5a41daf7d8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -0,0 +1,210 @@
+//===- MathToAPFloat.cpp - Mathmetic to APFloat Conversion ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+#include "Utils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
+ AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(math::AbsFOp op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
+ if (failed(fn))
+ return fn;
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ auto operand = op.getOperand();
+ auto floatTy = cast<FloatType>(operand.getType());
+ if (floatTy.getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+
+ // Truncate result to the original width.
+ Value truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ rewriter.replaceOp(
+ op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
+template <typename OpTy>
+struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+ IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
+ SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ APFloatName(APFloatName) {};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i1 = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ std::string funcName =
+ (llvm::Twine("_mlir_apfloat_is") + APFloatName).str();
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1);
+ if (failed(fn))
+ return fn;
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ auto operand = op.getOperand();
+ auto floatTy = cast<FloatType>(operand.getType());
+ if (floatTy.getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue...
[truncated]
|
|
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesPatch is 42.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171221.diff 17 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
new file mode 100644
index 0000000000000..86179a1611d5e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
@@ -0,0 +1,21 @@
+//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
+//
+// Part of the APFloat Project, under the Apache License v2.0 with APFloat
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 82bdfd02661a6..05ec2f8ce2538 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -44,6 +44,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fcbaf3ccc1486..7f24e58671aab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToAPFloat
+//===----------------------------------------------------------------------===//
+
+def MathToAPFloatConversionPass
+ : Pass<"convert-math-to-apfloat", "ModuleOp"> {
+ let summary = "Convert Math ops to APFloat runtime library calls";
+ let description = [{
+ This pass converts supported Math ops to APFloat-based runtime library
+ calls (APFloatWrappers.cpp). APFloat is a software implementation of
+ floating-point mathmetic operations.
+ }];
+ let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MathToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 00d50874a2e8d..079c1f461b6ed 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT,
SymbolTableCollection *symbolTables = nullptr);
+/// Create a FuncOp decl and insert it into `symTable` operation. If
+/// `symbolTables` is provided, then the decl will be inserted into the
+/// SymbolTableCollection.
+FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ FunctionType funcT, bool setPrivate,
+ SymbolTableCollection *symbolTables = nullptr);
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function with the given parameter types. Returns an int64_t, unless a
+/// different result type is specified.
+FailureOr<FuncOp>
+lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ TypeRange paramTypes,
+ SymbolTableCollection *symbolTables = nullptr,
+ Type resultType = {});
+
} // namespace func
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
similarity index 88%
rename from mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
rename to mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
index 79816fc6e3bf1..b9ba94ef08098 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -25,47 +26,6 @@ namespace mlir {
using namespace mlir;
using namespace mlir::func;
-static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, FunctionType funcT, bool setPrivate,
- SymbolTableCollection *symbolTables = nullptr) {
- OpBuilder::InsertionGuard g(b);
- assert(!symTable->getRegion(0).empty() && "expected non-empty region");
- b.setInsertionPointToStart(&symTable->getRegion(0).front());
- FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
- if (setPrivate)
- funcOp.setPrivate();
- if (symbolTables) {
- SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
- symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
- }
- return funcOp;
-}
-
-/// Helper function to look up or create the symbol for a runtime library
-/// function with the given parameter types. Returns an int64_t, unless a
-/// different result type is specified.
-static FailureOr<FuncOp>
-lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
- StringRef name, TypeRange paramTypes,
- SymbolTableCollection *symbolTables = nullptr,
- Type resultType = {}) {
- if (!resultType)
- resultType = IntegerType::get(symTable->getContext(), 64);
- std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
- auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
- FailureOr<FuncOp> func =
- lookupFnDecl(symTable, funcName, funcT, symbolTables);
- // Failed due to type mismatch.
- if (failed(func))
- return func;
- // Successfully matched existing decl.
- if (*func)
- return *func;
-
- return createFnDecl(b, symTable, funcName, funcT,
- /*setPrivate=*/true, symbolTables);
-}
-
/// Helper function to look up or create the symbol for a runtime library
/// function for a binary arithmetic operation.
///
@@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
SymbolTableCollection *symbolTables = nullptr) {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
- symbolTables);
-}
-
-static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
- int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
- return arith::ConstantOp::create(b, loc, b.getI32Type(),
- b.getIntegerAttr(b.getI32Type(), sem));
+ std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+ return lookupOrCreateFnDecl(b, symTable, funcName,
+ {i32Type, i64Type, i64Type}, symbolTables);
}
/// Given two operands of vector type and vector result type (with the same
@@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
- rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+ FailureOr<FuncOp> fn =
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
+ {i32Type, i32Type, i64Type});
if (failed(fn))
return fn;
@@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
@@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
- {i32Type, i32Type, i1Type, i64Type});
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outIntTy = cast<IntegerType>(resultType);
Value outWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
@@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
auto i1Type = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
- {i32Type, i32Type, i1Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_convert_from_int",
+ {i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;
@@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
// Call APFloat function.
auto outFloatTy = cast<FloatType>(resultType);
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value outSemValue =
+ getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
Value inWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
@@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "compare",
- {i32Type, i64Type, i64Type}, nullptr, i8Type);
+ lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
+ {i32Type, i64Type, i64Type}, nullptr, i8Type);
if (failed(fn))
return fn;
@@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
Value comparisonResult =
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
@@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
- FailureOr<FuncOp> fn =
- lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
if (failed(fn))
return fn;
@@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, operand1));
// Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operandBits};
Value negatedBits =
func::CallOp::create(rewriter, loc, TypeRange(i64Type),
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..bad8226ac88ec
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
@@ -0,0 +1,49 @@
+add_mlir_library(ArithAndMathToAPFloatUtils
+ Utils.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ )
+
+add_mlir_conversion_library(MLIRArithToAPFloat
+ ArithToAPFloat.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ ArithAndMathToAPFloatUtils
+ MLIRArithDialect
+ MLIRArithTransforms
+ MLIRFuncDialect
+ MLIRFuncUtils
+ MLIRVectorDialect
+ )
+
+add_mlir_conversion_library(MLIRMathToAPFloat
+ MathToAPFloat.cpp
+ PARTIAL_SOURCES_INTENDED
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ ArithAndMathToAPFloatUtils
+ MLIRMathDialect
+ MLIRFuncDialect
+ MLIRFuncUtils
+ )
diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
new file mode 100644
index 0000000000000..9cd5a41daf7d8
--- /dev/null
+++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp
@@ -0,0 +1,210 @@
+//===- MathToAPFloat.cpp - Mathmetic to APFloat Conversion ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
+#include "Utils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> {
+ AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(math::AbsFOp op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type});
+ if (failed(fn))
+ return fn;
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ auto operand = op.getOperand();
+ auto floatTy = cast<FloatType>(operand.getType());
+ if (floatTy.getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+
+ // Truncate result to the original width.
+ Value truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ rewriter.replaceOp(
+ op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
+template <typename OpTy>
+struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+ IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName,
+ SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ APFloatName(APFloatName) {};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ // Get APFloat function from runtime library.
+ auto i1 = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ std::string funcName =
+ (llvm::Twine("_mlir_apfloat_is") + APFloatName).str();
+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
+ rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1);
+ if (failed(fn))
+ return fn;
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ // Cast operands to 64-bit integers.
+ auto operand = op.getOperand();
+ auto floatTy = cast<FloatType>(operand.getType());
+ if (floatTy.getIntOrFloatBitWidth() > 64) {
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand));
+
+ // Call APFloat function.
+ Value semValue...
[truncated]
|
| auto operand = op.getOperand(); | ||
| auto floatTy = cast<FloatType>(operand.getType()); | ||
| if (floatTy.getIntOrFloatBitWidth() > 64) { | ||
| return rewriter.notifyMatchFailure(op, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"return failure/notifyMatchFailure()" should be before any kind of IR modification. (lookupOrCreateFnDecl modifies the IR.)
| rewriter.setInsertionPoint(op); | ||
| // Cast operands to 64-bit integers. | ||
| auto operand = op.getOperand(); | ||
| auto floatTy = cast<FloatType>(operand.getType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The operand type could be a vector. I added support to ArithToAPFloat recently to support vectors. Maybe we can follow the same design here. (This PR is already large enough, I would just return failure for now if it's not a float type.)
| // RUN: --shared-libs=%mlir_c_runner_utils \ | ||
| // RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s --check-prefix=CHECK-FP8 | ||
|
|
||
| // Case 2: Only unsupported arithmetics is lowered through APFloat. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be more interesting to also lower f32 to APFloat emulation. The reason why I had Case 1 and Case 2 in the other test case is to show that the order of the two passes --convert-math-to-apfloat --convert-to-llvm matters. We already did that in the other test-apfloat-emulation integration test. We don't get any extra test coverage by lowering to LLVM.
I'll leave it up to you, but I would drop Case 2 and use a single FileCheck and entry func for both f8 and f32.
1d77e12 to
8e33e31
Compare
8e33e31 to
e368626
Compare
|
should this directory be renamed as well? |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/23270 Here is the relevant piece of the build log for the reference |
…th-to-apfloat" (#172714) Reverts #171221 Broken builder https://lab.llvm.org/buildbot/#/builders/138/builds/23270
I had to revert but sure I can do it in the reapply #172716 |
…ng pass: math-to-apfloat" (#172714) Reverts llvm/llvm-project#171221 Broken builder https://lab.llvm.org/buildbot/#/builders/138/builds/23270
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/205/builds/30917 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/203/builds/32128 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/204/builds/30938 Here is the relevant piece of the build log for the reference |
…ing pass: math-to-apfloat" (#172714) (#172716) Reapply llvm/llvm-project#171221 - Fix builder by linking `MLIRTransformUtils`. Also move headers to `mlir/Conversion/ArithAndMathToAPFloat`.
Integrate llvm/llvm-project@5f15fee Local revert of llvm/llvm-project#169614 due to suspiscion of link to numerical issues observed in #22649. * Also had to revert iree-org/llvm-project@dea9ec8 which had landed on top. Cherry pick of Bazel fix: llvm/llvm-project@a341180 * Could not just advance to that as this crosses another Bazel-regressing commit, llvm/llvm-project#171221 Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
Add APFloat software implementation for
math.fma,math.abs,math.isnan,math.isfinite,math.isinf,math.isnormalfor reduced precision (fp4*,fp6*,fp8*).