Skip to content

Commit

Permalink
[mlir] Removed TanHOp lowering from ConvertStandardToLLVM since there…
Browse files Browse the repository at this point in the history
… is no reasonable TanH representation in LLVM.

Summary: The current ConvertStandardToLLVM phase lowers the standard TanHOp to function calls to external tanh symbols. However, this leads to misunderstandings since these external symbols are not defined anywhere. This commit removes the TanHOp lowering functionality from ConvertStandardToLLVM, adapts the LowerGpuOpsToNVVMOps and LowerGpuOpsToROCDLOps passes and adjusts the affected test cases.

Reviewers: mravishankar, herhut

Subscribers: jholewinski, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, csigg, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75509
  • Loading branch information
dfki-mako committed Mar 25, 2020
1 parent 69def20 commit 2b529a3
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 114 deletions.
18 changes: 0 additions & 18 deletions mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
Expand Up @@ -95,24 +95,6 @@ struct OpToFuncCallLowering : public ConvertToLLVMPattern {
const std::string f64Func;
};

namespace gpu {
/// Returns a predicate to be used with addDynamicallyLegalOp. The predicate
/// returns false for calls to the provided intrinsics and true otherwise.
inline std::function<bool(Operation *)>
filterIllegalLLVMIntrinsics(ArrayRef<StringRef> intrinsics, MLIRContext *ctx) {
SmallVector<StringRef, 4> illegalIds(intrinsics.begin(), intrinsics.end());
return [illegalIds](Operation *op) -> bool {
LLVM::CallOp callOp = dyn_cast<LLVM::CallOp>(op);
if (!callOp || !callOp.callee())
return true;
StringRef callee = callOp.callee().getValue();
return !llvm::any_of(illegalIds, [callee](StringRef intrinsic) {
return callee.equals(intrinsic);
});
};
}
} // namespace gpu

} // namespace mlir

#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
2 changes: 0 additions & 2 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Expand Up @@ -279,8 +279,6 @@ class LowerGpuOpsToNVVMOpsPass
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op>();
target.addIllegalOp<FuncOp>();
target.addLegalDialect<NVVM::NVVMDialect>();
target.addDynamicallyLegalOp<mlir::LLVM::CallOp>(
gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
// TODO(csigg): Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
if (failed(applyPartialConversion(m, target, patterns, &converter)))
Expand Down
2 changes: 0 additions & 2 deletions mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Expand Up @@ -71,8 +71,6 @@ class LowerGpuOpsToROCDLOpsPass
target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::FAbsOp, LLVM::FCeilOp,
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op>();
target.addDynamicallyLegalOp<LLVM::CallOp>(
gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
target.addIllegalOp<FuncOp>();
if (failed(applyPartialConversion(m, target, patterns, &converter)))
signalPassFailure();
Expand Down
52 changes: 1 addition & 51 deletions mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Expand Up @@ -1737,56 +1737,6 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
}
};

// A `tanh` is converted into a call to the `tanh` function.
struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {

using LLVMFuncOpT = LLVM::LLVMFuncOp;
using LLVMTypeT = LLVM::LLVMType;

OperandAdaptor<TanhOp> transformed(operands);
LLVMTypeT operandType =
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();

if (!operandType)
return failure();

std::string functionName;
if (operandType.isFloatTy())
functionName = "tanhf";
else if (operandType.isDoubleTy())
functionName = "tanh";
else
return failure();

// Get a reference to the tanh function, inserting it if necessary.
Operation *tanhFunc =
SymbolTable::lookupNearestSymbolFrom(op, functionName);

LLVMFuncOpT tanhLLVMFunc;
if (tanhFunc) {
tanhLLVMFunc = cast<LLVMFuncOpT>(tanhFunc);
} else {
PatternRewriter::InsertionGuard insertGuard(rewriter);
auto module = op->getParentOfType<ModuleOp>();
rewriter.setInsertionPointToStart(module.getBody());
tanhLLVMFunc = rewriter.create<LLVMFuncOpT>(
module.getLoc(), functionName,
LLVMTypeT::getFunctionTy(operandType, operandType,
/*isVarArg=*/false));
}

rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc),
transformed.operand());
return success();
}
};

struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;

Expand Down Expand Up @@ -2833,7 +2783,6 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
SqrtOpLowering,
SubFOpLowering,
SubIOpLowering,
TanhOpLowering,
TruncateIOpLowering,
UnsignedDivIOpLowering,
UnsignedRemIOpLowering,
Expand Down Expand Up @@ -3022,6 +2971,7 @@ mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
this->addLegalDialect<LLVM::LLVMDialect>();
this->addIllegalOp<LLVM::DialectCastOp>();
this->addIllegalOp<TanhOp>();
}

