-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] Add FP software implementation lowering pass: arith-to-apfloat
#166618
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
949a4e9
a93ba6c
0055cb0
38ae089
0f3b820
b713f60
78df4a8
1180064
930a664
b644bfb
d94e2d1
7d29b71
c79a9ef
64eff4c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| //===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===// | ||
| // | ||
| // Part of the APFloat Project, under the Apache License v2.0 with APFloat | ||
| // Exceptions. See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H | ||
| #define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H | ||
|
|
||
| #include <memory> | ||
|
|
||
| namespace mlir { | ||
| class Pass; | ||
|
|
||
| #define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS | ||
| #include "mlir/Conversion/Passes.h.inc" | ||
| } // namespace mlir | ||
|
|
||
| #endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| //===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" | ||
|
|
||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||
| #include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
| #include "mlir/Dialect/Func/IR/FuncOps.h" | ||
| #include "mlir/Dialect/Func/Utils/Utils.h" | ||
| #include "mlir/IR/PatternMatch.h" | ||
| #include "mlir/IR/Verifier.h" | ||
| #include "mlir/Transforms/WalkPatternRewriteDriver.h" | ||
| #include "llvm/Support/Debug.h" | ||
|
|
||
| #define DEBUG_TYPE "arith-to-apfloat" | ||
|
|
||
| namespace mlir { | ||
| #define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS | ||
| #include "mlir/Conversion/Passes.h.inc" | ||
| } // namespace mlir | ||
|
|
||
| using namespace mlir; | ||
| using namespace mlir::func; | ||
|
|
||
| static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, | ||
| StringRef name, FunctionType funcT, bool setPrivate, | ||
| SymbolTableCollection *symbolTables = nullptr) { | ||
| OpBuilder::InsertionGuard g(b); | ||
| assert(!symTable->getRegion(0).empty() && "expected non-empty region"); | ||
| b.setInsertionPointToStart(&symTable->getRegion(0).front()); | ||
| FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT); | ||
| if (setPrivate) | ||
| funcOp.setPrivate(); | ||
| if (symbolTables) { | ||
| SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable); | ||
| symbolTable.insert(funcOp, symTable->getRegion(0).front().begin()); | ||
| } | ||
| return funcOp; | ||
| } | ||
|
|
||
| /// Helper function to look up or create the symbol for a runtime library | ||
| /// function for a binary arithmetic operation. | ||
| /// | ||
| /// Parameter 1: APFloat semantics | ||
| /// Parameter 2: Left-hand side operand | ||
| /// Parameter 3: Right-hand side operand | ||
| /// | ||
| /// This function will return a failure if the function is found but has an | ||
| /// unexpected signature. | ||
| /// | ||
| static FailureOr<FuncOp> | ||
| lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, | ||
| SymbolTableCollection *symbolTables = nullptr) { | ||
| auto i32Type = IntegerType::get(symTable->getContext(), 32); | ||
| auto i64Type = IntegerType::get(symTable->getContext(), 64); | ||
|
|
||
| std::string funcName = (llvm::Twine("__mlir_apfloat_") + name).str(); | ||
| FunctionType funcT = | ||
| FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type}); | ||
| FailureOr<FuncOp> func = | ||
| lookupFnDecl(symTable, funcName, funcT, symbolTables); | ||
| // Failed due to type mismatch. | ||
| if (failed(func)) | ||
| return func; | ||
| // Successfully matched existing decl. | ||
| if (*func) | ||
| return *func; | ||
|
|
||
| return createFnDecl(b, symTable, funcName, funcT, | ||
| /*setPrivate=*/true, symbolTables); | ||
| } | ||
|
|
||
| /// Rewrite a binary arithmetic operation to an APFloat function call. | ||
| template <typename OpTy, const char *APFloatName> | ||
makslevental marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> { | ||
| BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit, | ||
| SymbolOpInterface symTable) | ||
| : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}; | ||
|
|
||
| LogicalResult matchAndRewrite(OpTy op, | ||
| PatternRewriter &rewriter) const override { | ||
| // Get APFloat function from runtime library. | ||
| FailureOr<FuncOp> fn = | ||
| lookupOrCreateBinaryFn(rewriter, symTable, APFloatName); | ||
| if (failed(fn)) | ||
| return fn; | ||
|
|
||
| rewriter.setInsertionPoint(op); | ||
| // Cast operands to 64-bit integers. | ||
| Location loc = op.getLoc(); | ||
| auto floatTy = cast<FloatType>(op.getType()); | ||
| auto intWType = rewriter.getIntegerType(floatTy.getWidth()); | ||
| auto int64Type = rewriter.getI64Type(); | ||
| Value lhsBits = arith::ExtUIOp::create( | ||
| rewriter, loc, int64Type, | ||
| arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs())); | ||
| Value rhsBits = arith::ExtUIOp::create( | ||
| rewriter, loc, int64Type, | ||
| arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs())); | ||
|
|
||
| // Call APFloat function. | ||
| int32_t sem = | ||
| llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); | ||
| Value semValue = arith::ConstantOp::create( | ||
| rewriter, loc, rewriter.getI32Type(), | ||
| rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); | ||
| SmallVector<Value> params = {semValue, lhsBits, rhsBits}; | ||
| auto resultOp = | ||
| func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), | ||
| SymbolRefAttr::get(*fn), params); | ||
|
|
||
| // Truncate result to the original width. | ||
| Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, | ||
| resultOp->getResult(0)); | ||
| rewriter.replaceOp( | ||
| op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits)); | ||
| return success(); | ||
| } | ||
|
|
||
| SymbolOpInterface symTable; | ||
| }; | ||
|
|
||
| namespace { | ||
| struct ArithToAPFloatConversionPass final | ||
| : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> { | ||
| using Base::Base; | ||
|
|
||
| void runOnOperation() override { | ||
| MLIRContext *context = &getContext(); | ||
| RewritePatternSet patterns(context); | ||
| static const char add[] = "add"; | ||
| static const char subtract[] = "subtract"; | ||
| static const char multiply[] = "multiply"; | ||
| static const char divide[] = "divide"; | ||
| static const char remainder[] = "remainder"; | ||
| patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp, add>, | ||
| BinaryArithOpToAPFloatConversion<arith::SubFOp, subtract>, | ||
| BinaryArithOpToAPFloatConversion<arith::MulFOp, multiply>, | ||
| BinaryArithOpToAPFloatConversion<arith::DivFOp, divide>, | ||
| BinaryArithOpToAPFloatConversion<arith::RemFOp, remainder>>( | ||
| context, 1, getOperation()); | ||
| LogicalResult result = success(); | ||
| ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { | ||
| if (diag.getSeverity() == DiagnosticSeverity::Error) { | ||
| result = failure(); | ||
| } | ||
| // NB: if you don't return failure, no other diag handlers will fire (see | ||
| // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). | ||
| return failure(); | ||
| }); | ||
|
Comment on lines
+148
to
+155
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if this is something the walk rewriter should do? (separately from this PR)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My 2 cents: leaving it decomposed is better - this is not hard to write (once you know how diaghandlers work 😜). But it's up to you - you're code owner/designer of that. If you so wish it. I can refactor this into the rewriter. Maybe the "abstraction" could be passing just the callback. |
||
| walkAndApplyPatterns(getOperation(), std::move(patterns)); | ||
| if (failed(result)) | ||
| return signalPassFailure(); | ||
|
Comment on lines
+156
to
+158
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @matthias-springer this is what we want right? if there's a failure due to overlapping decls we want to fail the pass right?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That sounds reasonable to me. |
||
| } | ||
| }; | ||
| } // namespace | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| add_mlir_conversion_library(MLIRArithToAPFloat | ||
| ArithToAPFloat.cpp | ||
|
|
||
| ADDITIONAL_HEADER_DIRS | ||
| ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM | ||
|
|
||
| DEPENDS | ||
| MLIRConversionPassIncGen | ||
|
|
||
| LINK_COMPONENTS | ||
| Core | ||
|
|
||
| LINK_LIBS PUBLIC | ||
| MLIRArithDialect | ||
| MLIRArithTransforms | ||
| MLIRFuncDialect | ||
| ) |
Uh oh!
There was an error while loading. Please reload this page.