diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e..1aad0b2988451 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -201,6 +201,7 @@ class LibCallSimplifier {
   Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B);
   Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
+  Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
   Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 52eef9ab58a4d..ab30cc2e249ad 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2538,6 +2538,77 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) {
   return Ret;
 }
 
+// Fold sqrt(exp(X)) -> exp(X * 0.5). The exp2 and exp10 library calls and the
+// llvm.exp / llvm.exp2 intrinsics are folded the same way. Both calls must be
+// marked reassoc and the inner call must have a single use, since that call
+// is rewritten in place and returned as the replacement for the sqrt.
+// Returns nullptr when the fold does not apply.
+Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) {
+  if (!CI->hasAllowReassoc())
+    return nullptr;
+
+  Function *SqrtFn = CI->getCalledFunction();
+  CallInst *Arg = dyn_cast<CallInst>(CI->getArgOperand(0));
+  if (!Arg || !Arg->hasAllowReassoc() || !Arg->hasOneUse())
+    return nullptr;
+  Intrinsic::ID ArgID = Arg->getIntrinsicID();
+  LibFunc ArgLb = NotLibFunc;
+  TLI->getLibFunc(*Arg, ArgLb);
+
+  // Select the exp-family library calls whose type matches this sqrt.
+  LibFunc SqrtLb, ExpLb, Exp2Lb, Exp10Lb;
+
+  if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb))
+    switch (SqrtLb) {
+    case LibFunc_sqrtf:
+      ExpLb = LibFunc_expf;
+      Exp2Lb = LibFunc_exp2f;
+      Exp10Lb = LibFunc_exp10f;
+      break;
+    case LibFunc_sqrt:
+      ExpLb = LibFunc_exp;
+      Exp2Lb = LibFunc_exp2;
+      Exp10Lb = LibFunc_exp10;
+      break;
+    case LibFunc_sqrtl:
+      ExpLb = LibFunc_expl;
+      Exp2Lb = LibFunc_exp2l;
+      Exp10Lb = LibFunc_exp10l;
+      break;
+    default:
+      return nullptr;
+    }
+  else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) {
+    if (CI->getType()->getScalarType()->isFloatTy()) {
+      ExpLb = LibFunc_expf;
+      Exp2Lb = LibFunc_exp2f;
+      Exp10Lb = LibFunc_exp10f;
+    } else if (CI->getType()->getScalarType()->isDoubleTy()) {
+      ExpLb = LibFunc_exp;
+      Exp2Lb = LibFunc_exp2;
+      Exp10Lb = LibFunc_exp10;
+    } else
+      return nullptr;
+  } else
+    return nullptr;
+
+  if (ArgLb != ExpLb && ArgLb != Exp2Lb && ArgLb != Exp10Lb &&
+      ArgID != Intrinsic::exp && ArgID != Intrinsic::exp2)
+    return nullptr;
+
+  // Scale the exponent immediately before the exp call; the multiply takes
+  // its fast-math flags from the sqrt call.
+  IRBuilderBase::InsertPointGuard Guard(B);
+  B.SetInsertPoint(Arg);
+  auto *ExpOperand = Arg->getOperand(0);
+  auto *FMul =
+      B.CreateFMulFMF(ExpOperand, ConstantFP::get(ExpOperand->getType(), 0.5),
+                      CI, "merged.sqrt");
+
+  Arg->setOperand(0, FMul);
+  return Arg;
+}
+
 Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
@@ -2550,6 +2621,9 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
                                   Callee->getIntrinsicID() == Intrinsic::sqrt))
     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
 
+  if (Value *Opt = mergeSqrtToExp(CI, B))
+    return Opt;
+
   if (!CI->isFast())
     return Ret;
 