std::unique_ptr<OpPassBase<ModuleOp>>
Expand Down
62 changes: 21 additions & 41 deletions mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Expand Up @@ -407,43 +407,39 @@ func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
// CHECK-NEXT: %2 = llvm.icmp "slt" %arg2, %1 : !llvm.i32
%2 = cmpi "slt", %arg2, %1 : i32
// CHECK-NEXT: %3 = llvm.sdiv %arg2, %arg3 : !llvm.i32
%4 = divi_signed %arg2, %arg3 : i32
%3 = divi_signed %arg2, %arg3 : i32
// CHECK-NEXT: %4 = llvm.udiv %arg2, %arg3 : !llvm.i32
%5 = divi_unsigned %arg2, %arg3 : i32
%4 = divi_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %5 = llvm.srem %arg2, %arg3 : !llvm.i32
%6 = remi_signed %arg2, %arg3 : i32
%5 = remi_signed %arg2, %arg3 : i32
// CHECK-NEXT: %6 = llvm.urem %arg2, %arg3 : !llvm.i32
%7 = remi_unsigned %arg2, %arg3 : i32
%6 = remi_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %7 = llvm.select %2, %arg2, %arg3 : !llvm.i1, !llvm.i32
%8 = select %2, %arg2, %arg3 : i32
%7 = select %2, %arg2, %arg3 : i32
// CHECK-NEXT: %8 = llvm.fdiv %arg0, %arg1 : !llvm.float
%9 = divf %arg0, %arg1 : f32
%8 = divf %arg0, %arg1 : f32
// CHECK-NEXT: %9 = llvm.frem %arg0, %arg1 : !llvm.float
%10 = remf %arg0, %arg1 : f32
%9 = remf %arg0, %arg1 : f32
// CHECK-NEXT: %10 = llvm.and %arg2, %arg3 : !llvm.i32
%11 = and %arg2, %arg3 : i32
%10 = and %arg2, %arg3 : i32
// CHECK-NEXT: %11 = llvm.or %arg2, %arg3 : !llvm.i32
%12 = or %arg2, %arg3 : i32
%11 = or %arg2, %arg3 : i32
// CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32
%13 = xor %arg2, %arg3 : i32
%12 = xor %arg2, %arg3 : i32
// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
%14 = std.exp %arg0 : f32
// CHECK-NEXT: %14 = llvm.call @tanhf(%arg0) : (!llvm.float) -> !llvm.float
%15 = std.tanh %arg0 : f32
// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
%16 = constant 7.9e-01 : f64
// CHECK-NEXT: %16 = llvm.call @tanh(%15) : (!llvm.double) -> !llvm.double
%17 = std.tanh %16 : f64
// CHECK-NEXT: %17 = llvm.shl %arg2, %arg3 : !llvm.i32
%18 = shift_left %arg2, %arg3 : i32
// CHECK-NEXT: %18 = llvm.ashr %arg2, %arg3 : !llvm.i32
%19 = shift_right_signed %arg2, %arg3 : i32
// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
%20 = shift_right_unsigned %arg2, %arg3 : i32
%13 = std.exp %arg0 : f32
// CHECK-NEXT: %14 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
%14 = constant 7.9e-01 : f64
// CHECK-NEXT: %15 = llvm.shl %arg2, %arg3 : !llvm.i32
%15 = shift_left %arg2, %arg3 : i32
// CHECK-NEXT: %16 = llvm.ashr %arg2, %arg3 : !llvm.i32
%16 = shift_right_signed %arg2, %arg3 : i32
// CHECK-NEXT: %17 = llvm.lshr %arg2, %arg3 : !llvm.i32
%17 = shift_right_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
%21 = std.sqrt %arg0 : f32
%18 = std.sqrt %arg0 : f32
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
%22 = std.sqrt %arg4 : f64
%19 = std.sqrt %arg4 : f64
return %0, %4 : f32, i32
}

Expand Down Expand Up @@ -853,22 +849,6 @@ func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1)

// -----

module {
func @check_tanh_func_added_only_once_to_symbol_table(%f: f32, %lf: f64) -> () {
%f0 = std.tanh %f : f32
%f1 = std.tanh %f0 : f32
%lf0 = std.tanh %lf : f64
%lf1 = std.tanh %lf0 : f64
return
}
// CHECK: module {
// CHECK: llvm.func @tanh(!llvm.double) -> !llvm.double
// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
}

// -----

// CHECK-LABEL: func @atomic_rmw
func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
atomic_rmw "assign" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
Expand Down

0 comments on commit 2b529a3

Please sign in to comment.