From e2e29a50941572c0ba1e10fc466542126f5a5dc4 Mon Sep 17 00:00:00 2001 From: makslevental Date: Mon, 10 Nov 2025 17:41:17 -0800 Subject: [PATCH 1/4] Reapply "Reapply "[mlir] Add FP software implementation lowering pass: `arith-to-apfloat` (#166618)" (#167431)" This reverts commit 6631fe3fac81f6f333b7491107d97849b7d7ae0e. --- .../ArithToAPFloat/ArithToAPFloat.h | 21 +++ mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 15 ++ mlir/include/mlir/Dialect/Func/Utils/Utils.h | 7 + .../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 4 + .../ArithToAPFloat/ArithToAPFloat.cpp | 161 ++++++++++++++++++ .../Conversion/ArithToAPFloat/CMakeLists.txt | 18 ++ .../Conversion/ArithToLLVM/ArithToLLVM.cpp | 1 + mlir/lib/Conversion/CMakeLists.txt | 1 + .../VectorToLLVM/ConvertVectorToLLVM.cpp | 14 ++ mlir/lib/Dialect/Func/Utils/Utils.cpp | 25 +++ .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 11 ++ mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 81 +++++++++ mlir/lib/ExecutionEngine/CMakeLists.txt | 12 ++ .../ArithToApfloat/arith-to-apfloat.mlir | 128 ++++++++++++++ .../Arith/CPU/test-apfloat-emulation.mlir | 34 ++++ 16 files changed, 534 insertions(+) create mode 100644 mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h create mode 100644 mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp create mode 100644 mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt create mode 100644 mlir/lib/ExecutionEngine/APFloatWrappers.cpp create mode 100644 mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir create mode 100644 mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h new file mode 100644 index 0000000000000..64a42a228199e --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h @@ -0,0 +1,21 @@ +//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===// 
+// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H +#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 40d866ec7bf10..82bdfd02661a6 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -12,6 +12,7 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 70e3e45c225db..79bc380dbcb7a 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -186,6 +186,21 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> { ]; } +//===----------------------------------------------------------------------===// +// ArithToAPFloat +//===----------------------------------------------------------------------===// + +def ArithToAPFloatConversionPass + : Pass<"convert-arith-to-apfloat", "ModuleOp"> { + let summary = "Convert Arith ops to APFloat runtime library calls"; + let description = [{ + This pass converts supported 
Arith ops to APFloat-based runtime library + calls (APFloatWrappers.cpp). APFloat is a software implementation of + floating-point arithmetic operations. + }]; + let dependentDialects = ["func::FuncDialect"]; +} + //===----------------------------------------------------------------------===// // ArithToSPIRV //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h index 3576126a487ac..00d50874a2e8d 100644 --- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h @@ -60,6 +60,13 @@ mlir::FailureOr> deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp); +/// Look up a FuncOp named `name` with function type `funcT`. Return a null +/// FuncOp if none is found, and a failure if one is found but with a +/// different signature. +FailureOr lookupFnDecl(SymbolOpInterface symTable, StringRef name, + FunctionType funcT, + SymbolTableCollection *symbolTables = nullptr); + } // namespace func } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 8ad9ed18acebd..b09d32022e348 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, FailureOr lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); + /// Declares a function to print a C-string. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. 
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp new file mode 100644 index 0000000000000..012e934d3050f --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -0,0 +1,161 @@ +//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "arith-to-apfloat" + +namespace mlir { +#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::func; + +static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, FunctionType funcT, bool setPrivate, + SymbolTableCollection *symbolTables = nullptr) { + OpBuilder::InsertionGuard g(b); + assert(!symTable->getRegion(0).empty() && "expected non-empty region"); + b.setInsertionPointToStart(&symTable->getRegion(0).front()); + FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT); + if (setPrivate) + funcOp.setPrivate(); + if (symbolTables) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable); + symbolTable.insert(funcOp, symTable->getRegion(0).front().begin()); + } + return funcOp; +} + +/// Helper function to look up or create the symbol for a runtime library +/// 
function for a binary arithmetic operation. +/// +/// Parameter 1: APFloat semantics +/// Parameter 2: Left-hand side operand +/// Parameter 3: Right-hand side operand +/// +/// This function will return a failure if the function is found but has an +/// unexpected signature. +/// +static FailureOr +lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, + SymbolTableCollection *symbolTables = nullptr) { + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + + std::string funcName = (llvm::Twine("__mlir_apfloat_") + name).str(); + FunctionType funcT = + FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type}); + FailureOr func = + lookupFnDecl(symTable, funcName, funcT, symbolTables); + // Failed due to type mismatch. + if (failed(func)) + return func; + // Successfully matched existing decl. + if (*func) + return *func; + + return createFnDecl(b, symTable, funcName, funcT, + /*setPrivate=*/true, symbolTables); +} + +/// Rewrite a binary arithmetic operation to an APFloat function call. +template +struct BinaryArithOpToAPFloatConversion final : OpRewritePattern { + BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit, + SymbolOpInterface symTable) + : OpRewritePattern(context, benefit), symTable(symTable) {}; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Get APFloat function from runtime library. + FailureOr fn = + lookupOrCreateBinaryFn(rewriter, symTable, APFloatName); + if (failed(fn)) + return fn; + + rewriter.setInsertionPoint(op); + // Cast operands to 64-bit integers. 
+ Location loc = op.getLoc(); + auto floatTy = cast(op.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + auto int64Type = rewriter.getI64Type(); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs())); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs())); + + // Call APFloat function. + int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + SmallVector params = {semValue, lhsBits, rhsBits}; + auto resultOp = + func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, + resultOp->getResult(0)); + rewriter.replaceOp( + op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits)); + return success(); + } + + SymbolOpInterface symTable; +}; + +namespace { +struct ArithToAPFloatConversionPass final + : impl::ArithToAPFloatConversionPassBase { + using Base::Base; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + static const char add[] = "add"; + static const char subtract[] = "subtract"; + static const char multiply[] = "multiply"; + static const char divide[] = "divide"; + static const char remainder[] = "remainder"; + patterns.add, + BinaryArithOpToAPFloatConversion, + BinaryArithOpToAPFloatConversion, + BinaryArithOpToAPFloatConversion, + BinaryArithOpToAPFloatConversion>( + context, 1, getOperation()); + LogicalResult result = success(); + ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) { + 
result = failure(); + } + // NB: if you don't return failure, no other diag handlers will fire (see + // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). + return failure(); + }); + walkAndApplyPatterns(getOperation(), std::move(patterns)); + if (failed(result)) + return signalPassFailure(); + } +}; +} // namespace diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt new file mode 100644 index 0000000000000..b5ec49c087163 --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(MLIRArithToAPFloat + ArithToAPFloat.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRArithTransforms + MLIRFuncDialect + MLIRFuncUtils + ) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index b6099902cc337..f2bacc3399144 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index bebf1b8fff3f9..613dc6d242ceb 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard) add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) add_subdirectory(ArithToAMDGPU) +add_subdirectory(ArithToAPFloat) add_subdirectory(ArithToArmSME) add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) diff --git 
a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 69a317ecd101f..c747e1b59558a 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { return failure(); } } + } else if (auto floatTy = dyn_cast(printType)) { + // Print other floating-point types using the APFloat runtime library. + int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + Value floatBits = + LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value); + printer = + LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables); + emitCall(rewriter, loc, printer.value(), + ValueRange({semValue, floatBits})); + return success(); } else { return failure(); } diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp index b4cb0932ef631..d6dfd0229963c 100644 --- a/mlir/lib/Dialect/Func/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp @@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp, return std::make_pair(*newFuncOpOrFailure, newCallOp); } + +FailureOr +func::lookupFnDecl(SymbolOpInterface symTable, StringRef name, + FunctionType funcT, SymbolTableCollection *symbolTables) { + FuncOp func; + if (symbolTables) { + func = symbolTables->lookupSymbolIn( + symTable, StringAttr::get(symTable->getContext(), name)); + } else { + func = llvm::dyn_cast_or_null( + SymbolTable::lookupSymbolIn(symTable, name)); + } + + if (!func) + return func; + + mlir::FunctionType foundFuncT = func.getFunctionType(); + // Assert the signature of the found function is same as expected + if (funcT != foundFuncT) { + return 
func.emitError("matched function '") + << name << "' but with different type: " << foundFuncT + << " (expected " << funcT << ")"; + } + return func; +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index feaffa34897b6..160b6ae89215c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; +static constexpr llvm::StringRef kPrintApFloat = "printApFloat"; static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; @@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } +FailureOr +mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kPrintApFloat, + {IntegerType::get(moduleOp->getContext(), 32), + IntegerType::get(moduleOp->getContext(), 64)}, + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); +} + static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { return LLVM::LLVMPointerType::get(context); } diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp new file mode 100644 index 0000000000000..85ea0986cde5b --- /dev/null +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -0,0 +1,81 @@ +//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the APFloat infrastructure to MLIR programs as a runtime +// library. APFloat is a software implementation of floating-point arithmetic. +// +// On the MLIR side, floating-point values must be bitcasted to 64-bit integers +// before calling a runtime function. If a floating-point type has less than +// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an +// integer. +// +// Runtime functions receive the floating-point operands of the arithmetic +// operation in the form of 64-bit integers, along with the APFloat semantics +// in the form of a 32-bit integer, which will be interpreted as an +// APFloatBase::Semantics enum value. +// +#include "llvm/ADT/APFloat.h" + +#if (defined(_WIN32) || defined(__CYGWIN__)) +#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport) +#else +#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default"))) +#endif + +/// Binary operations without rounding mode. +#define APFLOAT_BINARY_OP(OP) \ + int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + lhs.OP(rhs); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + +/// Binary operations with rounding mode. 
+#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \ + int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + lhs.OP(rhs, ROUNDING_MODE); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + +extern "C" { + +#define BIN_OPS_WITH_ROUNDING(X) \ + X(add, llvm::RoundingMode::NearestTiesToEven) \ + X(subtract, llvm::RoundingMode::NearestTiesToEven) \ + X(multiply, llvm::RoundingMode::NearestTiesToEven) \ + X(divide, llvm::RoundingMode::NearestTiesToEven) + +BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE) +#undef BIN_OPS_WITH_ROUNDING +#undef APFLOAT_BINARY_OP_ROUNDING_MODE + +APFLOAT_BINARY_OP(remainder) + +#undef APFLOAT_BINARY_OP + +void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics, + uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + double d = x.convertToDouble(); + fprintf(stdout, "%lg", d); +} +} diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt index fdeb4dacf9278..8c09e50e4de7b 100644 --- a/mlir/lib/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/CMakeLists.txt @@ -2,6 +2,7 @@ # is a big dependency which most don't need. 
set(LLVM_OPTIONAL_SOURCES + APFloatWrappers.cpp ArmRunnerUtils.cpp ArmSMEStubs.cpp AsyncRuntime.cpp @@ -167,6 +168,15 @@ if(LLVM_ENABLE_PIC) set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17) target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS) + add_mlir_library(mlir_apfloat_wrappers + SHARED + APFloatWrappers.cpp + + EXCLUDE_FROM_LIBMLIR + ) + set_property(TARGET mlir_apfloat_wrappers PROPERTY CXX_STANDARD 17) + target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS) + add_subdirectory(SparseTensor) add_mlir_library(mlir_c_runner_utils @@ -177,6 +187,7 @@ if(LLVM_ENABLE_PIC) EXCLUDE_FROM_LIBMLIR LINK_LIBS PUBLIC + mlir_apfloat_wrappers mlir_float16_utils MLIRSparseTensorEnums MLIRSparseTensorRuntime @@ -191,6 +202,7 @@ if(LLVM_ENABLE_PIC) EXCLUDE_FROM_LIBMLIR LINK_LIBS PUBLIC + mlir_apfloat_wrappers mlir_float16_utils ) target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS) diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir new file mode 100644 index 0000000000000..fe4d28a56f808 --- /dev/null +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -0,0 +1,128 @@ +// RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func private @__mlir_apfloat_add(i32, i64, i64) -> i64 + +// CHECK-LABEL: func.func @foo() -> f8E4M3FN { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN +// CHECK: return %[[CONSTANT_0]] : f8E4M3FN +// CHECK: } + +// CHECK-LABEL: func.func @bar() -> f6E3M2FN { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 3.000000e+00 : f6E3M2FN +// CHECK: return %[[CONSTANT_0]] : f6E3M2FN +// CHECK: } + +// Illustrate that both f8E4M3FN and f6E3M2FN calling the same __mlir_apfloat_add is fine +// because each gets its own semantics enum and gets bitcast/extui/trunci to its own width. 
+// CHECK-LABEL: func.func @full_example() { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.375000e+00 : f8E4M3FN +// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f8E4M3FN +// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64 +// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[VAL_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_1]] : i8 to i64 +// // fltSemantics semantics for f8E4M3FN +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_1:.*]] = call @__mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64 +// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i8 +// CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN +// CHECK: vector.print %[[BITCAST_2]] : f8E4M3FN + +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2.500000e+00 : f6E3M2FN +// CHECK: %[[VAL_2:.*]] = call @bar() : () -> f6E3M2FN +// CHECK: %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f6E3M2FN to i6 +// CHECK: %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i6 to i64 +// CHECK: %[[BITCAST_4:.*]] = arith.bitcast %[[VAL_2]] : f6E3M2FN to i6 +// CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i6 to i64 +// // fltSemantics semantics for f6E3M2FN +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 16 : i32 +// CHECK: %[[VAL_3:.*]] = call @__mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64 +// CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_3]] : i64 to i6 +// CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i6 to f6E3M2FN +// CHECK: vector.print %[[BITCAST_5]] : f6E3M2FN +// CHECK: return +// CHECK: } + +// Put rhs into separate function so that it won't be constant-folded. 
+func.func @foo() -> f8E4M3FN { + %cst = arith.constant 2.2 : f8E4M3FN + return %cst : f8E4M3FN +} + +func.func @bar() -> f6E3M2FN { + %cst = arith.constant 3.2 : f6E3M2FN + return %cst : f6E3M2FN +} + +func.func @full_example() { + %a = arith.constant 1.4 : f8E4M3FN + %b = func.call @foo() : () -> (f8E4M3FN) + %c = arith.addf %a, %b : f8E4M3FN + vector.print %c : f8E4M3FN + + %d = arith.constant 2.4 : f6E3M2FN + %e = func.call @bar() : () -> (f6E3M2FN) + %f = arith.addf %d, %e : f6E3M2FN + vector.print %f : f6E3M2FN + return +} + +// ----- + +// CHECK: func.func private @__mlir_apfloat_add(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @__mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.addf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// Test decl collision (different type) +// expected-error@+1{{matched function '__mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}} +func.func private @__mlir_apfloat_add(i32, i32, f32) -> index +func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.addf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @__mlir_apfloat_subtract(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @__mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.subf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @__mlir_apfloat_multiply(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @__mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.mulf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @__mlir_apfloat_divide(i32, i64, 
i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @__mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.divf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @__mlir_apfloat_remainder(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @__mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.remf %arg0, %arg1 : f4E2M1FN + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir new file mode 100644 index 0000000000000..a2b3eb73a60b8 --- /dev/null +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -0,0 +1,34 @@ +// Case 1: All floating-point arithmetics is lowered through APFloat. +// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s + +// Case 2: Only unsupported arithmetics (f8E4M3FN) is lowered through APFloat. +// Arithmetics on f32 is lowered directly to LLVM. +// RUN: mlir-opt %s --convert-to-llvm --convert-arith-to-apfloat \ +// RUN: --convert-to-llvm --reconcile-unrealized-casts | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s + +// Put rhs into separate function so that it won't be constant-folded. 
+func.func @foo() -> (f8E4M3FN, f32) { + %cst1 = arith.constant 2.2 : f8E4M3FN + %cst2 = arith.constant 2.2 : f32 + return %cst1, %cst2 : f8E4M3FN, f32 +} + +func.func @entry() { + %a1 = arith.constant 1.4 : f8E4M3FN + %a2 = arith.constant 1.4 : f32 + %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32) + %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM + %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM + + // CHECK: 3.5 + vector.print %c1 : f8E4M3FN + + // CHECK: 3.6 + vector.print %c2 : f32 + + return +} From 77a05230282b7b29251c67eb7ee9d089223d1d20 Mon Sep 17 00:00:00 2001 From: makslevental Date: Mon, 10 Nov 2025 17:46:14 -0800 Subject: [PATCH 2/4] put symbols first --- mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 85ea0986cde5b..cbb9b35e6dc9a 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -29,7 +29,7 @@ /// Binary operations without rounding mode. #define APFLOAT_BINARY_OP(OP) \ - int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP( \ + MLIR_APFLOAT_WRAPPERS_EXPORTED int64_t __mlir_apfloat_##OP( \ int32_t semantics, uint64_t a, uint64_t b) { \ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ static_cast(semantics)); \ @@ -42,7 +42,7 @@ /// Binary operations with rounding mode. 
#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \ - int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP( \ + MLIR_APFLOAT_WRAPPERS_EXPORTED int64_t __mlir_apfloat_##OP( \ int32_t semantics, uint64_t a, uint64_t b) { \ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ static_cast(semantics)); \ @@ -69,7 +69,7 @@ APFLOAT_BINARY_OP(remainder) #undef APFLOAT_BINARY_OP -void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics, +MLIR_APFLOAT_WRAPPERS_EXPORTED void printApFloat(int32_t semantics, uint64_t a) { const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( static_cast(semantics)); From 9e17a7e1321d67eab5ce6ee7b6edc558b8e3904d Mon Sep 17 00:00:00 2001 From: makslevental Date: Mon, 10 Nov 2025 20:26:22 -0800 Subject: [PATCH 3/4] one underscore --- .../ArithToAPFloat/ArithToAPFloat.cpp | 5 +-- mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 24 +++++++++----- .../ArithToApfloat/arith-to-apfloat.mlir | 32 +++++++++---------- 3 files changed, 33 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp index 012e934d3050f..01fd5b278aca4 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -15,9 +15,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "arith-to-apfloat" namespace mlir { #define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS @@ -59,7 +56,7 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); - std::string funcName = (llvm::Twine("__mlir_apfloat_") + name).str(); + std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); FunctionType funcT = 
FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type}); FailureOr func = diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index cbb9b35e6dc9a..0a05f7369e556 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -21,15 +21,24 @@ // #include "llvm/ADT/APFloat.h" -#if (defined(_WIN32) || defined(__CYGWIN__)) -#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport) +#ifdef _WIN32 +#ifndef MLIR_APFLOAT_WRAPPERS_EXPORT +#ifdef mlir_apfloat_wrappers_EXPORTS +// We are building this library +#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllexport) #else -#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default"))) -#endif +// We are using this library +#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllimport) +#endif // mlir_apfloat_wrappers_EXPORTS +#endif // MLIR_APFLOAT_WRAPPERS_EXPORT +#else +// Non-windows: use visibility attributes. +#define MLIR_APFLOAT_WRAPPERS_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 /// Binary operations without rounding mode. #define APFLOAT_BINARY_OP(OP) \ - MLIR_APFLOAT_WRAPPERS_EXPORTED int64_t __mlir_apfloat_##OP( \ + MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP( \ int32_t semantics, uint64_t a, uint64_t b) { \ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ static_cast(semantics)); \ @@ -42,7 +51,7 @@ /// Binary operations with rounding mode. 
#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \ - MLIR_APFLOAT_WRAPPERS_EXPORTED int64_t __mlir_apfloat_##OP( \ + MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP( \ int32_t semantics, uint64_t a, uint64_t b) { \ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ static_cast(semantics)); \ @@ -69,8 +78,7 @@ APFLOAT_BINARY_OP(remainder) #undef APFLOAT_BINARY_OP -MLIR_APFLOAT_WRAPPERS_EXPORTED void printApFloat(int32_t semantics, - uint64_t a) { +MLIR_APFLOAT_WRAPPERS_EXPORT void printApFloat(int32_t semantics, uint64_t a) { const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( static_cast(semantics)); unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir index fe4d28a56f808..797f42c37a26f 100644 --- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file -verify-diagnostics | FileCheck %s -// CHECK-LABEL: func.func private @__mlir_apfloat_add(i32, i64, i64) -> i64 +// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64 // CHECK-LABEL: func.func @foo() -> f8E4M3FN { // CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN @@ -12,7 +12,7 @@ // CHECK: return %[[CONSTANT_0]] : f6E3M2FN // CHECK: } -// Illustrate that both f8E4M3FN and f6E3M2FN calling the same __mlir_apfloat_add is fine +// Illustrate that both f8E4M3FN and f6E3M2FN calling the same _mlir_apfloat_add is fine // because each gets its own semantics enum and gets bitcast/extui/trunci to its own width. 
// CHECK-LABEL: func.func @full_example() { // CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.375000e+00 : f8E4M3FN @@ -23,7 +23,7 @@ // CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_1]] : i8 to i64 // // fltSemantics semantics for f8E4M3FN // CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32 -// CHECK: %[[VAL_1:.*]] = call @__mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64 +// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64 // CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i8 // CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN // CHECK: vector.print %[[BITCAST_2]] : f8E4M3FN @@ -36,7 +36,7 @@ // CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i6 to i64 // // fltSemantics semantics for f6E3M2FN // CHECK: %[[CONSTANT_3:.*]] = arith.constant 16 : i32 -// CHECK: %[[VAL_3:.*]] = call @__mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64 +// CHECK: %[[VAL_3:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64 // CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_3]] : i64 to i6 // CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i6 to f6E3M2FN // CHECK: vector.print %[[BITCAST_5]] : f6E3M2FN @@ -69,9 +69,9 @@ func.func @full_example() { // ----- -// CHECK: func.func private @__mlir_apfloat_add(i32, i64, i64) -> i64 +// CHECK: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64 // CHECK: %[[sem:.*]] = arith.constant 18 : i32 -// CHECK: call @__mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +// CHECK: call @_mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { %0 = arith.addf %arg0, %arg1 : f4E2M1FN return @@ -80,8 +80,8 @@ func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { // ----- // Test decl collision (different type) -// 
expected-error@+1{{matched function '__mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}} -func.func private @__mlir_apfloat_add(i32, i32, f32) -> index +// expected-error@+1{{matched function '_mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}} +func.func private @_mlir_apfloat_add(i32, i32, f32) -> index func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { %0 = arith.addf %arg0, %arg1 : f4E2M1FN return @@ -89,9 +89,9 @@ func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { // ----- -// CHECK: func.func private @__mlir_apfloat_subtract(i32, i64, i64) -> i64 +// CHECK: func.func private @_mlir_apfloat_subtract(i32, i64, i64) -> i64 // CHECK: %[[sem:.*]] = arith.constant 18 : i32 -// CHECK: call @__mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +// CHECK: call @_mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { %0 = arith.subf %arg0, %arg1 : f4E2M1FN return @@ -99,9 +99,9 @@ func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { // ----- -// CHECK: func.func private @__mlir_apfloat_multiply(i32, i64, i64) -> i64 +// CHECK: func.func private @_mlir_apfloat_multiply(i32, i64, i64) -> i64 // CHECK: %[[sem:.*]] = arith.constant 18 : i32 -// CHECK: call @__mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +// CHECK: call @_mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { %0 = arith.mulf %arg0, %arg1 : f4E2M1FN return @@ -109,9 +109,9 @@ func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { // ----- -// CHECK: func.func private @__mlir_apfloat_divide(i32, i64, i64) -> i64 +// CHECK: func.func private @_mlir_apfloat_divide(i32, i64, i64) -> i64 // CHECK: %[[sem:.*]] = arith.constant 18 : i32 -// CHECK: call @__mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, 
i64, i64) -> i64 +// CHECK: call @_mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { %0 = arith.divf %arg0, %arg1 : f4E2M1FN return @@ -119,9 +119,9 @@ func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { // ----- -// CHECK: func.func private @__mlir_apfloat_remainder(i32, i64, i64) -> i64 +// CHECK: func.func private @_mlir_apfloat_remainder(i32, i64, i64) -> i64 // CHECK: %[[sem:.*]] = arith.constant 18 : i32 -// CHECK: call @__mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +// CHECK: call @_mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { %0 = arith.remf %arg0, %arg1 : f4E2M1FN return From 55d11cc779b8e73e156c18f64d90da69021f6177 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 11 Nov 2025 00:44:05 -0500 Subject: [PATCH 4/4] add mlir_apfloat_wrappers --- .../Dialect/Arith/CPU/test-apfloat-emulation.mlir | 6 ++++-- mlir/test/lit.cfg.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir index a2b3eb73a60b8..2768afe0834b5 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -1,14 +1,16 @@ // Case 1: All floating-point arithmetics is lowered through APFloat. // RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \ // RUN: mlir-runner -e entry --entry-point-result=void \ -// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s // Case 2: Only unsupported arithmetics (f8E4M3FN) is lowered through APFloat. // Arithmetics on f32 is lowered directly to LLVM. 
// RUN: mlir-opt %s --convert-to-llvm --convert-arith-to-apfloat \ // RUN: --convert-to-llvm --reconcile-unrealized-casts | \ // RUN: mlir-runner -e entry --entry-point-result=void \ -// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s // Put rhs into separate function so that it won't be constant-folded. func.func @foo() -> (f8E4M3FN, f32) { diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 6ff12d66523f5..4a38ed605be0c 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -208,6 +208,7 @@ def find_real_python_interpreter(): add_runtime("mlir_c_runner_utils"), add_runtime("mlir_async_runtime"), add_runtime("mlir_float16_utils"), + add_runtime("mlir_apfloat_wrappers"), "mlir-linalg-ods-yaml-gen", "mlir-reduce", "mlir-pdll",