diff --git a/llvm/test/Transforms/InstCombine/sqrt.ll b/llvm/test/Transforms/InstCombine/sqrt.ll
index 004df3e30c72a..f72fe5a6a5817 100644
--- a/llvm/test/Transforms/InstCombine/sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/sqrt.ll
@@ -88,7 +88,127 @@ define float @sqrt_call_fabs_f32(float %x) {
   ret float %sqrt
 }
 
+define double @sqrt_exp(double %x) {
+; CHECK-LABEL: @sqrt_exp(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
+; CHECK-NEXT:    [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %e = call reassoc double @llvm.exp.f64(double %x)
+  %res = call reassoc double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_2(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
+; CHECK-NEXT:    [[E:%.*]] = call reassoc double @exp(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %e = call reassoc double @exp(double %x)
+  %res = call reassoc double @sqrt(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp2(double %x) {
+; CHECK-LABEL: @sqrt_exp2(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
+; CHECK-NEXT:    [[E:%.*]] = call reassoc double @exp2(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %e = call reassoc double @exp2(double %x)
+  %res = call reassoc double @sqrt(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp10(double %x) {
+; CHECK-LABEL: @sqrt_exp10(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
+; CHECK-NEXT:    [[E:%.*]] = call reassoc double @exp10(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %e = call reassoc double @exp10(double %x)
+  %res = call reassoc double @sqrt(double %e)
+  ret double %res
+}
+
+; Negative test
+define double @sqrt_exp_nofast_1(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_1(
+; CHECK-NEXT:    [[E:%.*]] = call double @llvm.exp.f64(double [[X:%.*]])
+; CHECK-NEXT:    [[RES:%.*]] = call reassoc double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %e = call double @llvm.exp.f64(double %x)
+  %res = call reassoc double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+; Negative test
+define double @sqrt_exp_nofast_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_2(
+; CHECK-NEXT:    [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[X:%.*]])
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %e = call reassoc double @llvm.exp.f64(double %x)
+  %res = call double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_merge_constant(double %x, double %y) {
+; CHECK-LABEL: @sqrt_exp_merge_constant(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul reassoc nsz double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %mul = fmul reassoc nsz double %x, 10.0
+  %e = call reassoc double @llvm.exp.f64(double %mul)
+  %res = call reassoc nsz double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_intr_and_libcall(double %x) {
+; CHECK-LABEL: @sqrt_exp_intr_and_libcall(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
+; CHECK-NEXT:    [[E:%.*]] = call reassoc double @exp(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %e = call reassoc double @exp(double %x)
+  %res = call reassoc double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_intr_and_libcall_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_intr_and_libcall_2(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
+; CHECK-NEXT:    [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %e = call reassoc double @llvm.exp.f64(double %x)
+  %res = call reassoc double @sqrt(double %e)
+  ret double %res
+}
+
+define <2 x float> @sqrt_exp_vec(<2 x float> %x) {
+; CHECK-LABEL: @sqrt_exp_vec(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul reassoc <2 x float> [[X:%.*]], <float 5.000000e-01, float 5.000000e-01>
+; CHECK-NEXT:    [[E:%.*]] = call reassoc <2 x float> @llvm.exp.v2f32(<2 x float> [[MERGED_SQRT]])
+; CHECK-NEXT:    ret <2 x float> [[E]]
+;
+  %e = call reassoc <2 x float> @llvm.exp.v2f32(<2 x float> %x)
+  %res = call reassoc <2 x float> @llvm.sqrt.v2f32(<2 x float> %e)
+  ret <2 x float> %res
+}
+
 declare i32 @foo(double)
 declare double @sqrt(double) readnone
 declare float @sqrtf(float)
 declare float @llvm.fabs.f32(float)
+declare double @llvm.exp.f64(double)
+declare double @llvm.sqrt.f64(double)
+declare double @exp(double)
+declare double @exp2(double)
+declare double @exp10(double)
+declare <2 x float> @llvm.exp.v2f32(<2 x float>)
+declare <2 x float> @llvm.sqrt.v2f32(<2 x float>)