Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SimplifyLibCalls] Merge sqrt into the power of exp #79146

Merged
merged 1 commit into from
Feb 6, 2024
Merged

Conversation

asi-sc
Copy link
Contributor

@asi-sc asi-sc commented Jan 23, 2024

Under fast-math flags it's possible to convert sqrt(exp(X)) into exp(X * 0.5). I suppose that this transformation is always profitable. This is similar to the optimization existing in GCC.

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 23, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Anton Sidorenko (asi-sc)

Changes

Under fast-math flags it's possible to convert sqrt(exp(X)) into exp(X * 0.5). I suppose that this transformation is always profitable. This is similar to the optimization existing in GCC.


Full diff: https://github.com/llvm/llvm-project/pull/79146.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h (+1)
  • (modified) llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp (+67)
  • (modified) llvm/test/Transforms/InstCombine/sqrt.ll (+107)
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4c..1aad0b2988451cd 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -201,6 +201,7 @@ class LibCallSimplifier {
   Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B);
   Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
+  Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
   Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 52eef9ab58a4d92..047a793f349e7ef 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2538,6 +2538,70 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) {
   return Ret;
 }
 
+// sqrt(exp(X)) -> exp(X * 0.5)
+Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) {
+  if (!CI->isFast())
+    return nullptr;
+
+  Function *SqrtFn = CI->getCalledFunction();
+  CallInst *Arg = dyn_cast<CallInst>(CI->getArgOperand(0));
+  if (!Arg || !Arg->isFast() || !Arg->hasOneUse())
+    return nullptr;
+  Intrinsic::ID ArgID = Arg->getIntrinsicID();
+  LibFunc ArgLb = NotLibFunc;
+  TLI->getLibFunc(*Arg, ArgLb);
+
+  LibFunc SqrtLb, ExpLb, Exp2Lb, Exp10Lb;
+
+  if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb))
+    switch (SqrtLb) {
+    case LibFunc_sqrtf:
+      ExpLb = LibFunc_expf;
+      Exp2Lb = LibFunc_exp2f;
+      Exp10Lb = LibFunc_exp10f;
+      break;
+    case LibFunc_sqrt:
+      ExpLb = LibFunc_exp;
+      Exp2Lb = LibFunc_exp2;
+      Exp10Lb = LibFunc_exp10;
+      break;
+    case LibFunc_sqrtl:
+      ExpLb = LibFunc_expl;
+      Exp2Lb = LibFunc_exp2l;
+      Exp10Lb = LibFunc_exp10l;
+      break;
+    default:
+      return nullptr;
+    }
+  else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) {
+    if (CI->getType()->getScalarType()->isFloatTy()) {
+      ExpLb = LibFunc_expf;
+      Exp2Lb = LibFunc_exp2f;
+      Exp10Lb = LibFunc_exp10f;
+    } else if (CI->getType()->getScalarType()->isDoubleTy()) {
+      ExpLb = LibFunc_exp;
+      Exp2Lb = LibFunc_exp2;
+      Exp10Lb = LibFunc_exp10;
+    } else
+      return nullptr;
+  } else
+    return nullptr;
+
+  if (ArgLb != ExpLb && ArgLb != Exp2Lb && ArgLb != Exp10Lb &&
+      ArgID != Intrinsic::exp && ArgID != Intrinsic::exp2)
+    return nullptr;
+
+  IRBuilderBase::InsertPointGuard Guard(B);
+  B.SetInsertPoint(Arg);
+  auto *ExpOperand = Arg->getOperand(0);
+  auto *FMul =
+      B.CreateFMulFMF(ExpOperand, ConstantFP::get(ExpOperand->getType(), 0.5),
+                      CI, "merged.sqrt");
+
+  Arg->setOperand(0, FMul);
+  return Arg;
+}
+
 Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
@@ -2553,6 +2617,9 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   if (!CI->isFast())
     return Ret;
 
+  if (Value *Opt = mergeSqrtToExp(CI, B))
+    return Opt;
+
   Instruction *I = dyn_cast<Instruction>(CI->getArgOperand(0));
   if (!I || I->getOpcode() != Instruction::FMul || !I->isFast())
     return Ret;
diff --git a/llvm/test/Transforms/InstCombine/sqrt.ll b/llvm/test/Transforms/InstCombine/sqrt.ll
index 004df3e30c72a1e..b9cdbb6f6910c01 100644
--- a/llvm/test/Transforms/InstCombine/sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/sqrt.ll
@@ -88,7 +88,114 @@ define float @sqrt_call_fabs_f32(float %x) {
   ret float %sqrt
 }
 
