-
Notifications
You must be signed in to change notification settings - Fork 11.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir] Add patterns to lower Math operations to LLVM based libm calls.
Some Math operations do not have an equivalent in LLVM. In these cases, allow a low priority fallback of calling the libm functions. This is to give functionality and is not a performant option. Differential Revision: https://reviews.llvm.org/D100367
- Loading branch information
Showing
7 changed files
with
277 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
//===- MathToLibm.h - Utils to convert from the complex dialect --------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
#ifndef MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_ | ||
#define MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_ | ||
|
||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace mlir { | ||
template <typename T> | ||
class OperationPass; | ||
|
||
/// Populate the given list with patterns that convert from Math to Libm calls. | ||
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, | ||
PatternBenefit benefit); | ||
|
||
/// Create a pass to convert Math operations to libm calls. | ||
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass(); | ||
|
||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
add_mlir_conversion_library(MLIRMathToLibm | ||
MathToLibm.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLibm | ||
|
||
DEPENDS | ||
MLIRConversionPassIncGen | ||
|
||
LINK_COMPONENTS | ||
Core | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRMath | ||
MLIRStandardOpsTransforms | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/MathToLibm/MathToLibm.h" | ||
|
||
#include "../PassDetail.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Dialect/StandardOps/IR/Ops.h" | ||
#include "mlir/Dialect/Vector/VectorOps.h" | ||
#include "mlir/IR/BuiltinDialect.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
// Pattern to convert vector operations to scalar operations. This is needed as | ||
// libm calls require scalars. | ||
template <typename Op> | ||
struct VecOpToScalarOp : public OpRewritePattern<Op> { | ||
public: | ||
using OpRewritePattern<Op>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; | ||
}; | ||
// Pattern to convert scalar math operations to calls to libm functions. | ||
// Additionally the libm function signatures are declared. | ||
template <typename Op> | ||
struct ScalarOpToLibmCall : public OpRewritePattern<Op> { | ||
public: | ||
using OpRewritePattern<Op>::OpRewritePattern; | ||
ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc, | ||
StringRef doubleFunc, PatternBenefit benefit) | ||
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc), | ||
doubleFunc(doubleFunc){}; | ||
|
||
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; | ||
|
||
private: | ||
std::string floatFunc, doubleFunc; | ||
}; | ||
} // namespace | ||
|
||
template <typename Op> | ||
LogicalResult | ||
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { | ||
auto opType = op.getType(); | ||
auto loc = op.getLoc(); | ||
auto vecType = opType.template dyn_cast<VectorType>(); | ||
|
||
if (!vecType) | ||
return failure(); | ||
if (!vecType.hasRank()) | ||
return failure(); | ||
auto shape = vecType.getShape(); | ||
// TODO: support multidimensional vectors | ||
if (shape.size() != 1) | ||
return failure(); | ||
|
||
Value result = rewriter.create<ConstantOp>( | ||
loc, DenseElementsAttr::get( | ||
vecType, FloatAttr::get(vecType.getElementType(), 0.0))); | ||
for (auto i = 0; i < shape.front(); ++i) { | ||
SmallVector<Value> operands; | ||
for (auto input : op->getOperands()) | ||
operands.push_back( | ||
rewriter.create<vector::ExtractElementOp>(loc, input, i)); | ||
Value scalarOp = | ||
rewriter.create<Op>(loc, vecType.getElementType(), operands); | ||
result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i); | ||
} | ||
rewriter.replaceOp(op, {result}); | ||
return success(); | ||
} | ||
|
||
template <typename Op> | ||
LogicalResult | ||
ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, | ||
PatternRewriter &rewriter) const { | ||
auto module = op->template getParentOfType<ModuleOp>(); | ||
auto type = op.getType(); | ||
// TODO: Support Float16 by upcasting to Float32 | ||
if (!type.template isa<Float32Type, Float64Type>()) | ||
return failure(); | ||
|
||
auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; | ||
auto opFunc = module.template lookupSymbol<FuncOp>(name); | ||
// Forward declare function if it hasn't already been | ||
if (!opFunc) { | ||
OpBuilder::InsertionGuard guard(rewriter); | ||
rewriter.setInsertionPointToStart(module.getBody()); | ||
auto opFunctionTy = FunctionType::get( | ||
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); | ||
opFunc = | ||
rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy); | ||
opFunc.setPrivate(); | ||
} | ||
assert(opFunc.getType().template cast<FunctionType>().getResults() == | ||
op->getResultTypes()); | ||
assert(opFunc.getType().template cast<FunctionType>().getInputs() == | ||
op->getOperandTypes()); | ||
|
||
rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands()); | ||
|
||
return success(); | ||
} | ||
|
||
void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, | ||
PatternBenefit benefit) { | ||
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, | ||
VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit); | ||
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), | ||
"atan2f", "atan2", benefit); | ||
patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(), | ||
"expm1f", "expm1", benefit); | ||
patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf", | ||
"tanh", benefit); | ||
} | ||
|
||
namespace { | ||
struct ConvertMathToLibmPass | ||
: public ConvertMathToLibmBase<ConvertMathToLibmPass> { | ||
void runOnOperation() override; | ||
}; | ||
} // namespace | ||
|
||
void ConvertMathToLibmPass::runOnOperation() { | ||
auto module = getOperation(); | ||
|
||
RewritePatternSet patterns(&getContext()); | ||
populateMathToLibmConversionPatterns(patterns, /*benefit=*/1); | ||
|
||
ConversionTarget target(getContext()); | ||
target.addLegalDialect<BuiltinDialect, StandardOpsDialect, | ||
vector::VectorDialect>(); | ||
target.addIllegalDialect<math::MathDialect>(); | ||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) | ||
signalPassFailure(); | ||
} | ||
|
||
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() { | ||
return std::make_unique<ConvertMathToLibmPass>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s | ||
|
||
// CHECK-DAG: @expm1(f64) -> f64 | ||
// CHECK-DAG: @expm1f(f32) -> f32 | ||
// CHECK-DAG: @atan2(f64, f64) -> f64 | ||
// CHECK-DAG: @atan2f(f32, f32) -> f32 | ||
// CHECK-DAG: @tanh(f64) -> f64 | ||
// CHECK-DAG: @tanhf(f32) -> f32 | ||
|
||
// CHECK-LABEL: func @tanh_caller | ||
// CHECK-SAME: %[[FLOAT:.*]]: f32 | ||
// CHECK-SAME: %[[DOUBLE:.*]]: f64 | ||
func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) { | ||
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @tanhf(%[[FLOAT]]) : (f32) -> f32 | ||
%float_result = math.tanh %float : f32 | ||
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @tanh(%[[DOUBLE]]) : (f64) -> f64 | ||
%double_result = math.tanh %double : f64 | ||
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] | ||
return %float_result, %double_result : f32, f64 | ||
} | ||
|
||
|
||
// CHECK-LABEL: func @atan2_caller | ||
// CHECK-SAME: %[[FLOAT:.*]]: f32 | ||
// CHECK-SAME: %[[DOUBLE:.*]]: f64 | ||
func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) { | ||
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32 | ||
%float_result = math.atan2 %float, %float : f32 | ||
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64 | ||
%double_result = math.atan2 %double, %double : f64 | ||
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] | ||
return %float_result, %double_result : f32, f64 | ||
} | ||
|
||
// CHECK-LABEL: func @expm1_caller | ||
// CHECK-SAME: %[[FLOAT:.*]]: f32 | ||
// CHECK-SAME: %[[DOUBLE:.*]]: f64 | ||
func @expm1_caller(%float: f32, %double: f64) -> (f32, f64) { | ||
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @expm1f(%[[FLOAT]]) : (f32) -> f32 | ||
%float_result = math.expm1 %float : f32 | ||
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @expm1(%[[DOUBLE]]) : (f64) -> f64 | ||
%double_result = math.expm1 %double : f64 | ||
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] | ||
return %float_result, %double_result : f32, f64 | ||
} | ||
|
||
func @expm1_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { | ||
%float_result = math.expm1 %float : vector<2xf32> | ||
%double_result = math.expm1 %double : vector<2xf64> | ||
return %float_result, %double_result : vector<2xf32>, vector<2xf64> | ||
} | ||
// CHECK-LABEL: func @expm1_vec_caller( | ||
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, | ||
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { | ||
// CHECK: %[[CVF:.*]] = constant dense<0.000000e+00> : vector<2xf32> | ||
// CHECK: %[[CVD:.*]] = constant dense<0.000000e+00> : vector<2xf64> | ||
// CHECK: %[[C0:.*]] = constant 0 : i32 | ||
// CHECK: %[[C1:.*]] = constant 1 : i32 | ||
// CHECK: %[[IN0_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C0]] : i32] : vector<2xf32> | ||
// CHECK: %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32 | ||
// CHECK: %[[VAL_8:.*]] = vector.insertelement %[[OUT0_F32]], %[[CVF]]{{\[}}%[[C0]] : i32] : vector<2xf32> | ||
// CHECK: %[[IN1_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C1]] : i32] : vector<2xf32> | ||
// CHECK: %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32 | ||
// CHECK: %[[VAL_11:.*]] = vector.insertelement %[[OUT1_F32]], %[[VAL_8]]{{\[}}%[[C1]] : i32] : vector<2xf32> | ||
// CHECK: %[[IN0_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C0]] : i32] : vector<2xf64> | ||
// CHECK: %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64 | ||
// CHECK: %[[VAL_14:.*]] = vector.insertelement %[[OUT0_F64]], %[[CVD]]{{\[}}%[[C0]] : i32] : vector<2xf64> | ||
// CHECK: %[[IN1_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C1]] : i32] : vector<2xf64> | ||
// CHECK: %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64 | ||
// CHECK: %[[VAL_17:.*]] = vector.insertelement %[[OUT1_F64]], %[[VAL_14]]{{\[}}%[[C1]] : i32] : vector<2xf64> | ||
// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64> | ||
// CHECK: } | ||
|