diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h index eb10545ee149e..de08b26173f6d 100644 --- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h +++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h @@ -202,7 +202,7 @@ class LibCallSimplifier { Value *optimizeLog(CallInst *CI, IRBuilderBase &B); Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B); Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B); - Value *optimizeTan(CallInst *CI, IRBuilderBase &B); + Value *optimizeTrigInversionPairs(CallInst *CI, IRBuilderBase &B); // Wrapper for all floating point library call optimizations Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func, IRBuilderBase &B); diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 52eef9ab58a4d..7a38016574b10 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -2607,13 +2607,16 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { return copyFlags(*CI, FabsCall); } -// TODO: Generalize to handle any trig function and its inverse. -Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) { +Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI, + IRBuilderBase &B) { Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); Value *Ret = nullptr; StringRef Name = Callee->getName(); - if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name)) + if (UnsafeFPShrink && + (Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh" || + Name == "asinh") && + hasFloatVersion(M, Name)) Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); Value *Op1 = CI->getArgOperand(0); @@ -2626,16 +2629,34 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) { return Ret; // tan(atan(x)) -> x - // tanf(atanf(x)) -> x - // tanl(atanl(x)) -> x + // atanh(tanh(x)) -> x + // sinh(asinh(x)) -> x + // asinh(sinh(x)) -> x + // cosh(acosh(x)) -> x LibFunc Func; Function *F = OpC->getCalledFunction(); if (F && TLI->getLibFunc(F->getName(), Func) && - isLibFuncEmittable(M, TLI, Func) && - ((Func == LibFunc_atan && Callee->getName() == "tan") || - (Func == LibFunc_atanf && Callee->getName() == "tanf") || - (Func == LibFunc_atanl && Callee->getName() == "tanl"))) - Ret = OpC->getArgOperand(0); + isLibFuncEmittable(M, TLI, Func)) { + LibFunc inverseFunc = llvm::StringSwitch(Callee->getName()) + .Case("tan", LibFunc_atan) + .Case("atanh", LibFunc_tanh) + .Case("sinh", LibFunc_asinh) + .Case("cosh", LibFunc_acosh) + .Case("tanf", LibFunc_atanf) + .Case("atanhf", LibFunc_tanhf) + .Case("sinhf", LibFunc_asinhf) + .Case("coshf", LibFunc_acoshf) + .Case("tanl", LibFunc_atanl) + .Case("atanhl", LibFunc_tanhl) + .Case("sinhl", LibFunc_asinhl) + .Case("coshl", LibFunc_acoshl) + .Case("asinh", LibFunc_sinh) + .Case("asinhf", LibFunc_sinhf) + .Case("asinhl", LibFunc_sinhl) + .Default(NumLibFuncs); // Used as error value + if (Func == inverseFunc) + Ret = OpC->getArgOperand(0); + } return Ret; } @@ -3628,7 +3649,19 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_tan: case LibFunc_tanf: case LibFunc_tanl: - return optimizeTan(CI, Builder); + case LibFunc_sinh: + case LibFunc_sinhf: + case LibFunc_sinhl: + case LibFunc_asinh: + case LibFunc_asinhf: + case LibFunc_asinhl: + case LibFunc_cosh: + case LibFunc_coshf: + case LibFunc_coshl: + case LibFunc_atanh: + case LibFunc_atanhf: + case LibFunc_atanhl: + return optimizeTrigInversionPairs(CI, Builder); case LibFunc_ceil: return replaceUnaryCall(CI, Builder, Intrinsic::ceil); case LibFunc_floor: @@ -3646,17 +3679,13 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_acos: case LibFunc_acosh: case LibFunc_asin: - case LibFunc_asinh: case LibFunc_atan: - case LibFunc_atanh: case LibFunc_cbrt: - case LibFunc_cosh: case LibFunc_exp: case LibFunc_exp10: case LibFunc_expm1: case LibFunc_cos: case LibFunc_sin: - case LibFunc_sinh: case LibFunc_tanh: if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName())) return optimizeUnaryDoubleFP(CI, Builder, TLI, true); diff --git a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll b/llvm/test/Transforms/InstCombine/tan-nofastmath.ll deleted file mode 100644 index 514ff4e40d618..0000000000000 --- a/llvm/test/Transforms/InstCombine/tan-nofastmath.ll +++ /dev/null @@ -1,17 +0,0 @@ -; RUN: opt < %s -passes=instcombine -S | FileCheck %s - -define float @mytan(float %x) { -entry: - %call = call float @atanf(float %x) - %call1 = call float @tanf(float %call) - ret float %call1 -} - -; CHECK-LABEL: define float @mytan( -; CHECK: %call = call float @atanf(float %x) -; CHECK-NEXT: %call1 = call float @tanf(float %call) -; CHECK-NEXT: ret float %call1 -; CHECK-NEXT: } - -declare float @tanf(float) -declare float @atanf(float) diff --git a/llvm/test/Transforms/InstCombine/tan.ll b/llvm/test/Transforms/InstCombine/tan.ll deleted file mode 100644 index 49f6e00e6d9ba..0000000000000 --- a/llvm/test/Transforms/InstCombine/tan.ll +++ /dev/null @@ -1,23 +0,0 @@ -; RUN: opt < %s -passes=instcombine -S | FileCheck %s - -define float @mytan(float %x) { - %call = call fast float @atanf(float %x) - %call1 = call fast float @tanf(float %call) - ret float %call1 -} - -; CHECK-LABEL: define float @mytan( -; CHECK: ret float %x - -define float @test2(ptr %fptr) { - %call1 = call fast float %fptr() - %tan = call fast float @tanf(float %call1) - ret float %tan -} - -; CHECK-LABEL: @test2 -; CHECK: tanf - -declare float @tanf(float) -declare float @atanf(float) - diff --git a/llvm/test/Transforms/InstCombine/trig.ll b/llvm/test/Transforms/InstCombine/trig.ll new file mode 100644 index 0000000000000..5dda1524396d4 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/trig.ll @@ -0,0 +1,140 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +define float @tanAtanInverseFast(float %x) { +; CHECK-LABEL: define float @tanAtanInverseFast( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call fast float @atanf(float [[X]]) +; CHECK-NEXT: ret float [[X]] +; + %call = call fast float @atanf(float %x) + %call1 = call fast float @tanf(float %call) + ret float %call1 +} + +define float @atanhTanhInverseFast(float %x) { +; CHECK-LABEL: define float @atanhTanhInverseFast( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call fast float @tanhf(float [[X]]) +; CHECK-NEXT: ret float [[X]] +; + %call = call fast float @tanhf(float %x) + %call1 = call fast float @atanhf(float %call) + ret float %call1 +} + +define float @sinhAsinhInverseFast(float %x) { +; CHECK-LABEL: define float @sinhAsinhInverseFast( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call fast float @asinhf(float [[X]]) +; CHECK-NEXT: ret float [[X]] +; + %call = call fast float @asinhf(float %x) + %call1 = call fast float @sinhf(float %call) + ret float %call1 +} + +define float @asinhSinhInverseFast(float %x) { +; CHECK-LABEL: define float @asinhSinhInverseFast( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call fast float @sinhf(float [[X]]) +; CHECK-NEXT: ret float [[X]] +; + %call = call fast float @sinhf(float %x) + %call1 = call fast float @asinhf(float %call) + ret float %call1 +} + +define float @coshAcoshInverseFast(float %x) { +; CHECK-LABEL: define float @coshAcoshInverseFast( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call fast float @acoshf(float [[X]]) +; CHECK-NEXT: ret float [[X]] +; + %call = call fast float @acoshf(float %x) + %call1 = call fast float @coshf(float %call) + ret float %call1 +} + +define float @indirectTanCall(ptr %fptr) { +; CHECK-LABEL: define float @indirectTanCall( +; CHECK-SAME: ptr [[FPTR:%.*]]) { +; CHECK-NEXT: [[CALL1:%.*]] = call fast float [[FPTR]]() +; CHECK-NEXT: [[TAN:%.*]] = call fast float @tanf(float [[CALL1]]) +; CHECK-NEXT: ret float [[TAN]] +; + %call1 = call fast float %fptr() + %tan = call fast float @tanf(float %call1) + ret float %tan +} + +; No fast-math. + +define float @tanAtanInverse(float %x) { +; CHECK-LABEL: define float @tanAtanInverse( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call float @atanf(float [[X]]) +; CHECK-NEXT: [[CALL1:%.*]] = call float @tanf(float [[CALL]]) +; CHECK-NEXT: ret float [[CALL1]] +; + %call = call float @atanf(float %x) + %call1 = call float @tanf(float %call) + ret float %call1 +} + +define float @atanhTanhInverse(float %x) { +; CHECK-LABEL: define float @atanhTanhInverse( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call float @tanhf(float [[X]]) +; CHECK-NEXT: [[CALL1:%.*]] = call float @atanhf(float [[CALL]]) +; CHECK-NEXT: ret float [[CALL1]] +; + %call = call float @tanhf(float %x) + %call1 = call float @atanhf(float %call) + ret float %call1 +} + +define float @sinhAsinhInverse(float %x) { +; CHECK-LABEL: define float @sinhAsinhInverse( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call float @asinhf(float [[X]]) +; CHECK-NEXT: [[CALL1:%.*]] = call float @sinhf(float [[CALL]]) +; CHECK-NEXT: ret float [[CALL1]] +; + %call = call float @asinhf(float %x) + %call1 = call float @sinhf(float %call) + ret float %call1 +} + +define float @asinhSinhInverse(float %x) { +; CHECK-LABEL: define float @asinhSinhInverse( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call float @sinhf(float [[X]]) +; CHECK-NEXT: [[CALL1:%.*]] = call float @asinhf(float [[CALL]]) +; CHECK-NEXT: ret float [[CALL1]] +; + %call = call float @sinhf(float %x) + %call1 = call float @asinhf(float %call) + ret float %call1 +} + +define float @coshAcoshInverse(float %x) { +; CHECK-LABEL: define float @coshAcoshInverse( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[CALL:%.*]] = call float @acoshf(float [[X]]) +; CHECK-NEXT: [[CALL1:%.*]] = call float @coshf(float [[CALL]]) +; CHECK-NEXT: ret float [[CALL1]] +; + %call = call float @acoshf(float %x) + %call1 = call float @coshf(float %call) + ret float %call1 +} + +declare float @asinhf(float) +declare float @sinhf(float) +declare float @acoshf(float) +declare float @coshf(float) +declare float @tanhf(float) +declare float @atanhf(float) +declare float @tanf(float) +declare float @atanf(float)