Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check if operand is div in fold FDivSqrtDivisor #81970

Merged
merged 3 commits into from
Mar 9, 2024

Conversation

zjaffal
Copy link
Contributor

@zjaffal zjaffal commented Feb 16, 2024

This patch fixes the issues introduced in bb5c389.

I moved the check for the instruction to be div before I check for the fast math flags which resolves the crash in

float a, b;
double sqrt();
void c() { b = a / sqrt(a); }

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 16, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Zain Jaffal (zjaffal)

Changes

This patch fixes the issues introduced in bb5c389.

I moved the check for the instruction to be div before I check for the fast math flags which resolves the crash in

float a, b;
double sqrt();
void c() { b = a / sqrt(a); }

Full diff: https://github.com/llvm/llvm-project/pull/81970.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+31)
  • (modified) llvm/test/Transforms/InstCombine/fdiv-sqrt.ll (+26-9)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 0bd4b6d1a835af..912d9ac404052a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1706,6 +1706,34 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I,
   return BinaryOperator::CreateFMulFMF(Op0, Pow, &I);
 }
 
+/// Convert div to mul if we have an sqrt divisor iff sqrt's operand is a fdiv
+/// instruction.
+static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
+                                        InstCombiner::BuilderTy &Builder) {
+  // X / sqrt(Y / Z) -->  X * sqrt(Z / Y)
+  if (!I.hasAllowReassoc() || !I.hasAllowReciprocal())
+    return nullptr;
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  auto *II = dyn_cast<IntrinsicInst>(Op1);
+  if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() ||
+      !II->hasAllowReassoc() || !II->hasAllowReciprocal())
+    return nullptr;
+
+  Value *Y, *Z;
+  auto *DivOp = dyn_cast<Instruction>(II->getOperand(0));
+  if (!DivOp)
+    return nullptr;
+  if (!match(DivOp, m_FDiv(m_Value(Y), m_Value(Z))))
+    return nullptr;
+  if (!DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() ||
+      !DivOp->hasOneUse())
+    return nullptr;
+  Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
+  Value *NewSqrt =
+      Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), SwapDiv, II);
+  return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
+}
+
 Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   Module *M = I.getModule();
 
@@ -1813,6 +1841,9 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
     return Mul;
 
+  if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder))
+    return Mul;
+
   // pow(X, Y) / X --> pow(X, Y-1)
   if (I.hasAllowReassoc() &&
       match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1),
diff --git a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
index 346271be7da761..58cc7c297e90a0 100644
--- a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
@@ -6,9 +6,9 @@ declare double @llvm.sqrt.f64(double)
 define double @sqrt_div_fast(double %x, double %y, double %z) {
 ; CHECK-LABEL: @sqrt_div_fast(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT:    [[DIV1:%.*]] = fdiv fast double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT:    [[TMP0:%.*]] = fdiv fast double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fmul fast double [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret double [[DIV1]]
 ;
 entry:
@@ -36,9 +36,9 @@ entry:
 define double @sqrt_div_reassoc_arcp(double %x, double %y, double %z) {
 ; CHECK-LABEL: @sqrt_div_reassoc_arcp(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc arcp double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT:    [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT:    [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret double [[DIV1]]
 ;
 entry:
@@ -96,9 +96,9 @@ entry:
 define double @sqrt_div_arcp_missing(double %x, double %y, double %z) {
 ; CHECK-LABEL: @sqrt_div_arcp_missing(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[DIV:%.*]] = fdiv reassoc double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT:    [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT:    [[TMP0:%.*]] = fdiv reassoc double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT:    [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret double [[DIV1]]
 ;
 entry:
@@ -173,3 +173,20 @@ entry:
   ret double %div1
 }
 
+; Function Attrs: nounwind ssp uwtable(sync)
+define float @sqrt_non_div_operator(float %a) {
+; CHECK-LABEL: @sqrt_non_div_operator(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CONV:%.*]] = fpext float [[A:%.*]] to double
+; CHECK-NEXT:    [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[CONV]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv fast double [[CONV]], [[SQRT]]
+; CHECK-NEXT:    [[CONV2:%.*]] = fptrunc double [[DIV]] to float
+; CHECK-NEXT:    ret float [[CONV2]]
+;
+entry:
+  %conv = fpext float %a to double
+  %sqrt = call fast double @llvm.sqrt.f64(double %conv)
+  %div = fdiv fast double %conv, %sqrt
+  %conv2 = fptrunc double %div to float
+  ret float %conv2
+}

@zjaffal zjaffal changed the title Zjaffal/fix fold fdiv sqrt check if operand is div in fold FDivSqrtDivisor Feb 16, 2024
llvm/test/Transforms/InstCombine/fdiv-sqrt.ll Outdated Show resolved Hide resolved
@zjaffal zjaffal merged commit f581149 into llvm:main Mar 9, 2024
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants