diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h new file mode 100644 index 0000000000000..64a42a228199e --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h @@ -0,0 +1,21 @@ +//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===// +// +// Part of the APFloat Project, under the Apache License v2.0 with APFloat +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H +#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 40d866ec7bf10..82bdfd02661a6 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -12,6 +12,7 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 70e3e45c225db..79bc380dbcb7a 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -186,6 +186,21 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> { ]; } +//===----------------------------------------------------------------------===// +// ArithToAPFloat +//===----------------------------------------------------------------------===// + +def ArithToAPFloatConversionPass + : Pass<"convert-arith-to-apfloat", "ModuleOp"> { + let summary = "Convert Arith ops to APFloat runtime library calls"; + let description = [{ + This pass converts supported Arith ops to APFloat-based runtime library + calls (APFloatWrappers.cpp). APFloat is a software implementation of + floating-point arithmetic operations. + }]; + let dependentDialects = ["func::FuncDialect"]; +} + //===----------------------------------------------------------------------===// // ArithToSPIRV //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h index 3576126a487ac..00d50874a2e8d 100644 --- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h @@ -60,6 +60,13 @@ mlir::FailureOr> deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp); +/// Look up a FuncOp with signature `resultTypes`(`paramTypes`)` and name +/// `name`. Return a failure if the FuncOp is found but with a different +/// signature. +FailureOr lookupFnDecl(SymbolOpInterface symTable, StringRef name, + FunctionType funcT, + SymbolTableCollection *symbolTables = nullptr); + } // namespace func } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 8ad9ed18acebd..b09d32022e348 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, FailureOr lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables = nullptr); +FailureOr +lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); + /// Declares a function to print a C-string. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp new file mode 100644 index 0000000000000..699edb188a70a --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -0,0 +1,163 @@ +//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::func; + +static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, FunctionType funcT, bool setPrivate, + SymbolTableCollection *symbolTables = nullptr) { + OpBuilder::InsertionGuard g(b); + assert(!symTable->getRegion(0).empty() && "expected non-empty region"); + b.setInsertionPointToStart(&symTable->getRegion(0).front()); + FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT); + if (setPrivate) + funcOp.setPrivate(); + if (symbolTables) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable); + symbolTable.insert(funcOp, symTable->getRegion(0).front().begin()); + } + return funcOp; +} + +/// Helper function to look up or create the symbol for a runtime library +/// function for a binary arithmetic operation. +/// +/// Parameter 1: APFloat semantics +/// Parameter 2: Left-hand side operand +/// Parameter 3: Right-hand side operand +/// +/// This function will return a failure if the function is found but has an +/// unexpected signature. +/// +static FailureOr +lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, + SymbolTableCollection *symbolTables = nullptr) { + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + + std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); + FunctionType funcT = + FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type}); + FailureOr func = + lookupFnDecl(symTable, funcName, funcT, symbolTables); + // Failed due to type mismatch. + if (failed(func)) + return func; + // Successfully matched existing decl. + if (*func) + return *func; + + return createFnDecl(b, symTable, funcName, funcT, + /*setPrivate=*/true, symbolTables); +} + +/// Rewrite a binary arithmetic operation to an APFloat function call. +template +struct BinaryArithOpToAPFloatConversion final : OpRewritePattern { + BinaryArithOpToAPFloatConversion(MLIRContext *context, + const char *APFloatName, + SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), symTable(symTable), + APFloatName(APFloatName) {}; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Get APFloat function from runtime library. + FailureOr fn = + lookupOrCreateBinaryFn(rewriter, symTable, APFloatName); + if (failed(fn)) + return fn; + + rewriter.setInsertionPoint(op); + // Cast operands to 64-bit integers. + Location loc = op.getLoc(); + auto floatTy = cast(op.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + auto int64Type = rewriter.getI64Type(); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs())); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs())); + + // Call APFloat function. + int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + SmallVector params = {semValue, lhsBits, rhsBits}; + auto resultOp = + func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, + resultOp->getResult(0)); + rewriter.replaceOp( + op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits)); + return success(); + } + + SymbolOpInterface symTable; + const char *APFloatName; +}; + +namespace { +struct ArithToAPFloatConversionPass final + : impl::ArithToAPFloatConversionPassBase { + using Base::Base; + + void runOnOperation() override; +}; + +void ArithToAPFloatConversionPass::runOnOperation() { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add>(context, "add", + getOperation()); + patterns.add>( + context, "subtract", getOperation()); + patterns.add>( + context, "multiply", getOperation()); + patterns.add>( + context, "divide", getOperation()); + patterns.add>( + context, "remainder", getOperation()); + LogicalResult result = success(); + ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) { + result = failure(); + } + // NB: if you don't return failure, no other diag handlers will fire (see + // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). + return failure(); + }); + walkAndApplyPatterns(getOperation(), std::move(patterns)); + if (failed(result)) + return signalPassFailure(); +} +} // namespace diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt new file mode 100644 index 0000000000000..b5ec49c087163 --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(MLIRArithToAPFloat + ArithToAPFloat.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRArithTransforms + MLIRFuncDialect + MLIRFuncUtils + ) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index b6099902cc337..f2bacc3399144 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index bebf1b8fff3f9..613dc6d242ceb 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard) add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) add_subdirectory(ArithToAMDGPU) +add_subdirectory(ArithToAPFloat) add_subdirectory(ArithToArmSME) add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 69a317ecd101f..c747e1b59558a 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { return failure(); } } + } else if (auto floatTy = dyn_cast(printType)) { + // Print other floating-point types using the APFloat runtime library. + int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + Value floatBits = + LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value); + printer = + LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables); + emitCall(rewriter, loc, printer.value(), + ValueRange({semValue, floatBits})); + return success(); } else { return failure(); } diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp index b4cb0932ef631..d6dfd0229963c 100644 --- a/mlir/lib/Dialect/Func/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp @@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp, return std::make_pair(*newFuncOpOrFailure, newCallOp); } + +FailureOr +func::lookupFnDecl(SymbolOpInterface symTable, StringRef name, + FunctionType funcT, SymbolTableCollection *symbolTables) { + FuncOp func; + if (symbolTables) { + func = symbolTables->lookupSymbolIn( + symTable, StringAttr::get(symTable->getContext(), name)); + } else { + func = llvm::dyn_cast_or_null( + SymbolTable::lookupSymbolIn(symTable, name)); + } + + if (!func) + return func; + + mlir::FunctionType foundFuncT = func.getFunctionType(); + // Assert the signature of the found function is same as expected + if (funcT != foundFuncT) { + return func.emitError("matched function '") + << name << "' but with different type: " << foundFuncT + << " (expected " << funcT << ")"; + } + return func; +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index feaffa34897b6..160b6ae89215c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; +static constexpr llvm::StringRef kPrintApFloat = "printApFloat"; static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; @@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } +FailureOr +mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kPrintApFloat, + {IntegerType::get(moduleOp->getContext(), 32), + IntegerType::get(moduleOp->getContext(), 64)}, + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); +} + static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { return LLVM::LLVMPointerType::get(context); } diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp new file mode 100644 index 0000000000000..0a05f7369e556 --- /dev/null +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -0,0 +1,89 @@ +//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the APFloat infrastructure to MLIR programs as a runtime +// library. APFloat is a software implementation of floating point arithmetics. +// +// On the MLIR side, floating-point values must be bitcasted to 64-bit integers +// before calling a runtime function. If a floating-point type has less than +// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an +// integer. +// +// Runtime functions receive the floating-point operands of the arithmeic +// operation in the form of 64-bit integers, along with the APFloat semantics +// in the form of a 32-bit integer, which will be interpreted as an +// APFloatBase::Semantics enum value. +// +#include "llvm/ADT/APFloat.h" + +#ifdef _WIN32 +#ifndef MLIR_APFLOAT_WRAPPERS_EXPORT +#ifdef mlir_apfloat_wrappers_EXPORTS +// We are building this library +#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllexport) +#else +// We are using this library +#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllimport) +#endif // mlir_apfloat_wrappers_EXPORTS +#endif // MLIR_APFLOAT_WRAPPERS_EXPORT +#else +// Non-windows: use visibility attributes. +#define MLIR_APFLOAT_WRAPPERS_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 + +/// Binary operations without rounding mode. +#define APFLOAT_BINARY_OP(OP) \ + MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + lhs.OP(rhs); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + +/// Binary operations with rounding mode. +#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \ + MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + lhs.OP(rhs, ROUNDING_MODE); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + +extern "C" { + +#define BIN_OPS_WITH_ROUNDING(X) \ + X(add, llvm::RoundingMode::NearestTiesToEven) \ + X(subtract, llvm::RoundingMode::NearestTiesToEven) \ + X(multiply, llvm::RoundingMode::NearestTiesToEven) \ + X(divide, llvm::RoundingMode::NearestTiesToEven) + +BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE) +#undef BIN_OPS_WITH_ROUNDING +#undef APFLOAT_BINARY_OP_ROUNDING_MODE + +APFLOAT_BINARY_OP(remainder) + +#undef APFLOAT_BINARY_OP + +MLIR_APFLOAT_WRAPPERS_EXPORT void printApFloat(int32_t semantics, uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + double d = x.convertToDouble(); + fprintf(stdout, "%lg", d); +} +} diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt index fdeb4dacf9278..c813a431270d0 100644 --- a/mlir/lib/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/CMakeLists.txt @@ -2,6 +2,7 @@ # is a big dependency which most don't need. set(LLVM_OPTIONAL_SOURCES + APFloatWrappers.cpp ArmRunnerUtils.cpp ArmSMEStubs.cpp AsyncRuntime.cpp @@ -167,6 +168,26 @@ if(LLVM_ENABLE_PIC) set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17) target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS) + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + # TODO: This support library is only used on Linux builds until we figure + # out how to hide LLVM symbols in a way that works for all platforms. + add_mlir_library(mlir_apfloat_wrappers + SHARED + APFloatWrappers.cpp + + EXCLUDE_FROM_LIBMLIR + ) + set_target_properties( + mlir_apfloat_wrappers + PROPERTIES CXX_STANDARD 17 + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON + ) + target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS) + # Hide LLVM symbols to avoid ODR violations. + target_link_options(mlir_apfloat_wrappers PRIVATE "-Wl,--exclude-libs,ALL") + endif() + add_subdirectory(SparseTensor) add_mlir_library(mlir_c_runner_utils diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir new file mode 100644 index 0000000000000..797f42c37a26f --- /dev/null +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -0,0 +1,128 @@ +// RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64 + +// CHECK-LABEL: func.func @foo() -> f8E4M3FN { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN +// CHECK: return %[[CONSTANT_0]] : f8E4M3FN +// CHECK: } + +// CHECK-LABEL: func.func @bar() -> f6E3M2FN { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 3.000000e+00 : f6E3M2FN +// CHECK: return %[[CONSTANT_0]] : f6E3M2FN +// CHECK: } + +// Illustrate that both f8E4M3FN and f6E3M2FN calling the same _mlir_apfloat_add is fine +// because each gets its own semantics enum and gets bitcast/extui/trunci to its own width. +// CHECK-LABEL: func.func @full_example() { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.375000e+00 : f8E4M3FN +// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f8E4M3FN +// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64 +// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[VAL_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_1]] : i8 to i64 +// // fltSemantics semantics for f8E4M3FN +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64 +// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i8 +// CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN +// CHECK: vector.print %[[BITCAST_2]] : f8E4M3FN + +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2.500000e+00 : f6E3M2FN +// CHECK: %[[VAL_2:.*]] = call @bar() : () -> f6E3M2FN +// CHECK: %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f6E3M2FN to i6 +// CHECK: %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i6 to i64 +// CHECK: %[[BITCAST_4:.*]] = arith.bitcast %[[VAL_2]] : f6E3M2FN to i6 +// CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i6 to i64 +// // fltSemantics semantics for f6E3M2FN +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 16 : i32 +// CHECK: %[[VAL_3:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64 +// CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_3]] : i64 to i6 +// CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i6 to f6E3M2FN +// CHECK: vector.print %[[BITCAST_5]] : f6E3M2FN +// CHECK: return +// CHECK: } + +// Put rhs into separate function so that it won't be constant-folded. +func.func @foo() -> f8E4M3FN { + %cst = arith.constant 2.2 : f8E4M3FN + return %cst : f8E4M3FN +} + +func.func @bar() -> f6E3M2FN { + %cst = arith.constant 3.2 : f6E3M2FN + return %cst : f6E3M2FN +} + +func.func @full_example() { + %a = arith.constant 1.4 : f8E4M3FN + %b = func.call @foo() : () -> (f8E4M3FN) + %c = arith.addf %a, %b : f8E4M3FN + vector.print %c : f8E4M3FN + + %d = arith.constant 2.4 : f6E3M2FN + %e = func.call @bar() : () -> (f6E3M2FN) + %f = arith.addf %d, %e : f6E3M2FN + vector.print %f : f6E3M2FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.addf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// Test decl collision (different type) +// expected-error@+1{{matched function '_mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}} +func.func private @_mlir_apfloat_add(i32, i32, f32) -> index +func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.addf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_subtract(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.subf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_multiply(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.mulf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_divide(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.divf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_remainder(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.remf %arg0, %arg1 : f4E2M1FN + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir new file mode 100644 index 0000000000000..dbaa20346a03a --- /dev/null +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -0,0 +1,40 @@ +// REQUIRES: system-linux +// TODO: Run only on Linux until we figure out how to build +// mlir_apfloat_wrappers in a platform-independent way. + +// Case 1: All floating-point arithmetics is lowered through APFloat. +// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +// Case 2: Only unsupported arithmetics (f8E4M3FN) is lowered through APFloat. +// Arithmetics on f32 is lowered directly to LLVM. +// RUN: mlir-opt %s --convert-to-llvm --convert-arith-to-apfloat \ +// RUN: --convert-to-llvm --reconcile-unrealized-casts | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +// Put rhs into separate function so that it won't be constant-folded. +func.func @foo() -> (f8E4M3FN, f32) { + %cst1 = arith.constant 2.2 : f8E4M3FN + %cst2 = arith.constant 2.2 : f32 + return %cst1, %cst2 : f8E4M3FN, f32 +} + +func.func @entry() { + %a1 = arith.constant 1.4 : f8E4M3FN + %a2 = arith.constant 1.4 : f32 + %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32) + %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM + %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM + + // CHECK: 3.5 + vector.print %c1 : f8E4M3FN + + // CHECK: 3.6 + vector.print %c2 : f32 + + return +} diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 6ff12d66523f5..7081c51994ec1 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -214,6 +214,11 @@ def find_real_python_interpreter(): "not", ] +if "Linux" in config.host_os: + # TODO: Run only on Linux until we figure out how to build + # mlir_apfloat_wrappers in a platform-independent way. + tools.extend([add_runtime("mlir_apfloat_wrappers")]) + if config.enable_vulkan_runner: tools.extend([add_runtime("mlir_vulkan_runtime")])