-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SimplifyLibCalls] Merge sqrt into the power of exp #79146
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Anton Sidorenko (asi-sc) Changes: Under fast-math flags it's possible to convert sqrt(exp(X)) into exp(X * 0.5). Full diff: https://github.com/llvm/llvm-project/pull/79146.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4c..1aad0b2988451cd 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -201,6 +201,7 @@ class LibCallSimplifier {
Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B);
Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
+ Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
// Wrapper for all floating point library call optimizations
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 52eef9ab58a4d92..047a793f349e7ef 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2538,6 +2538,70 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) {
return Ret;
}
+// sqrt(exp(X)) -> exp(X * 0.5). Also matches exp2 (libcall or intrinsic) and exp10 (libcall only). Returns the rewritten exp call, or nullptr when the pattern does not match.
+Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) {
+ if (!CI->isFast()) // the sqrt call itself must be marked fast
+ return nullptr;
+
+ Function *SqrtFn = CI->getCalledFunction();
+ CallInst *Arg = dyn_cast<CallInst>(CI->getArgOperand(0));
+ if (!Arg || !Arg->isFast() || !Arg->hasOneUse()) // Arg is modified in place below, so it must be fast and feed only this sqrt
+ return nullptr;
+ Intrinsic::ID ArgID = Arg->getIntrinsicID();
+ LibFunc ArgLb = NotLibFunc;
+ TLI->getLibFunc(*Arg, ArgLb); // return value ignored; ArgLb presumably stays NotLibFunc when Arg is not a recognized libcall
+
+ LibFunc SqrtLb, ExpLb, Exp2Lb, Exp10Lb;
+
+ if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb)) // libcall sqrt: select the exp/exp2/exp10 libcalls of the same precision
+ switch (SqrtLb) {
+ case LibFunc_sqrtf:
+ ExpLb = LibFunc_expf;
+ Exp2Lb = LibFunc_exp2f;
+ Exp10Lb = LibFunc_exp10f;
+ break;
+ case LibFunc_sqrt:
+ ExpLb = LibFunc_exp;
+ Exp2Lb = LibFunc_exp2;
+ Exp10Lb = LibFunc_exp10;
+ break;
+ case LibFunc_sqrtl:
+ ExpLb = LibFunc_expl;
+ Exp2Lb = LibFunc_exp2l;
+ Exp10Lb = LibFunc_exp10l;
+ break;
+ default:
+ return nullptr;
+ }
+ else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) { // intrinsic sqrt: select the libcall family from the scalar element type
+ if (CI->getType()->getScalarType()->isFloatTy()) {
+ ExpLb = LibFunc_expf;
+ Exp2Lb = LibFunc_exp2f;
+ Exp10Lb = LibFunc_exp10f;
+ } else if (CI->getType()->getScalarType()->isDoubleTy()) {
+ ExpLb = LibFunc_exp;
+ Exp2Lb = LibFunc_exp2;
+ Exp10Lb = LibFunc_exp10;
+ } else // only float and double element types are handled for the intrinsic form
+ return nullptr;
+ } else
+ return nullptr;
+
+ if (ArgLb != ExpLb && ArgLb != Exp2Lb && ArgLb != Exp10Lb &&
+ ArgID != Intrinsic::exp && ArgID != Intrinsic::exp2) // bail unless Arg is one of the selected libcalls or the exp/exp2 intrinsic
+ return nullptr;
+
+ IRBuilderBase::InsertPointGuard Guard(B); // presumably restores B's insert point when Guard leaves scope
+ B.SetInsertPoint(Arg); // emit the multiply immediately before the exp call
+ auto *ExpOperand = Arg->getOperand(0);
+ auto *FMul =
+ B.CreateFMulFMF(ExpOperand, ConstantFP::get(ExpOperand->getType(), 0.5),
+ CI, "merged.sqrt"); // CI is passed as the fast-math-flag source for the new fmul
+
+ Arg->setOperand(0, FMul); // the existing exp call is rewritten in place; no new call is created
+ return Arg; // optimizeSqrt returns this value as the simplified result of the sqrt call
+}
+
Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
Module *M = CI->getModule();
Function *Callee = CI->getCalledFunction();
@@ -2553,6 +2617,9 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
if (!CI->isFast())
return Ret;
+ if (Value *Opt = mergeSqrtToExp(CI, B))
+ return Opt;
+
Instruction *I = dyn_cast<Instruction>(CI->getArgOperand(0));
if (!I || I->getOpcode() != Instruction::FMul || !I->isFast())
return Ret;
diff --git a/llvm/test/Transforms/InstCombine/sqrt.ll b/llvm/test/Transforms/InstCombine/sqrt.ll
index 004df3e30c72a1e..b9cdbb6f6910c01 100644
--- a/llvm/test/Transforms/InstCombine/sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/sqrt.ll
@@ -88,7 +88,114 @@ define float @sqrt_call_fabs_f32(float %x) {
ret float %sqrt
}
+define double @sqrt_exp(double %x) {
+; CHECK-LABEL: @sqrt_exp(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @llvm.exp.f64(double %mul)
+ %res = call fast double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
+define double @sqrt_exp_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_2(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @exp(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @exp(double %mul)
+ %res = call fast double @sqrt(double %e)
+ ret double %res
+}
+
+define double @sqrt_exp2(double %x) {
+; CHECK-LABEL: @sqrt_exp2(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @exp2(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @exp2(double %mul)
+ %res = call fast double @sqrt(double %e)
+ ret double %res
+}
+
+define double @sqrt_exp10(double %x) {
+; CHECK-LABEL: @sqrt_exp10(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @exp10(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @exp10(double %mul)
+ %res = call fast double @sqrt(double %e)
+ ret double %res
+}
+
+define double @sqrt_exp_nofast_1(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_1(
+; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT: [[E:%.*]] = call fast double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT: ret double [[E]]
+;
+ %mul = fmul double %x, 10.0
+ %e = call fast double @llvm.exp.f64(double %mul)
+ %res = call fast double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
+; Negative test
+define double @sqrt_exp_nofast_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_2(
+; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT: [[E:%.*]] = call double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT: [[RES:%.*]] = call fast double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT: ret double [[RES]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call double @llvm.exp.f64(double %mul)
+ %res = call fast double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
+; Negative test
+define double @sqrt_exp_nofast_3(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_3(
+; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT: [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT: ret double [[RES]]
+;
+ %mul = fmul fast double %x, 10.0
+ %e = call fast double @llvm.exp.f64(double %mul)
+ %res = call double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
+; Negative test. NOTE(review): %res lacks 'fast', the same rejecting condition as sqrt_exp_nofast_3, so the non-constant multiplier is not what blocks the fold here -- confirm intent.
+define double @sqrt_exp_noconst(double %x, double %y) {
+; CHECK-LABEL: @sqrt_exp_noconst(
+; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT: ret double [[RES]]
+;
+ %mul = fmul fast double %x, %y
+ %e = call fast double @llvm.exp.f64(double %mul)
+ %res = call double @llvm.sqrt.f64(double %e)
+ ret double %res
+}
+
declare i32 @foo(double)
declare double @sqrt(double) readnone
declare float @sqrtf(float)
declare float @llvm.fabs.f32(float)
+declare double @llvm.exp.f64(double)
+declare double @llvm.sqrt.f64(double)
+declare double @exp(double)
+declare double @exp2(double)
+declare double @exp10(double)
|
if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb)) | ||
switch (SqrtLb) { | ||
case LibFunc_sqrtf: | ||
ExpLb = LibFunc_expf; | ||
Exp2Lb = LibFunc_exp2f; | ||
Exp10Lb = LibFunc_exp10f; | ||
break; | ||
case LibFunc_sqrt: | ||
ExpLb = LibFunc_exp; | ||
Exp2Lb = LibFunc_exp2; | ||
Exp10Lb = LibFunc_exp10; | ||
break; | ||
case LibFunc_sqrtl: | ||
ExpLb = LibFunc_expl; | ||
Exp2Lb = LibFunc_exp2l; | ||
Exp10Lb = LibFunc_exp10l; | ||
break; | ||
default: | ||
return nullptr; | ||
} | ||
else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we not have a better way of handling intrinsic-or-libcall?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't find any other approaches to this. All code in this file seems to be written in the same way.
@@ -2538,6 +2538,70 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { | |||
return Ret; | |||
} | |||
|
|||
// sqrt(exp(X)) -> exp(X * 0.5) | |||
Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) { | |||
if (!CI->isFast()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you just need reassoc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure whether just reassoc is enough. Don't we need afn (approximate calculations for functions) as well (or only afn)?
From the LangRef entry for sqrt: "When specified with the fast-math-flag ‘afn’, the result may be approximated using a less accurate calculation." https://llvm.org/docs/LangRef.html#llvm-sqrt-intrinsic
Or another example, from the Clang user's manual entry for -fapprox-func: "For example, a pow(x, 0.25) may be replaced with sqrt(sqrt(x))." https://clang.llvm.org/docs/UsersManual.html#cmdoption-f-no-approx-func
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did some experiments and this is net error reducing. This reassociation avoids some intermediate overflows to infinity and overall reduces worst case ulp. I think just reassoc is adequate
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@arsenm, thanks for the research. I've changed the code to check only reassoc flag.
%res = call double @llvm.sqrt.f64(double %e) | ||
ret double %res | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should reduce test flags. Also, can you add the tests with libcall exp + intrinsic sqrt and intrinsic exp + libcall sqrt? We shouldn't introduce new libcalls from intrinsics
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've slightly simplified tests and added libcall + intrinsic tests. Fast-flags weren't modified. I'll adjust them when we decide on the correct set of flags that controls transformation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed fast flag to reassoc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like the use of reassoc for this purpose, but it is consistent with existing practice, so I can't object to it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jcranmer-intel, in your opinion, what is the correct set of flags? Since you said that it is consistent with existing practice, I won't change them in this PR. But we could start a discussion on Discourse and systematically change all such places.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could use some vector intrinsic tests
sqrt(exp(X)) -> exp(X * 0.5). This is similar to the optimization existing in GCC.
Added. Also rebased and squashed. |
Under fast-math flags it's possible to convert sqrt(exp(X)) into exp(X * 0.5). I suppose that this transformation is always profitable. This is similar to the optimization existing in GCC.