Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Transforms] Expand optimizeTan to fold more inverse trig pairs #77799

Merged
merged 2 commits into from
Feb 6, 2024

Conversation

AtariDreams
Copy link
Contributor

@AtariDreams AtariDreams commented Jan 11, 2024

optimizeTan has been renamed to optimizeTrigInversionPairs as a result.

Sadly, this is not mathematically true that all inverse pairs fold to x. For example, asin(sin(x)) does not fold to x if x is over 2pi.

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 11, 2024

@llvm/pr-subscribers-llvm-transforms

Author: AtariDreams (AtariDreams)

Changes

It has been renamed to optimizeTrig as a result. Use a map to map functions to their inverses.


Full diff: https://github.com/llvm/llvm-project/pull/77799.diff

2 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h (+1-1)
  • (modified) llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp (+41-18)
diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
index eb10545ee149e4..b1b8b9a5b6ad6a 100644
--- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
+++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h
@@ -202,7 +202,7 @@ class LibCallSimplifier {
   Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
   Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
-  Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
+  Value *optimizeTrig(CallInst *CI, IRBuilderBase &B);
   // Wrapper for all floating point library call optimizations
   Value *optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func,
                                       IRBuilderBase &B);
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index a7cd68e860e467..bc09763d23f297 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -2603,13 +2603,29 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
   return copyFlags(*CI, FabsCall);
 }
 
-// TODO: Generalize to handle any trig function and its inverse.
-Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
+Value *LibCallSimplifier::optimizeTrig(CallInst *CI, IRBuilderBase &B) {
   Module *M = CI->getModule();
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
   StringRef Name = Callee->getName();
-  if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name))
+
+  // Map of trigonometric functions to their inverses.
+  static const std::map<std::string, std::string> TrigFuncMap = {
+      {"sin", "asin"},     {"cos", "acos"},     {"tan", "atan"},
+      {"sinf", "asinf"},   {"cosf", "acosf"},   {"tanf", "atanf"},
+      {"sinl", "asinl"},   {"cosl", "acosl"},   {"tanl", "atanl"},
+      {"sinh", "asin"},    {"cosh", "acosh"},   {"tanh", "atanh"},
+      {"sinhf", "asinf"},  {"coshf", "acoshf"}, {"tanhf", "atanhf"},
+      {"sinhl", "asinhl"}, {"coshl", "acoshl"}, {"tanhl", "atanhl"},
+  };
+
+  // Check if the function is a trigonometric function.
+  auto It = TrigFuncMap.find(Name.str());
+  if (It == TrigFuncMap.end())
+    return Ret;
+
+  // Check if the function has a float version.
+  if (UnsafeFPShrink && hasFloatVersion(M, Name))
     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
 
   Value *Op1 = CI->getArgOperand(0);
@@ -2621,16 +2637,12 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) {
   if (!CI->isFast() || !OpC->isFast())
     return Ret;
 
-  // tan(atan(x)) -> x
-  // tanf(atanf(x)) -> x
-  // tanl(atanl(x)) -> x
+  // Check if the operand is the inverse of the trigonometric function.
+  // in which case, a chain of inverses can be folded, ie: tan(atan(x)) -> x
   LibFunc Func;
   Function *F = OpC->getCalledFunction();
   if (F && TLI->getLibFunc(F->getName(), Func) &&
-      isLibFuncEmittable(M, TLI, Func) &&
-      ((Func == LibFunc_atan && Callee->getName() == "tan") ||
-       (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
-       (Func == LibFunc_atanl && Callee->getName() == "tanl")))
+      isLibFuncEmittable(M, TLI, Func) && F->getName() == It->second)
     Ret = OpC->getArgOperand(0);
   return Ret;
 }
@@ -3621,10 +3633,6 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_logb:
   case LibFunc_logbl:
     return optimizeLog(CI, Builder);
-  case LibFunc_tan:
-  case LibFunc_tanf:
-  case LibFunc_tanl:
-    return optimizeTan(CI, Builder);
   case LibFunc_ceil:
     return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
   case LibFunc_floor:
