
[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL #87474

Open · wants to merge 1 commit into main
Conversation

sushgokh
Contributor

@sushgokh sushgokh commented Apr 3, 2024

The proposed patch transforms the code sequence below:
x = 1.0 / sqrt(a);
r1 = x * x;       // same as 1.0 / a
r2 = a / sqrt(a); // same as sqrt(a)

TO

(If x, r1 and r2 are all used further in the code)
r1 = 1.0 / a
r2 = sqrt (a)
x = r1 * r2

The transform tries to make high latency sqrt and div operations independent and also saves on one multiplication.
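As a standalone sanity check (not part of the patch; the function and struct names are illustrative only), the algebraic identity behind the transform can be sketched in plain C++:

```cpp
#include <cmath>

struct Vals { double x, r1, r2; };

// Original sequence: sqrt -> fdiv -> fmul form one serial dependency chain.
Vals beforeTransform(double a) {
  double x = 1.0 / std::sqrt(a);
  double r1 = x * x; // == 1.0 / a   (up to rounding)
  double r2 = a * x; // == sqrt(a)   (up to rounding)
  return {x, r1, r2};
}

// Transformed sequence: the fdiv (1/a) and the sqrt no longer depend on
// each other, and one multiplication is saved overall.
Vals afterTransform(double a) {
  double r1 = 1.0 / a;
  double r2 = std::sqrt(a);
  double x = r1 * r2;
  return {x, r1, r2};
}
```

For positive normal inputs the two sequences agree up to final-bit rounding, which is why the transform is gated behind fast-math flags rather than applied unconditionally.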

The patch was tested with the SPEC2017 suite with cpu=neoverse-v2. The performance uplift achieved was:
544.nab_r ~4%

No other regressions were observed. Also, no compile time differences were observed with the patch.

Closes #54652

@arsenm
Contributor

arsenm commented Apr 3, 2024

I tried adjusting the comment to use the original variable names instead of expressing it as assignment of the original names

@llvmbot
Collaborator

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (sushgokh)

Changes



Patch is 27.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87474.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+174-3)
  • (added) llvm/test/Transforms/InstCombine/fsqrtdiv-transform.ll (+463)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 8c698e52b5a0e6..bfe65264738c4d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -626,6 +626,127 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
   return nullptr;
 }
 
+bool isFSqrtDivToFMulLegal(Instruction *X, SmallSetVector<Instruction *, 2> &R1,
+                           SmallSetVector<Instruction *, 2> &R2) {
+
+  BasicBlock *BBx = X->getParent();
+  BasicBlock *BBr1 = R1[0]->getParent();
+  BasicBlock *BBr2 = R2[0]->getParent();
+
+  auto IsStrictFP = [](Instruction *I) {
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
+    return II && II->isStrictFP();
+  };
+
+  // Check the constraints on instruction X.
+  auto XConstraintsSatisfied = [X, &IsStrictFP]() {
+    if (IsStrictFP(X))
+      return false;
+    // X must have at least 4 uses:
+    //   2 uses in r1 = x * x
+    //   1 use  in r2 = a * x
+    // Post-transform, r1/r2 will no longer use 'x', so if the changes to
+    // 'x' need to persist, 'x' must have at least one more use.
+    if (!X->hasNUsesOrMore(4))
+      return false;
+    // Check if reciprocalFP is enabled.
+    bool RecipFPMath = dyn_cast<FPMathOperator>(X)->hasAllowReciprocal();
+    return RecipFPMath;
+  };
+  if (!XConstraintsSatisfied())
+    return false;
+
+  // Check the constraints on instructions in R1.
+  auto R1ConstraintsSatisfied = [BBr1, &IsStrictFP](Instruction *I) {
+    if (IsStrictFP(I))
+      return false;
+    // When you have multiple instructions residing in R1 and R2 respectively,
+    // it's difficult to generate combinations of (R1,R2) and then check if we
+    // have the required pattern. So, for now, just be conservative.
+    if (I->getParent() != BBr1)
+      return false;
+    if (!I->hasNUsesOrMore(1))
+      return false;
+    // The optimization tries to convert
+    // R1 = div * div    where div = 1/sqrt(a)
+    // to
+    // R1 = 1/a
+    // This simplification is invalid when a < 0, because then sqrt(a) = NaN
+    // while 1/a is not.
+    if (!I->hasNoNaNs())
+      return false;
+    // sqrt(-0.0) = -0.0, and doing this simplification would change the sign
+    // of the result.
+    return I->hasNoSignedZeros();
+  };
+  if (!std::all_of(R1.begin(), R1.end(), R1ConstraintsSatisfied))
+    return false;
+
+  // Check the constraints on instructions in R2.
+  auto R2ConstraintsSatisfied = [BBr2, &IsStrictFP](Instruction *I) {
+    if (IsStrictFP(I))
+      return false;
+    // When you have multiple instructions residing in R1 and R2 respectively,
+    // it's difficult to generate combinations of (R1,R2) and then check if we
+    // have the required pattern. So, for now, just be conservative.
+    if (I->getParent() != BBr2)
+      return false;
+    if (!I->hasNUsesOrMore(1))
+      return false;
+    // This simplification changes
+    // R2 = a * 1/sqrt(a)
+    // to
+    // R2 = sqrt(a)
+    // Now, sqrt(-0.0) = -0.0, and doing this simplification would produce
+    // -0.0 instead of NaN.
+    return I->hasNoSignedZeros();
+  };
+  if (!std::all_of(R2.begin(), R2.end(), R2ConstraintsSatisfied))
+    return false;
+
+  // Check the constraints on X, R1 and R2 combined.
+  // fdiv instruction and one of the multiplications must reside in the same
+  // block. If not, the optimized code may execute more ops than before and
+  // this may hamper the performance.
+  return (BBx == BBr1 || BBx == BBr2);
+}
+
+void getFSqrtDivOptPattern(Value *Div, SmallSetVector<Instruction *, 2> &R1,
+                           SmallSetVector<Instruction *, 2> &R2) {
+  Value *A;
+  if (match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) ||
+      match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A))))) {
+    for (auto U : Div->users()) {
+      Instruction *I = dyn_cast<Instruction>(U);
+      if (!(I && I->getOpcode() == Instruction::FMul))
+        continue;
+
+      if (match(I, m_FMul(m_Specific(Div), m_Specific(Div)))) {
+        R1.insert(I);
+        continue;
+      }
+
+      Value *X;
+      if (match(I, m_FMul(m_Specific(Div), m_Value(X))) && X == A) {
+        R2.insert(I);
+        continue;
+      }
+
+      if (match(I, m_FMul(m_Value(X), m_Specific(Div))) && X == A) {
+        R2.insert(I);
+        continue;
+      }
+    }
+  }
+}
+
+bool delayFMulSqrtTransform(Value *Div) {
+  SmallSetVector<Instruction *, 2> R1, R2;
+  getFSqrtDivOptPattern(Div, R1, R2);
+  return (!(R1.empty() || R2.empty()) &&
+          isFSqrtDivToFMulLegal((Instruction *)Div, R1, R2));
+}
+
 Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0);
   Value *Op1 = I.getOperand(1);
@@ -705,11 +826,11 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
   // has the necessary (reassoc) fast-math-flags.
   if (I.hasNoSignedZeros() &&
       match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
-      match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
+      match(Y, m_Sqrt(m_Value(X))) && Op1 == X && !delayFMulSqrtTransform(Op0))
     return BinaryOperator::CreateFDivFMF(X, Y, &I);
   if (I.hasNoSignedZeros() &&
       match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
-      match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
+      match(Y, m_Sqrt(m_Value(X))) && Op0 == X && !delayFMulSqrtTransform(Op1))
     return BinaryOperator::CreateFDivFMF(X, Y, &I);
 
   // Like the similar transform in instsimplify, this requires 'nsz' because
@@ -717,7 +838,8 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
   if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) {
     // Peek through fdiv to find squaring of square root:
     // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
-    if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
+    if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y)))) &&
+        !delayFMulSqrtTransform(Op0)) {
       Value *XX = Builder.CreateFMulFMF(X, X, &I);
       return BinaryOperator::CreateFDivFMF(XX, Y, &I);
     }
@@ -1796,6 +1918,35 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
   return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
 }
 
+Value *convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
+                               SmallSetVector<Instruction *, 2> &R1,
+                               SmallSetVector<Instruction *, 2> &R2,
+                               Value *SqrtOp, InstCombiner::BuilderTy &B) {
+
+  // 1. Synthesize Tmp1 = 1/a and replace uses of r1.
+  B.SetInsertPoint(X);
+  Value *Tmp1 =
+      B.CreateFDivFMF(ConstantFP::get(R1[0]->getType(), 1.0), SqrtOp, R1[0]);
+  for (auto *I : R1)
+    I->replaceAllUsesWith(Tmp1);
+
+  // 2. No need to synthesize Tmp2 again; in this scenario, Tmp2 = CI. Replace
+  // uses of r2 with CI.
+  for (auto *I : R2)
+    I->replaceAllUsesWith(CI);
+
+  // 3. Synthesize Tmp3 = Tmp1 * Tmp2 and replace uses of 'x' with Tmp3.
+  Value *Tmp3;
+  // If x = -1/sqrt(a) initially, then Tmp3 = -(Tmp1 * Tmp2).
+  if (match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)))) {
+    Value *Mul = B.CreateFMul(Tmp1, CI);
+    Tmp3 = B.CreateFNegFMF(Mul, X);
+  } else
+    Tmp3 = B.CreateFMulFMF(Tmp1, CI, X);
+
+  return Tmp3;
+}
+
 Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   Module *M = I.getModule();
 
@@ -1820,6 +1971,26 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
     return R;
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+
+  // Convert
+  // x = 1.0/sqrt(a)
+  // r1 = x * x;
+  // r2 = a * x;
+  //
+  // TO
+  //
+  // r1 = 1/a
+  // r2 = sqrt(a)
+  // x = r1 * r2
+  SmallSetVector<Instruction *, 2> R1, R2;
+  getFSqrtDivOptPattern(&I, R1, R2);
+  if (!(R1.empty() || R2.empty()) && isFSqrtDivToFMulLegal(&I, R1, R2)) {
+    CallInst *CI = (CallInst *)((&I)->getOperand(1));
+    Value *SqrtOp = CI->getArgOperand(0);
+    if (Value *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, SqrtOp, Builder))
+      return replaceInstUsesWith(I, D);
+  }
+
   if (isa<Constant>(Op0))
     if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
       if (Instruction *R = FoldOpIntoSelect(I, SI))
diff --git a/llvm/test/Transforms/InstCombine/fsqrtdiv-transform.ll b/llvm/test/Transforms/InstCombine/fsqrtdiv-transform.ll
new file mode 100644
index 00000000000000..4852337d4b6586
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fsqrtdiv-transform.ll
@@ -0,0 +1,463 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes='instcombine<no-verify-fixpoint>' < %s | FileCheck %s
+
+@x = global double 0.000000e+00
+@r1 = global double 0.000000e+00
+@r2 = global double 0.000000e+00
+@r3 = global double 0.000000e+00
+
+; div/mul/mul1 all in the same block.
+define void @bb_constraint_case1(double %a) {
+; CHECK-LABEL: define void @bb_constraint_case1(
+; CHECK-SAME: double [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call double @llvm.sqrt.f64(double [[A]])
+; CHECK-NEXT:    [[TMP1:%.*]] = fdiv nnan nsz double 1.000000e+00, [[A]]
+; CHECK-NEXT:    [[DIV:%.*]] = fmul arcp double [[TMP1]], [[TMP0]]
+; CHECK-NEXT:    store double [[DIV]], ptr @x, align 8
+; CHECK-NEXT:    store double [[TMP1]], ptr @r1, align 8
+; CHECK-NEXT:    store double [[TMP0]], ptr @r2, align 8
+; CHECK-NEXT:    ret void
+entry:
+  %sqrt = tail call double @llvm.sqrt.f64(double %a)
+  %div = fdiv arcp double 1.000000e+00, %sqrt
+  store double %div, ptr @x
+  %mul = fmul nnan nsz double %div, %div
+  store double %mul, ptr @r1
+  %mul1 = fmul nsz double %a, %div
+  store double %mul1, ptr @r2
+  ret void
+}
+; div/mul in one block and mul1 in other block with conditional guard.
+define void @bb_constraint_case2(double %a, i32 %d) {
+; CHECK-LABEL: define void @bb_constraint_case2(
+; CHECK-SAME: double [[A:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call double @llvm.sqrt.f64(double [[A]])
+; CHECK-NEXT:    [[TMP1:%.*]] = fdiv nnan nsz double 1.000000e+00, [[A]]
+; CHECK-NEXT:    [[DIV:%.*]] = fmul arcp double [[TMP1]], [[TMP0]]
+; CHECK-NEXT:    store double [[DIV]], ptr @x, align 8
+; CHECK-NEXT:    store double [[TMP1]], ptr @r1, align 8
+; CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq i32 [[D]], 0
+; CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    store double [[TMP0]], ptr @r2, align 8
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+entry:
+  %sqrt = call double @llvm.sqrt.f64(double %a)
+  %div = fdiv arcp double 1.000000e+00, %sqrt
+  store double %div, ptr @x
+  %mul = fmul nnan nsz double %div, %div
+  store double %mul, ptr @r1
+  %tobool.not = icmp eq i32 %d, 0
+  br i1 %tobool.not, label %if.end, label %if.then
+
+if.then:                                          ; preds = %entry
+  %mul1 = fmul nsz double %div, %a
+  store double %mul1, ptr @r2
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %entry
+  ret void
+}
+
+; div in one block. mul/mul1 in other block and conditionally guarded. Don't optimize.
+define void @bb_constraint_case3(double %a, i32 %d) {
+; CHECK-LABEL: define void @bb_constraint_case3(
+; CHECK-SAME: double [[A:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call double @llvm.sqrt.f64(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv arcp double 1.000000e+00, [[TMP0]]
+; CHECK-NEXT:    store double [[DIV]], ptr @x, align 8
+; CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq i32 [[D]], 0
+; CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[MUL:%.*]] = fmul nnan nsz double [[DIV]], [[DIV]]
+; CHECK-NEXT:    store double [[MUL]], ptr @r1, align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load double, ptr @x, align 8
+; CHECK-NEXT:    [[MUL1:%.*]] = fmul nsz double [[TMP1]], [[A]]
+; CHECK-NEXT:    store double [[MUL1]], ptr @r2, align 8
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    ret void
+entry:
+  %sqrt = call double @llvm.sqrt.f64(double %a)
+  %div = fdiv arcp double 1.000000e+00, %sqrt
+  store double %div, ptr @x
+  %tobool = icmp ne i32 %d, 0
+  br i1 %tobool, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %mul = fmul nnan nsz double %div, %div
+  store double %mul, ptr @r1
+  %1 = load double, ptr @x
+  %mul1 = fmul nsz double %a, %1
+  store double %mul1, ptr @r2
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %entry
+  ret void
+}
+
+; div in one block. mul/mul3 each in different block and conditionally guarded. Don't optimize.
+define void @bb_constraint_case4(double %a, i32 %c, i32 %d) {
+; CHECK-LABEL: define void @bb_constraint_case4(
+; CHECK-SAME: double [[A:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call double @llvm.sqrt.f64(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv arcp double 1.000000e+00, [[TMP0]]
+; CHECK-NEXT:    store double [[DIV]], ptr @x, align 8
+; CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq i32 [[C]], 0
+; CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[MUL:%.*]] = fmul nnan nsz double [[DIV]], [[DIV]]
+; CHECK-NEXT:    store double [[MUL]], ptr @r1, align 8
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    [[TOBOOL1_NOT:%.*]] = icmp eq i32 [[D]], 0
+; CHECK-NEXT:    br i1 [[TOBOOL1_NOT]], label [[IF_END4:%.*]], label [[IF_THEN2:%.*]]
+; CHECK:       if.then2:
+; CHECK-NEXT:    [[TMP1:%.*]] = load double, ptr @x, align 8
+; CHECK-NEXT:    [[MUL3:%.*]] = fmul nsz double [[TMP1]], [[A]]
+; CHECK-NEXT:    store double [[MUL3]], ptr @r2, align 8
+; CHECK-NEXT:    br label [[IF_END4]]
+; CHECK:       if.end4:
+; CHECK-NEXT:    ret void
+entry:
+  %sqrt = call double @llvm.sqrt.f64(double %a)
+  %div = fdiv arcp double 1.000000e+00, %sqrt
+  store double %div, ptr @x
+  %tobool = icmp ne i32 %c, 0
+  br i1 %tobool, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %mul = fmul nnan nsz double %div, %div
+  store double %mul, ptr @r1
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %entry
+  %tobool1 = icmp ne i32 %d, 0
+  br i1 %tobool1, label %if.then2, label %if.end4
+
+if.then2:                                         ; preds = %if.end
+  %1 = load double, ptr @x
+  %mul3 = fmul nsz double %a, %1
+  store double %mul3, ptr @r2
+  br label %if.end4
+
+if.end4:                                          ; preds = %if.then2, %if.end
+  ret void
+}
+
+; sqrt value comes from different blocks. Don't optimize.
+define void @bb_constraint_case5(double %a, i32 %c) {
+; CHECK-LABEL: define void @bb_constraint_case5(
+; CHECK-SAME: double [[A:%.*]], i32 [[C:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq i32 [[C]], 0
+; CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[IF_ELSE:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[TMP0:%.*]] = call double @llvm.sqrt.f64(double [[A]])
+; CHECK-NEXT:    br label [[IF_END:%.*]]
+; CHECK:       if.else:
+; CHECK-NEXT:    [[ADD:%.*]] = fadd double [[A]], 1.000000e+01
+; CHECK-NEXT:    [[TMP1:%.*]] = call double @llvm.sqrt.f64(double [[ADD]])
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    [[DOTPN:%.*]] = phi double [ [[TMP0]], [[IF_THEN]] ], [ [[TMP1]], [[IF_ELSE]] ]
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv arcp double 1.000000e+00, [[DOTPN]]
+; CHECK-NEXT:    [[MUL:%.*]] = fmul nnan nsz double [[DIV]], [[DIV]]
+; CHECK-NEXT:    store double [[MUL]], ptr @r1, align 8
+; CHECK-NEXT:    [[MUL2:%.*]] = fmul nsz double [[DIV]], [[A]]
+; CHECK-NEXT:    store double [[MUL2]], ptr @r2, align 8
+; CHECK-NEXT:    ret void
+entry:
+  %tobool = icmp ne i32 %c, 0
+  br i1 %tobool, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %0 = call double @llvm.sqrt.f64(double %a)
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %add = fadd double %a, 1.000000e+01
+  %1 = call double @llvm.sqrt.f64(double %add)
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %sqrt = phi double[ %0, %if.then], [ %1, %if.else]
+  %div = fdiv arcp double 1.000000e+00, %sqrt
+  %mul = fmul nnan nsz double %div, %div
+  store double %mul, ptr @r1
+  %mul2 = fmul nsz double %a, %div
+  store double %mul2, ptr @r2
+  ret void
+}
+
+; div in one block and conditionally guarded. mul/mul1 in other block. Don't optimize.
+define void @bb_constraint_case6(double %a, i32 %d) {
+; CHECK-LABEL: define void @bb_constraint_case6(
+; CHECK-SAME: double [[A:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq i32 [[D]], 0
+; CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[IF_END:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       entry.if.end_crit_edge:
+; CHECK-NEXT:    [[DOTPRE:%.*]] = load double, ptr @x, align 8
+; CHECK-NEXT:    br label [[IF_END1:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call double @llvm.sqrt.f64(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv arcp double 1.000000e+00, [[TMP0]]
+; CHECK-NEXT:    store double [[DIV]], ptr @x, align 8
+; CHECK-NEXT:    br label [[IF_END1]]
+; CHECK:       if.end:
+; CHECK-NEXT:    [[TMP1:%.*]] = phi double [ [[DOTPRE]], [[IF_END]] ], [ [[DIV]], [[IF_THEN]] ]
+; CHECK-NEXT:    [[MUL:%.*]] = fmul nnan nsz double [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    store double [[MUL]], ptr @r1, align 8
+; CHECK-NEXT:    [[MUL1:%.*]] = fmul nsz double [[TMP1]], [[A]]
+; CHECK-NEXT:    store double [[MUL1]], ptr @r2, align 8
+; CHECK-NEXT:    ret void
+entry:
+  %tobool.not = icmp eq i32 %d, 0
+  br i1 %tobool.not, label %entry.if.end_crit_edge, label %if.then
+
+entry.if.end_crit_edge:                           ; preds = %entry
+  %.pre = load double, ptr @x
+  br label %if.end
+
+if.then:                                          ; preds = %entry
+  %sqrt = tail call double @llvm.sqrt.f64(double %a)
+  %div = fdiv arcp double 1.000000e+00, %sqrt
+  store double %div, ptr @x
+  br label %if.end
+
+if.end:                                           ; preds = %entry.if.end_crit_edge, %if.then
+  %1 = phi double [ %.pre, %entry.if.end_crit_edge ], [ %div, %if.then ]
+  %mul = fmul nnan nsz double %1, %1
+  store double %mul, ptr @r1
+  %mul1 = fmul nsz double %1, %a
+  store double %mul1, ptr @r2
+  ret void
+}
+
+; value for first mul(i.e. div4.sink) comes from different blocks. Don't optimize.
+define void @bb_constraint_case7(double %a, i32 %c, i32 %d) {
+; CHECK-LABEL: define void @bb_constraint_case7(
+; CHECK-SAME: double [[A:%.*]], i32 [[C:%.*]], i32 [[D:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call double @llvm.sqrt.f64(double [[A]])
+; CHECK-NEXT:    [[DIV:%.*]] = fdiv arcp double 1.000000e+00, [[TMP0]]
+; CHECK-NEXT:    store double [[DIV]], ptr @x, align 8
+; CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq i32 [[C]], 0
+; CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[IF_ELSE:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[DIV1:%.*]] = fdiv double 3.000000e+00, [[A]]
+; CHECK-NEXT:    br label [[IF_END6:%.*]]
+; CHECK:       if.else:
+; CHECK-NEXT:    [[TOBOOL2_NOT:%.*]] = icmp eq i32 [[D]], 0
+; CHECK-NEXT:    br i1 [[TOBOOL2_NOT]], label [[IF_ELSE5:%.*]], label [[IF_THEN3:%.*]]
+; CHECK:       if.then3:
+; CHECK-NEXT:    [[DIV4:%.*]] = fdiv double 2.000000e+00, [[A]]
+; CHECK-NEXT:    br label [[IF_END6]]
+; CHECK:       if.else5:
+; CHECK-NEXT:    [[MUL:%.*]] = fmul nnan nsz double [[DIV]], [[DIV]]
+; CHECK-NEXT:    br label [[IF_END6]]
+; CHECK:       if.end6:
+; CHECK-NEXT:    [[DIV4_SINK:%.*]] = phi double [ [[DIV4]], [[IF_THEN3]] ], [ [[MUL]], [[IF_ELSE5]] ], [ [[DIV1]], [[IF_THEN]] ]
+; CHECK-NEXT:    store double [[DIV4_SINK]], ptr @r1, align 8
+; CHECK-NEXT:    [[MUL7:%.*]] = fmul nsz double [[DIV]], [[A]]
+; CHECK-NEXT:    store double [[MUL7]], ptr @r2, align 8
+; CHECK-NEXT:    ret void
+entry:
+  %sqrt = tail call double @llvm.sqrt.f64(double %a)
...
[truncated]

@sushgokh sushgokh requested a review from arsenm April 5, 2024 07:50
Contributor

@jcranmer-intel jcranmer-intel left a comment


I haven't gone through all of the thinking on fast-math flags yet, but I've noted at least one incorrect flag:

Contributor

@nikic nikic left a comment


(drive-by comments only, I don't review FP transforms.)

@sushgokh
Contributor Author

@arsenm @jcranmer-intel I have tried addressing all the floating point issues. Are there any more issues I need to address?

The proposed patch, in general, tries to transform the below code sequence:
x = 1.0 / sqrt (a);
r1 = x * x;  // same as 1.0 / a
r2 = a / sqrt(a); // same as sqrt (a)

TO

(If x, r1 and r2 are all used further in the code)
tmp1 = 1.0 / a
tmp2 = sqrt (a)
tmp3 = tmp1 * tmp2
x = tmp3
r1 = tmp1
r2 = tmp2

The transform tries to make high latency sqrt and div operations independent and also saves on one multiplication.

The patch was tested with the SPEC2017 suite with cpu=neoverse-v2.
The performance uplift achieved was:
544.nab_r   ~4%

No other regressions were observed. Also, no compile time differences were observed with the patch.

Closes llvm#54652
@sushgokh
Contributor Author

ping @arsenm @jcranmer-intel

@sushgokh
Contributor Author

ping @arsenm @jcranmer-intel Is there anything more I need to look into?

Comment on lines +646 to +647
if (!FSqrt->hasAllowReassoc() || !FSqrt->hasNoNaNs() ||
!FSqrt->hasNoSignedZeros() || !FSqrt->hasNoInfs())
Contributor


Why does this require nnan? For a nan input a, r1, r2 and x are trivially nan in the input and output. For a < 0, in the input, x = nan, r1 = nan, r2 = nan. In the output, r1 = non-nan, but this is OK if the single use is the multiply to x. This would also be OK if the multiply had nnan instead

Contributor Author


The optimization is valid only for positive normal values. We can't put restrictions on the values of a, so I had to put the constraints on the call instruction, since it is used by all of x/r1/r2. Also, x/r1/r2 can each have multiple uses, so their values before and after the transform need to match.
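To make the edge cases concrete, here is a standalone C++ sketch (not from the patch; the helper names are illustrative) of how the original and transformed sequences diverge under plain IEEE semantics:

```cpp
#include <cmath>

// For a < 0: the original x, r1, r2 are all NaN (sqrt of a negative is NaN),
// but the transformed r1 = 1/a is a finite value. This is why nnan is needed.
bool negInputDiverges(double a) {   // expects a < 0
  double x = 1.0 / std::sqrt(a);    // NaN
  double origR1 = x * x;            // NaN
  double newR1 = 1.0 / a;           // finite
  return std::isnan(origR1) && !std::isnan(newR1);
}

// For a = -0.0: sqrt(-0.0) = -0.0, so x = 1.0 / -0.0 = -inf, and the original
// r2 = a * x = -0.0 * -inf = NaN, while the transformed r2 = sqrt(-0.0) = -0.0.
// This is why nsz is needed.
bool negZeroDiverges() {
  double a = -0.0;
  double x = 1.0 / std::sqrt(a);    // -inf
  double origR2 = a * x;            // NaN
  double newR2 = std::sqrt(a);      // -0.0
  return std::isnan(origR2) && !std::isnan(newR2);
}
```

Compiled without fast-math, both helpers return true; the rewrite changes observable results exactly in the a < 0 and a = -0.0 cases, hence the nnan/nsz requirements on the pattern.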

ret void
}
declare double @llvm.sqrt.f64(double)

Contributor


Can you add some cases with fdiv -1, x? Also test a vector case

Contributor Author


Can you add some cases with fdiv -1, x?

already goes by the name negative_fdiv_val.

Also test a vector case

Also, why would a vector case be needed? I know that InstCombine runs even after vectorization, but this transform would take place well before vectorization, right?

}

; missing flags for optimization.
define void @missing_flags_on_mul(double %a) {
Contributor


Name specifically which flags are missing?

@fhahn
Contributor

fhahn commented May 31, 2024

I am wondering if instcombine is the right place to reason about this? At this point, we don't have any info on the cost of instructions. This seems like it would be better placed in the backend, e.g. as a pattern MachineCombiner uses (which already tries to reassociate expressions, if it improves the critical path or other metrics)

@arsenm
Contributor

arsenm commented May 31, 2024

I am wondering if instcombine is the right place to reason about this? At this point, we don't have any info on the cost of instructions.

I think it's entirely reasonable to assume fmul is universally cheaper than fdiv. Plus, the infrastructure for doing anything in the backend is significantly worse.

This seems like it would be better placed in the backend, e.g. as a pattern MachineCombiner uses (which already tries to reassociate expressions, if it improves the critical path or other metrics)

This sounds unworkable for any target with nontrivial fdiv expansions


Successfully merging this pull request may close these issues.

Avoid dependent FSQRT and FDIV where possible -freciprocal-math and -funsafe-math-optimizations