Skip to content

Commit

Permalink
[mlir] Add patterns to lower Math operations to LLVM based libm calls.
Browse files Browse the repository at this point in the history
Some Math operations do not have an equivalent in LLVM. In these cases,
allow a low priority fallback of calling the libm functions. This is to
give functionality and is not a performant option.

Differential Revision: https://reviews.llvm.org/D100367
  • Loading branch information
tpopp committed Apr 20, 2021
1 parent 2ea6ed9 commit 34810e1
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 0 deletions.
26 changes: 26 additions & 0 deletions mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- MathToLibm.h - Utils to convert from the complex dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
#define MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_

#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
template <typename T>
class OperationPass;

/// Populate the given list with patterns that convert from Math to Libm calls.
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit);

/// Create a pass to convert Math operations to libm calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();

} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,19 @@ def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// MathToLibm
//===----------------------------------------------------------------------===//

def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
let summary = "Convert Math dialect to libm calls";
let description = [{
This pass converts supported Math ops to libm calls.
}];
let constructor = "mlir::createConvertMathToLibmPass()";
let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
}

//===----------------------------------------------------------------------===//
// OpenMPToLLVM
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_subdirectory(GPUToVulkan)
add_subdirectory(LinalgToLLVM)
add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(MathToLibm)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(SCFToGPU)
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Conversion/MathToLibm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
add_mlir_conversion_library(MLIRMathToLibm
MathToLibm.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLibm

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRMath
MLIRStandardOpsTransforms
)
147 changes: 147 additions & 0 deletions mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToLibm/MathToLibm.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
// Pattern to convert vector operations to scalar operations. This is needed as
// libm calls require scalars.
template <typename Op>
struct VecOpToScalarOp : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;

LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
};
// Pattern to convert scalar math operations to calls to libm functions.
// Additionally the libm function signatures are declared.
template <typename Op>
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
StringRef doubleFunc, PatternBenefit benefit)
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
doubleFunc(doubleFunc){};

LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;

private:
std::string floatFunc, doubleFunc;
};
} // namespace

template <typename Op>
LogicalResult
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
auto opType = op.getType();
auto loc = op.getLoc();
auto vecType = opType.template dyn_cast<VectorType>();

if (!vecType)
return failure();
if (!vecType.hasRank())
return failure();
auto shape = vecType.getShape();
// TODO: support multidimensional vectors
if (shape.size() != 1)
return failure();

Value result = rewriter.create<ConstantOp>(
loc, DenseElementsAttr::get(
vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
for (auto i = 0; i < shape.front(); ++i) {
SmallVector<Value> operands;
for (auto input : op->getOperands())
operands.push_back(
rewriter.create<vector::ExtractElementOp>(loc, input, i));
Value scalarOp =
rewriter.create<Op>(loc, vecType.getElementType(), operands);
result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
}
rewriter.replaceOp(op, {result});
return success();
}

template <typename Op>
LogicalResult
ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
PatternRewriter &rewriter) const {
auto module = op->template getParentOfType<ModuleOp>();
auto type = op.getType();
// TODO: Support Float16 by upcasting to Float32
if (!type.template isa<Float32Type, Float64Type>())
return failure();

auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
auto opFunc = module.template lookupSymbol<FuncOp>(name);
// Forward declare function if it hasn't already been
if (!opFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto opFunctionTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
opFunc =
rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, opFunctionTy);
opFunc.setPrivate();
}
assert(opFunc.getType().template cast<FunctionType>().getResults() ==
op->getResultTypes());
assert(opFunc.getType().template cast<FunctionType>().getInputs() ==
op->getOperandTypes());

rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands());

return success();
}

void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
"atan2f", "atan2", benefit);
patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
"expm1f", "expm1", benefit);
patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
"tanh", benefit);
}

namespace {
struct ConvertMathToLibmPass
: public ConvertMathToLibmBase<ConvertMathToLibmPass> {
void runOnOperation() override;
};
} // namespace

void ConvertMathToLibmPass::runOnOperation() {
auto module = getOperation();

RewritePatternSet patterns(&getContext());
populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);

ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, StandardOpsDialect,
vector::VectorDialect>();
target.addIllegalDialect<math::MathDialect>();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
return std::make_unique<ConvertMathToLibmPass>();
}
73 changes: 73 additions & 0 deletions mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s

// CHECK-DAG: @expm1(f64) -> f64
// CHECK-DAG: @expm1f(f32) -> f32
// CHECK-DAG: @atan2(f64, f64) -> f64
// CHECK-DAG: @atan2f(f32, f32) -> f32
// CHECK-DAG: @tanh(f64) -> f64
// CHECK-DAG: @tanhf(f32) -> f32

// CHECK-LABEL: func @tanh_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @tanhf(%[[FLOAT]]) : (f32) -> f32
%float_result = math.tanh %float : f32
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @tanh(%[[DOUBLE]]) : (f64) -> f64
%double_result = math.tanh %double : f64
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}


// CHECK-LABEL: func @atan2_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32
%float_result = math.atan2 %float, %float : f32
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64
%double_result = math.atan2 %double, %double : f64
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}

// CHECK-LABEL: func @expm1_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
func @expm1_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @expm1f(%[[FLOAT]]) : (f32) -> f32
%float_result = math.expm1 %float : f32
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @expm1(%[[DOUBLE]]) : (f64) -> f64
%double_result = math.expm1 %double : f64
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}

func @expm1_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
%float_result = math.expm1 %float : vector<2xf32>
%double_result = math.expm1 %double : vector<2xf64>
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
}
// CHECK-LABEL: func @expm1_vec_caller(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
// CHECK: %[[CVF:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[CVD:.*]] = constant dense<0.000000e+00> : vector<2xf64>
// CHECK: %[[C0:.*]] = constant 0 : i32
// CHECK: %[[C1:.*]] = constant 1 : i32
// CHECK: %[[IN0_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C0]] : i32] : vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32
// CHECK: %[[VAL_8:.*]] = vector.insertelement %[[OUT0_F32]], %[[CVF]]{{\[}}%[[C0]] : i32] : vector<2xf32>
// CHECK: %[[IN1_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C1]] : i32] : vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32
// CHECK: %[[VAL_11:.*]] = vector.insertelement %[[OUT1_F32]], %[[VAL_8]]{{\[}}%[[C1]] : i32] : vector<2xf32>
// CHECK: %[[IN0_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C0]] : i32] : vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64
// CHECK: %[[VAL_14:.*]] = vector.insertelement %[[OUT0_F64]], %[[CVD]]{{\[}}%[[C0]] : i32] : vector<2xf64>
// CHECK: %[[IN1_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C1]] : i32] : vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64
// CHECK: %[[VAL_17:.*]] = vector.insertelement %[[OUT1_F64]], %[[VAL_14]]{{\[}}%[[C1]] : i32] : vector<2xf64>
// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
// CHECK: }

0 comments on commit 34810e1

Please sign in to comment.