-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir] Add FP software implementation lowering pass: arith-to-apfloat
#166618
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
622e5fc to
c42f471
Compare
805e5fc to
78df4a8
Compare
|
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesLower floating point arithmetics that are unsupported by LLVM to calls into the execution engine runtime library. Patch is 23.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166618.diff 16 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
new file mode 100644
index 0000000000000..64a42a228199e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
@@ -0,0 +1,21 @@
+//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===//
+//
+// Part of the APFloat Project, under the Apache License v2.0 with APFloat
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 40d866ec7bf10..82bdfd02661a6 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 70e3e45c225db..2cd7d14f5517b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -186,6 +186,22 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToAPFloat
+//===----------------------------------------------------------------------===//
+
+def ArithToAPFloatConversionPass
+ : Pass<"convert-arith-to-apfloat", "ModuleOp"> {
+ let summary = "Convert Arith ops to APFloat runtime library calls";
+ let description = [{
+ This pass converts supported Arith ops to APFloat-based runtime library
+ calls (APFloatWrappers.cpp). APFloat is a software implementation of
+ floating-point arithmetic operations.
+ }];
+ let dependentDialects = ["func::FuncDialect"];
+ let options = [];
+}
+
//===----------------------------------------------------------------------===//
// ArithToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 3576126a487ac..9c9973cf84368 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -60,6 +60,14 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
mlir::ModuleOp moduleOp);
+/// Create a FuncOp with signature `resultTypes`(`paramTypes`)` and name `name`.
+/// Return a failure if the FuncOp found has unexpected signature.
+FailureOr<FuncOp>
+lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+ ArrayRef<Type> paramTypes = {},
+ ArrayRef<Type> resultTypes = {}, bool setPrivate = false,
+ SymbolTableCollection *symbolTables = nullptr);
+
} // namespace func
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 8ad9ed18acebd..b09d32022e348 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
new file mode 100644
index 0000000000000..62074625033ce
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -0,0 +1,113 @@
+//===- ArithToAPFloat.cpp - Arithmetic to APFloat impl conversion ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/IR/Verifier.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+static FailureOr<Operation *>
+lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+ SymbolTableCollection *symbolTables = nullptr) {
+ return lookupOrCreateFn(b, moduleOp,
+ (llvm::Twine("_mlir_apfloat_") + name).str(),
+ {IntegerType::get(moduleOp->getContext(), 32),
+ IntegerType::get(moduleOp->getContext(), 64),
+ IntegerType::get(moduleOp->getContext(), 64)},
+ {IntegerType::get(moduleOp->getContext(), 64)},
+ /*setPrivate*/ true, symbolTables);
+}
+
+template <typename OpTy>
+static LogicalResult rewriteBinaryOp(RewriterBase &rewriter, ModuleOp module,
+ OpTy op, StringRef apfloatName) {
+ // Get APFloat function from runtime library.
+ FailureOr<Operation *> fn =
+ lookupOrCreateBinaryFn(rewriter, module, apfloatName);
+ if (failed(fn))
+ return op->emitError("failed to lookup or create APFloat function");
+
+ // Cast operands to 64-bit integers.
+ Location loc = op.getLoc();
+ auto floatTy = cast<FloatType>(op.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ auto int64Type = rewriter.getI64Type();
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+
+ // Call APFloat function.
+ int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ Value truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, resultOp->getResult(0));
+ rewriter.replaceOp(
+ op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ return success();
+}
+
+namespace {
+struct ArithToAPFloatConversionPass final
+ : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
+ using impl::ArithToAPFloatConversionPassBase<
+ ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
+
+ void runOnOperation() override {
+ ModuleOp module = getOperation();
+ IRRewriter rewriter(getOperation()->getContext());
+ SmallVector<arith::AddFOp> addOps;
+ WalkResult status = module->walk([&](Operation *op) {
+ rewriter.setInsertionPoint(op);
+ LogicalResult result =
+ llvm::TypeSwitch<Operation *, LogicalResult>(op)
+ .Case<arith::AddFOp>([&](arith::AddFOp op) {
+ return rewriteBinaryOp(rewriter, module, op, "add");
+ })
+ .Case<arith::SubFOp>([&](arith::SubFOp op) {
+ return rewriteBinaryOp(rewriter, module, op, "subtract");
+ })
+ .Case<arith::MulFOp>([&](arith::MulFOp op) {
+ return rewriteBinaryOp(rewriter, module, op, "mulitply");
+ })
+ .Case<arith::DivFOp>([&](arith::DivFOp op) {
+ return rewriteBinaryOp(rewriter, module, op, "divide");
+ })
+ .Default([](Operation *op) { return success(); });
+ if (failed(result))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ if (status.wasInterrupted())
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..b0d1e46b3655f
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRArithToAPFloat
+ ArithToAPFloat.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRArithTransforms
+ MLIRFuncDialect
+ )
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index b6099902cc337..f2bacc3399144 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8fff3f9..613dc6d242ceb 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToAPFloat)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 69a317ecd101f..c747e1b59558a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return failure();
}
}
+ } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
+ // Print other floating-point types using the APFloat runtime library.
+ int32_t sem =
+ llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ Value floatBits =
+ LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
+ printer =
+ LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
+ emitCall(rewriter, loc, printer.value(),
+ ValueRange({semValue, floatBits}));
+ return success();
} else {
return failure();
}
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index b4cb0932ef631..e187e62cf6555 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -254,3 +254,45 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
return std::make_pair(*newFuncOpOrFailure, newCallOp);
}
+
+FailureOr<func::FuncOp>
+func::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+ ArrayRef<Type> paramTypes, ArrayRef<Type> resultTypes,
+ bool setPrivate, SymbolTableCollection *symbolTables) {
+ assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+ "expected SymbolTable operation");
+
+ FuncOp func;
+ if (symbolTables) {
+ func = symbolTables->lookupSymbolIn<FuncOp>(
+ moduleOp, StringAttr::get(moduleOp->getContext(), name));
+ } else {
+ func = llvm::dyn_cast_or_null<FuncOp>(
+ SymbolTable::lookupSymbolIn(moduleOp, name));
+ }
+
+ FunctionType funcT =
+ FunctionType::get(b.getContext(), paramTypes, resultTypes);
+ // Assert the signature of the found function is same as expected
+ if (func) {
+ if (funcT != func.getFunctionType()) {
+ func.emitError("redefinition of function '")
+ << name << "' of different type " << funcT << " is prohibited";
+ return failure();
+ }
+ return func;
+ }
+
+ OpBuilder::InsertionGuard g(b);
+ assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
+ b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
+ FuncOp funcOp = FuncOp::create(b, moduleOp->getLoc(), name, funcT);
+ if (setPrivate)
+ funcOp.setPrivate();
+ if (symbolTables) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
+ symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
+ }
+
+ return funcOp;
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index feaffa34897b6..160b6ae89215c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kPrintApFloat,
+ {IntegerType::get(moduleOp->getContext(), 32),
+ IntegerType::get(moduleOp->getContext(), 64)},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
+}
+
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
new file mode 100644
index 0000000000000..e8f57f231ce5c
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -0,0 +1,71 @@
+//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APFloat.h"
+
+#include <iostream>
+
+#if (defined(_WIN32) || defined(__CYGWIN__))
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
+#else
+#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
+#endif
+
+/// Binary operations without rounding mode.
+#define APFLOAT_BINARY_OP(OP) \
+ int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
+ int32_t semantics, uint64_t a, uint64_t b) { \
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
+ static_cast<llvm::APFloatBase::Semantics>(semantics)); \
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
+ llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
+ llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
+ llvm::APFloatBase::opStatus status = lhs.OP(rhs); \
+ return lhs.bitcastToAPInt().getZExtValue(); \
+ }
+
+/// Binary operations with rounding mode.
+#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \
+ int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
+ int32_t semantics, uint64_t a, uint64_t b) { \
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
+ static_cast<llvm::APFloatBase::Semantics>(semantics)); \
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
+ llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
+ llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
+ llvm::APFloatBase::opStatus status = lhs.OP(rhs, ROUNDING_MODE); \
+ return lhs.bitcastToAPInt().getZExtValue(); \
+ }
+
+extern "C" {
+
+#define BIN_OPS_WITH_ROUNDING(X) \
+ X(add, llvm::RoundingMode::NearestTiesToEven) \
+ X(subtract, llvm::RoundingMode::NearestTiesToEven) \
+ X(multiply, llvm::RoundingMode::NearestTiesToEven) \
+ X(divide, llvm::RoundingMode::NearestTiesToEven)
+
+BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE)
+#undef BIN_OPS_WITH_ROUNDING
+#undef APFLOAT_BINARY_OP_ROUNDING_MODE
+
+APFLOAT_BINARY_OP(remainder)
+APFLOAT_BINARY_OP(mod)
+
+#undef APFLOAT_BINARY_OP
+
+void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
+ uint64_t a) {
+ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+ static_cast<llvm::APFloatBase::Semantics>(semantics));
+ unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+ llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+ double d = x.convertToDouble();
+ fprintf(stdout, "%lg", d);
+}
+}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdeb4dacf9278..8c09e50e4de7b 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
# is a big dependency which most don't need.
set(LLVM_OPTIONAL_SOURCES
+ APFloatWrappers.cpp
ArmRunnerUtils.cpp
ArmSMEStubs.cpp
AsyncRuntime.cpp
@@ -167,6 +168,15 @@ if(LLVM_ENABLE_PIC)
set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17)
target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS)
+ add_mlir_library(mlir_apfloat_wrappers
+ SHARED
+ APFloatWrappers.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+ )
+ set_property(TARGET mlir_apfloat_wrappers PROPERTY CXX_STANDARD 17)
+ target_compile_definitions(mlir_apfloat...
[truncated]
|
de9b2af to
1180064
Compare
| walkAndApplyPatterns(getOperation(), std::move(patterns)); | ||
| if (failed(result)) | ||
| return signalPassFailure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@matthias-springer this is what we want right? if there's a failure due to overlapping decls we want to fail the pass right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds reasonable to me.
|
cc @nikalra |
| ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { | ||
| if (diag.getSeverity() == DiagnosticSeverity::Error) { | ||
| result = failure(); | ||
| } | ||
| // NB: if you don't return failure, no other diag handlers will fire (see | ||
| // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). | ||
| return failure(); | ||
| }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if this is something the walk rewriter should do? (separately from this PR)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My 2 cents: leaving it decomposed is better - this is not hard to write (once you know how diaghandlers work 😜).
But it's up to you - you're code owner/designer of that. If you so wish it. I can refactor this into the rewriter. Maybe the "abstraction" could be passing just the callback.
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/21601 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/140/builds/33850 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/20791 Here is the relevant piece of the build log for the reference |
|
reverting here #167429 |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/117/builds/14944 Here is the relevant piece of the build log for the reference |
…ss: `arith-to-apfloat` (#166618)" (#167431) Reland llvm/llvm-project#166618 with MLIRFuncUtils linked in.
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/203/builds/28691 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/205/builds/27482 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/16926 Here is the relevant piece of the build log for the reference |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/204/builds/27503 Here is the relevant piece of the build log for the reference |
This commit adds a new pass that lowers floating-point
arithoperations to calls into the execution engine runtime library. Currently supported operations:addf,subf,mulf,divf,remf.All floating-point types that have an APFloat semantics are supported. This includes low-precision floating-point types such as
f4E2M1FNthat cannot execute natively on CPUs.This commit also improves the
vector.printlowering pattern to call into the runtime library for floating-point types that are not supported by LLVM. This is necessary to write a meaningful integration test.The way it works is
gets transformed to
Note,
llvm::fltSemantics(f8E4M3FN)is emitted by the pattern each time anarithop is transformed, thereby making the call to__mlir_apfloat_addcorrect (i.e., no name mangling on type necessary).RFC: https://discourse.llvm.org/t/rfc-software-implementation-for-unsupported-fp-types-in-convert-arith-to-llvm/88785