Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Fix the correctness of missing check reassoc attribute #71277

Closed
wants to merge 1 commit into from

Conversation

vfdff
Copy link
Contributor

@vfdff vfdff commented Nov 4, 2023

The potential issue is based on the discussion of PR69998 .
The Transfrom is reasonable when the I and all of its operands have the reassoc attribute.
Also add some reassoc to the original test case to retain the original optimization logic.

Note:
The IR node may have different fast math flags within a function if you're doing LTO.

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 4, 2023

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: Allen (vfdff)

Changes

The potential issue is based on the discussion of PR69998 . The Transfrom is reasonable
when the I and all of its operands have the reassoc attribute.

The IR node may have different attribute within a function if you're doing LTO.


Patch is 36.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71277.diff

10 Files Affected:

  • (modified) llvm/include/llvm/IR/Instruction.h (+4)
  • (modified) llvm/lib/IR/Instruction.cpp (+9)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+177-170)
  • (modified) llvm/test/Transforms/InstCombine/fmul-exp.ll (+10-10)
  • (modified) llvm/test/Transforms/InstCombine/fmul-exp2.ll (+10-10)
  • (modified) llvm/test/Transforms/InstCombine/fmul-pow.ll (+25-25)
  • (modified) llvm/test/Transforms/InstCombine/fmul-sqrt.ll (+6-6)
  • (modified) llvm/test/Transforms/InstCombine/fmul.ll (+1-1)
  • (modified) llvm/test/Transforms/InstCombine/powi.ll (+11-11)
diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index b5ccdf020a4c006..e166cc0e43fda9b 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -416,6 +416,10 @@ class Instruction : public User,
   /// instruction.
   void setNonNeg(bool b = true);
 
+  /// It checks all of its operands have attribute Reassoc if they are
+  /// instruction.
+  bool hasAllowReassocOfAllOperand() const LLVM_READONLY;
+
   /// Determine whether the no unsigned wrap flag is set.
   bool hasNoUnsignedWrap() const LLVM_READONLY;
 
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 1b3c03348f41a70..895df100f6ff46d 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -177,6 +177,15 @@ void Instruction::setNonNeg(bool b) {
                          (b * PossiblyNonNegInst::NonNeg);
 }
 
+bool Instruction::hasAllowReassocOfAllOperand() const {
+  return all_of(operands(), [](Value *V) {
+    if (!isa<IntrinsicInst>(V))
+      return true;
+    IntrinsicInst *OptI = cast<IntrinsicInst>(V);
+    return OptI->hasAllowReassoc();
+  });
+}
+
 bool Instruction::hasNoUnsignedWrap() const {
   return cast<OverflowingBinaryOperator>(this)->hasNoUnsignedWrap();
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 34b10220ec88aba..68a8fb676d8d909 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -98,6 +98,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *visitSub(BinaryOperator &I);
   Instruction *visitFSub(BinaryOperator &I);
   Instruction *visitMul(BinaryOperator &I);
+  Instruction *foldFMulReassoc(BinaryOperator &I);
   Instruction *visitFMul(BinaryOperator &I);
   Instruction *visitURem(BinaryOperator &I);
   Instruction *visitSRem(BinaryOperator &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index bc784390c23be49..d2ac7abc42b9ce6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -560,6 +560,180 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) {
   return nullptr;
 }
 
+Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
+  Value *Op0 = I.getOperand(0);
+  Value *Op1 = I.getOperand(1);
+  Value *X, *Y;
+  Constant *C;
+
+  // Reassociate constant RHS with another constant to form constant
+  // expression.
+  if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
+    Constant *C1;
+    if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
+      // (C1 / X) * C --> (C * C1) / X
+      Constant *CC1 =
+          ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL);
+      if (CC1 && CC1->isNormalFP())
+        return BinaryOperator::CreateFDivFMF(CC1, X, &I);
+    }
+    if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
+      // (X / C1) * C --> X * (C / C1)
+      Constant *CDivC1 =
+          ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL);
+      if (CDivC1 && CDivC1->isNormalFP())
+        return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
+
+      // If the constant was a denormal, try reassociating differently.
+      // (X / C1) * C --> X / (C1 / C)
+      Constant *C1DivC =
+          ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL);
+      if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP())
+        return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
+    }
+
+    // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
+    // canonicalized to 'fadd X, C'. Distributing the multiply may allow
+    // further folds and (X * C) + C2 is 'fma'.
+    if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
+      // (X + C1) * C --> (X * C) + (C * C1)
+      if (Constant *CC1 =
+              ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
+        Value *XC = Builder.CreateFMulFMF(X, C, &I);
+        return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
+      }
+    }
+    if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
+      // (C1 - X) * C --> (C * C1) - (X * C)
+      if (Constant *CC1 =
+              ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
+        Value *XC = Builder.CreateFMulFMF(X, C, &I);
+        return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
+      }
+    }
+  }
+
+  Value *Z;
+  if (match(&I,
+            m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), m_Value(Z)))) {
+    // Sink division: (X / Y) * Z --> (X * Z) / Y
+    Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
+    return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
+  }
+
+  // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
+  // nnan disallows the possibility of returning a number if both operands are
+  // negative (in that case, we should return NaN).
+  if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) &&
+      match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) {
+    Value *XY = Builder.CreateFMulFMF(X, Y, &I);
+    Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I);
+    return replaceInstUsesWith(I, Sqrt);
+  }
+
+  // The following transforms are done irrespective of the number of uses
+  // for the expression "1.0/sqrt(X)".
+  //  1) 1.0/sqrt(X) * X -> X/sqrt(X)
+  //  2) X * 1.0/sqrt(X) -> X/sqrt(X)
+  // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
+  // has the necessary (reassoc) fast-math-flags.
+  if (I.hasNoSignedZeros() &&
+      match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+      match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
+    return BinaryOperator::CreateFDivFMF(X, Y, &I);
+  if (I.hasNoSignedZeros() &&
+      match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+      match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
+    return BinaryOperator::CreateFDivFMF(X, Y, &I);
+
+  // Like the similar transform in instsimplify, this requires 'nsz' because
+  // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
+  if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) {
+    // Peek through fdiv to find squaring of square root:
+    // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
+    if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
+      Value *XX = Builder.CreateFMulFMF(X, X, &I);
+      return BinaryOperator::CreateFDivFMF(XX, Y, &I);
+    }
+    // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
+    if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) {
+      Value *XX = Builder.CreateFMulFMF(X, X, &I);
+      return BinaryOperator::CreateFDivFMF(Y, XX, &I);
+    }
+  }
+
+  // pow(X, Y) * X --> pow(X, Y+1)
+  // X * pow(X, Y) --> pow(X, Y+1)
+  if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X),
+                                                              m_Value(Y))),
+                         m_Deferred(X)))) {
+    Value *Y1 = Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I);
+    Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I);
+    return replaceInstUsesWith(I, Pow);
+  }
+
+  if (I.isOnlyUserOfAnyOperand()) {
+    // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z)
+    if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
+        match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
+      auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
+      auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
+      return replaceInstUsesWith(I, NewPow);
+    }
+    // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y)
+    if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
+        match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) {
+      auto *XZ = Builder.CreateFMulFMF(X, Z, &I);
+      auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I);
+      return replaceInstUsesWith(I, NewPow);
+    }
+
+    // powi(x, y) * powi(x, z) -> powi(x, y + z)
+    if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
+        match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) &&
+        Y->getType() == Z->getType()) {
+      auto *YZ = Builder.CreateAdd(Y, Z);
+      auto *NewPow = Builder.CreateIntrinsic(
+          Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I);
+      return replaceInstUsesWith(I, NewPow);
+    }
+
+    // exp(X) * exp(Y) -> exp(X + Y)
+    if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
+        match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) {
+      Value *XY = Builder.CreateFAddFMF(X, Y, &I);
+      Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I);
+      return replaceInstUsesWith(I, Exp);
+    }
+
+    // exp2(X) * exp2(Y) -> exp2(X + Y)
+    if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) &&
+        match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) {
+      Value *XY = Builder.CreateFAddFMF(X, Y, &I);
+      Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I);
+      return replaceInstUsesWith(I, Exp2);
+    }
+  }
+
+  // (X*Y) * X => (X*X) * Y where Y != X
+  //  The purpose is two-fold:
+  //   1) to form a power expression (of X).
+  //   2) potentially shorten the critical path: After transformation, the
+  //  latency of the instruction Y is amortized by the expression of X*X,
+  //  and therefore Y is in a "less critical" position compared to what it
+  //  was before the transformation.
+  if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && Op1 != Y) {
+    Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I);
+    return BinaryOperator::CreateFMulFMF(XX, Y, &I);
+  }
+  if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && Op0 != Y) {
+    Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
+    return BinaryOperator::CreateFMulFMF(XX, Y, &I);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
   if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1),
                                   I.getFastMathFlags(),
