Skip to content
Merged
21 changes: 21 additions & 0 deletions mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===//
//
// Part of the APFloat Project, under the Apache License v2.0 with APFloat
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,21 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
];
}

//===----------------------------------------------------------------------===//
// ArithToAPFloat
//===----------------------------------------------------------------------===//

def ArithToAPFloatConversionPass
: Pass<"convert-arith-to-apfloat", "ModuleOp"> {
let summary = "Convert Arith ops to APFloat runtime library calls";
let description = [{
This pass converts supported Arith ops to APFloat-based runtime library
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point arithmetic operations.
}];
let dependentDialects = ["func::FuncDialect"];
}

//===----------------------------------------------------------------------===//
// ArithToSPIRV
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Func/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
mlir::ModuleOp moduleOp);

/// Look up a FuncOp with signature `resultTypes`(`paramTypes`)` and name
/// `name`. Return a failure if the FuncOp is found but with a different
/// signature.
FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT,
SymbolTableCollection *symbolTables = nullptr);

} // namespace func
} // namespace mlir

Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);

/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
Expand Down
161 changes: 161 additions & 0 deletions mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "arith-to-apfloat"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::func;

static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
StringRef name, FunctionType funcT, bool setPrivate,
SymbolTableCollection *symbolTables = nullptr) {
OpBuilder::InsertionGuard g(b);
assert(!symTable->getRegion(0).empty() && "expected non-empty region");
b.setInsertionPointToStart(&symTable->getRegion(0).front());
FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
if (setPrivate)
funcOp.setPrivate();
if (symbolTables) {
SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
}
return funcOp;
}

/// Helper function to look up or create the symbol for a runtime library
/// function for a binary arithmetic operation.
///
/// Parameter 1: APFloat semantics
/// Parameter 2: Left-hand side operand
/// Parameter 3: Right-hand side operand
///
/// This function will return a failure if the function is found but has an
/// unexpected signature.
///
static FailureOr<FuncOp>
lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
SymbolTableCollection *symbolTables = nullptr) {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);

std::string funcName = (llvm::Twine("__mlir_apfloat_") + name).str();
FunctionType funcT =
FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type});
FailureOr<FuncOp> func =
lookupFnDecl(symTable, funcName, funcT, symbolTables);
// Failed due to type mismatch.
if (failed(func))
return func;
// Successfully matched existing decl.
if (*func)
return *func;

return createFnDecl(b, symTable, funcName, funcT,
/*setPrivate=*/true, symbolTables);
}

/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy, const char *APFloatName>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit,
SymbolOpInterface symTable)
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {};

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Get APFloat function from runtime library.
FailureOr<FuncOp> fn =
lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
if (failed(fn))
return fn;

rewriter.setInsertionPoint(op);
// Cast operands to 64-bit integers.
Location loc = op.getLoc();
auto floatTy = cast<FloatType>(op.getType());
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
auto int64Type = rewriter.getI64Type();
Value lhsBits = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
Value rhsBits = arith::ExtUIOp::create(
rewriter, loc, int64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));

// Call APFloat function.
int32_t sem =
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
Value semValue = arith::ConstantOp::create(
rewriter, loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
auto resultOp =
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
SymbolRefAttr::get(*fn), params);

// Truncate result to the original width.
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
resultOp->getResult(0));
rewriter.replaceOp(
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
return success();
}

SymbolOpInterface symTable;
};

namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
using Base::Base;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
static const char add[] = "add";
static const char subtract[] = "subtract";
static const char multiply[] = "multiply";
static const char divide[] = "divide";
static const char remainder[] = "remainder";
patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp, add>,
BinaryArithOpToAPFloatConversion<arith::SubFOp, subtract>,
BinaryArithOpToAPFloatConversion<arith::MulFOp, multiply>,
BinaryArithOpToAPFloatConversion<arith::DivFOp, divide>,
BinaryArithOpToAPFloatConversion<arith::RemFOp, remainder>>(
context, 1, getOperation());
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
result = failure();
}
// NB: if you don't return failure, no other diag handlers will fire (see
// mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
return failure();
});
Comment on lines +148 to +155
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this is something the walk rewriter should do? (separately from this PR)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My 2 cents: leaving it decomposed is better - this is not hard to write (once you know how diaghandlers work 😜).

But it's up to you - you're code owner/designer of that. If you so wish it. I can refactor this into the rewriter. Maybe the "abstraction" could be passing just the callback.

walkAndApplyPatterns(getOperation(), std::move(patterns));
if (failed(result))
return signalPassFailure();
Comment on lines +156 to +158
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthias-springer this is what we want right? if there's a failure due to overlapping decls we want to fail the pass right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds reasonable to me.

}
};
} // namespace
17 changes: 17 additions & 0 deletions mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_mlir_conversion_library(MLIRArithToAPFloat
ArithToAPFloat.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRArithTransforms
MLIRFuncDialect
)
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToAPFloat)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return failure();
}
}
} else if (auto floatTy = dyn_cast<FloatType>(printType)) {
// Print other floating-point types using the APFloat runtime library.
int32_t sem =
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
Value semValue = LLVM::ConstantOp::create(
rewriter, loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
Value floatBits =
LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
printer =
LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
emitCall(rewriter, loc, printer.value(),
ValueRange({semValue, floatBits}));
return success();
} else {
return failure();
}
Expand Down
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/Func/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,

return std::make_pair(*newFuncOpOrFailure, newCallOp);
}

FailureOr<func::FuncOp>
func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT, SymbolTableCollection *symbolTables) {
FuncOp func;
if (symbolTables) {
func = symbolTables->lookupSymbolIn<FuncOp>(
symTable, StringAttr::get(symTable->getContext(), name));
} else {
func = llvm::dyn_cast_or_null<FuncOp>(
SymbolTable::lookupSymbolIn(symTable, name));
}

if (!func)
return func;

mlir::FunctionType foundFuncT = func.getFunctionType();
// Assert the signature of the found function is same as expected
if (funcT != foundFuncT) {
return func.emitError("matched function '")
<< name << "' but with different type: " << foundFuncT
<< " (expected " << funcT << ")";
}
return func;
}
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
Expand Down Expand Up @@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}

FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintApFloat,
{IntegerType::get(moduleOp->getContext(), 32),
IntegerType::get(moduleOp->getContext(), 64)},
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}

static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
Expand Down
Loading