-
Notifications
You must be signed in to change notification settings - Fork 11.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GlobalISel needs fdiv 1 / sqrt(x) to rsq combine #78673
Conversation
@llvm/pr-subscribers-llvm-globalisel @llvm/pr-subscribers-backend-amdgpu Author: Nick Anderson (nickleus27) ChangesFixes #64743 @arsenm @Pierre-vh Could you guys review and let me know if I am headed in the right direction.
Full diff: https://github.com/llvm/llvm-project/pull/78673.diff 3 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index ea6ed322e9b1927..6ffb0842db3e4e6 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -495,6 +495,12 @@ m_GFMul(const LHS &L, const RHS &R) {
return BinaryOp_match<LHS, RHS, TargetOpcode::G_FMUL, true>(L, R);
}
+template <typename LHS, typename RHS>
+inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, true>
+m_GFDiv(const LHS &L, const RHS &R) {
+ return BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, true>(L, R);
+}
+
template <typename LHS, typename RHS>
inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FSUB, false>
m_GFSub(const LHS &L, const RHS &R) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
index b9411e2052120d8..f26fb12dc1149f0 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
@@ -33,6 +33,11 @@ def rcp_sqrt_to_rsq : GICombineRule<
[{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>;
+def fdiv_1_by_sqrt_to_rsq : GICombineRule<
+ (defs root:$root, build_fn_matchinfo:$matchinfo),
+ (match (wip_match_opcode G_FSQRT, G_FDIV):$root,
+ [{ return matchFDivSqrt(*${root}, ${matchinfo}); }]),
+ (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
def cvt_f32_ubyteN_matchdata : GIDefMatchData<"CvtF32UByteMatchInfo">;
@@ -156,7 +161,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
"AMDGPUPostLegalizerCombinerImpl",
[all_combines, gfx6gfx7_combines, gfx8_combines,
uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
- rcp_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
+ rcp_sqrt_to_rsq, fdiv_1_by_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
let CombineAllMethodName = "tryCombineAllImpl";
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
index a1c34e92a57f356..9cd8436c188dc47 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
@@ -83,6 +83,9 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
matchRcpSqrtToRsq(MachineInstr &MI,
std::function<void(MachineIRBuilder &)> &MatchInfo) const;
+ bool matchFDivSqrt(MachineInstr &MI,
+ std::function<void(MachineIRBuilder &)> &MatchInfo) const;
+
// FIXME: Should be able to have 2 separate matchdatas rather than custom
// struct boilerplate.
struct CvtF32UByteMatchInfo {
@@ -334,6 +337,59 @@ bool AMDGPUPostLegalizerCombinerImpl::matchRcpSqrtToRsq(
return false;
}
+bool AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt(
+ MachineInstr &MI,
+ std::function<void(MachineIRBuilder &)> &MatchInfo) const {
+
+ // TODO: Can I match fdiv 1.0 / sqrt(x) from here?
+ // My apologies, this code is still a mess. Trying to figure out
+ // what value MI should hold when getting to this point
+
+ auto getSqrtSrc = [=](const MachineInstr &MI) -> MachineInstr * {
+ if (!MI.getFlag(MachineInstr::FmContract))
+ return nullptr;
+ MachineInstr *SqrtSrcMI = nullptr;
+ auto Match =
+ mi_match(MI.getOperand(0).getReg(), MRI, m_GFSqrt(m_MInstr(SqrtSrcMI)));
+ (void)Match;
+ return SqrtSrcMI;
+ };
+
+ // Do I need to match write a matcher for %one:_(s16) = G_FCONSTANT half 1.0
+ // ??
+
+ auto getFdivSrc = [=](const MachineInstr &MI) -> MachineInstr * {
+ if (!MI.getFlag(MachineInstr::FmContract))
+ return nullptr;
+
+ MachineInstr *FDivSrcMI = nullptr;
+ Register One;
+ auto Match = mi_match(MI.getOperand(0).getReg(), MRI,
+ m_GFDiv(m_Reg(One), m_MInstr(FDivSrcMI)));
+ // Not sure how to check for FDiv operancd has a 1.0 value ?
+ if (!MI.getOperand(1).isFPImm()) {
+ return nullptr;
+ }
+ if (!MI.getOperand(1).getFPImm()->isOneValue()) {
+ return nullptr;
+ }
+ (void)Match;
+ return FDivSrcMI;
+ };
+
+ MachineInstr *FDivSrcMI = nullptr, *SqrtSrcMI = nullptr;
+ if ((SqrtSrcMI = getSqrtSrc(MI)) && (FDivSrcMI = getFdivSrc(*SqrtSrcMI))) {
+ MatchInfo = [SqrtSrcMI, &MI](MachineIRBuilder &B) {
+ B.buildIntrinsic(Intrinsic::amdgcn_rsq, {MI.getOperand(0)})
+ .addUse(SqrtSrcMI->getOperand(0).getReg())
+ .setMIFlags(MI.getFlags());
+ };
+ return true;
+ }
+
+ return false;
+}
+
bool AMDGPUPostLegalizerCombinerImpl::matchCvtF32UByteN(
MachineInstr &MI, CvtF32UByteMatchInfo &MatchInfo) const {
Register SrcReg = MI.getOperand(1).getReg();
|
@@ -33,6 +33,11 @@ def rcp_sqrt_to_rsq : GICombineRule< | |||
[{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]), | |||
(apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>; | |||
|
|||
def fdiv_1_by_sqrt_to_rsq : GICombineRule< | |||
(defs root:$root, build_fn_matchinfo:$matchinfo), | |||
(match (wip_match_opcode G_FSQRT, G_FDIV):$root, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can remove the G_FSQRT
.
bool AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt( | ||
MachineInstr &MI, | ||
std::function<void(MachineIRBuilder &)> &MatchInfo) const { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assert(MI.getOpode() == TargetOpcode::G_FDIV && "expected fdiv");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're indeed trying to match something like this:
%sqrt:_(s16) = contract G_FSQRT %x
%one:_(s16) = G_FCONSTANT half 1.0
%rsq:_(s16) = contract G_FDIV %one, %sqrt
@arsenm posted a few examples in the ticket. Please add them all to a .mir
file in test/CodeGen/AMDGPU/GlobalISel
and use update_mir_test_checks
to generate check lines. This allows you to test your changes because you can just rebuild llc, regen the checks and see what your new combine does.
When adding things to the combiner, I would recommend doing it in a test-driven way. It's very helpful IMO because you can verify changes fast: implement something, build llc, regen the lines
Lastly, once this all looks good, don't forget to run check-llvm-codegen
to catch issues :)
(defs root:$root, build_fn_matchinfo:$matchinfo), | ||
(match (wip_match_opcode G_FSQRT, G_FDIV):$root, | ||
[{ return matchFDivSqrt(*${root}, ${matchinfo}); }]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can do this fully in MIR patterns :) https://llvm.org/docs/GlobalISel/MIRPatterns.html
We just can't match FP constants yet so that needs to be done in C++
I also don't know if we want this for half only or for all FP types. If it's the former, you'd also need to check that the type of $dst
is LLT::scalar(16)
.
(defs root:$root, build_fn_matchinfo:$matchinfo), | |
(match (wip_match_opcode G_FSQRT, G_FDIV):$root, | |
[{ return matchFDivSqrt(*${root}, ${matchinfo}); }]), | |
(defs root:$dst), | |
(match (G_FSQRT $sqrt, $x), | |
(G_FCONSTANT $one, $fpimm), | |
(G_FDIV $dst, $sqrt, $fpimm, (MIFlags FmContract)), | |
[{ return ${fpimm}.getOperand(1).isExactlyValue(1.0); }]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Pierre-vh Thank you for your review. I am using this example, and I am making some progress now. However, one error I am getting that I do not understand is ('G_FCONSTANT') is unreachable from the pattern root!
. Any ideas why I would be getting this error? I am not seeing why this pattern would not be reachable.
(defs root:$root, build_fn_matchinfo:$matchinfo), | ||
(match (wip_match_opcode G_FSQRT, G_FDIV):$root, | ||
[{ return matchFDivSqrt(*${root}, ${matchinfo}); }]), | ||
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't create intrinsics in MIR patterns yet (I always forget to work on it haha), so that needs to stay in C++.
You will need to create your own apply function with the contents of the BuildFn you have below.
- You won't need matchdata, you can directly pass
${x}.getReg()
-${x}
is a substitution that'll be replaced with a reference to the operand$x
matched - You will need to pass the
G_FDIV
inst so you can delete it. You can do that by adding a name to that pattern (just like the wip_match_opcode earlier was named $root). It'll be aMachineInstr *
.
// TODO: Can I match fdiv 1.0 / sqrt(x) from here? | ||
// My apologies, this code is still a mess. Trying to figure out | ||
// what value MI should hold when getting to this point |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you do MIR patterns this code goes away but I'll review it:
Trying to figure out what value MI should hold when getting to this point
MI here should be a MachineInstr for the root of the match. There was a mistake in the TableGen: (wip_match_opcode G_FSQRT, G_FDIV)
- this should only contain the root instruction you're trying to match, so G_FDIV
.
If you use that, the combiner will call your function every time it sees a G_FDIV
, and MI
would be any G_FDIV
in the code
Now your matcher code needs to:
- Check that
MI
has FmContract - Check that the first operand of
MI
is 1.0 - Check that the second operand of
MI
is an instruction that's aG_FSQRT
and also has FmContract.
GlobalISel/Utils.h
has helpers for the last 2 points - to get a ConstantFP
and to find a Def of an instruction, ignoring any COPY
in the process. You could also use mi_match
but it's a bit more complicated
Register One; | ||
auto Match = mi_match(MI.getOperand(0).getReg(), MRI, | ||
m_GFDiv(m_Reg(One), m_MInstr(FDivSrcMI))); | ||
// Not sure how to check for FDiv operancd has a 1.0 value ? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// My apologies, this code is still a mess. Trying to figure out | ||
// what value MI should hold when getting to this point | ||
|
||
auto getSqrtSrc = [=](const MachineInstr &MI) -> MachineInstr * { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does not compile, but ...
std::optional<FPValueAndVReg > One;
Register SqrtReg;
if (!mi_match(MI.getOperand(1).getReg(), MRI, m_GFCstOrSplat(One)) || !mi_match(MI.getOperand(2).getReg(), MIR, m_OneNonDBGUse (m_GFSqrt(m_Reg(SqrtReg)))
return false;
Slightly off-topic, but why do we need to match against 1.0? Wouldn't it be better to combine |
533b424
to
59a7d7d
Compare
We're missing folds for that, at least for f64 last I looked. It was on my todo list for math stuff but I never got around to it |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking very good, thanks - just a few nits left to sort out 😄
[{ return ${fpimm}.getFPImm()->isExactlyValue(1.0) | ||
|| ${fpimm}.getFPImm()->isExactlyValue(-1.0); }]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indent
[{ return ${fpimm}.getFPImm()->isExactlyValue(1.0) | |
|| ${fpimm}.getFPImm()->isExactlyValue(-1.0); }]), | |
[{ return ${fpimm}.getFPImm()->isExactlyValue(1.0) | |
|| ${fpimm}.getFPImm()->isExactlyValue(-1.0); }]), |
@@ -334,6 +336,19 @@ bool AMDGPUPostLegalizerCombinerImpl::matchRcpSqrtToRsq( | |||
return false; | |||
} | |||
|
|||
void AMDGPUPostLegalizerCombinerImpl::applyOneFDivSqrtToRsq( | |||
MachineInstr &MI, const Register &X) const { | |||
// B.setInstrAndDebugLoc(MI); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should already be set so you can delete it - the combiner will automatically set the insertion point to the MI being looked at
|
||
Register Dst = MI.getOperand(0).getReg(); | ||
|
||
B.buildIntrinsic(Intrinsic::amdgcn_rsq, ArrayRef<Register>({Dst})) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
B.buildIntrinsic(Intrinsic::amdgcn_rsq, ArrayRef<Register>({Dst})) | |
B.buildIntrinsic(Intrinsic::amdgcn_rsq, {Dst}) |
This should work I think, maybe you don't even need the {}
- I don't remember if implicit ctor works in this case
|
||
B.buildIntrinsic(Intrinsic::amdgcn_rsq, ArrayRef<Register>({Dst})) | ||
.addUse(X) | ||
.setMIFlags(MI.getFlags()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems reasonable to me to copy the flags. Contracting again should be allowed (but I don't know if it would ever happen in practice for rsq). Other flags like nnan and ninf could perhaps be useful too. If you wanted to get really clever you should union (or intersect?) the flags from the two operations being contracted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have removed adding the flags since the initial review. Let me know what the final decision is so I can implement. Since I am trying to learn I would like to ask, what would be the potential harm of adding the flags?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copying the flag sounds good then :)
Flags are optimization hints to enable assumptions, and if those assumptions are broken - or the flag is added when an assumption isn't verified - you can end up with very weird results https://llvm.org/docs/LangRef.html#fastmath
e.g.: When you have contract
, you're telling the compiler "in this case, you can assume the result of a fused operation will be the same as the result of two distinct operation". This matters because a distinct multiply + add may round numbers twice - once after each operation, but a fused multiply-add (fma) may only round once. If the compiler knows the inputs of the operations don't care, then it can create a fma, if it doesn't know that, it needs to emit two instructions instead to ensure the result is correct
(G_FCONSTANT $one, $fpimm), | ||
(G_FDIV $dst, $one, $sqrt, (MIFlags FmContract)):$root, | ||
[{ return ${fpimm}.getFPImm()->isExactlyValue(1.0) | ||
|| ${fpimm}.getFPImm()->isExactlyValue(-1.0); }]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How it is correct to match -1.0 here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am a novice, so I am asking in hopes of learning, but what would make this incorrect? This seems to work as far as matching MIR that contains -1/sqrt(x) pattern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Converting 1/sqrt(x) to rsq(x) seems fine since that is the definition of rsq. Converting -1/sqrt(x) to rsq(x) seems wrong because that's just not what rsq does.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, that makes sense. Could we do -1/sqrt(x) to fneg(rsq(x))?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we do try to fold that in other places. AMDGPUCodeGenPrepare handles most of these folds, it just leaves the fastest math paths for codegen.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we do -1/sqrt(x) to fneg(rsq(x))?
Yes, especially since fneg is cheap since it can be folded into lots of other (AMDGPU) instructions. But it does make me wonder again, if you're going to generalise, why not generalise to any value? I.e. y/sqrt(x) -> y*rsq(x)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we do -1/sqrt(x) to fneg(rsq(x))?
Yes, especially since fneg is cheap since it can be folded into lots of other (AMDGPU) instructions. But it does make me wonder again, if you're going to generalise, why not generalise to any value? I.e. y/sqrt(x) -> y*rsq(x)
I agree generalizing makes sense if you're going to handle more than 1/sqrt - you can simplify the apply function at least to just emit y * rsq(x)
and let other combine eliminate a 1*rsq
, or transform -1*rsq
into fneg(rsq)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree generalizing makes sense if you're going to handle more than 1/sqrt - you can simplify the apply function at least to just emit
y * rsq(x)
and let other combine eliminate a1*rsq
, or transform-1*rsq
intofneg(rsq)
So should I re-implement this PR to match y/sqrt(x)
and have the apply function emit y * rsq(x)
and a second combine match 1 * rsq
& -1 * rsq
to rsq
& fneg(rsq)
respectively?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right_identity_one_fp
will take care of the 1 * rsq
, it seems like we don't have one for -1.0
to G_FNEG
though so you can do it in a follow up patch if you want. X * 1
and X * -1
are generic simplifications that work for anything in X
So no, you don't need to reimplement much, just make the matcher more generic (don't check for +-1.0) and in the apply, you always emit rsq(x) * y
. Prefer to emit the y
on the rhs, that way if it's a constant it's already the canonical form for mul (constants are always on the RHS for commutative ops IIRC)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I would like to add fold for -1.0
to G_FNEG
in a follow up patch.
%x:_(s16) = G_TRUNC %0:_(s32) | ||
%sqrt:_(s16) = contract G_FSQRT %x | ||
%three:_(s16) = G_FCONSTANT half 3.0 | ||
%rsq:_(s16) = contract G_FDIV %three, %sqrt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Almost done, just needs a bit more testing
It needs some tests for 1/sqrt and -1/sqrt (ideally for every type too) so we can see what happens in those case, and it might be nice to also have a bit more variance in the numerators - try some values like +-0.5, or bigger numbers like 10?
I'm wondering if 0.5 gets converted to rsq/2
You can just change all these tests to use + or -1, and then add the few extra tests with different numerators at the bottom, no need to have those for every type I think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added in the suggested tests. I can add more if it is better to be on the safe side?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's enough tests, IMO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, but wait for @arsenm to confirm before landing as he created the issue
@Pierre-vh I am waiting for @arsenm's approval since he opened the issue. But, do you know why I do not have the ability to merge the PR (because I don't have write access)? I do not see the "Merge Pull Request" button. And, do you know if there is something I can do to make the button available to me? That way I do not have to ask others to merge for me? I attached a screen shot of where I would expect to see the button. |
def fdiv_by_sqrt_to_rsq : GICombineRule< | ||
(defs root:$root), | ||
(match (G_FSQRT $sqrt, $x, (MIFlags FmContract)), | ||
(G_FDIV $dst, $y, $sqrt, (MIFlags FmContract)):$root), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs more flags. Without the denormal correction code, this isn't more precise. Also switches to dropping denormals and 1ulp. You can refer to the dag version for the exact conditions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also missing hasOneUse check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually contract is adequate for the f16 case, f32 is the complicated one
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not seeing why / where we need hasOneUse check? Do we need to check for the FDIV operands?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to check that G_FSQRT
has one user, so we don't do the combine if it means the G_FSQRT
stays.
(also means you will need a test with a G_FSQRT
that has multiple users)
To check hasOneUse
, you can use a C++ predicate on top of the pattern that calls MRI.hasOneUse
on ${sqrt}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the reply.
I assume you don't have commit access? Then you won't be able to merge yourself, but we can do it. |
I think this is close. I forgot that I remove the codegen support for rsq in the DAG over the summer. I think this patch only needs to handle the f16 part, and not introduce f32/f64 rsq. We do that in CodeGenPrepare |
// What about v_rsq_f64? - Is UnsafeFPMath sufficient to do this for f64? The | ||
// maximum ULP error seems really high at 2^29 ULP. | ||
return true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't allow f64. We've never done it, and I believe we need should add extra correction code when we do use it. It's been on my todo list for a long time to try to make use of it. Currently the library code has an expansion we should move into the compiler
@@ -33,12 +33,12 @@ def rcp_sqrt_to_rsq : GICombineRule< | |||
[{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]), | |||
(apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>; | |||
|
|||
def fdiv_by_sqrt_to_rsq : GICombineRule< | |||
def fdiv_by_sqrt_to_rsq_f16 : GICombineRule< | |||
(defs root:$root), | |||
(match (G_FSQRT $sqrt, $x, (MIFlags FmContract)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f16:sqrt should also work, I think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we also want the f16 vectors to work too (although that doesn't work since we're emitting the intrinsic and post legalize)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So don't worry about f16 vectors? just f16:sqrt is okay?
// f32/f64 rsq is handled in AMDGPUCodeGenPrepare | ||
// only match if operand type is f16 | ||
// v_rsq_f16 supports denormals and 0.51ulp. | ||
if (DstTy == LLT::scalar(16)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can move the type check into the pattern
MachineInstr &MI) const { | ||
Register Dst = MI.getOperand(0).getReg(); | ||
Register Sqrt = MI.getOperand(2).getReg(); | ||
LLT DstTy = MRI.getType(Dst); | ||
const MachineFunction &MF = B.getMF(); | ||
bool AllowInaccurateRsq = | ||
MI.getFlag(MachineInstr::FmAfn) || MF.getTarget().Options.UnsafeFPMath; | ||
if (!MRI.hasOneUse(Sqrt)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we had combiner magic for hasOneUse but I'm not finding it. Also, this should really have been hasOneNonDbgUse
$vgpr0 = COPY %ext | ||
|
||
... | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing tests with the missing contract on each inst
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably doesn't really matter, eventually the real testing should be the end to end IR tests shared with the DAG
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can add in the tests without contract.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably doesn't really matter, eventually the real testing should be the end to end IR tests shared with the DAG
What is the end to end IR tests shared with the DAG?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
e.g. we have llvm/test/CodeGen/AMDGPU/rsq.f32.ll and a bunch of others. I was working on getting consistent coverage here, but not sure I ever got to f16
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I will work on adding rsq.16.ll and I will use rsq.32.ll as an example.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see llvm/test/CodeGen/AMDGPU/fdiv.f16.ll has some rsq tests. Should I add more tests to this file or make a new file rsq.f16.ll?
@arsenm I pushed tests with missing contract instructions for f16 type in the .mir file. I looked into adding .ll file tests for f16 type rsq. I found f16 rsq tests in llvm/test/CodeGen/AMDGPU/fdiv.f16.ll which included tests for f16 type with and without contract. I am not sure what I would need to do to make the tests applicable to this patch. Any suggestions? |
If everyone approves could someone go ahead and commit for me. I do not have write access. Thank you. |
follow up patch to #78673 @Pierre-vh @jayfoad @arsenm Could you review when you have a chance.
Since #80526 landed first. I need to squash these fixup commits, and pull from upstream since the aforementioned commit will change the .mir in this patch. So once I update the .mir I will try back for review. |
432ad2d
to
ac88229
Compare
@arsenm can you land this for me if everything looks good? If not please let me know what needs to be improved. |
Fixes #64743
@arsenm @Pierre-vh Could you guys review and let me know if I am headed in the right direction.
%sqrt:_(s16) = contract G_FSQRT %x
%one:_(s16) = G_FCONSTANT half 1.0
%rsq:_(s16) = contract G_FDIV %one, %sqrt
?
AMDGPUCombine.td
match the above MIR and call the function I made calledmatchFDivSqrt
?AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt
would be appreciated. For example, what is the state of MI that is passed in? Is it a single instruction or is it a chain/tree of instructions?