[InstCombine] Fold copysign of selects from sign comparison to sign operand #85627

Open
wants to merge 5 commits into
base: main

Krishna-13-cyber
Contributor

@Krishna-13-cyber Krishna-13-cyber commented Mar 18, 2024

This is currently under development and needs some assistance/review to go about solving this.
This PR intends to solve this issue #64884.

@llvmbot
Collaborator

llvmbot commented Mar 18, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Krishna Narayanan (Krishna-13-cyber)

Changes

This is currently under development and needs some assistance/review to go about solving this.
This PR intends to solve this issue #64884.
Blocker:
I pattern-matched the scenario that seemed most likely to me. This approach does not produce the desired fold, and the tests fail.
Any pointers on where I am going wrong or what I could improve?


Full diff: https://github.com/llvm/llvm-project/pull/85627.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+45)
  • (modified) llvm/test/Transforms/InstCombine/fcmp.ll (+16)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ee76a6294428b3..49e10646eb5d15 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2790,6 +2790,47 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
   return ChangedFMF ? &SI : nullptr;
 }
 
+// Canonicalize select with fcmp -> select
+static Instruction *foldSelectWithFCmp(SelectInst &SI, InstCombinerImpl &IC) {
+  /* From
+    %4 = fcmp olt float %1, 0.000000e+00
+    %5 = and i1 %4, %0
+    %6 = select i1 %5, float -1.000000e+00, float 1.000000e+00
+  */
+  /* To
+   %4 = select i1 %0, float %1, float 1.000000e+00
+  */
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+  Value *One = Constant::getAllOnesValue(FalseVal->getType());
+  Value *X, *C, *Op;
+  const APFloat *A, *E;
+  CmpInst::Predicate Pred;
+  for (bool Swap : {false, true}) {
+    if (Swap)
+      std::swap(TrueVal, FalseVal);
+    if (match(&SI, (m_Value(CondVal), m_APFloat(A), m_APFloat(E)))) {
+      if (!match(TrueVal, m_APFloatAllowUndef(A)) &&
+          !match(FalseVal, m_APFloatAllowUndef(E)))
+        return nullptr;
+      if (!match(CondVal, m_And(m_FCmp(Pred, m_Specific(X), m_PosZeroFP()),
+                                m_Value(C))) &&
+          (X->hasOneUse() && C->hasOneUse()))
+        return nullptr;
+      if (!A->isNegative() && E->isNegative())
+        return nullptr;
+      if (!Swap && (Pred == FCmpInst::FCMP_OLT)) {
+        return SelectInst::Create(C, X, One);
+      }
+      if (Swap && (Pred == FCmpInst::FCMP_OGT)) {
+        return SelectInst::Create(C, X, One);
+      }
+    }
+  }
+  return nullptr;
+}
+
 // Match the following IR pattern:
 //   %x.lowbits = and i8 %x, %lowbitmask
 //   %x.lowbits.are.zero = icmp eq i8 %x.lowbits, 0
@@ -3508,6 +3549,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Instruction *Fabs = foldSelectWithFCmpToFabs(SI, *this))
     return Fabs;
 
+  // Fold selecting to ffold.
+  if (Instruction *Ffold = foldSelectWithFCmp(SI, *this))
+    return Ffold;
+
   // See if we are selecting two values based on a comparison of the two values.
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal))
     if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 159c84d0dd8aa9..574fb1f9417742 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -1284,3 +1284,19 @@ define <1 x i1> @bitcast_1vec_eq0(i32 %x) {
   %cmp = fcmp oeq <1 x float> %f, zeroinitializer
   ret <1 x i1> %cmp
 }
+
+define float @copysign_conditional(i1 noundef zeroext %0, float %1, float %2) {
+; CHECK-LABEL: define float @copysign_conditional(
+; CHECK-SAME: i1 noundef zeroext [[TMP0:%.*]], float [[TMP1:%.*]], float [[TMP2:%.*]]) {
+; CHECK-NEXT:    [[TMP4:%.*]] = fcmp olt float [[TMP1]], 0.000000e+00
+; CHECK-NEXT:    [[TMP5:%.*]] = and i1 [[TMP4]], [[TMP0]]
+; CHECK-NEXT:    [[TMP6:%.*]] = select i1 [[TMP5]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[TMP7:%.*]] = tail call float @llvm.copysign.f32(float [[TMP2]], float [[TMP6]])
+; CHECK-NEXT:    ret float [[TMP7]]
+;
+  %4 = fcmp olt float %1, 0.000000e+00
+  %5 = and i1 %4, %0
+  %6 = select i1 %5, float -1.000000e+00, float 1.000000e+00
+  %7 = tail call float @llvm.copysign.f32(float %2, float %6)
+  ret float %7
+}
\ No newline at end of file

@nikic
Contributor

nikic commented Mar 19, 2024

The copysign is important here. The transform is valid only because copysign demands just the sign of this operand; you can't do it for all selects. So the root instruction for this transform should be copysign, not select.
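
To make this point concrete, here is a minimal standalone C++ sketch (plain C++, not LLVM code). It models the sign operand from the PR's test, `(b && x < 0) ? -1.0 : 1.0`, and the replacement operand from the patch comment, `b ? x : 1.0`. The function names are hypothetical, and the check uses one ordinary negative input only; signed zeros and NaNs are deliberately left out.

```cpp
#include <cassert>
#include <cmath>

// Sign operand as written in the PR's test: (b && x < 0) ? -1.0 : 1.0.
float sign_before(bool b, float x) { return (b && x < 0.0f) ? -1.0f : 1.0f; }

// Replacement operand from the patch comment: b ? x : 1.0.
float sign_after(bool b, float x) { return b ? x : 1.0f; }

// Illustrative check; the input values are arbitrary.
void check_sign_only_demanded() {
  const float mag = 2.0f;
  const float x = -3.5f;

  // As plain values, the two selects are different.
  assert(sign_before(true, x) == -1.0f);
  assert(sign_after(true, x) == -3.5f);

  // copysign reads only the sign bit of its second operand,
  // so both operands give the same result for this input.
  assert(std::copysign(mag, sign_before(true, x)) ==
         std::copysign(mag, sign_after(true, x)));
}
```

Calling `check_sign_only_demanded()` from a `main` runs the assertions. Any other user of the select would observe the changed value, which is why the fold has to start from the copysign call rather than from the select.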

@nikic nikic added the floating-point Floating-point math label Mar 19, 2024
@nikic nikic requested review from dtcxzyw and removed request for nikic March 19, 2024 14:24

✅ With the latest revision this PR passed the Python code formatter.

github-actions bot commented Mar 25, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

return nullptr;
// Match select ?, TC, FC where the constants are equal but negated.
// Check for these 8 conditions
/*
Contributor

C++ style comments

%sel = select i1 %and, float -1.000000e+00, float 1.000000e+00
%res = tail call float @llvm.copysign.f32(float %z, float %sel)
ret float %res
}
Contributor

This needs more tests covering the range of compare predicates. Also test fast-math flag handling, and test the select with identical constants on each side.

Contributor Author

  1. I tested it for a set of predicates. The one case that deviates is oeq, which I tried to handle with specific values. Adding the patch below leads to a different folding (fabs and fneg) and changes the results for all the predicates.
  /*
 oeq
 copysign(Mag, B & (A == 0.0) ? -TC : TC) --> copysign(Mag, 1) B->true.
 copysign(Mag, B & (A == 0.0) ? -TC : TC) --> copysign(Mag, -1) B->false.
  */
    Value *One;
    One = ConstantFP::get(Mag->getType(), 1.0);
   if (Pred == CmpInst::FCMP_OEQ)
     if (match(B, m_Zero()) && TC->isNegative())
       return replaceOperand(*II, 1, One);
     if (!match(B, m_Zero()) && TC->isNegative())
       One = Builder.CreateFNeg(One);
       return replaceOperand(*II, 1, One);
  2. Added a fast-math test.
  3. For identical constants, i.e. when both are negative or both are positive, it's folded into fabs.
    https://godbolt.org/z/qrfWvr465
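
One possible reading of why the results change for every predicate, assuming the snippet was compiled exactly as pasted: without braces, C++ attaches only the first inner `if` to the `FCMP_OEQ` check, and the final `return` runs unconditionally. The standalone sketch below (plain C++, with hypothetical names and booleans standing in for the matcher results) contrasts the code as the compiler parses it with the code as the indentation suggests.

```cpp
#include <cassert>

// Which path the snippet takes. Names are illustrative only.
enum class Exit { NoFold, PlusOne, MinusOne };

// The pasted snippet as C++ actually parses it (no braces).
Exit as_parsed(bool is_oeq, bool b_is_zero, bool tc_negative) {
  bool negate = false;  // stands in for One = Builder.CreateFNeg(One)
  if (is_oeq)
    if (b_is_zero && tc_negative)
      return Exit::PlusOne;
  if (!b_is_zero && tc_negative)
    negate = true;
  // Unconditional: the sign operand is replaced for every predicate.
  return negate ? Exit::MinusOne : Exit::PlusOne;
}

// The same logic with the braces that the indentation implies.
Exit as_indented(bool is_oeq, bool b_is_zero, bool tc_negative) {
  if (is_oeq) {
    if (b_is_zero && tc_negative)
      return Exit::PlusOne;
    if (!b_is_zero && tc_negative)
      return Exit::MinusOne;
  }
  return Exit::NoFold;
}

void check_parse_difference() {
  // A non-oeq predicate still rewrites the operand in the unbraced form.
  assert(as_parsed(false, false, false) == Exit::PlusOne);
  assert(as_indented(false, false, false) == Exit::NoFold);
}
```

Whether the braced form is the intended logic is a separate question; the sketch only shows that the two forms behave differently for predicates other than oeq.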

@dtcxzyw dtcxzyw marked this pull request as ready for review March 28, 2024 09:22
Comment on lines 2485 to 2486
if (!match(Sign, m_Select((m_And(m_FCmp(Pred, m_Value(A), m_PosZeroFP()),
m_Value(B))),
Contributor

Suggested change
- if (!match(Sign, m_Select((m_And(m_FCmp(Pred, m_Value(A), m_PosZeroFP()),
- m_Value(B))),
+ if (!match(Sign, m_Select(m_And(m_FCmp(Pred, m_Value(A), m_PosZeroFP()),
+ m_Value(B)),

Extra pair of parentheses?

*/

if (Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_ULE) {
if (match(A, m_Negative()) && TC->isNegative())
Contributor

@arsenm arsenm May 8, 2024

I don't think this m_Negative check works. Looking at PatternMatch, I think this only supports integers? Would it be worth adding an FP variant of it?

Contributor Author

Yes, that works only for integers. An FP variant would be helpful in this case and in other FP folds that involve non-negative constraints. Should we open a new thread/issue for this?

Contributor

But that means this, and the m_ZeroInt checks below are dead code. You could open a separate PR to add an FP version of m_Negative. For FP values we already have m_AnyZeroFP, m_PosZeroFP and m_NegZeroFP
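
If an FP counterpart of m_Negative were added, it would have to pick a definition of "negative". The standalone sketch below (plain C++, with hypothetical helper names that are not an LLVM API) shows two candidate definitions, sign bit set versus ordered less-than-zero, and the inputs on which they disagree.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Candidate 1: the sign bit is set (true for -0.0 and for negative NaNs).
bool sign_bit_set(float v) { return std::signbit(v); }

// Candidate 2: ordered comparison against zero (false for -0.0 and any NaN).
bool ordered_less_than_zero(float v) { return v < 0.0f; }

void check_negative_definitions() {
  // Ordinary negative value: both definitions agree.
  assert(sign_bit_set(-1.0f) && ordered_less_than_zero(-1.0f));

  // Negative zero: the sign bit is set, but -0.0 < 0 is false.
  assert(sign_bit_set(-0.0f) && !ordered_less_than_zero(-0.0f));

  // NaN with the sign bit set: ordered comparisons with NaN are false.
  const float neg_nan =
      std::copysign(std::numeric_limits<float>::quiet_NaN(), -1.0f);
  assert(sign_bit_set(neg_nan) && !ordered_less_than_zero(neg_nan));
}
```

Since copysign consumes only the sign bit, the first definition is the one that lines up with what the intrinsic reads, but which one a matcher should use is a design decision for that separate PR.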

%sel = select i1 %and, float -1.000000e+00, float 1.000000e+00
%res = tail call float @llvm.copysign.f32(float %z, float %sel)
ret float %res
}
Contributor

Still missing many tests. Your comment had 8 different instances, so you at minimum need the 8 variants of the test. On top of that, you need the assorted negative tests with swapped constants, commuted operands, and fast math flag handling

; CHECK-NEXT: ret float [[RES]]
;
entry:
%cmp = fcmp fast olt float %x, 0.000000e+00
Contributor

It's best to use the minimal set of flags needed to show the behavior, rather than turning all of them on.

Contributor

@arsenm arsenm left a comment

It would also be easier to review this if you split the baseline tests into a separate commit, as per https://llvm.org/docs/InstCombineContributorGuide.html#precommit-tests
