Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVPTX] Improve support for rsqrt.approx #89417

Merged
merged 2 commits into from
Apr 23, 2024

Conversation

AlexMaclean
Copy link
Member

Complete support for rsqrt.approx with rsqrt.approx.f64 (PTX ISA 9.7.3.17. Floating Point Instructions: rsqrt.approx.ftz.f64). Additionally, add support for folding sqrt into rsqrt, with an optional flag to disable.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 19, 2024

@llvm/pr-subscribers-llvm-ir

Author: Alex MacLean (AlexMaclean)

Changes

Complete support for rsqrt.approx with rsqrt.approx.f64 (PTX ISA 9.7.3.17. Floating Point Instructions: rsqrt.approx.ftz.f64). Additionally, add support for folding sqrt into rsqrt, with an optional flag to disable.


Full diff: https://github.com/llvm/llvm-project/pull/89417.diff

7 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+6)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+25)
  • (added) llvm/test/CodeGen/NVPTX/rsqrt-opt.ll (+75)
  • (added) llvm/test/CodeGen/NVPTX/rsqrt.ll (+35)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 726cea004606e2..0a9139e0062ba3 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1003,6 +1003,8 @@ let TargetPrefix = "nvvm" in {
 
   def int_nvvm_rsqrt_approx_ftz_f : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_f">,
       DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_rsqrt_approx_ftz_d : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_d">,
+      DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;
   def int_nvvm_rsqrt_approx_f : ClangBuiltin<"__nvvm_rsqrt_approx_f">,
       DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
   def int_nvvm_rsqrt_approx_d : ClangBuiltin<"__nvvm_rsqrt_approx_d">,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 3ff8994602e16b..26b95b396cfebf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -30,6 +30,10 @@ using namespace llvm;
 #define DEBUG_TYPE "nvptx-isel"
 #define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
 
+static cl::opt<bool>
+    DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden,
+                    cl::desc("Disable reciprocal sqrt optimization"));
+
 /// createNVPTXISelDag - This pass converts a legalized DAG into a
 /// NVPTX-specific DAG, ready for instruction scheduling.
 FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -78,6 +82,8 @@ bool NVPTXDAGToDAGISel::useShortPointers() const {
   return TM.useShortPointers();
 }
 
+bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return !DisableRsqrtOpt; }
+
 /// Select - Select instructions not customized! Used for
 /// expanded, promoted and normal instructions.
 void NVPTXDAGToDAGISel::Select(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 84c8432047ca31..c74f8fa884596a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -37,6 +37,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool allowFMA() const;
   bool allowUnsafeFPMath() const;
   bool useShortPointers() const;
+  bool doRsqrtOpt() const;
 
 public:
   static char ID;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index cd8546005c0289..0af7423dfd0b9b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -142,6 +142,7 @@ def hasLDU : Predicate<"Subtarget->hasLDU()">;
 
 def doF32FTZ : Predicate<"useF32FTZ()">;
 def doNoF32FTZ : Predicate<"!useF32FTZ()">;
+def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
 
 def doMulWide      : Predicate<"doMulWide">;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c0c53380a13e9b..05b22aa41198d1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1171,11 +1171,36 @@ def : Pat<(int_nvvm_sqrt_f Float32Regs:$a),
 def INT_NVVM_RSQRT_APPROX_FTZ_F
   : F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs,
     int_nvvm_rsqrt_approx_ftz_f>;
+def INT_NVVM_RSQRT_APPROX_FTZ_D
+  : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs,
+    int_nvvm_rsqrt_approx_ftz_d>;
+
 def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;",
   Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>;
 def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
   Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>;
 
+// 1.0f / sqrt_approx -> rsqrt_approx
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt]>;
+// same for int_nvvm_sqrt_f when non-precision sqrt is requested
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
+
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
 //
 // Add
 //
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
new file mode 100644
index 00000000000000..11bd82a015ce12
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
@@ -0,0 +1,75 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-NOOPT
+; RUN: llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-OPT
+; RUN: llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT
+;
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | %ptxas-verify %}
+
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.approx.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+; CHECK-LABEL: .func{{.*}}test2
+define float @test2(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.approx.ftz.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define float @test4(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test5
+define float @test5(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.sqrt.f32(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test6
+define float @test6(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+  %sqrt = tail call float @llvm.sqrt.f32(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+
+declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.nvvm.sqrt.approx.f(float)
+declare float @llvm.nvvm.sqrt.approx.ftz.f(float)
+declare float @llvm.sqrt.f32(float)
+
+attributes #0 = { "denormal-fp-math-f32" = "preserve-sign" }
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt.ll b/llvm/test/CodeGen/NVPTX/rsqrt.ll
new file mode 100644
index 00000000000000..c7367245c532e3
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt.ll
@@ -0,0 +1,35 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f32
+  %call = call float @llvm.nvvm.rsqrt.approx.f(float %in)
+  ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test2
+define double @test2(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f64
+  %call = call double @llvm.nvvm.rsqrt.approx.d(double %in)
+  ret double %call
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f32
+  %call = tail call float @llvm.nvvm.rsqrt.approx.ftz.f(float %in)
+  ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define double @test4(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f64
+  %call = tail call double @llvm.nvvm.rsqrt.approx.ftz.d(double %in)
+  ret double %call
+}
+
+declare float @llvm.nvvm.rsqrt.approx.ftz.f(float)
+declare double @llvm.nvvm.rsqrt.approx.ftz.d(double)
+declare float @llvm.nvvm.rsqrt.approx.f(float)
+declare double @llvm.nvvm.rsqrt.approx.d(double)

@@ -30,6 +30,10 @@ using namespace llvm;
#define DEBUG_TYPE "nvptx-isel"
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"

static cl::opt<bool>
DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need someone familiar with FP nuances to take a look.

Enabling sqrt->rsqrt.approx by default sounds risky to me.

How do other targets handle sqrt optimizations?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need someone familiar with FP nuances to take a look.

Any ideas who might have the necessary expertise?

Enabling sqrt->rsqrt.approx by default sounds risky to me.

By default such a transformation is disabled. We'll only do that transformation when -nvptx-prec-sqrtf32=0 has been specified. In these cases we'd lower sqrt to sqrt.approx before this change. Adding sqrt.approx -> rsqrt.approx seems safe to me, and FWIW it hasn't caused any issues with our internal testing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The patch sets DisableRsqrtOpt = false by default, so doRsqrtOpt() will return true.
I was looking at tests 3-6 below which optimized sqrt -> rqsrt.approx

It was not obvious that it's only in effect when -nvptx-prec-sqrtf32=0 is in effect. This is probably OK, as we're already promising an approximate result.

@@ -30,6 +30,10 @@ using namespace llvm;
#define DEBUG_TYPE "nvptx-isel"
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"

static cl::opt<bool>
DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The patch sets DisableRsqrtOpt = false by default, so doRsqrtOpt() will return true.
I was looking at tests 3-6 below which optimized sqrt -> rqsrt.approx

It was not obvious that it's only in effect when -nvptx-prec-sqrtf32=0 is in effect. This is probably OK, as we're already promising an approximate result.

@@ -30,6 +30,10 @@ using namespace llvm;
#define DEBUG_TYPE "nvptx-isel"
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"

static cl::opt<bool>
DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, a nit on the option name. Disabling a "disable-something" knob always strikes me as odd. Can we make it a positive control? E.g. -nvptx-approx-rsqrt ?

Naming is hard. :-/

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, what about -nvptx-rsqrt-approx-opt or -nvptx-rsqrt-approx-folding? I think it would be good to make clear that this controls optimizing sqrt to rsqrt not whether rsqrt will be emitted at all.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that it's a hidden option, I don't have particularly strong feelings about the name. -nvptx-rsqrt-approx-opt works for me.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I've updated the MR

Copy link

github-actions bot commented Apr 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@AlexMaclean AlexMaclean merged commit df60805 into llvm:main Apr 23, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants