-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[InstCombine] Fold copysign of selects from sign comparison to sign operand #85627
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Krishna Narayanan (Krishna-13-cyber) ChangesThis is currently under development and needs some assistance/review to go about solving this. Full diff: https://github.com/llvm/llvm-project/pull/85627.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ee76a6294428b3..49e10646eb5d15 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2790,6 +2790,47 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
return ChangedFMF ? &SI : nullptr;
}
+// Canonicalize select with fcmp -> select
+static Instruction *foldSelectWithFCmp(SelectInst &SI, InstCombinerImpl &IC) {
+ /* From
+ %4 = fcmp olt float %1, 0.000000e+00
+ %5 = and i1 %4, %0
+ %6 = select i1 %5, float -1.000000e+00, float 1.000000e+00
+ */
+ /* To
+ %4 = select i1 %0, float %1, float 1.000000e+00
+ */
+ Value *CondVal = SI.getCondition();
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+ Value *One = Constant::getAllOnesValue(FalseVal->getType());
+ Value *X, *C, *Op;
+ const APFloat *A, *E;
+ CmpInst::Predicate Pred;
+ for (bool Swap : {false, true}) {
+ if (Swap)
+ std::swap(TrueVal, FalseVal);
+ if (match(&SI, (m_Value(CondVal), m_APFloat(A), m_APFloat(E)))) {
+ if (!match(TrueVal, m_APFloatAllowUndef(A)) &&
+ !match(FalseVal, m_APFloatAllowUndef(E)))
+ return nullptr;
+ if (!match(CondVal, m_And(m_FCmp(Pred, m_Specific(X), m_PosZeroFP()),
+ m_Value(C))) &&
+ (X->hasOneUse() && C->hasOneUse()))
+ return nullptr;
+ if (!A->isNegative() && E->isNegative())
+ return nullptr;
+ if (!Swap && (Pred == FCmpInst::FCMP_OLT)) {
+ return SelectInst::Create(C, X, One);
+ }
+ if (Swap && (Pred == FCmpInst::FCMP_OGT)) {
+ return SelectInst::Create(C, X, One);
+ }
+ }
+ }
+ return nullptr;
+}
+
// Match the following IR pattern:
// %x.lowbits = and i8 %x, %lowbitmask
// %x.lowbits.are.zero = icmp eq i8 %x.lowbits, 0
@@ -3508,6 +3549,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *Fabs = foldSelectWithFCmpToFabs(SI, *this))
return Fabs;
+ // Fold selecting to ffold.
+ if (Instruction *Ffold = foldSelectWithFCmp(SI, *this))
+ return Ffold;
+
// See if we are selecting two values based on a comparison of the two values.
if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal))
if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 159c84d0dd8aa9..574fb1f9417742 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -1284,3 +1284,19 @@ define <1 x i1> @bitcast_1vec_eq0(i32 %x) {
%cmp = fcmp oeq <1 x float> %f, zeroinitializer
ret <1 x i1> %cmp
}
+
+define float @copysign_conditional(i1 noundef zeroext %0, float %1, float %2) {
+; CHECK-LABEL: define float @copysign_conditional(
+; CHECK-SAME: i1 noundef zeroext [[TMP0:%.*]], float [[TMP1:%.*]], float [[TMP2:%.*]]) {
+; CHECK-NEXT: [[TMP4:%.*]] = fcmp olt float [[TMP1]], 0.000000e+00
+; CHECK-NEXT: [[TMP5:%.*]] = and i1 [[TMP4]], [[TMP0]]
+; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[TMP5]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT: [[TMP7:%.*]] = tail call float @llvm.copysign.f32(float [[TMP2]], float [[TMP6]])
+; CHECK-NEXT: ret float [[TMP7]]
+;
+ %4 = fcmp olt float %1, 0.000000e+00
+ %5 = and i1 %4, %0
+ %6 = select i1 %5, float -1.000000e+00, float 1.000000e+00
+ %7 = tail call float @llvm.copysign.f32(float %2, float %6)
+ ret float %7
+}
\ No newline at end of file
|
|
The copysign is important here. The transform is valid only because only the sign is demanded, you can't do it for all selects. So the root instruction for this transform should be copysign, not select. |
0ac62b9
to
4626d38
Compare
✅ With the latest revision this PR passed the Python code formatter. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
return nullptr; | ||
// Match select ?, TC, FC where the constants are equal but negated. | ||
// Check for these 8 conditions | ||
/* |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
C++ style comments
%sel = select i1 %and, float -1.000000e+00, float 1.000000e+00 | ||
%res = tail call float @llvm.copysign.f32(float %z, float %sel) | ||
ret float %res | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need more tests covering the range of compare predicates. Also test fast math flag handling. Also should test the select with identical constants on each side
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I tested it for a set of predicates, the case where there’s a bit of deviation is in the oeq case, tried to handle it with specific values. On adding the below patch, this leads to a different folding (fabs and fneg) and changes results for all the predicates.
/*
oeq
copysign(Mag, B & (A == 0.0) ? -TC : TC) --> copysign(Mag, 1) B->true.
copysign(Mag, B & (A == 0.0) ? -TC : TC) --> copysign(Mag, -1) B->false.
*/
Value *One;
One = ConstantFP::get(Mag->getType(), 1.0);
if (Pred == CmpInst::FCMP_OEQ)
if (match(B, m_Zero()) && TC->isNegative())
return replaceOperand(*II, 1, One);
if (!match(B, m_Zero()) && TC->isNegative())
One = Builder.CreateFNeg(One);
return replaceOperand(*II, 1, One);
- Added fast-math test.
- For identical constants(cases), suppose if we have both as negative or positive, it's folded into fabs.
https://godbolt.org/z/qrfWvr465
if (!match(Sign, m_Select((m_And(m_FCmp(Pred, m_Value(A), m_PosZeroFP()), | ||
m_Value(B))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (!match(Sign, m_Select((m_And(m_FCmp(Pred, m_Value(A), m_PosZeroFP()), | |
m_Value(B))), | |
if (!match(Sign, m_Select(m_And(m_FCmp(Pred, m_Value(A), m_PosZeroFP()), | |
m_Value(B)), |
Extra pair of parentheses?
*/ | ||
|
||
if (Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_ULE) { | ||
if (match(A, m_Negative()) && TC->isNegative()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this m_Negative check works. Looking at PatternMatch, I think this only supports integers? Would it be worth add an FP variant of it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that works only for integers. FP variant would be helpful in this case and FP calculations which include non negative constraints. Should we initiate a new thread/issue for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But that means this, and the m_ZeroInt checks below are dead code. You could open a separate PR to add an FP version of m_Negative. For FP values we already have m_AnyZeroFP, m_PosZeroFP and m_NegZeroFP
%sel = select i1 %and, float -1.000000e+00, float 1.000000e+00 | ||
%res = tail call float @llvm.copysign.f32(float %z, float %sel) | ||
ret float %res | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still missing many tests. Your comment had 8 different instances, so you at minimum need the 8 variants of the test. On top of that, you need the assorted negative tests with swapped constants, commuted operands, and fast math flag handling
*/ | ||
|
||
if (Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_ULE) { | ||
if (match(A, m_Negative()) && TC->isNegative()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But that means this, and the m_ZeroInt checks below are dead code. You could open a separate PR to add an FP version of m_Negative. For FP values we already have m_AnyZeroFP, m_PosZeroFP and m_NegZeroFP
; CHECK-NEXT: ret float [[RES]] | ||
; | ||
entry: | ||
%cmp = fcmp fast olt float %x, 0.000000e+00 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's best to have minimal flags to show the behavior, not put all flags on
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would also be easier to review this if you split the baseline tests into a separate commit, as per https://llvm.org/docs/InstCombineContributorGuide.html#precommit-tests
This is currently under development and needs some assistance/review to go about solving this.
This PR intends to solve this issue #64884.