+define double @sqrt_exp(double %x) {
+; CHECK-LABEL: @sqrt_exp(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @llvm.exp.f64(double %mul)
+  %res = call fast double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_2(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @exp(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @exp(double %mul)
+  %res = call fast double @sqrt(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp2(double %x) {
+; CHECK-LABEL: @sqrt_exp2(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @exp2(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @exp2(double %mul)
+  %res = call fast double @sqrt(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp10(double %x) {
+; CHECK-LABEL: @sqrt_exp10(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @exp10(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @exp10(double %mul)
+  %res = call fast double @sqrt(double %e)
+  ret double %res
+}
+
+define double @sqrt_exp_nofast_1(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_1(
+; CHECK-NEXT:    [[MERGED_SQRT:%.*]] = fmul fast double [[X:%.*]], 5.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MERGED_SQRT]])
+; CHECK-NEXT:    ret double [[E]]
+;
+  %mul = fmul double %x, 10.0
+  %e = call fast double @llvm.exp.f64(double %mul)
+  %res = call fast double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+; Negative test
+define double @sqrt_exp_nofast_2(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_2(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT:    [[E:%.*]] = call double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call fast double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call double @llvm.exp.f64(double %mul)
+  %res = call fast double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+; Negative test
+define double @sqrt_exp_nofast_3(double %x) {
+; CHECK-LABEL: @sqrt_exp_nofast_3(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], 1.000000e+01
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, 10.0
+  %e = call fast double @llvm.exp.f64(double %mul)
+  %res = call double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
+; Negative test
+define double @sqrt_exp_noconst(double %x, double %y) {
+; CHECK-LABEL: @sqrt_exp_noconst(
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[E:%.*]] = call fast double @llvm.exp.f64(double [[MUL]])
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %mul = fmul fast double %x, %y
+  %e = call fast double @llvm.exp.f64(double %mul)
+  %res = call double @llvm.sqrt.f64(double %e)
+  ret double %res
+}
+
 declare i32 @foo(double)
 declare double @sqrt(double) readnone
 declare float @sqrtf(float)
 declare float @llvm.fabs.f32(float)
+declare double @llvm.exp.f64(double)
+declare double @llvm.sqrt.f64(double)
+declare double @exp(double)
+declare double @exp2(double)
+declare double @exp10(double)

Comment on lines +2556 to +2576
if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb))
switch (SqrtLb) {
case LibFunc_sqrtf:
ExpLb = LibFunc_expf;
Exp2Lb = LibFunc_exp2f;
Exp10Lb = LibFunc_exp10f;
break;
case LibFunc_sqrt:
ExpLb = LibFunc_exp;
Exp2Lb = LibFunc_exp2;
Exp10Lb = LibFunc_exp10;
break;
case LibFunc_sqrtl:
ExpLb = LibFunc_expl;
Exp2Lb = LibFunc_exp2l;
Exp10Lb = LibFunc_exp10l;
break;
default:
return nullptr;
}
else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not have a better way of handling intrinsic-or-libcall?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't find any other approaches to this. All code in this file seems to be written in the same way.

@@ -2538,6 +2538,70 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) {
return Ret;
}

// sqrt(exp(X)) -> exp(X * 0.5)
Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) {
if (!CI->isFast())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you just need reassoc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure whether just reassoc is enough. Don't we need afn (approximate calculations for functions) as well (or only afn)?

From langref for sqrt: When specified with the fast-math-flag ‘afn’, the result may be approximated using a less accurate calculation. https://llvm.org/docs/LangRef.html#llvm-sqrt-intrinsic
Or another example from clang users manual for -fapprox-func: For example, a pow(x, 0.25) may be replaced with sqrt(sqrt(x)). https://clang.llvm.org/docs/UsersManual.html#cmdoption-f-no-approx-func

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did some experiments and this is net error reducing. This reassociation avoids some intermediate overflows to infinity and overall reduces worst case ulp. I think just reassoc is adequate

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arsenm, thanks for the research. I've changed the code to check only reassoc flag.

%res = call double @llvm.sqrt.f64(double %e)
ret double %res
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should reduce test flags. Also, can you add the tests with libcall exp + intrinsic sqrt and intrinsic exp + libcall sqrt? We shouldn't introduce new libcalls from intrinsics

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've slightly simplified tests and added libcall + intrinsic tests. Fast-flags weren't modified. I'll adjust them when we decide on the correct set of flags that controls transformation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed fast flag to reassoc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the use of reassoc for this purpose, but it is consistent with existing practice, so I can't object to it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jcranmer-intel, in your oppinion, what is the correct set of flags? Since you said that it is consistent with the existing practice, I won't change them in this PR. But we may start a discussion at discourse and systematically change all places.

@asi-sc asi-sc requested a review from arsenm January 24, 2024 14:34
Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could use some vector intrinsic tests

sqrt(exp(X)) -> exp(X * 0.5). This is similar to the optimization existing in GCC.
@asi-sc
Copy link
Contributor Author

asi-sc commented Feb 5, 2024

Could use some vector intrinsic tests

Added.

Also rebased and squashed.

@arsenm arsenm merged commit 933247d into llvm:main Feb 6, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants