Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===//
//
// Part of the APFloat Project, under the Apache License v2.0 with APFloat
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,21 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
];
}

//===----------------------------------------------------------------------===//
// ArithToAPFloat
//===----------------------------------------------------------------------===//

def ArithToAPFloatConversionPass
: Pass<"convert-arith-to-apfloat", "ModuleOp"> {
let summary = "Convert Arith ops to APFloat runtime library calls";
let description = [{
This pass converts supported Arith ops to APFloat-based runtime library
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point arithmetic operations.
}];
let dependentDialects = ["func::FuncDialect"];
}

//===----------------------------------------------------------------------===//
// ArithToSPIRV
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Func/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
mlir::ModuleOp moduleOp);

/// Look up a FuncOp with signature `resultTypes`(`paramTypes`)` and name
/// `name`. Return a failure if the FuncOp is found but with a different
/// signature.
FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT,
SymbolTableCollection *symbolTables = nullptr);

} // namespace func
} // namespace mlir

Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);

/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
Expand Down
158 changes: 158 additions & 0 deletions mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::func;

static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
StringRef name, FunctionType funcT, bool setPrivate,
SymbolTableCollection *symbolTables = nullptr) {
OpBuilder::InsertionGuard g(b);
assert(!symTable->getRegion(0).empty() && "expected non-empty region");
b.setInsertionPointToStart(&symTable->getRegion(0).front());
FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
if (setPrivate)
funcOp.setPrivate();
if (symbolTables) {
SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
}
return funcOp;
}

/// Helper function to look up or create the symbol for a runtime library
/// function for a binary arithmetic operation.
///
/// Parameter 1: APFloat semantics
/// Parameter 2: Left-hand side operand
/// Parameter 3: Right-hand side operand
///
/// This function will return a failure if the function is found but has an
/// unexpected signature.
///
static FailureOr<FuncOp>
lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
SymbolTableCollection *symbolTables = nullptr) {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);

std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
FunctionType funcT =
FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type});
FailureOr<FuncOp> func =
lookupFnDecl(symTable, funcName, funcT, symbolTables);
// Failed due to type mismatch.
if (failed(func))
return func;
// Successfully matched existing decl.
if (*func)
return *func;

return createFnDecl(b, symTable, funcName, funcT,
/*setPrivate=*/true, symbolTables);
}

/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy, const char *APFloatName>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit,
SymbolOpInterface symTable)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {};

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Get APFloat function from runtime library.
FailureOr<FuncOp> fn =
lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
if (failed(fn))
return fn;

rewriter.setInsertionPoint(op);
// Cast operands to 64-bit integers.
Location loc = op.getLoc();
auto floatTy = cast<FloatType>(op.getType());
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
auto int64Type = rewriter.getI64Type();
Value lhsBits = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
Value rhsBits = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));

// Call APFloat function.
int32_t sem =
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
Value semValue = arith::ConstantOp::create(
rewriter, loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
auto resultOp =
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
SymbolRefAttr::get(*fn), params);

// Truncate result to the original width.
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
resultOp->getResult(0));
rewriter.replaceOp(
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
return success();
}

SymbolOpInterface symTable;
};

namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
using Base::Base;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
static const char add[] = "add";
static const char subtract[] = "subtract";
static const char multiply[] = "multiply";
static const char divide[] = "divide";
static const char remainder[] = "remainder";
patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp, add>,
BinaryArithOpToAPFloatConversion<arith::SubFOp, subtract>,
BinaryArithOpToAPFloatConversion<arith::MulFOp, multiply>,
BinaryArithOpToAPFloatConversion<arith::DivFOp, divide>,
BinaryArithOpToAPFloatConversion<arith::RemFOp, remainder>>(
context, 1, getOperation());
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
result = failure();
}
// NB: if you don't return failure, no other diag handlers will fire (see
// mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
return failure();
});
walkAndApplyPatterns(getOperation(), std::move(patterns));
if (failed(result))
return signalPassFailure();
}
};
} // namespace
18 changes: 18 additions & 0 deletions mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
add_mlir_conversion_library(MLIRArithToAPFloat
ArithToAPFloat.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRArithTransforms
MLIRFuncDialect
MLIRFuncUtils
)
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToAPFloat)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return failure();
}
}
} else if (auto floatTy = dyn_cast<FloatType>(printType)) {
// Print other floating-point types using the APFloat runtime library.
int32_t sem =
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
Value semValue = LLVM::ConstantOp::create(
rewriter, loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
Value floatBits =
LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
printer =
LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
emitCall(rewriter, loc, printer.value(),
ValueRange({semValue, floatBits}));
return success();
} else {
return failure();
}
Expand Down
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/Func/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,

return std::make_pair(*newFuncOpOrFailure, newCallOp);
}

FailureOr<func::FuncOp>
func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT, SymbolTableCollection *symbolTables) {
FuncOp func;
if (symbolTables) {
func = symbolTables->lookupSymbolIn<FuncOp>(
symTable, StringAttr::get(symTable->getContext(), name));
} else {
func = llvm::dyn_cast_or_null<FuncOp>(
SymbolTable::lookupSymbolIn(symTable, name));
}

if (!func)
return func;

mlir::FunctionType foundFuncT = func.getFunctionType();
// Assert the signature of the found function is same as expected
if (funcT != foundFuncT) {
return func.emitError("matched function '")
<< name << "' but with different type: " << foundFuncT
<< " (expected " << funcT << ")";
}
return func;
}
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
Expand Down Expand Up @@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}

FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintApFloat,
{IntegerType::get(moduleOp->getContext(), 32),
IntegerType::get(moduleOp->getContext(), 64)},
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}

static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
Expand Down
Loading
Loading