-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Improve support for rsqrt.approx #89417
Conversation
@llvm/pr-subscribers-llvm-ir Author: Alex MacLean (AlexMaclean) ChangesComplete support for rsqrt.approx with rsqrt.approx.f64 (PTX ISA 9.7.3.17. Floating Point Instructions: rsqrt.approx.ftz.f64). Additionally, add support for folding Full diff: https://github.com/llvm/llvm-project/pull/89417.diff 7 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 726cea004606e2..0a9139e0062ba3 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1003,6 +1003,8 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_rsqrt_approx_ftz_f : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_f">,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
+ def int_nvvm_rsqrt_approx_ftz_d : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_d">,
+ DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;
def int_nvvm_rsqrt_approx_f : ClangBuiltin<"__nvvm_rsqrt_approx_f">,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
def int_nvvm_rsqrt_approx_d : ClangBuiltin<"__nvvm_rsqrt_approx_d">,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 3ff8994602e16b..26b95b396cfebf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -30,6 +30,10 @@ using namespace llvm;
#define DEBUG_TYPE "nvptx-isel"
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
+static cl::opt<bool>
+ DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden,
+ cl::desc("Disable reciprocal sqrt optimization"));
+
/// createNVPTXISelDag - This pass converts a legalized DAG into a
/// NVPTX-specific DAG, ready for instruction scheduling.
FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -78,6 +82,8 @@ bool NVPTXDAGToDAGISel::useShortPointers() const {
return TM.useShortPointers();
}
+bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return !DisableRsqrtOpt; }
+
/// Select - Select instructions not customized! Used for
/// expanded, promoted and normal instructions.
void NVPTXDAGToDAGISel::Select(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 84c8432047ca31..c74f8fa884596a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -37,6 +37,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool allowFMA() const;
bool allowUnsafeFPMath() const;
bool useShortPointers() const;
+ bool doRsqrtOpt() const;
public:
static char ID;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index cd8546005c0289..0af7423dfd0b9b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -142,6 +142,7 @@ def hasLDU : Predicate<"Subtarget->hasLDU()">;
def doF32FTZ : Predicate<"useF32FTZ()">;
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
+def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
def doMulWide : Predicate<"doMulWide">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c0c53380a13e9b..05b22aa41198d1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1171,11 +1171,36 @@ def : Pat<(int_nvvm_sqrt_f Float32Regs:$a),
def INT_NVVM_RSQRT_APPROX_FTZ_F
: F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs,
int_nvvm_rsqrt_approx_ftz_f>;
+def INT_NVVM_RSQRT_APPROX_FTZ_D
+ : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs,
+ int_nvvm_rsqrt_approx_ftz_d>;
+
def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;",
Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>;
def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>;
+// 1.0f / sqrt_approx -> rsqrt_approx
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt]>;
+// same for int_nvvm_sqrt_f when non-precision sqrt is requested
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
+
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
//
// Add
//
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
new file mode 100644
index 00000000000000..11bd82a015ce12
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
@@ -0,0 +1,75 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-NOOPT
+; RUN: llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-OPT
+; RUN: llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT
+;
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | %ptxas-verify %}
+
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+ %sqrt = tail call float @llvm.nvvm.sqrt.approx.f(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+; CHECK-LABEL: .func{{.*}}test2
+define float @test2(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+ %sqrt = tail call float @llvm.nvvm.sqrt.approx.ftz.f(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+ %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define float @test4(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+ %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test5
+define float @test5(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+ %sqrt = tail call float @llvm.sqrt.f32(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test6
+define float @test6(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+ %sqrt = tail call float @llvm.sqrt.f32(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+
+declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.nvvm.sqrt.approx.f(float)
+declare float @llvm.nvvm.sqrt.approx.ftz.f(float)
+declare float @llvm.sqrt.f32(float)
+
+attributes #0 = { "denormal-fp-math-f32" = "preserve-sign" }
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt.ll b/llvm/test/CodeGen/NVPTX/rsqrt.ll
new file mode 100644
index 00000000000000..c7367245c532e3
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt.ll
@@ -0,0 +1,35 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f32
+ %call = call float @llvm.nvvm.rsqrt.approx.f(float %in)
+ ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test2
+define double @test2(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f64
+ %call = call double @llvm.nvvm.rsqrt.approx.d(double %in)
+ ret double %call
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f32
+ %call = tail call float @llvm.nvvm.rsqrt.approx.ftz.f(float %in)
+ ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define double @test4(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f64
+ %call = tail call double @llvm.nvvm.rsqrt.approx.ftz.d(double %in)
+ ret double %call
+}
+
+declare float @llvm.nvvm.rsqrt.approx.ftz.f(float)
+declare double @llvm.nvvm.rsqrt.approx.ftz.d(double)
+declare float @llvm.nvvm.rsqrt.approx.f(float)
+declare double @llvm.nvvm.rsqrt.approx.d(double)
|
@@ -30,6 +30,10 @@ using namespace llvm; | |||
#define DEBUG_TYPE "nvptx-isel" | |||
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection" | |||
|
|||
static cl::opt<bool> | |||
DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll need someone familiar with FP nuances to take a look.
Enabling sqrt->rsqrt.approx by default sounds risky to me.
How do other targets handle sqrt optimizations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll need someone familiar with FP nuances to take a look.
Any ideas who might have the necessary expertise?
Enabling sqrt->rsqrt.approx by default sounds risky to me.
By default such a transformation is disabled. We'll only do that transformation when -nvptx-prec-sqrtf32=0
has been specified. In these cases we'd lower sqrt
to sqrt.approx
before this change. Adding sqrt.approx
-> rsqrt.approx
seems safe to me, and FWIW it hasn't caused any issues with our internal testing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The patch sets DisableRsqrtOpt = false
by default, so doRsqrtOpt()
will return true.
I was looking at tests 3-6 below which optimized sqrt -> rqsrt.approx
It was not obvious that it's only in effect when -nvptx-prec-sqrtf32=0
is in effect. This is probably OK, as we're already promising an approximate result.
@@ -30,6 +30,10 @@ using namespace llvm; | |||
#define DEBUG_TYPE "nvptx-isel" | |||
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection" | |||
|
|||
static cl::opt<bool> | |||
DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The patch sets DisableRsqrtOpt = false
by default, so doRsqrtOpt()
will return true.
I was looking at tests 3-6 below which optimized sqrt -> rqsrt.approx
It was not obvious that it's only in effect when -nvptx-prec-sqrtf32=0
is in effect. This is probably OK, as we're already promising an approximate result.
@@ -30,6 +30,10 @@ using namespace llvm; | |||
#define DEBUG_TYPE "nvptx-isel" | |||
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection" | |||
|
|||
static cl::opt<bool> | |||
DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, a nit on the option name. Disabling a "disable-something" knob always strikes me as odd. Can we make it a positive control? E.g. -nvptx-approx-rsqrt
?
Naming is hard. :-/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, what about -nvptx-rsqrt-approx-opt
or -nvptx-rsqrt-approx-folding
? I think it would be good to make clear that this controls optimizing sqrt
to rsqrt
not whether rsqrt
will be emitted at all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that it's a hidden option, I don't have particularly strong feelings about the name. -nvptx-rsqrt-approx-opt
works for me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I've updated the MR
✅ With the latest revision this PR passed the C/C++ code formatter. |
315e3ec
to
42055df
Compare
42055df
to
91b50c4
Compare
Complete support for rsqrt.approx with rsqrt.approx.f64 (PTX ISA 9.7.3.17. Floating Point Instructions: rsqrt.approx.ftz.f64). Additionally, add support for folding
sqrt
intorsqrt
, with an optional flag to disable.