diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h similarity index 73% rename from mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h rename to mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h index 64a42a228199e..6702aca045ba4 100644 --- a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h +++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H -#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H +#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H +#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H #include @@ -18,4 +18,4 @@ class Pass; #include "mlir/Conversion/Passes.h.inc" } // namespace mlir -#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H +#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_ARITHTOAPFLOAT_H diff --git a/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h new file mode 100644 index 0000000000000..6cb44c89ecebb --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h @@ -0,0 +1,21 @@ +//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===// +// +// Part of the APFloat Project, under the Apache License v2.0 with APFloat +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H +#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_MATHTOAPFLOAT_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 82bdfd02661a6..7c2b450ca6710 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -11,8 +11,9 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h" +#include "mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" -#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index fcbaf3ccc1486..7f24e58671aab 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> { ]; } +//===----------------------------------------------------------------------===// +// MathToAPFloat +//===----------------------------------------------------------------------===// + +def MathToAPFloatConversionPass + : Pass<"convert-math-to-apfloat", "ModuleOp"> { + let summary = "Convert Math ops to APFloat runtime library calls"; + let description = [{ + This pass converts supported Math ops to APFloat-based runtime library + calls (APFloatWrappers.cpp). APFloat is a software implementation of + floating-point mathmetic operations. + }]; + let dependentDialects = ["math::MathDialect", "func::FuncDialect"]; +} + //===----------------------------------------------------------------------===// // MathToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h index 00d50874a2e8d..079c1f461b6ed 100644 --- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h @@ -67,6 +67,22 @@ FailureOr lookupFnDecl(SymbolOpInterface symTable, StringRef name, FunctionType funcT, SymbolTableCollection *symbolTables = nullptr); +/// Create a FuncOp decl and insert it into `symTable` operation. If +/// `symbolTables` is provided, then the decl will be inserted into the +/// SymbolTableCollection. +FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, + FunctionType funcT, bool setPrivate, + SymbolTableCollection *symbolTables = nullptr); + +/// Helper function to look up or create the symbol for a runtime library +/// function with the given parameter types. Returns an int64_t, unless a +/// different result type is specified. +FailureOr +lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, + TypeRange paramTypes, + SymbolTableCollection *symbolTables = nullptr, + Type resultType = {}); + } // namespace func } // namespace mlir diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp similarity index 88% rename from mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp rename to mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp index 79816fc6e3bf1..813a854f2fc97 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp @@ -6,8 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" +#include "Utils.h" +#include "mlir/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -25,47 +26,6 @@ namespace mlir { using namespace mlir; using namespace mlir::func; -static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, - StringRef name, FunctionType funcT, bool setPrivate, - SymbolTableCollection *symbolTables = nullptr) { - OpBuilder::InsertionGuard g(b); - assert(!symTable->getRegion(0).empty() && "expected non-empty region"); - b.setInsertionPointToStart(&symTable->getRegion(0).front()); - FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT); - if (setPrivate) - funcOp.setPrivate(); - if (symbolTables) { - SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable); - symbolTable.insert(funcOp, symTable->getRegion(0).front().begin()); - } - return funcOp; -} - -/// Helper function to look up or create the symbol for a runtime library -/// function with the given parameter types. Returns an int64_t, unless a -/// different result type is specified. -static FailureOr -lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, - StringRef name, TypeRange paramTypes, - SymbolTableCollection *symbolTables = nullptr, - Type resultType = {}) { - if (!resultType) - resultType = IntegerType::get(symTable->getContext(), 64); - std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); - auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType}); - FailureOr func = - lookupFnDecl(symTable, funcName, funcT, symbolTables); - // Failed due to type mismatch. - if (failed(func)) - return func; - // Successfully matched existing decl. - if (*func) - return *func; - - return createFnDecl(b, symTable, funcName, funcT, - /*setPrivate=*/true, symbolTables); -} - /// Helper function to look up or create the symbol for a runtime library /// function for a binary arithmetic operation. /// @@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables = nullptr) { auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); - return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type}, - symbolTables); -} - -static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) { - int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); - return arith::ConstantOp::create(b, loc, b.getI32Type(), - b.getIntegerAttr(b.getI32Type(), sem)); + std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); + return lookupOrCreateFnDecl(b, symTable, funcName, + {i32Type, i64Type, i64Type}, symbolTables); } /// Given two operands of vector type and vector result type (with the same @@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern { arith::BitcastOp::create(rewriter, loc, intWType, rhs)); // Call APFloat function. - Value semValue = getSemanticsValue(rewriter, loc, floatTy); + Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); SmallVector params = {semValue, lhsBits, rhsBits}; auto resultOp = func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), @@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern { // Get APFloat function from runtime library. auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); - FailureOr fn = lookupOrCreateApFloatFn( - rewriter, symTable, "convert", {i32Type, i32Type, i64Type}); + FailureOr fn = + lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert", + {i32Type, i32Type, i64Type}); if (failed(fn)) return fn; @@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern { arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); // Call APFloat function. - Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy); auto outFloatTy = cast(resultType); - Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + Value outSemValue = + getAPFloatSemanticsValue(rewriter, loc, outFloatTy); std::array params = {inSemValue, outSemValue, operandBits}; auto resultOp = func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), @@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern { auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); FailureOr fn = - lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int", - {i32Type, i32Type, i1Type, i64Type}); + lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int", + {i32Type, i32Type, i1Type, i64Type}); if (failed(fn)) return fn; @@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern { arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); // Call APFloat function. - Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy); auto outIntTy = cast(resultType); Value outWidthValue = arith::ConstantOp::create( rewriter, loc, i32Type, @@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern { auto i1Type = IntegerType::get(symTable->getContext(), 1); auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); - FailureOr fn = - lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int", - {i32Type, i32Type, i1Type, i64Type}); + FailureOr fn = lookupOrCreateFnDecl( + rewriter, symTable, "_mlir_apfloat_convert_from_int", + {i32Type, i32Type, i1Type, i64Type}); if (failed(fn)) return fn; @@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern { // Call APFloat function. auto outFloatTy = cast(resultType); - Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + Value outSemValue = + getAPFloatSemanticsValue(rewriter, loc, outFloatTy); Value inWidthValue = arith::ConstantOp::create( rewriter, loc, i32Type, rewriter.getIntegerAttr(i32Type, inIntTy.getWidth())); @@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern { auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); FailureOr fn = - lookupOrCreateApFloatFn(rewriter, symTable, "compare", - {i32Type, i64Type, i64Type}, nullptr, i8Type); + lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare", + {i32Type, i64Type, i64Type}, nullptr, i8Type); if (failed(fn)) return fn; @@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern { arith::BitcastOp::create(rewriter, loc, intWType, rhs)); // Call APFloat function. - Value semValue = getSemanticsValue(rewriter, loc, floatTy); + Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); SmallVector params = {semValue, lhsBits, rhsBits}; Value comparisonResult = func::CallOp::create(rewriter, loc, TypeRange(i8Type), @@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern { // Get APFloat function from runtime library. auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); - FailureOr fn = - lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type}); + FailureOr fn = lookupOrCreateFnDecl( + rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type}); if (failed(fn)) return fn; @@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern { arith::BitcastOp::create(rewriter, loc, intWType, operand1)); // Call APFloat function. - Value semValue = getSemanticsValue(rewriter, loc, floatTy); + Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); SmallVector params = {semValue, operandBits}; Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type), diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt new file mode 100644 index 0000000000000..e8fd9c493b975 --- /dev/null +++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt @@ -0,0 +1,50 @@ +add_mlir_library(ArithAndMathToAPFloatUtils + Utils.cpp + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + MLIRArithDialect + ) + +add_mlir_conversion_library(MLIRArithToAPFloat + ArithToAPFloat.cpp + PARTIAL_SOURCES_INTENDED + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + ArithAndMathToAPFloatUtils + MLIRArithDialect + MLIRArithTransforms + MLIRFuncDialect + MLIRFuncUtils + MLIRVectorDialect + ) + +add_mlir_conversion_library(MLIRMathToAPFloat + MathToAPFloat.cpp + PARTIAL_SOURCES_INTENDED + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRTransformUtils + ArithAndMathToAPFloatUtils + MLIRMathDialect + MLIRFuncDialect + MLIRFuncUtils + ) diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp new file mode 100644 index 0000000000000..784028f5cf2eb --- /dev/null +++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/MathToAPFloat.cpp @@ -0,0 +1,219 @@ +//===- MathToAPFloat.cpp - Mathmetic to APFloat Conversion ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Utils.h" + +#include "mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Utils/Utils.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::func; + +struct AbsFOpToAPFloatConversion final : OpRewritePattern { + AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(math::AbsFOp op, + PatternRewriter &rewriter) const override { + // Cast operands to 64-bit integers. + auto operand = op.getOperand(); + auto floatTy = dyn_cast(operand.getType()); + if (!floatTy) + return rewriter.notifyMatchFailure(op, + "only scalar FloatTypes supported"); + if (floatTy.getIntOrFloatBitWidth() > 64) { + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + } + // Get APFloat function from runtime library. + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr fn = lookupOrCreateFnDecl( + rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type}); + if (failed(fn)) + return fn; + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, operand)); + + // Call APFloat function. + Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); + SmallVector params = {semValue, operandBits}; + Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Truncate result to the original width. + Value truncatedBits = + arith::TruncIOp::create(rewriter, loc, intWType, negatedBits); + rewriter.replaceOp( + op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits)); + return success(); + } + + SymbolOpInterface symTable; +}; + +template +struct IsOpToAPFloatConversion final : OpRewritePattern { + IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, + SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), symTable(symTable), + APFloatName(APFloatName) {}; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Cast operands to 64-bit integers. + auto operand = op.getOperand(); + auto floatTy = dyn_cast(operand.getType()); + if (!floatTy) + return rewriter.notifyMatchFailure(op, + "only scalar FloatTypes supported"); + if (floatTy.getIntOrFloatBitWidth() > 64) { + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + } + // Get APFloat function from runtime library. + auto i1 = IntegerType::get(symTable->getContext(), 1); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + std::string funcName = + (llvm::Twine("_mlir_apfloat_is") + APFloatName).str(); + FailureOr fn = lookupOrCreateFnDecl( + rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1); + if (failed(fn)) + return fn; + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, operand)); + + // Call APFloat function. + Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); + SmallVector params = {semValue, operandBits}; + rewriter.replaceOpWithNewOp(op, TypeRange(i1), + SymbolRefAttr::get(*fn), params); + return success(); + } + + SymbolOpInterface symTable; + const char *APFloatName; +}; + +struct FmaOpToAPFloatConversion final : OpRewritePattern { + FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), symTable(symTable) {}; + + LogicalResult matchAndRewrite(math::FmaOp op, + PatternRewriter &rewriter) const override { + // Cast operands to 64-bit integers. + auto floatTy = cast(op.getResult().getType()); + if (!floatTy) + return rewriter.notifyMatchFailure(op, + "only scalar FloatTypes supported"); + if (floatTy.getIntOrFloatBitWidth() > 64) { + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + } + + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr fn = lookupOrCreateFnDecl( + rewriter, symTable, "_mlir_apfloat_fused_multiply_add", + {i32Type, i64Type, i64Type, i64Type}); + if (failed(fn)) + return fn; + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + auto int64Type = rewriter.getI64Type(); + Value operand = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getA())); + Value multiplicand = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getB())); + Value addend = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getC())); + + // Call APFloat function. + Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); + SmallVector params = {semValue, operand, multiplicand, addend}; + auto resultOp = + func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, + resultOp->getResult(0)); + rewriter.replaceOpWithNewOp(op, floatTy, truncatedBits); + return success(); + } + + SymbolOpInterface symTable; +}; + +namespace { +struct MathToAPFloatConversionPass final + : impl::MathToAPFloatConversionPassBase { + using Base::Base; + + void runOnOperation() override; +}; + +void MathToAPFloatConversionPass::runOnOperation() { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + patterns.add(context, getOperation()); + patterns.add>(context, "finite", + getOperation()); + patterns.add>(context, "infinite", + getOperation()); + patterns.add>(context, "nan", + getOperation()); + patterns.add>(context, "normal", + getOperation()); + patterns.add(context, getOperation()); + + LogicalResult result = success(); + ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) { + result = failure(); + } + // NB: if you don't return failure, no other diag handlers will fire (see + // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). + return failure(); + }); + walkAndApplyPatterns(getOperation(), std::move(patterns)); + if (failed(result)) + return signalPassFailure(); +} +} // namespace diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp new file mode 100644 index 0000000000000..2b5857367dc40 --- /dev/null +++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.cpp @@ -0,0 +1,22 @@ +//===- Utils.cpp - Utils for APFloat Conversion ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Utils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" + +mlir::Value mlir::getAPFloatSemanticsValue(OpBuilder &b, Location loc, + FloatType floatTy) { + int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + return arith::ConstantOp::create(b, loc, b.getI32Type(), + b.getIntegerAttr(b.getI32Type(), sem)); +} diff --git a/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h new file mode 100644 index 0000000000000..5f11d24261b43 --- /dev/null +++ b/mlir/lib/Conversion/ArithAndMathToAPFloat/Utils.h @@ -0,0 +1,21 @@ +//===- Utils.h - Utils for APFloat Conversion - C++ -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_ +#define MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_ + +namespace mlir { +class Value; +class OpBuilder; +class Location; +class FloatType; + +Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy); +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHANDMATHTOAPFLOAT_UTILS_H_ diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt deleted file mode 100644 index 31fce7a4de8a2..0000000000000 --- a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_mlir_conversion_library(MLIRArithToAPFloat - ArithToAPFloat.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM - - DEPENDS - MLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRArithDialect - MLIRArithTransforms - MLIRFuncDialect - MLIRFuncUtils - MLIRVectorDialect - ) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 613dc6d242ceb..2ed10effb53da 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -2,7 +2,7 @@ add_subdirectory(AffineToStandard) add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) add_subdirectory(ArithToAMDGPU) -add_subdirectory(ArithToAPFloat) +add_subdirectory(ArithAndMathToAPFloat) add_subdirectory(ArithToArmSME) add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index a08cc98e4d5bf..e5f8763127a1b 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -37,7 +37,9 @@ using ConvertFMFMathToLLVMPattern = VectorConvertToLLVMPattern; -using AbsFOpLowering = ConvertFMFMathToLLVMPattern; +using AbsFOpLowering = + ConvertFMFMathToLLVMPattern; using CeilOpLowering = ConvertFMFMathToLLVMPattern; using CopySignOpLowering = ConvertFMFMathToLLVMPattern; @@ -52,7 +54,8 @@ using Exp2OpLowering = ConvertFMFMathToLLVMPattern; using ExpOpLowering = ConvertFMFMathToLLVMPattern; using FloorOpLowering = ConvertFMFMathToLLVMPattern; -using FmaOpLowering = ConvertFMFMathToLLVMPattern; +using FmaOpLowering = ConvertFMFMathToLLVMPattern; using Log10OpLowering = ConvertFMFMathToLLVMPattern; using Log2OpLowering = ConvertFMFMathToLLVMPattern; diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp index d6dfd0229963c..7dc12adad0531 100644 --- a/mlir/lib/Dialect/Func/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp @@ -279,3 +279,42 @@ func::lookupFnDecl(SymbolOpInterface symTable, StringRef name, } return func; } + +func::FuncOp func::createFnDecl(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, FunctionType funcT, + bool setPrivate, + SymbolTableCollection *symbolTables) { + OpBuilder::InsertionGuard g(b); + assert(!symTable->getRegion(0).empty() && "expected non-empty region"); + b.setInsertionPointToStart(&symTable->getRegion(0).front()); + func::FuncOp funcOp = + func::FuncOp::create(b, symTable->getLoc(), name, funcT); + if (setPrivate) + funcOp.setPrivate(); + if (symbolTables) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable); + symbolTable.insert(funcOp, symTable->getRegion(0).front().begin()); + } + return funcOp; +} + +FailureOr +func::lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, + StringRef funcName, TypeRange paramTypes, + SymbolTableCollection *symbolTables, + Type resultType) { + if (!resultType) + resultType = IntegerType::get(symTable->getContext(), 64); + auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType}); + FailureOr func = + lookupFnDecl(symTable, funcName, funcT, symbolTables); + // Failed due to type mismatch. + if (failed(func)) + return func; + // Successfully matched existing decl. + if (*func) + return *func; + + return createFnDecl(b, symTable, funcName, funcT, + /*setPrivate=*/true, symbolTables); +} diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index f3e38eb8ffa2d..9deb900fbe35d 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -21,6 +21,7 @@ // #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APSInt.h" +#include "llvm/Support/Debug.h" #ifdef _WIN32 #ifndef MLIR_APFLOAT_WRAPPERS_EXPORT @@ -143,7 +144,8 @@ MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics, return static_cast(x.compare(y)); } -MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint64_t a) { +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, + uint64_t a) { const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( static_cast(semantics)); unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); @@ -152,6 +154,67 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint6 return x.bitcastToAPInt().getZExtValue(); } +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_abs(int32_t semantics, + uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + return abs(x).bitcastToAPInt().getZExtValue(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isfinite(int32_t semantics, + uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + return x.isFinite(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isinfinite(int32_t semantics, + uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + return x.isInfinity(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isnormal(int32_t semantics, + uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + return x.isNormal(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT bool _mlir_apfloat_isnan(int32_t semantics, + uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + return x.isNaN(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t +_mlir_apfloat_fused_multiply_add(int32_t semantics, uint64_t operand, + uint64_t multiplicand, uint64_t addend) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat operand_(sem, llvm::APInt(bitWidth, operand)); + llvm::APFloat multiplicand_(sem, llvm::APInt(bitWidth, multiplicand)); + llvm::APFloat addend_(sem, llvm::APInt(bitWidth, addend)); + llvm::detail::opStatus stat = operand_.fusedMultiplyAdd( + multiplicand_, addend_, llvm::RoundingMode::NearestTiesToEven); + assert(stat == llvm::APFloatBase::opOK && + "expected fusedMultiplyAdd status to be OK"); + return operand_.bitcastToAPInt().getZExtValue(); +} + /// Min/max operations. #define APFLOAT_MIN_MAX_OP(OP) \ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \ diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir similarity index 100% rename from mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir rename to mlir/test/Conversion/ArithAndMathToAPFloat/arith-to-apfloat.mlir diff --git a/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir b/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir new file mode 100644 index 0000000000000..e52eb5866093c --- /dev/null +++ b/mlir/test/Conversion/ArithAndMathToAPFloat/math-to-apfloat.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt %s --convert-math-to-apfloat | FileCheck %s + +func.func @full_example() { + %neg14fp8 = arith.constant -1.4 : f8E4M3FN + %abs = math.absf %neg14fp8 : f8E4M3FN + + // see llvm/unittests/ADT/APFloatTest::TEST(APFloatTest, Float8E8M0FNUFMA) + %twof8E8M0FNU = arith.constant 2.0 : f8E8M0FNU + %fourf8E8M0FNU = arith.constant 4.0 : f8E8M0FNU + %eightf8E8M0FNU = arith.constant 8.0 : f8E8M0FNU + %fma = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU + + %isinf = math.isinf %neg14fp8 : f8E4M3FN + %isnan = math.isnan %neg14fp8 : f8E4M3FN + %isnormal = math.isnormal %neg14fp8 : f8E4M3FN + %isfinite = math.isfinite %neg14fp8 : f8E4M3FN + + return +} + +// CHECK-LABEL: func.func private @_mlir_apfloat_isfinite(i32, i64) -> i1 +// CHECK: func.func private @_mlir_apfloat_isnormal(i32, i64) -> i1 +// CHECK: func.func private @_mlir_apfloat_isnan(i32, i64) -> i1 +// CHECK: func.func private @_mlir_apfloat_isinfinite(i32, i64) -> i1 +// CHECK: func.func private @_mlir_apfloat_fused_multiply_add(i32, i64, i64, i64) -> i64 +// CHECK: func.func private @_mlir_apfloat_abs(i32, i64) -> i64 + +// CHECK-LABEL: func.func @full_example() { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant -1.375000e+00 : f8E4M3FN +// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64 +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_0:.*]] = call @_mlir_apfloat_abs(%[[CONSTANT_1]], %[[EXTUI_0]]) : (i32, i64) -> i64 +// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_0]] : i64 to i8 +// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2.000000e+00 : f8E8M0FNU +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 4.000000e+00 : f8E8M0FNU +// CHECK: %[[CONSTANT_4:.*]] = arith.constant 8.000000e+00 : f8E8M0FNU +// CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[CONSTANT_3]] : f8E8M0FNU to i8 +// CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_2]] : i8 to i64 +// CHECK: %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f8E8M0FNU to i8 +// CHECK: %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i8 to i64 +// CHECK: %[[BITCAST_4:.*]] = arith.bitcast %[[CONSTANT_4]] : f8E8M0FNU to i8 +// CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i8 to i64 +// CHECK: %[[CONSTANT_5:.*]] = arith.constant 15 : i32 +// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_fused_multiply_add(%[[CONSTANT_5]], %[[EXTUI_1]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64, i64) -> i64 +// CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_1]] : i64 to i8 +// CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i8 to f8E8M0FNU +// CHECK: %[[BITCAST_6:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_4:.*]] = arith.extui %[[BITCAST_6]] : i8 to i64 +// CHECK: %[[CONSTANT_6:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_2:.*]] = call @_mlir_apfloat_isinfinite(%[[CONSTANT_6]], %[[EXTUI_4]]) : (i32, i64) -> i1 +// CHECK: %[[BITCAST_7:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_5:.*]] = arith.extui %[[BITCAST_7]] : i8 to i64 +// CHECK: %[[CONSTANT_7:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_3:.*]] = call @_mlir_apfloat_isnan(%[[CONSTANT_7]], %[[EXTUI_5]]) : (i32, i64) -> i1 +// CHECK: %[[BITCAST_8:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_6:.*]] = arith.extui %[[BITCAST_8]] : i8 to i64 +// CHECK: %[[CONSTANT_8:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_4:.*]] = call @_mlir_apfloat_isnormal(%[[CONSTANT_8]], %[[EXTUI_6]]) : (i32, i64) -> i1 +// CHECK: %[[BITCAST_9:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_7:.*]] = arith.extui %[[BITCAST_9]] : i8 to i64 +// CHECK: %[[CONSTANT_9:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_5:.*]] = call @_mlir_apfloat_isfinite(%[[CONSTANT_9]], %[[EXTUI_7]]) : (i32, i64) -> i1 +// CHECK: return +// CHECK: } diff --git a/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir new file mode 100644 index 0000000000000..c890b470b563a --- /dev/null +++ b/mlir/test/Integration/Dialect/Math/CPU/test-apfloat-emulation.mlir @@ -0,0 +1,68 @@ +// REQUIRES: system-linux || system-darwin +// TODO: Run only on Linux and MacOS until we figure out how to build +// mlir_apfloat_wrappers in a platform-independent way. + +// RUN: mlir-opt %s --convert-math-to-apfloat --convert-to-llvm | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +func.func @entry() { + + // FP8 + + %neg14fp8 = arith.constant -1.4 : f8E4M3FN + %absfp8 = math.absf %neg14fp8 : f8E4M3FN + // CHECK: 1.375 + vector.print %absfp8 : f8E4M3FN + + // see llvm/unittests/ADT/APFloatTest::TEST(APFloatTest, Float8E8M0FNUFMA) + %twof8E8M0FNU = arith.constant 2.0 : f8E8M0FNU + %fourf8E8M0FNU = arith.constant 4.0 : f8E8M0FNU + %eightf8E8M0FNU = arith.constant 8.0 : f8E8M0FNU + %fmafp8 = math.fma %fourf8E8M0FNU, %twof8E8M0FNU, %eightf8E8M0FNU : f8E8M0FNU + // CHECK: 16 + vector.print %fmafp8 : f8E8M0FNU + + // CHECK: 0 + %isinffp8 = math.isinf %neg14fp8 : f8E4M3FN + vector.print %isinffp8 : i1 + // CHECK: 0 + %isnanfp8 = math.isnan %neg14fp8 : f8E4M3FN + vector.print %isnanfp8 : i1 + %isnormalfp8 = math.isnormal %neg14fp8 : f8E4M3FN + // CHECK: 1 + vector.print %isnormalfp8 : i1 + %isfinitefp8 = math.isfinite %neg14fp8 : f8E4M3FN + // CHECK: 1 + vector.print %isfinitefp8 : i1 + + // FP32 + + %neg14fp32 = arith.constant -1.4 : f32 + %absfp32 = math.absf %neg14fp32 : f32 + // CHECK: 1.4 + vector.print %absfp32 : f32 + + %twofp32 = arith.constant 2.0 : f32 + %fourfp32 = arith.constant 4.0 : f32 + %eightfp32 = arith.constant 8.0 : f32 + %fmafp32 = math.fma %fourfp32, %twofp32, %eightfp32 : f32 + // CHECK: 16 + vector.print %fmafp32 : f32 + + // CHECK: 0 + %isinffp32 = math.isinf %neg14fp32 : f32 + vector.print %isinffp32 : i1 + // CHECK: 0 + %isnanfp32 = math.isnan %neg14fp32 : f32 + vector.print %isnanfp32 : i1 + %isnormalfp32 = math.isnormal %neg14fp32 : f32 + // CHECK: 1 + vector.print %isnormalfp32 : i1 + %isfinitefp32 = math.isfinite %neg14fp32 : f32 + // CHECK: 1 + vector.print %isfinitefp32 : i1 + + return +}