@@ -3646,17 +3654,32 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
   case LibFunc_atan:
   case LibFunc_atanh:
   case LibFunc_cbrt:
-  case LibFunc_cosh:
   case LibFunc_exp:
   case LibFunc_exp10:
   case LibFunc_expm1:
+    if (UnsafeFPShrink &&
+        hasFloatVersion(M, CI->getCalledFunction()->getName()))
+      return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
+    return nullptr;
   case LibFunc_cos:
+  case LibFunc_cosf:
+  case LibFunc_cosl:
+  case LibFunc_cosh:
+  case LibFunc_coshf:
+  case LibFunc_coshl:
   case LibFunc_sin:
+  case LibFunc_sinf:
+  case LibFunc_sinl:
   case LibFunc_sinh:
+  case LibFunc_sinhf:
+  case LibFunc_sinhl:
+  case LibFunc_tan:
+  case LibFunc_tanf:
+  case LibFunc_tanl:
   case LibFunc_tanh:
-    if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
-      return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
-    return nullptr;
+  case LibFunc_tanhf:
+  case LibFunc_tanhl:
+    return optimizeTrig(CI, Builder);
   case LibFunc_copysign:
     if (hasFloatVersion(M, CI->getCalledFunction()->getName()))
       return optimizeBinaryDoubleFP(CI, Builder, TLI);

@AtariDreams AtariDreams changed the title Resolve FIXME: Generalize optimizeTan to support other trig functions Create more optimizing functions to fold inverse pairs Jan 11, 2024
@AtariDreams AtariDreams changed the title Create more optimizing functions to fold inverse pairs [Transforms] Create more optimizing functions to fold inverse pairs Jan 11, 2024
@AtariDreams AtariDreams changed the title [Transforms] Create more optimizing functions to fold inverse pairs [Transforms] Create more optimizing functions to fold inverse trig pairs Jan 11, 2024
@dtcxzyw
Copy link
Member

dtcxzyw commented Jan 11, 2024

Could you please add some regression tests?

@AtariDreams
Copy link
Contributor Author

AtariDreams commented Jan 11, 2024

Could you please add some regression tests?

Done!

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp Outdated Show resolved Hide resolved
llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp Outdated Show resolved Hide resolved
llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp Outdated Show resolved Hide resolved
llvm/test/Transforms/InstCombine/trig-nofastmath.ll Outdated Show resolved Hide resolved
llvm/test/Transforms/InstCombine/trig-nofastmath.ll Outdated Show resolved Hide resolved
llvm/test/Transforms/InstCombine/trig-nofastmath.ll Outdated Show resolved Hide resolved
@AtariDreams AtariDreams changed the title [Transforms] Create more optimizing functions to fold inverse trig pairs [Transforms] Expand optimizeTan to fold more inverse trig pairs Jan 12, 2024
@AtariDreams
Copy link
Contributor Author

@arsenm Can you review this?

@AtariDreams
Copy link
Contributor Author

@dtcxzyw Is this ready to merge?

@dtcxzyw
Copy link
Member

dtcxzyw commented Jan 22, 2024

@dtcxzyw Is this ready to merge?

Please wait for approval from @arsenm or @jcranmer-intel.

@AtariDreams
Copy link
Contributor Author

@arsenm

@AtariDreams
Copy link
Contributor Author

@dtcxzyw Ready!

Merge tan-nofastmath.ll and tan.ll into trig.ll
optimizeTan has been renamed to optimizeTrigInversionPairs as a result.

Sadly, this is not mathematically true that all inverse pairs fold to x.

For example, asin(sin(x)) does not fold to x if x is over 2*pi.
@AtariDreams
Copy link
Contributor Author

@arsenm Can we merge this please?

@arsenm arsenm merged commit c6b5ea3 into llvm:main Feb 6, 2024
4 checks passed
@AtariDreams AtariDreams deleted the trig branch February 6, 2024 14:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants