From 6eab9dd7f01e6cad9f1a93bd52e4c6e7b4c3c1fa Mon Sep 17 00:00:00 2001 From: Alex MacLean Date: Mon, 8 Jan 2024 15:17:00 -0800 Subject: [PATCH] [NVPTX] remove incorrect NVPTX intrinsic transformations (#76870) `nvvm_fabs_f` `nvvm_fabs_ftz_f` Unfortunately, llvm fabs is not equivalent to these intrinsics since llvm fabs is defined to only set the sign bit to zero while these can also flush subnormal inputs and modify NaNs. `nvvm_round_d` `nvvm_round_f` `nvvm_round_ftz_f` llvm.nvvm.round uses RNI, while llvm.round codegens to RZI. LLVM defines llvm.round to use the same rounding as libm `round[f]()`, which is not necessary the same as how we define llvm.nvvm.round. `nvvm_sqrt_rn_f` `nvvm_sqrt_rn_ftz_f` sqrt may be lowered to a less precise version of sqrt, such as sqrt.approx in NVPTX depending on factors such as the value of -nvptx-prec-sqrtf32. These intrinsics should always become the corresponding NVPTX instructions. `nvvm_add_rn_d` `nvvm_add_rn_f` `nvvm_add_rn_ftz_f` `nvvm_mul_rn_d` `nvvm_mul_rn_f` `nvvm_mul_rn_ftz_f` These nvvm intrinsics have an explicitly specified rounding mode (.rn). They should always be lowered to a PTX instruction with the same explicit rounding mode. Converting to fmul and fadd instructions result in the PTX instructions without rounding modes specified. This can cause issue because: > An add [or mul] instruction with no rounding modifier defaults to round-to-nearest-even and may be optimized aggressively by the code optimizer. In particular, mul/add sequences with no rounding modifiers may be optimized to use fused-multiply-add instructions on the target device. `nvvm_div_rn_f` `nvvm_div_rn_ftz_f` `nvvm_rcp_rn_f` `nvvm_rcp_rn_ftz_f` fdiv may be lowered to a less precise version of div, such as div.full in NVPTX depending on factors such as the value of -nvptx-prec-divf32. These intrinsics should always become the corresponding NVPTX instructions. --- .../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 34 ------------- .../InstCombine/NVPTX/nvvm-intrins.ll | 48 +++++++------------ 2 files changed, 17 insertions(+), 65 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp index c73721da46e35..7aa63f9fc0c96 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -180,10 +180,6 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) { return {Intrinsic::ceil, FTZ_MustBeOn}; case Intrinsic::nvvm_fabs_d: return {Intrinsic::fabs, FTZ_Any}; - case Intrinsic::nvvm_fabs_f: - return {Intrinsic::fabs, FTZ_MustBeOff}; - case Intrinsic::nvvm_fabs_ftz_f: - return {Intrinsic::fabs, FTZ_MustBeOn}; case Intrinsic::nvvm_floor_d: return {Intrinsic::floor, FTZ_Any}; case Intrinsic::nvvm_floor_f: @@ -264,12 +260,6 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) { return {Intrinsic::minimum, FTZ_MustBeOff, true}; case Intrinsic::nvvm_fmin_ftz_nan_f16x2: return {Intrinsic::minimum, FTZ_MustBeOn, true}; - case Intrinsic::nvvm_round_d: - return {Intrinsic::round, FTZ_Any}; - case Intrinsic::nvvm_round_f: - return {Intrinsic::round, FTZ_MustBeOff}; - case Intrinsic::nvvm_round_ftz_f: - return {Intrinsic::round, FTZ_MustBeOn}; case Intrinsic::nvvm_sqrt_rn_d: return {Intrinsic::sqrt, FTZ_Any}; case Intrinsic::nvvm_sqrt_f: @@ -278,10 +268,6 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) { // the ftz-ness of the surrounding code. sqrt_rn_f and sqrt_rn_ftz_f are // the versions with explicit ftz-ness. return {Intrinsic::sqrt, FTZ_Any}; - case Intrinsic::nvvm_sqrt_rn_f: - return {Intrinsic::sqrt, FTZ_MustBeOff}; - case Intrinsic::nvvm_sqrt_rn_ftz_f: - return {Intrinsic::sqrt, FTZ_MustBeOn}; case Intrinsic::nvvm_trunc_d: return {Intrinsic::trunc, FTZ_Any}; case Intrinsic::nvvm_trunc_f: @@ -316,24 +302,8 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) { return {Instruction::UIToFP}; // NVVM intrinsics that map to LLVM binary ops. - case Intrinsic::nvvm_add_rn_d: - return {Instruction::FAdd, FTZ_Any}; - case Intrinsic::nvvm_add_rn_f: - return {Instruction::FAdd, FTZ_MustBeOff}; - case Intrinsic::nvvm_add_rn_ftz_f: - return {Instruction::FAdd, FTZ_MustBeOn}; - case Intrinsic::nvvm_mul_rn_d: - return {Instruction::FMul, FTZ_Any}; - case Intrinsic::nvvm_mul_rn_f: - return {Instruction::FMul, FTZ_MustBeOff}; - case Intrinsic::nvvm_mul_rn_ftz_f: - return {Instruction::FMul, FTZ_MustBeOn}; case Intrinsic::nvvm_div_rn_d: return {Instruction::FDiv, FTZ_Any}; - case Intrinsic::nvvm_div_rn_f: - return {Instruction::FDiv, FTZ_MustBeOff}; - case Intrinsic::nvvm_div_rn_ftz_f: - return {Instruction::FDiv, FTZ_MustBeOn}; // The remainder of cases are NVVM intrinsics that map to LLVM idioms, but // need special handling. @@ -342,10 +312,6 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) { // as well. case Intrinsic::nvvm_rcp_rn_d: return {SPC_Reciprocal, FTZ_Any}; - case Intrinsic::nvvm_rcp_rn_f: - return {SPC_Reciprocal, FTZ_MustBeOff}; - case Intrinsic::nvvm_rcp_rn_ftz_f: - return {SPC_Reciprocal, FTZ_MustBeOn}; // We do not currently simplify intrinsics that give an approximate // answer. These include: diff --git a/llvm/test/Transforms/InstCombine/NVPTX/nvvm-intrins.ll b/llvm/test/Transforms/InstCombine/NVPTX/nvvm-intrins.ll index ca1a5237f905d..633aa43c4fc89 100644 --- a/llvm/test/Transforms/InstCombine/NVPTX/nvvm-intrins.ll +++ b/llvm/test/Transforms/InstCombine/NVPTX/nvvm-intrins.ll @@ -49,15 +49,13 @@ define double @fabs_double(double %a) #0 { } ; CHECK-LABEL: @fabs_float define float @fabs_float(float %a) #0 { -; NOFTZ: call float @llvm.fabs.f32 -; FTZ: call float @llvm.nvvm.fabs.f +; CHECK: call float @llvm.nvvm.fabs.f %ret = call float @llvm.nvvm.fabs.f(float %a) ret float %ret } ; CHECK-LABEL: @fabs_float_ftz define float @fabs_float_ftz(float %a) #0 { -; NOFTZ: call float @llvm.nvvm.fabs.ftz.f -; FTZ: call float @llvm.fabs.f32 +; CHECK: call float @llvm.nvvm.fabs.ftz.f %ret = call float @llvm.nvvm.fabs.ftz.f(float %a) ret float %ret } @@ -148,21 +146,19 @@ define float @fmin_float_ftz(float %a, float %b) #0 { ; CHECK-LABEL: @round_double define double @round_double(double %a) #0 { -; CHECK: call double @llvm.round.f64 +; CHECK: call double @llvm.nvvm.round.d %ret = call double @llvm.nvvm.round.d(double %a) ret double %ret } ; CHECK-LABEL: @round_float define float @round_float(float %a) #0 { -; NOFTZ: call float @llvm.round.f32 -; FTZ: call float @llvm.nvvm.round.f +; CHECK: call float @llvm.nvvm.round.f %ret = call float @llvm.nvvm.round.f(float %a) ret float %ret } ; CHECK-LABEL: @round_float_ftz define float @round_float_ftz(float %a) #0 { -; NOFTZ: call float @llvm.nvvm.round.ftz.f -; FTZ: call float @llvm.round.f32 +; CHECK: call float @llvm.nvvm.round.ftz.f %ret = call float @llvm.nvvm.round.ftz.f(float %a) ret float %ret } @@ -292,42 +288,38 @@ define float @test_ull2f(i64 %a) #0 { ; CHECK-LABEL: @test_add_rn_d define double @test_add_rn_d(double %a, double %b) #0 { -; CHECK: fadd +; CHECK: call double @llvm.nvvm.add.rn.d %ret = call double @llvm.nvvm.add.rn.d(double %a, double %b) ret double %ret } ; CHECK-LABEL: @test_add_rn_f define float @test_add_rn_f(float %a, float %b) #0 { -; NOFTZ: fadd -; FTZ: call float @llvm.nvvm.add.rn.f +; CHECK: call float @llvm.nvvm.add.rn.f %ret = call float @llvm.nvvm.add.rn.f(float %a, float %b) ret float %ret } ; CHECK-LABEL: @test_add_rn_f_ftz define float @test_add_rn_f_ftz(float %a, float %b) #0 { -; NOFTZ: call float @llvm.nvvm.add.rn.f -; FTZ: fadd +; CHECK: call float @llvm.nvvm.add.rn.ftz.f(float %a, float %b) %ret = call float @llvm.nvvm.add.rn.ftz.f(float %a, float %b) ret float %ret } ; CHECK-LABEL: @test_mul_rn_d define double @test_mul_rn_d(double %a, double %b) #0 { -; CHECK: fmul +; CHECK: call double @llvm.nvvm.mul.rn.d %ret = call double @llvm.nvvm.mul.rn.d(double %a, double %b) ret double %ret } ; CHECK-LABEL: @test_mul_rn_f define float @test_mul_rn_f(float %a, float %b) #0 { -; NOFTZ: fmul -; FTZ: call float @llvm.nvvm.mul.rn.f +; CHECK: call float @llvm.nvvm.mul.rn.f %ret = call float @llvm.nvvm.mul.rn.f(float %a, float %b) ret float %ret } ; CHECK-LABEL: @test_mul_rn_f_ftz define float @test_mul_rn_f_ftz(float %a, float %b) #0 { -; NOFTZ: call float @llvm.nvvm.mul.rn.f -; FTZ: fmul +; CHECK: call float @llvm.nvvm.mul.rn.ftz.f(float %a, float %b) %ret = call float @llvm.nvvm.mul.rn.ftz.f(float %a, float %b) ret float %ret } @@ -340,15 +332,13 @@ define double @test_div_rn_d(double %a, double %b) #0 { } ; CHECK-LABEL: @test_div_rn_f define float @test_div_rn_f(float %a, float %b) #0 { -; NOFTZ: fdiv -; FTZ: call float @llvm.nvvm.div.rn.f +; CHECK: call float @llvm.nvvm.div.rn.f %ret = call float @llvm.nvvm.div.rn.f(float %a, float %b) ret float %ret } ; CHECK-LABEL: @test_div_rn_f_ftz define float @test_div_rn_f_ftz(float %a, float %b) #0 { -; NOFTZ: call float @llvm.nvvm.div.rn.f -; FTZ: fdiv +; CHECK: call float @llvm.nvvm.div.rn.ftz.f(float %a, float %b) %ret = call float @llvm.nvvm.div.rn.ftz.f(float %a, float %b) ret float %ret } @@ -357,15 +347,13 @@ define float @test_div_rn_f_ftz(float %a, float %b) #0 { ; CHECK-LABEL: @test_rcp_rn_f define float @test_rcp_rn_f(float %a) #0 { -; NOFTZ: fdiv float 1.0{{.*}} %a -; FTZ: call float @llvm.nvvm.rcp.rn.f +; CHECK: call float @llvm.nvvm.rcp.rn.f %ret = call float @llvm.nvvm.rcp.rn.f(float %a) ret float %ret } ; CHECK-LABEL: @test_rcp_rn_f_ftz define float @test_rcp_rn_f_ftz(float %a) #0 { -; NOFTZ: call float @llvm.nvvm.rcp.rn.f -; FTZ: fdiv float 1.0{{.*}} %a +; CHECK: call float @llvm.nvvm.rcp.rn.ftz.f(float %a) %ret = call float @llvm.nvvm.rcp.rn.ftz.f(float %a) ret float %ret } @@ -385,15 +373,13 @@ define float @test_sqrt_f(float %a) #0 { } ; CHECK-LABEL: @test_sqrt_rn_f define float @test_sqrt_rn_f(float %a) #0 { -; NOFTZ: call float @llvm.sqrt.f32(float %a) -; FTZ: call float @llvm.nvvm.sqrt.rn.f +; CHECK: call float @llvm.nvvm.sqrt.rn.f %ret = call float @llvm.nvvm.sqrt.rn.f(float %a) ret float %ret } ; CHECK-LABEL: @test_sqrt_rn_f_ftz define float @test_sqrt_rn_f_ftz(float %a) #0 { -; NOFTZ: call float @llvm.nvvm.sqrt.rn.f -; FTZ: call float @llvm.sqrt.f32(float %a) +; CHECK: call float @llvm.nvvm.sqrt.rn.ftz.f(float %a) %ret = call float @llvm.nvvm.sqrt.rn.ftz.f(float %a) ret float %ret }