Skip to content

Commit

Permalink
[SimplifyLibCalls] Tranform log(pow(x, y)) -> y*log(x).
Browse files Browse the repository at this point in the history
This one is enabled only under -ffast-math. There are cases where the
difference between the value computed and the correct value is huge
even for ffast-math, e.g. as Steven pointed out:

x = -1, y = -4
log(pow(-1), 4) = 0
4*log(-1) = NaN

I checked what GCC does and apparently they do the same optimization
(which result in the dramatic difference). Future work might try to
make this (slightly) less worse.

Differential Revision:	http://reviews.llvm.org/D14400

llvm-svn: 254263
  • Loading branch information
dcci committed Nov 29, 2015
1 parent a5c0493 commit b8b7133
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 5 deletions.
1 change: 1 addition & 0 deletions llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
Expand Up @@ -132,6 +132,7 @@ class LibCallSimplifier {
Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
Value *optimizeFabs(CallInst *CI, IRBuilder<> &B);
Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B);
Value *optimizeLog(CallInst *CI, IRBuilder<> &B);
Value *optimizeSqrt(CallInst *CI, IRBuilder<> &B);
Value *optimizeSinCosPi(CallInst *CI, IRBuilder<> &B);
Value *optimizeTan(CallInst *CI, IRBuilder<> &B);
Expand Down
55 changes: 50 additions & 5 deletions llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
Expand Up @@ -1284,6 +1284,48 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {
return B.CreateSelect(Cmp, Op0, Op1);
}

Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();
Value *Ret = nullptr;
StringRef Name = Callee->getName();
if (UnsafeFPShrink && hasFloatVersion(Name))
Ret = optimizeUnaryDoubleFP(CI, B, true);
FunctionType *FT = Callee->getFunctionType();

// Just make sure this has 1 argument of FP type, which matches the
// result type.
if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) ||
!FT->getParamType(0)->isFloatingPointTy())
return Ret;

if (!canUseUnsafeFPMath(CI->getParent()->getParent()))
return Ret;
Value *Op1 = CI->getArgOperand(0);
auto *OpC = dyn_cast<CallInst>(Op1);
if (!OpC)
return Ret;

// log(pow(x,y)) -> y*log(x)
// This is only applicable to log, log2, log10.
if (Name != "log" && Name != "log2" && Name != "log10")
return Ret;

IRBuilder<>::FastMathFlagGuard Guard(B);
FastMathFlags FMF;
FMF.setUnsafeAlgebra();
B.SetFastMathFlags(FMF);

LibFunc::Func Func;
Function *F = OpC->getCalledFunction();
StringRef FuncName = F->getName();
if ((TLI->getLibFunc(FuncName, Func) && TLI->has(Func) &&
Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow)
return B.CreateFMul(OpC->getArgOperand(1),
EmitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B,
Callee->getAttributes()), "mul");
return Ret;
}

Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
Function *Callee = CI->getCalledFunction();

Expand Down Expand Up @@ -2088,6 +2130,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
return optimizeExp2(CI, Builder);
case Intrinsic::fabs:
return optimizeFabs(CI, Builder);
case Intrinsic::log:
return optimizeLog(CI, Builder);
case Intrinsic::sqrt:
return optimizeSqrt(CI, Builder);
default:
Expand Down Expand Up @@ -2170,6 +2214,12 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
return optimizeFWrite(CI, Builder);
case LibFunc::fputs:
return optimizeFPuts(CI, Builder);
case LibFunc::log:
case LibFunc::log10:
case LibFunc::log1p:
case LibFunc::log2:
case LibFunc::logb:
return optimizeLog(CI, Builder);
case LibFunc::puts:
return optimizePuts(CI, Builder);
case LibFunc::tan:
Expand Down Expand Up @@ -2203,11 +2253,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
case LibFunc::exp:
case LibFunc::exp10:
case LibFunc::expm1:
case LibFunc::log:
case LibFunc::log10:
case LibFunc::log1p:
case LibFunc::log2:
case LibFunc::logb:
case LibFunc::sin:
case LibFunc::sinh:
case LibFunc::tanh:
Expand Down
17 changes: 17 additions & 0 deletions llvm/test/Transforms/InstCombine/log-pow-nofastmath.ll
@@ -0,0 +1,17 @@
; RUN: opt < %s -instcombine -S | FileCheck %s

define double @mylog(double %x, double %y) #0 {
entry:
%pow = call double @llvm.pow.f64(double %x, double %y)
%call = call double @log(double %pow) #0
ret double %call
}

; CHECK-LABEL: define double @mylog(
; CHECK: %pow = call double @llvm.pow.f64(double %x, double %y)
; CHECK: %call = call double @log(double %pow)
; CHECK: ret double %call
; CHECK: }

declare double @log(double) #0
declare double @llvm.pow.f64(double, double)
19 changes: 19 additions & 0 deletions llvm/test/Transforms/InstCombine/log-pow.ll
@@ -0,0 +1,19 @@
; RUN: opt < %s -instcombine -S | FileCheck %s

define double @mylog(double %x, double %y) #0 {
entry:
%pow = call double @llvm.pow.f64(double %x, double %y)
%call = call double @log(double %pow) #0
ret double %call
}

; CHECK-LABEL: define double @mylog(
; CHECK: %log = call double @log(double %x) #0
; CHECK: %mul = fmul fast double %log, %y
; CHECK: ret double %mul
; CHECK: }

declare double @log(double) #0
declare double @llvm.pow.f64(double, double)

attributes #0 = { "unsafe-fp-math"="true" }

0 comments on commit b8b7133

Please sign in to comment.