@@ -607,176 +781,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
   if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
     return replaceInstUsesWith(I, V);
 
-  if (I.hasAllowReassoc()) {
-    // Reassociate constant RHS with another constant to form constant
-    // expression.
-    if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
-      Constant *C1;
-      if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
-        // (C1 / X) * C --> (C * C1) / X
-        Constant *CC1 =
-            ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL);
-        if (CC1 && CC1->isNormalFP())
-          return BinaryOperator::CreateFDivFMF(CC1, X, &I);
-      }
-      if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
-        // (X / C1) * C --> X * (C / C1)
-        Constant *CDivC1 =
-            ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL);
-        if (CDivC1 && CDivC1->isNormalFP())
-          return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
-
-        // If the constant was a denormal, try reassociating differently.
-        // (X / C1) * C --> X / (C1 / C)
-        Constant *C1DivC =
-            ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL);
-        if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP())
-          return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
-      }
-
-      // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
-      // canonicalized to 'fadd X, C'. Distributing the multiply may allow
-      // further folds and (X * C) + C2 is 'fma'.
-      if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
-        // (X + C1) * C --> (X * C) + (C * C1)
-        if (Constant *CC1 = ConstantFoldBinaryOpOperands(
-                Instruction::FMul, C, C1, DL)) {
-          Value *XC = Builder.CreateFMulFMF(X, C, &I);
-          return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
-        }
-      }
-      if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
-        // (C1 - X) * C --> (C * C1) - (X * C)
-        if (Constant *CC1 = ConstantFoldBinaryOpOperands(
-                Instruction::FMul, C, C1, DL)) {
-          Value *XC = Builder.CreateFMulFMF(X, C, &I);
-          return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
-        }
-      }
-    }
-
-    Value *Z;
-    if (match(&I, m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))),
-                           m_Value(Z)))) {
-      // Sink division: (X / Y) * Z --> (X * Z) / Y
-      Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
-      return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
-    }
-
-    // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
-    // nnan disallows the possibility of returning a number if both operands are
-    // negative (in that case, we should return NaN).
-    if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) &&
-        match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) {
-      Value *XY = Builder.CreateFMulFMF(X, Y, &I);
-      Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I);
-      return replaceInstUsesWith(I, Sqrt);
-    }
-
-    // The following transforms are done irrespective of the number of uses
-    // for the expression "1.0/sqrt(X)".
-    //  1) 1.0/sqrt(X) * X -> X/sqrt(X)
-    //  2) X * 1.0/sqrt(X) -> X/sqrt(X)
-    // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
-    // has the necessary (reassoc) fast-math-flags.
-    if (I.hasNoSignedZeros() &&
-        match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
-        match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
-      return BinaryOperator::CreateFDivFMF(X, Y, &I);
-    if (I.hasNoSignedZeros() &&
-        match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
-        match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
-      return BinaryOperator::CreateFDivFMF(X, Y, &I);
-
-    // Like the similar transform in instsimplify, this requires 'nsz' because
-    // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
-    if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 &&
-        Op0->hasNUses(2)) {
-      // Peek through fdiv to find squaring of square root:
-      // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
-      if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
-        Value *XX = Builder.CreateFMulFMF(X, X, &I);
-        return BinaryOperator::CreateFDivFMF(XX, Y, &I);
-      }
-      // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
-      if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) {
-        Value *XX = Builder.CreateFMulFMF(X, X, &I);
-        return BinaryOperator::CreateFDivFMF(Y, XX, &I);
-      }
-    }
-
-    // pow(X, Y) * X --> pow(X, Y+1)
-    // X * pow(X, Y) --> pow(X, Y+1)
-    if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X),
-                                                                m_Value(Y))),
-                           m_Deferred(X)))) {
-      Value *Y1 =
-          Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I);
-      Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I);
-      return replaceInstUsesWith(I, Pow);
-    }
-
-    if (I.isOnlyUserOfAnyOperand()) {
-      // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z)
-      if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
-          match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
-        auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
-        auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
-        return replaceInstUsesWith(I, NewPow);
-      }
-      // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y)
-      if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
-          match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) {
-        auto *XZ = Builder.CreateFMulFMF(X, Z, &I);
-        auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I);
-        return replaceInstUsesWith(I, NewPow);
-      }
-
-      // powi(x, y) * powi(x, z) -> powi(x, y + z)
-      if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
-          match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) &&
-          Y->getType() == Z->getType()) {
-        auto *YZ = Builder.CreateAdd(Y, Z);
-        auto *NewPow = Builder.CreateIntrinsic(
-            Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I);
-        return replaceInstUsesWith(I, NewPow);
-      }
-
-      // exp(X) * exp(Y) -> exp(X + Y)
-      if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
-          match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) {
-        Value *XY = Builder.CreateFAddFMF(X, Y, &I);
-        Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I);
-        return replaceInstUsesWith(I, Exp);
-      }
-
-      // exp2(X) * exp2(Y) -> exp2(X + Y)
-      if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) &&
-          match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) {
-        Value *XY = Builder.CreateFAddFMF(X, Y, &I);
-        Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I);
-        return replaceInstUsesWith(I, Exp2);
-      }
-    }
-
-    // (X*Y) * X => (X*X) * Y where Y != X
-    //  The purpose is two-fold:
-    //   1) to form a power expression (of X).
-    //   2) potentially shorten the critical path: After transformation, the
-    //  latency of the instruction Y is amortized by the expression of X*X,
-    //  and therefore Y is in a "less critical" position compared to what it
-    //  was before the transformation.
-    if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) &&
-        Op1 != Y) {
-      Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I);
-      return BinaryOperator::CreateFMulFMF(XX, Y, &I);
-    }
-    if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) &&
-        Op0 != Y) {
-      Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
-      return BinaryOperator::CreateFMulFMF(XX, Y, &I);
-    }
-  }
+  if (I.hasAllowReassoc() && I.hasAllowReassocOfAllOperand())
+    if (Instruction *FoldedMul = foldFMulReassoc(I))
+      return FoldedMul;
 
   // log2(X * 0.5) * Y = log2(X) * Y - Y
   if (I.isFast()) {
diff --git a/llvm/test/Transforms/InstCombine/fmul-exp.ll b/llvm/test/Transforms/InstCombine/fmul-exp.ll
index 62d22b8c085c267..16066b5d5bc5168 100644
--- a/llvm/test/Transforms/InstCombine/fmul-exp.ll
+++ b/llvm/test/Transforms/InstCombine/fmul-exp.ll
@@ -21,14 +21,14 @@ define double @exp_a_exp_b(double %a, double %b) {
 ; exp(a) * exp(b) reassoc, multiple uses
 define double @exp_a_exp_b_multiple_uses(double %a, double %b) {
 ; CHECK-LABEL: @exp_a_exp_b_multiple_uses(
-; CHECK-NEXT:    [[T1:%.*]] = call double @llvm.exp.f64(double [[B:%.*]])
+; CHECK-NEXT:    [[T1:%.*]] = call reassoc double @llvm.exp.f64(double [[B:%.*]])
 ; CHECK-NEXT:    [[TMP1:%.*]] = fadd reassoc double [[A:%.*]], [[B]]
 ; CHECK-NEXT:    [[MUL:%.*]] = call reassoc double @llvm.exp.f64(double [[TMP1]])
 ; CHECK-NEXT:    call void @use(double [[T1]])
 ; CHECK-NEXT:    ret double [[MUL]]
 ;
-  %t = call double @llvm.exp.f64(double %a)
-  %t1 = call double @llvm.exp.f64(double %b)
+  %t = call reassoc double @llvm.exp.f64(double %a)
+  %t1 = call reassoc double @llvm.exp.f64(double %b)
   %mul = fmul reassoc double %t, %t1
   call void @use(double %t1)
   ret double %mul
@@ -59,8 +59,8 @@ define double @exp_a_exp_b_reassoc(double %a, double %b) {
 ; CHECK-NEXT:    [[MUL:%.*]] = call reassoc double @llvm.exp.f64(double [[TMP1]])
 ; CHECK-NEXT:    ret double [[MUL]]
 ;
-  %t = call double @llvm.exp.f64(double %a)
-  %t1 = call double @llvm.exp.f64(double %b)
+  %t = call reassoc double @llvm.exp.f64(double %a)
+  %t1 = call reassoc double @llvm.exp.f64(double %b)
   %mul = fmul reassoc double %t, %t1
   ret double %mul
 }
@@ -71,7 +71,7 @@ define double @exp_a_a(double %a) {
 ; CHECK-NEXT:    [[M:%.*]] = call reassoc double @llvm.exp.f64(double [[TMP1]])
 ; CHECK-NEXT:    ret double [[M]]
 ;
-  %t = call double @llvm.exp.f64(double %a)
+  %t = call reassoc double @llvm.exp.f64(double %a)
   %m = fmul reassoc double %t, %t
   ret double %m
 }
@@ -100,12 +100,12 @@ define double @exp_a_exp_b_exp_c_ex...
[truncated]

@arsenm
Copy link
Contributor

arsenm commented Nov 7, 2023

The IR node may have different attribute within a function if you're doing LTO.

Should say fast math flags, attributes are a different thing

llvm/lib/IR/Instruction.cpp Outdated Show resolved Hide resolved
llvm/lib/IR/Instruction.cpp Show resolved Hide resolved
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp Outdated Show resolved Hide resolved
The potential issue is based on the discussion of PR69998. The Transfrom is
reasonable when the I and all of its operands have the reassoc flag.
Also add some reassoc to the original test case to retain the original
optimization logic.

NOTE:
The IR node may have different fast math flags within a function if you're
doing LTO.
@vfdff
Copy link
Contributor Author

vfdff commented Nov 15, 2023

ping ?

Comment on lines +326 to +329
if (!isa<FPMathOperator>(V))
return true;

auto *FPOp = cast<FPMathOperator>(V);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dyn_cast instead of isa+cast

return FPOp->hasAllowReassoc();

default:
return true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default should be false?

bool Instruction::hasAllowReassocOfAllOperand() const {
return all_of(operands(), [](Value *V) {
if (!isa<FPMathOperator>(V))
return true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default should be false?


auto *FPOp = cast<FPMathOperator>(V);
switch (FPOp->getOpcode()) {
case Instruction::FNeg:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't really need the opcode checks?

@@ -781,7 +781,7 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);

if (I.hasAllowReassoc())
if (I.hasAllowReassoc() && I.hasAllowReassocOfAllOperand())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a bit too specific of a query to add to Instruction. It will also be imprecise in cases where the source is a constant / load /argument or other non-instruction source

@vfdff
Copy link
Contributor Author

vfdff commented Mar 27, 2024

This is hard to fix with a unified approach, and need continuous improvement, such as PR86428, so close it now

@vfdff vfdff closed this Mar 27, 2024
@arsenm
Copy link
Contributor

arsenm commented Mar 27, 2024

This is hard to fix with a unified approach, and need continuous improvement, such as PR86428, so close it now

I don't follow why this is being closed. What do you mean unified approach?

@vfdff
Copy link
Contributor Author

vfdff commented Mar 27, 2024

I think it is difficult to accurately describe which operands need to have the reassoc attribute. A case for point

define float @reassoc_common_operand1(float %x, float %y) {
  %mul0 = fmul float %y, 2.0
  %mul1 = fmul reassoc float %x, 1.0
  %mul2 = fmul reassoc float %mul1, %mul0
  ret float %mul2
}

We may think the above case can't be fold because the %mul0 missing reassoc attribute, but in fact we can fold the %mul1 into %mul2. Therefore, it is not appropriate to ensure that all operands have reassoc attribute as a prerequisite for optimization.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants