GlobalISel needs fdiv 1 / sqrt(x) to rsq combine #78673

Merged
merged 1 commit into from
Feb 22, 2024

Contributor

@nickleus27 nickleus27 commented Jan 19, 2024

Fixes #64743

@arsenm @Pierre-vh Could you guys review and let me know if I am headed in the right direction?

  1. Is this the MIR I am trying to match against?
    %sqrt:_(s16) = contract G_FSQRT %x
    %one:_(s16) = G_FCONSTANT half 1.0
    %rsq:_(s16) = contract G_FDIV %one, %sqrt
  2. Will the matcher in AMDGPUCombine.td match the above MIR and call the function I added, matchFDivSqrt?
  3. Any advice on what needs to be done in AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt would be appreciated. For example, what is the state of the MI that is passed in? Is it a single instruction or a chain/tree of instructions?

@llvmbot
Collaborator

llvmbot commented Jan 19, 2024

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-amdgpu

Author: Nick Anderson (nickleus27)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/78673.diff

3 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h (+6)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUCombine.td (+6-1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp (+56)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index ea6ed322e9b1927..6ffb0842db3e4e6 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -495,6 +495,12 @@ m_GFMul(const LHS &L, const RHS &R) {
   return BinaryOp_match<LHS, RHS, TargetOpcode::G_FMUL, true>(L, R);
 }
 
+template <typename LHS, typename RHS>
+inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, true>
+m_GFDiv(const LHS &L, const RHS &R) {
+  return BinaryOp_match<LHS, RHS, TargetOpcode::G_FDIV, true>(L, R);
+}
+
 template <typename LHS, typename RHS>
 inline BinaryOp_match<LHS, RHS, TargetOpcode::G_FSUB, false>
 m_GFSub(const LHS &L, const RHS &R) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
index b9411e2052120d8..f26fb12dc1149f0 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
@@ -33,6 +33,11 @@ def rcp_sqrt_to_rsq : GICombineRule<
          [{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>;
 
+def fdiv_1_by_sqrt_to_rsq : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_FSQRT, G_FDIV):$root,
+         [{ return matchFDivSqrt(*${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
 
 def cvt_f32_ubyteN_matchdata : GIDefMatchData<"CvtF32UByteMatchInfo">;
 
@@ -156,7 +161,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
   "AMDGPUPostLegalizerCombinerImpl",
   [all_combines, gfx6gfx7_combines, gfx8_combines,
    uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
-   rcp_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
+   rcp_sqrt_to_rsq, fdiv_1_by_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
   let CombineAllMethodName = "tryCombineAllImpl";
 }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
index a1c34e92a57f356..9cd8436c188dc47 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
@@ -83,6 +83,9 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
   matchRcpSqrtToRsq(MachineInstr &MI,
                     std::function<void(MachineIRBuilder &)> &MatchInfo) const;
 
+  bool matchFDivSqrt(MachineInstr &MI,
+                     std::function<void(MachineIRBuilder &)> &MatchInfo) const;
+
   // FIXME: Should be able to have 2 separate matchdatas rather than custom
   // struct boilerplate.
   struct CvtF32UByteMatchInfo {
@@ -334,6 +337,59 @@ bool AMDGPUPostLegalizerCombinerImpl::matchRcpSqrtToRsq(
   return false;
 }
 
+bool AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt(
+    MachineInstr &MI,
+    std::function<void(MachineIRBuilder &)> &MatchInfo) const {
+
+  // TODO: Can I match fdiv 1.0 / sqrt(x) from here?
+  // My apologies, this code is still a mess. Trying to figure out
+  // what value MI should hold when getting to this point
+
+  auto getSqrtSrc = [=](const MachineInstr &MI) -> MachineInstr * {
+    if (!MI.getFlag(MachineInstr::FmContract))
+      return nullptr;
+    MachineInstr *SqrtSrcMI = nullptr;
+    auto Match =
+        mi_match(MI.getOperand(0).getReg(), MRI, m_GFSqrt(m_MInstr(SqrtSrcMI)));
+    (void)Match;
+    return SqrtSrcMI;
+  };
+
+  // Do I need to match write a matcher for  %one:_(s16) = G_FCONSTANT half 1.0
+  // ??
+
+  auto getFdivSrc = [=](const MachineInstr &MI) -> MachineInstr * {
+    if (!MI.getFlag(MachineInstr::FmContract))
+      return nullptr;
+
+    MachineInstr *FDivSrcMI = nullptr;
+    Register One;
+    auto Match = mi_match(MI.getOperand(0).getReg(), MRI,
+                          m_GFDiv(m_Reg(One), m_MInstr(FDivSrcMI)));
+    // Not sure how to check for FDiv operancd has a 1.0 value ?
+    if (!MI.getOperand(1).isFPImm()) {
+      return nullptr;
+    }
+    if (!MI.getOperand(1).getFPImm()->isOneValue()) {
+      return nullptr;
+    }
+    (void)Match;
+    return FDivSrcMI;
+  };
+
+  MachineInstr *FDivSrcMI = nullptr, *SqrtSrcMI = nullptr;
+  if ((SqrtSrcMI = getSqrtSrc(MI)) && (FDivSrcMI = getFdivSrc(*SqrtSrcMI))) {
+    MatchInfo = [SqrtSrcMI, &MI](MachineIRBuilder &B) {
+      B.buildIntrinsic(Intrinsic::amdgcn_rsq, {MI.getOperand(0)})
+          .addUse(SqrtSrcMI->getOperand(0).getReg())
+          .setMIFlags(MI.getFlags());
+    };
+    return true;
+  }
+
+  return false;
+}
+
 bool AMDGPUPostLegalizerCombinerImpl::matchCvtF32UByteN(
     MachineInstr &MI, CvtF32UByteMatchInfo &MatchInfo) const {
   Register SrcReg = MI.getOperand(1).getReg();

@@ -33,6 +33,11 @@ def rcp_sqrt_to_rsq : GICombineRule<
[{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>;

def fdiv_1_by_sqrt_to_rsq : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_FSQRT, G_FDIV):$root,
Member

You can remove the G_FSQRT.

bool AMDGPUPostLegalizerCombinerImpl::matchFDivSqrt(
MachineInstr &MI,
std::function<void(MachineIRBuilder &)> &MatchInfo) const {

Member

assert(MI.getOpcode() == TargetOpcode::G_FDIV && "expected fdiv");

Contributor

@Pierre-vh Pierre-vh left a comment

You're indeed trying to match something like this:

%sqrt:_(s16) = contract G_FSQRT %x
%one:_(s16) = G_FCONSTANT half 1.0
%rsq:_(s16) = contract G_FDIV %one, %sqrt

@arsenm posted a few examples in the ticket. Please add them all to a .mir file in test/CodeGen/AMDGPU/GlobalISel and use update_mir_test_checks to generate check lines. This allows you to test your changes because you can just rebuild llc, regen the checks and see what your new combine does.

When adding things to the combiner, I would recommend doing it in a test-driven way. It's very helpful IMO because you can verify changes fast: implement something, build llc, regenerate the check lines.

Lastly, once this all looks good, don't forget to run check-llvm-codegen to catch issues :)

Comment on lines 37 to 39
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_FSQRT, G_FDIV):$root,
[{ return matchFDivSqrt(*${root}, ${matchinfo}); }]),
Contributor

You can do this fully in MIR patterns :) https://llvm.org/docs/GlobalISel/MIRPatterns.html
We just can't match FP constants yet so that needs to be done in C++

I also don't know if we want this for half only or for all FP types. If it's the former, you'd also need to check that the type of $dst is LLT::scalar(16).

Suggested change
-(defs root:$root, build_fn_matchinfo:$matchinfo),
-(match (wip_match_opcode G_FSQRT, G_FDIV):$root,
-[{ return matchFDivSqrt(*${root}, ${matchinfo}); }]),
+(defs root:$dst),
+(match (G_FSQRT $sqrt, $x),
+(G_FCONSTANT $one, $fpimm),
+(G_FDIV $dst, $sqrt, $fpimm, (MIFlags FmContract)),
+[{ return ${fpimm}.getOperand(1).isExactlyValue(1.0); }]),

Contributor Author

@Pierre-vh Thank you for your review. I am using this example, and I am making some progress now. However, I am getting one error that I do not understand: "('G_FCONSTANT') is unreachable from the pattern root!". Any idea why I would be getting this error? I do not see why this pattern would not be reachable.

(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_FSQRT, G_FDIV):$root,
[{ return matchFDivSqrt(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
Contributor

We can't create intrinsics in MIR patterns yet (I always forget to work on it haha), so that needs to stay in C++.
You will need to create your own apply function with the contents of the BuildFn you have below.

  • You won't need matchdata, you can directly pass ${x}.getReg() - ${x} is a substitution that'll be replaced with a reference to the operand $x matched
  • You will need to pass the G_FDIV inst so you can delete it. You can do that by adding a name to that pattern (just like the wip_match_opcode earlier was named $root). It'll be a MachineInstr *.

Comment on lines 344 to 346
// TODO: Can I match fdiv 1.0 / sqrt(x) from here?
// My apologies, this code is still a mess. Trying to figure out
// what value MI should hold when getting to this point
Contributor

If you do MIR patterns this code goes away but I'll review it:

> Trying to figure out what value MI should hold when getting to this point

MI here should be a MachineInstr for the root of the match. There was a mistake in the TableGen: (wip_match_opcode G_FSQRT, G_FDIV) should only contain the root instruction you're trying to match, so G_FDIV.

If you use that, the combiner will call your function every time it sees a G_FDIV, and MI can be any G_FDIV in the code.

Now your matcher code needs to:

  • Check that MI has FmContract
  • Check that the first operand of MI is 1.0
  • Check that the second operand of MI is an instruction that's a G_FSQRT and also has FmContract.

GlobalISel/Utils.h has helpers for the last 2 points - to get a ConstantFP and to find a Def of an instruction, ignoring any COPY in the process. You could also use mi_match but it's a bit more complicated
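
The three checks above can be sketched with a toy instruction model that has nothing to do with LLVM's real classes. ToyInst, ToyOpcode, ToyFmContract, DefMap, and matchFDivOneBySqrt are all hypothetical names invented for this illustration; an actual matcher would work on MachineInstr and MachineRegisterInfo, possibly through the GlobalISel/Utils.h helpers mentioned above.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Toy stand-ins for MIR concepts; none of these are LLVM types.
enum class ToyOpcode { FConstant, FSqrt, FDiv };
const std::uint32_t ToyFmContract = 1u << 0;

struct ToyInst {
  ToyOpcode Opc;
  std::uint32_t Flags;
  std::vector<int> Uses; // virtual registers read by this instruction
  double Imm;            // only meaningful for FConstant
};

// Maps a virtual register to the instruction that defines it.
using DefMap = std::unordered_map<int, const ToyInst *>;

// Returns true and sets X when Root is
// `contract fdiv 1.0, (contract fsqrt X)`.
bool matchFDivOneBySqrt(const ToyInst &Root, const DefMap &Defs, int &X) {
  // The combiner hands us only the root, which must be the fdiv.
  if (Root.Opc != ToyOpcode::FDiv || Root.Uses.size() != 2)
    return false;
  // Check 1: the root carries the contract flag.
  if (!(Root.Flags & ToyFmContract))
    return false;
  // Check 2: the numerator is defined by a constant equal to 1.0.
  auto Num = Defs.find(Root.Uses[0]);
  if (Num == Defs.end() || Num->second->Opc != ToyOpcode::FConstant ||
      Num->second->Imm != 1.0)
    return false;
  // Check 3: the denominator is a square root that also has contract.
  auto Den = Defs.find(Root.Uses[1]);
  if (Den == Defs.end() || Den->second->Opc != ToyOpcode::FSqrt ||
      !(Den->second->Flags & ToyFmContract) || Den->second->Uses.size() != 1)
    return false;
  X = Den->second->Uses[0];
  return true;
}
```

The function is only ever called with the division as the root and walks backwards to the definitions of its operands, which mirrors the point that the TableGen rule should name G_FDIV alone as the root.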

Register One;
auto Match = mi_match(MI.getOperand(0).getReg(), MRI,
m_GFDiv(m_Reg(One), m_MInstr(FDivSrcMI)));
// Not sure how to check for FDiv operancd has a 1.0 value ?
Contributor

// My apologies, this code is still a mess. Trying to figure out
// what value MI should hold when getting to this point

auto getSqrtSrc = [=](const MachineInstr &MI) -> MachineInstr * {
Member

Does not compile, but ...

std::optional<FPValueAndVReg> One;
Register SqrtReg;
if (!mi_match(MI.getOperand(1).getReg(), MRI, m_GFCstOrSplat(One)) ||
    !mi_match(MI.getOperand(2).getReg(), MRI, m_OneNonDBGUse(m_GFSqrt(m_Reg(SqrtReg)))))
  return false;

@jayfoad
Contributor

jayfoad commented Jan 19, 2024

> 1. is the MIR I am trying to match against
>     %sqrt:_(s16) = contract G_FSQRT %x
>     %one:_(s16) = G_FCONSTANT half 1.0
>     %rsq:_(s16) = contract G_FDIV %one, %sqrt
>     ?

Slightly off-topic, but why do we need to match against 1.0? Wouldn't it be better to combine y/sqrt(x) into y*rsq(x) for all y, and then later optimize away the multiply if y happened to be 1.0?

@arsenm
Contributor

arsenm commented Jan 23, 2024

> > 1. is the MIR I am trying to match against
> >     %sqrt:_(s16) = contract G_FSQRT %x
> >     %one:_(s16) = G_FCONSTANT half 1.0
> >     %rsq:_(s16) = contract G_FDIV %one, %sqrt
> >     ?
>
> Slightly off-topic, but why do we need to match against 1.0? Wouldn't it be better to combine y/sqrt(x) into y*rsq(x) for all y, and then later optimize away the multiply if y happened to be 1.0?

We're missing folds for that, at least for f64 last I looked. It was on my todo list for math stuff but I never got around to it

@nickleus27 nickleus27 marked this pull request as ready for review January 24, 2024 05:31
Contributor

@Pierre-vh Pierre-vh left a comment

Looking very good, thanks - just a few nits left to sort out 😄

Comment on lines 41 to 42
[{ return ${fpimm}.getFPImm()->isExactlyValue(1.0)
|| ${fpimm}.getFPImm()->isExactlyValue(-1.0); }]),
Contributor

indent

Suggested change
-[{ return ${fpimm}.getFPImm()->isExactlyValue(1.0)
-|| ${fpimm}.getFPImm()->isExactlyValue(-1.0); }]),
+[{ return ${fpimm}.getFPImm()->isExactlyValue(1.0)
+|| ${fpimm}.getFPImm()->isExactlyValue(-1.0); }]),

@@ -334,6 +336,19 @@ bool AMDGPUPostLegalizerCombinerImpl::matchRcpSqrtToRsq(
return false;
}

void AMDGPUPostLegalizerCombinerImpl::applyOneFDivSqrtToRsq(
MachineInstr &MI, const Register &X) const {
// B.setInstrAndDebugLoc(MI);
Contributor

This should already be set so you can delete it - the combiner will automatically set the insertion point to the MI being looked at


Register Dst = MI.getOperand(0).getReg();

B.buildIntrinsic(Intrinsic::amdgcn_rsq, ArrayRef<Register>({Dst}))
Contributor

Suggested change
-B.buildIntrinsic(Intrinsic::amdgcn_rsq, ArrayRef<Register>({Dst}))
+B.buildIntrinsic(Intrinsic::amdgcn_rsq, {Dst})

This should work I think, maybe you don't even need the {} - I don't remember if implicit ctor works in this case


B.buildIntrinsic(Intrinsic::amdgcn_rsq, ArrayRef<Register>({Dst}))
.addUse(X)
.setMIFlags(MI.getFlags());
Contributor

I'm not sure at all if that's needed (or if it's even correct - could we contract two operations again after applying contract already?)
cc @arsenm / @jayfoad

Contributor

Seems reasonable to me to copy the flags. Contracting again should be allowed (but I don't know if it would ever happen in practice for rsq). Other flags like nnan and ninf could perhaps be useful too. If you wanted to get really clever you should union (or intersect?) the flags from the two operations being contracted.
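
The union-versus-intersect question can be made concrete with a small bitmask sketch. The ToyFlag names and both helper functions are hypothetical and are not LLVM's flag API; the sketch only shows what each choice would keep. Intersection is the conservative option, since it retains an assumption only when both original operations carried it.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical fast-math style flags, one bit each.
const std::uint32_t ToyNoNans = 1u << 0;
const std::uint32_t ToyNoInfs = 1u << 1;
const std::uint32_t ToyContract = 1u << 2;

// Keeps only the assumptions that hold for both operations.
std::uint32_t intersectFlags(std::uint32_t A, std::uint32_t B) {
  return A & B;
}

// Keeps every assumption made by either operation.
std::uint32_t unionFlags(std::uint32_t A, std::uint32_t B) {
  return A | B;
}
```

For a square root flagged contract and nnan and a division flagged only contract, the intersection leaves contract alone, while the union would also claim nnan for the combined operation even though the division never promised it.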

Contributor Author

I have removed adding the flags since the initial review. Let me know what the final decision is so I can implement. Since I am trying to learn I would like to ask, what would be the potential harm of adding the flags?

Contributor

Copying the flag sounds good then :)

Flags are optimization hints that let the compiler make assumptions. If those assumptions are broken, or a flag is added where the assumption has not been verified, you can end up with very weird results: https://llvm.org/docs/LangRef.html#fastmath

e.g.: When you have contract, you're telling the compiler "in this case, you may treat a fused operation as interchangeable with the two distinct operations". This matters because a distinct multiply + add rounds twice, once after each operation, while a fused multiply-add (fma) rounds only once. If the compiler knows the code does not depend on that difference, it can create an fma; if it doesn't know that, it needs to emit two instructions to ensure the result is correct.
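
The rounding difference described here can be observed directly in standard C++. This is a minimal sketch using double precision; mulThenAdd and fusedMulAdd are hypothetical helper names, and the volatile temporary is there only to stop the compiler from fusing the two operations on its own.

```cpp
#include <cassert>
#include <cmath>

// Multiply and add as two separately rounded operations.
double mulThenAdd(double A, double B, double C) {
  volatile double Product = A * B; // first rounding happens here
  return Product + C;              // second rounding happens here
}

// Multiply and add with a single rounding step at the end.
double fusedMulAdd(double A, double B, double C) {
  return std::fma(A, B, C);
}
```

With A = 1 + 2^-27, B = 1 - 2^-27 and C = -1, the exact product is 1 - 2^-54, which lies halfway between two doubles and rounds to 1.0. mulThenAdd therefore returns 0, while fusedMulAdd keeps the exact product and returns -2^-54.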

(G_FCONSTANT $one, $fpimm),
(G_FDIV $dst, $one, $sqrt, (MIFlags FmContract)):$root,
[{ return ${fpimm}.getFPImm()->isExactlyValue(1.0)
|| ${fpimm}.getFPImm()->isExactlyValue(-1.0); }]),
Contributor

How is it correct to match -1.0 here?

Contributor Author

I am a novice, so I am asking in hopes of learning, but what would make this incorrect? This seems to work as far as matching MIR that contains -1/sqrt(x) pattern.

Contributor

Converting 1/sqrt(x) to rsq(x) seems fine since that is the definition of rsq. Converting -1/sqrt(x) to rsq(x) seems wrong because that's just not what rsq does.

Contributor Author

Right, that makes sense. Could we do -1/sqrt(x) to fneg(rsq(x))?

Contributor

Yes, we do try to fold that in other places. AMDGPUCodeGenPrepare handles most of these folds, it just leaves the fastest math paths for codegen.

Contributor

> Could we do -1/sqrt(x) to fneg(rsq(x))?

Yes, especially since fneg is cheap since it can be folded into lots of other (AMDGPU) instructions. But it does make me wonder again, if you're going to generalise, why not generalise to any value? I.e. y/sqrt(x) -> y*rsq(x)

Contributor

> > Could we do -1/sqrt(x) to fneg(rsq(x))?
>
> Yes, especially since fneg is cheap since it can be folded into lots of other (AMDGPU) instructions. But it does make me wonder again, if you're going to generalise, why not generalise to any value? I.e. y/sqrt(x) -> y*rsq(x)

I agree generalizing makes sense if you're going to handle more than 1/sqrt - you can simplify the apply function at least to just emit y * rsq(x) and let other combine eliminate a 1*rsq, or transform -1*rsq into fneg(rsq)

Contributor Author

> I agree generalizing makes sense if you're going to handle more than 1/sqrt - you can simplify the apply function at least to just emit y * rsq(x) and let other combine eliminate a 1*rsq, or transform -1*rsq into fneg(rsq)

So should I re-implement this PR to match y/sqrt(x) and have the apply function emit y * rsq(x) and a second combine match 1 * rsq & -1 * rsq to rsq & fneg(rsq) respectively?

Contributor

right_identity_one_fp will take care of the 1 * rsq, it seems like we don't have one for -1.0 to G_FNEG though so you can do it in a follow up patch if you want. X * 1 and X * -1 are generic simplifications that work for anything in X

So no, you don't need to reimplement much, just make the matcher more generic (don't check for +-1.0) and in the apply, you always emit rsq(x) * y. Prefer to emit the y on the rhs, that way if it's a constant it's already the canonical form for mul (constants are always on the RHS for commutative ops IIRC)
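
The generalized rewrite can be checked numerically with a small sketch. rsqRef is a hypothetical reference stand-in computed as 1/sqrt(x), not the hardware rsq instruction, and divBySqrt and rsqTimes are made-up names for the two forms of the expression. The inputs in the usage note below are chosen so every intermediate value is exact; for arbitrary inputs the two forms can differ in the last bit, which is why the combine is gated on fast-math flags.

```cpp
#include <cassert>
#include <cmath>

// Reference reciprocal square root (assumption: stands in for rsq).
float rsqRef(float X) { return 1.0f / std::sqrt(X); }

// Original form: y / sqrt(x).
float divBySqrt(float Y, float X) { return Y / std::sqrt(X); }

// Rewritten form: rsq(x) * y, with y on the right-hand side.
float rsqTimes(float Y, float X) { return rsqRef(X) * Y; }
```

With x = 4, both forms give 1.5 for y = 3. For y = 1 the multiply is an identity, and for y = -1 it is a plain negation, which are the two generic simplifications left to other combines.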

Contributor Author

Yeah, I would like to add a fold for -1.0 to G_FNEG in a follow-up patch.

%x:_(s16) = G_TRUNC %0:_(s32)
%sqrt:_(s16) = contract G_FSQRT %x
%three:_(s16) = G_FCONSTANT half 3.0
%rsq:_(s16) = contract G_FDIV %three, %sqrt
Contributor

Almost done, just needs a bit more testing

It needs some tests for 1/sqrt and -1/sqrt (ideally for every type too) so we can see what happens in those cases, and it might be nice to also have a bit more variance in the numerators - try some values like +-0.5, or bigger numbers like 10?
I'm wondering if 0.5 gets converted to rsq/2

You can just change all these tests to use + or -1, and then add the few extra tests with different numerators at the bottom, no need to have those for every type I think

Contributor Author

I added in the suggested tests. I can add more if it is better to be on the safe side?

Contributor

It's enough tests, IMO

@Pierre-vh Pierre-vh requested a review from arsenm January 30, 2024 09:04
Contributor

@Pierre-vh Pierre-vh left a comment

Looks good to me, but wait for @arsenm to confirm before landing as he created the issue

@nickleus27
Contributor Author

nickleus27 commented Feb 3, 2024

@Pierre-vh I am waiting for @arsenm's approval since he opened the issue. But, do you know why I do not have the ability to merge the PR (because I don't have write access)? I do not see the "Merge Pull Request" button. And, do you know if there is something I can do to make the button available to me? That way I do not have to ask others to merge for me? I attached a screen shot of where I would expect to see the button.

def fdiv_by_sqrt_to_rsq : GICombineRule<
(defs root:$root),
(match (G_FSQRT $sqrt, $x, (MIFlags FmContract)),
(G_FDIV $dst, $y, $sqrt, (MIFlags FmContract)):$root),
Contributor

Needs more flags. Without the denormal correction code, this isn't more precise. Also switches to dropping denormals and 1ulp. You can refer to the dag version for the exact conditions

Contributor

Also missing hasOneUse check

Contributor

Actually contract is adequate for the f16 case, f32 is the complicated one

Contributor Author

@nickleus27 nickleus27 Feb 6, 2024

I am not seeing why / where we need hasOneUse check? Do we need to check for the FDIV operands?

Contributor

Need to check that G_FSQRT has one user, so we don't do the combine if it means the G_FSQRT stays.
(also means you will need a test with a G_FSQRT that has multiple users)

To check hasOneUse, you can use a C++ predicate on top of the pattern that calls MRI.hasOneUse on ${sqrt}
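
The reason for the single-use condition can be illustrated with a toy use counter. ToyUseInst, countUses, and hasSingleUse are hypothetical names for this sketch only; in the combiner the equivalent query would be the MRI.hasOneUse call mentioned above.

```cpp
#include <cassert>
#include <vector>

// Toy instruction: one defined register and the registers it reads.
struct ToyUseInst {
  int Def;
  std::vector<int> Uses;
};

// Counts how many operands across all instructions read Reg.
int countUses(const std::vector<ToyUseInst> &Insts, int Reg) {
  int Count = 0;
  for (const ToyUseInst &I : Insts)
    for (int Use : I.Uses)
      if (Use == Reg)
        ++Count;
  return Count;
}

// True when exactly one operand reads Reg.
bool hasSingleUse(const std::vector<ToyUseInst> &Insts, int Reg) {
  return countUses(Insts, Reg) == 1;
}
```

If the square root's result is read only by the division, replacing the pair with rsq lets the square root be deleted. If anything else also reads it, the square root has to stay, so the rewrite would add an instruction instead of removing one.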

Contributor Author

Thanks for the reply.

@Pierre-vh
Contributor

> @Pierre-vh I am waiting for @arsenm's approval since he opened the issue. But, do you know why I do not have the ability to merge the PR (because I don't have write access)? I do not see the "Merge Pull Request" button. And, do you know if there is something I can do to make the button available to me? That way I do not have to ask others to merge for me? I attached a screen shot of where I would expect to see the button.

I assume you don't have commit access? Then you won't be able to merge yourself, but we can do it.
Once you have merged a few patches, you can ask for commit access: https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access

@arsenm
Contributor

arsenm commented Feb 6, 2024

I think this is close. I forgot that I removed the codegen support for rsq in the DAG over the summer. I think this patch only needs to handle the f16 part, and not introduce f32/f64 rsq. We do that in CodeGenPrepare.

Comment on lines 360 to 362
// What about v_rsq_f64? - Is UnsafeFPMath sufficient to do this for f64? The
// maximum ULP error seems really high at 2^29 ULP.
return true;
Contributor

Don't allow f64. We've never done it, and I believe we should add extra correction code when we do use it. It's been on my todo list for a long time to try to make use of it. Currently the library code has an expansion we should move into the compiler.

@@ -33,12 +33,12 @@ def rcp_sqrt_to_rsq : GICombineRule<
[{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>;

-def fdiv_by_sqrt_to_rsq : GICombineRule<
+def fdiv_by_sqrt_to_rsq_f16 : GICombineRule<
(defs root:$root),
(match (G_FSQRT $sqrt, $x, (MIFlags FmContract)),
Contributor

@arsenm arsenm Feb 8, 2024

f16:sqrt should also work, I think

Contributor

@arsenm arsenm Feb 8, 2024

I guess we also want the f16 vectors to work too (although that doesn't work since we're emitting the intrinsic and post legalize)

Contributor Author

So don't worry about f16 vectors? just f16:sqrt is okay?

// f32/f64 rsq is handled in AMDGPUCodeGenPrepare
// only match if operand type is f16
// v_rsq_f16 supports denormals and 0.51ulp.
if (DstTy == LLT::scalar(16)) {
Contributor

can move the type check into the pattern

MachineInstr &MI) const {
Register Dst = MI.getOperand(0).getReg();
Register Sqrt = MI.getOperand(2).getReg();
LLT DstTy = MRI.getType(Dst);
const MachineFunction &MF = B.getMF();
bool AllowInaccurateRsq =
MI.getFlag(MachineInstr::FmAfn) || MF.getTarget().Options.UnsafeFPMath;
if (!MRI.hasOneUse(Sqrt)) {
Contributor

I thought we had combiner magic for hasOneUse but I'm not finding it. Also, this should really have been hasOneNonDBGUse.

$vgpr0 = COPY %ext

...

Contributor

missing tests with the missing contract on each inst

Contributor

Probably doesn't really matter, eventually the real testing should be the end to end IR tests shared with the DAG

Contributor Author

I can add in the tests without contract.

Contributor Author

> Probably doesn't really matter, eventually the real testing should be the end to end IR tests shared with the DAG

What are the end-to-end IR tests shared with the DAG?

Contributor

e.g. we have llvm/test/CodeGen/AMDGPU/rsq.f32.ll and a bunch of others. I was working on getting consistent coverage here, but not sure I ever got to f16

Contributor Author

Okay, I will work on adding rsq.f16.ll and I will use rsq.f32.ll as an example.

Contributor Author

I see llvm/test/CodeGen/AMDGPU/fdiv.f16.ll has some rsq tests. Should I add more tests to this file or make a new file rsq.f16.ll?

@nickleus27
Contributor Author

@arsenm I pushed tests with missing contract instructions for f16 type in the .mir file. I looked into adding .ll file tests for f16 type rsq. I found f16 rsq tests in llvm/test/CodeGen/AMDGPU/fdiv.f16.ll which included tests for f16 type with and without contract. I am not sure what I would need to do to make the tests applicable to this patch. Any suggestions?

@nickleus27
Contributor Author

If everyone approves, could someone go ahead and commit this for me? I do not have write access. Thank you.

jayfoad pushed a commit that referenced this pull request Feb 21, 2024
follow up patch to #78673

@Pierre-vh @jayfoad @arsenm Could you review when you have a chance.
@nickleus27
Contributor Author

Since #80526 landed first, I need to squash these fixup commits and pull from upstream, because that commit changes the .mir in this patch. Once I have updated the .mir, I will come back for review.

@nickleus27
Contributor Author

@arsenm can you land this for me if everything looks good? If not please let me know what needs to be improved.

@Pierre-vh Pierre-vh merged commit 8bd327d into llvm:main Feb 22, 2024
3 of 4 checks passed
@nickleus27 nickleus27 deleted the rsqCombine branch February 22, 2024 08:49
6 participants