Skip to content

Commit

Permalink
[mlir] Add support for lowering tanh to LLVMIR.
Browse files Browse the repository at this point in the history
Summary:
Fixed build of D81618

Add a pattern for expanding tanh op into exp form.
A `tanh` is expanded into:
   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0.

Differential Revision: https://reviews.llvm.org/D82040
  • Loading branch information
hanhanW committed Jun 18, 2020
1 parent c835b5c commit 9cb1029
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 0 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
Expand Up @@ -20,10 +20,15 @@
namespace mlir {

class Pass;
class MLIRContext;
class OwningRewritePatternList;

/// Creates an instance of the ExpandAtomic pass.
std::unique_ptr<Pass> createExpandAtomicPass();

void populateExpandTanhPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx);

} // end namespace mlir

#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
1 change: 1 addition & 0 deletions mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRStandardOpsTransforms
ExpandAtomic.cpp
ExpandTanh.cpp
FuncConversions.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
70 changes: 70 additions & 0 deletions mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp
@@ -0,0 +1,70 @@
//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of tanh op.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {

/// Expands tanh op into
/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
struct TanhOpConverter : public OpRewritePattern<TanhOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(TanhOp op,
PatternRewriter &rewriter) const final {
auto floatType = op.operand().getType();
Location loc = op.getLoc();
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
Value one = rewriter.create<ConstantOp>(loc, floatOne);
Value two = rewriter.create<ConstantOp>(loc, floatTwo);
Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);

// Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
Value exp2x = rewriter.create<ExpOp>(loc, negDoubledX);
Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);

// Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
exp2x = rewriter.create<ExpOp>(loc, doubledX);
dividend = rewriter.create<SubFOp>(loc, exp2x, one);
divisor = rewriter.create<AddFOp>(loc, exp2x, one);
Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);

// tanh(x) = x >= 0 ? positiveRes : negativeRes
auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
Value zero = rewriter.create<ConstantOp>(loc, floatZero);
Value cmpRes =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
return success();
}
};
} // namespace

void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<TanhOpConverter>(ctx);
}
23 changes: 23 additions & 0 deletions mlir/test/Dialect/Standard/expand-tanh.mlir
@@ -0,0 +1,23 @@
// RUN: mlir-opt %s -test-expand-tanh | FileCheck %s

// CHECK-LABEL: func @tanh
func @tanh(%arg: f32) -> f32 {
%res = tanh %arg : f32
return %res : f32
}
// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK-DAG: %[[ONE:.+]] = constant 1.000000e+00 : f32
// CHECK-DAG: %[[TWO:.+]] = constant 2.000000e+00 : f32
// CHECK: %[[DOUBLEDX:.+]] = mulf %arg0, %[[TWO]] : f32
// CHECK: %[[NEGDOUBLEDX:.+]] = negf %[[DOUBLEDX]] : f32
// CHECK: %[[EXP1:.+]] = exp %[[NEGDOUBLEDX]] : f32
// CHECK: %[[DIVIDEND1:.+]] = subf %[[ONE]], %[[EXP1]] : f32
// CHECK: %[[DIVISOR1:.+]] = addf %[[ONE]], %[[EXP1]] : f32
// CHECK: %[[RES1:.+]] = divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
// CHECK: %[[EXP2:.+]] = exp %[[DOUBLEDX]] : f32
// CHECK: %[[DIVIDEND2:.+]] = subf %[[EXP2]], %[[ONE]] : f32
// CHECK: %[[DIVISOR2:.+]] = addf %[[EXP2]], %[[ONE]] : f32
// CHECK: %[[RES2:.+]] = divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
// CHECK: %[[COND:.+]] = cmpf "oge", %arg0, %[[ZERO]] : f32
// CHECK: %[[RESULT:.+]] = select %[[COND]], %[[RES1]], %[[RES2]] : f32
// CHECK: return %[[RESULT]]
1 change: 1 addition & 0 deletions mlir/test/lib/Transforms/CMakeLists.txt
Expand Up @@ -2,6 +2,7 @@
add_mlir_library(MLIRTestTransforms
TestAllReduceLowering.cpp
TestBufferPlacement.cpp
TestExpandTanh.cpp
TestCallGraph.cpp
TestConstantFold.cpp
TestConvertGPUKernelToCubin.cpp
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/lib/Transforms/TestExpandTanh.cpp
@@ -0,0 +1,37 @@
//===- TestExpandTanh.cpp - Test expand tanh op into exp form ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains test passes for expanding tanh.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
struct TestExpandTanhPass
: public PassWrapper<TestExpandTanhPass, FunctionPass> {
void runOnFunction() override;
};
} // end anonymous namespace

void TestExpandTanhPass::runOnFunction() {
OwningRewritePatternList patterns;
populateExpandTanhPattern(patterns, &getContext());
applyPatternsAndFoldGreedily(getOperation(), patterns);
}

namespace mlir {
void registerTestExpandTanhPass() {
PassRegistration<TestExpandTanhPass> pass("test-expand-tanh",
"Test expanding tanh");
}
} // namespace mlir
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Expand Up @@ -48,6 +48,7 @@ void registerTestConstantFold();
void registerTestConvertGPUKernelToCubinPass();
void registerTestConvertGPUKernelToHsacoPass();
void registerTestDominancePass();
void registerTestExpandTanhPass();
void registerTestFunc();
void registerTestGpuMemoryPromotionPass();
void registerTestLinalgHoisting();
Expand Down Expand Up @@ -122,6 +123,7 @@ void registerTestPasses() {
registerTestBufferPlacementPreparationPass();
registerTestDominancePass();
registerTestFunc();
registerTestExpandTanhPass();
registerTestGpuMemoryPromotionPass();
registerTestLinalgHoisting();
registerTestLinalgTransforms();
Expand Down

0 comments on commit 9cb1029

Please sign in to comment.