Skip to content

Commit

Permalink
[NVPTX] remove incorrect NVPTX intrinsic transformations (llvm#76870)
Browse files Browse the repository at this point in the history
`nvvm_fabs_f`
`nvvm_fabs_ftz_f`

Unfortunately, llvm fabs is not equivalent to these intrinsics since
llvm fabs is defined to only set the sign bit to zero while these can
also flush subnormal inputs and modify NaNs.

`nvvm_round_d`
`nvvm_round_f`
`nvvm_round_ftz_f`

llvm.nvvm.round uses RNI, while llvm.round codegens to RZI. LLVM defines
llvm.round to use the same rounding as libm
`round[f]()`, which is not necessary the same as how we define
llvm.nvvm.round.

`nvvm_sqrt_rn_f`
`nvvm_sqrt_rn_ftz_f`

sqrt may be lowered to a less precise version of sqrt, such as
sqrt.approx in NVPTX depending on factors such as the value of
-nvptx-prec-sqrtf32. These intrinsics should always become the
corresponding NVPTX instructions.

`nvvm_add_rn_d`
`nvvm_add_rn_f`
`nvvm_add_rn_ftz_f`
`nvvm_mul_rn_d`
`nvvm_mul_rn_f`
`nvvm_mul_rn_ftz_f`

These nvvm intrinsics have an explicitly specified rounding mode (.rn).
They should always be lowered to a PTX instruction with the same
explicit rounding mode. Converting to fmul and fadd instructions result
in the PTX instructions without rounding modes specified. This can cause
issue because:

> An add [or mul] instruction with no rounding modifier defaults to
round-to-nearest-even and may be optimized aggressively by the code
optimizer. In particular, mul/add sequences with no rounding modifiers
may be optimized to use fused-multiply-add instructions on the target
device.

`nvvm_div_rn_f`
`nvvm_div_rn_ftz_f`
`nvvm_rcp_rn_f`
`nvvm_rcp_rn_ftz_f`

fdiv may be lowered to a less precise version of div, such as div.full
in NVPTX depending on factors such as the value of -nvptx-prec-divf32.
These intrinsics should always become the corresponding NVPTX
instructions.
  • Loading branch information
AlexMaclean authored and justinfargnoli committed Jan 28, 2024
1 parent b0794fa commit f5e4a52
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 65 deletions.
34 changes: 0 additions & 34 deletions llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
Expand Up @@ -180,10 +180,6 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
return {Intrinsic::ceil, FTZ_MustBeOn};
case Intrinsic::nvvm_fabs_d:
return {Intrinsic::fabs, FTZ_Any};
case Intrinsic::nvvm_fabs_f:
return {Intrinsic::fabs, FTZ_MustBeOff};
case Intrinsic::nvvm_fabs_ftz_f:
return {Intrinsic::fabs, FTZ_MustBeOn};
case Intrinsic::nvvm_floor_d:
return {Intrinsic::floor, FTZ_Any};
case Intrinsic::nvvm_floor_f:
Expand Down Expand Up @@ -264,12 +260,6 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
return {Intrinsic::minimum, FTZ_MustBeOff, true};
case Intrinsic::nvvm_fmin_ftz_nan_f16x2:
return {Intrinsic::minimum, FTZ_MustBeOn, true};
case Intrinsic::nvvm_round_d:
return {Intrinsic::round, FTZ_Any};
case Intrinsic::nvvm_round_f:
return {Intrinsic::round, FTZ_MustBeOff};
case Intrinsic::nvvm_round_ftz_f:
return {Intrinsic::round, FTZ_MustBeOn};
case Intrinsic::nvvm_sqrt_rn_d:
return {Intrinsic::sqrt, FTZ_Any};
case Intrinsic::nvvm_sqrt_f:
Expand All @@ -278,10 +268,6 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
// the ftz-ness of the surrounding code. sqrt_rn_f and sqrt_rn_ftz_f are
// the versions with explicit ftz-ness.
return {Intrinsic::sqrt, FTZ_Any};
case Intrinsic::nvvm_sqrt_rn_f:
return {Intrinsic::sqrt, FTZ_MustBeOff};
case Intrinsic::nvvm_sqrt_rn_ftz_f:
return {Intrinsic::sqrt, FTZ_MustBeOn};
case Intrinsic::nvvm_trunc_d:
return {Intrinsic::trunc, FTZ_Any};
case Intrinsic::nvvm_trunc_f:
Expand Down Expand Up @@ -316,24 +302,8 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
return {Instruction::UIToFP};

// NVVM intrinsics that map to LLVM binary ops.
case Intrinsic::nvvm_add_rn_d:
return {Instruction::FAdd, FTZ_Any};
case Intrinsic::nvvm_add_rn_f:
return {Instruction::FAdd, FTZ_MustBeOff};
case Intrinsic::nvvm_add_rn_ftz_f:
return {Instruction::FAdd, FTZ_MustBeOn};
case Intrinsic::nvvm_mul_rn_d:
return {Instruction::FMul, FTZ_Any};
case Intrinsic::nvvm_mul_rn_f:
return {Instruction::FMul, FTZ_MustBeOff};
case Intrinsic::nvvm_mul_rn_ftz_f:
return {Instruction::FMul, FTZ_MustBeOn};
case Intrinsic::nvvm_div_rn_d:
return {Instruction::FDiv, FTZ_Any};
case Intrinsic::nvvm_div_rn_f:
return {Instruction::FDiv, FTZ_MustBeOff};
case Intrinsic::nvvm_div_rn_ftz_f:
return {Instruction::FDiv, FTZ_MustBeOn};

// The remainder of cases are NVVM intrinsics that map to LLVM idioms, but
// need special handling.
Expand All @@ -342,10 +312,6 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
// as well.
case Intrinsic::nvvm_rcp_rn_d:
return {SPC_Reciprocal, FTZ_Any};
case Intrinsic::nvvm_rcp_rn_f:
return {SPC_Reciprocal, FTZ_MustBeOff};
case Intrinsic::nvvm_rcp_rn_ftz_f:
return {SPC_Reciprocal, FTZ_MustBeOn};

// We do not currently simplify intrinsics that give an approximate
// answer. These include:
Expand Down
48 changes: 17 additions & 31 deletions llvm/test/Transforms/InstCombine/NVPTX/nvvm-intrins.ll
Expand Up @@ -49,15 +49,13 @@ define double @fabs_double(double %a) #0 {
}
; CHECK-LABEL: @fabs_float
define float @fabs_float(float %a) #0 {
; NOFTZ: call float @llvm.fabs.f32
; FTZ: call float @llvm.nvvm.fabs.f
; CHECK: call float @llvm.nvvm.fabs.f
%ret = call float @llvm.nvvm.fabs.f(float %a)
ret float %ret
}
; CHECK-LABEL: @fabs_float_ftz
define float @fabs_float_ftz(float %a) #0 {
; NOFTZ: call float @llvm.nvvm.fabs.ftz.f
; FTZ: call float @llvm.fabs.f32
; CHECK: call float @llvm.nvvm.fabs.ftz.f
%ret = call float @llvm.nvvm.fabs.ftz.f(float %a)
ret float %ret
}
Expand Down Expand Up @@ -148,21 +146,19 @@ define float @fmin_float_ftz(float %a, float %b) #0 {

; CHECK-LABEL: @round_double
define double @round_double(double %a) #0 {
; CHECK: call double @llvm.round.f64
; CHECK: call double @llvm.nvvm.round.d
%ret = call double @llvm.nvvm.round.d(double %a)
ret double %ret
}
; CHECK-LABEL: @round_float
define float @round_float(float %a) #0 {
; NOFTZ: call float @llvm.round.f32
; FTZ: call float @llvm.nvvm.round.f
; CHECK: call float @llvm.nvvm.round.f
%ret = call float @llvm.nvvm.round.f(float %a)
ret float %ret
}
; CHECK-LABEL: @round_float_ftz
define float @round_float_ftz(float %a) #0 {
; NOFTZ: call float @llvm.nvvm.round.ftz.f
; FTZ: call float @llvm.round.f32
; CHECK: call float @llvm.nvvm.round.ftz.f
%ret = call float @llvm.nvvm.round.ftz.f(float %a)
ret float %ret
}
Expand Down Expand Up @@ -292,42 +288,38 @@ define float @test_ull2f(i64 %a) #0 {

; CHECK-LABEL: @test_add_rn_d
define double @test_add_rn_d(double %a, double %b) #0 {
; CHECK: fadd
; CHECK: call double @llvm.nvvm.add.rn.d
%ret = call double @llvm.nvvm.add.rn.d(double %a, double %b)
ret double %ret
}
; CHECK-LABEL: @test_add_rn_f
define float @test_add_rn_f(float %a, float %b) #0 {
; NOFTZ: fadd
; FTZ: call float @llvm.nvvm.add.rn.f
; CHECK: call float @llvm.nvvm.add.rn.f
%ret = call float @llvm.nvvm.add.rn.f(float %a, float %b)
ret float %ret
}
; CHECK-LABEL: @test_add_rn_f_ftz
define float @test_add_rn_f_ftz(float %a, float %b) #0 {
; NOFTZ: call float @llvm.nvvm.add.rn.f
; FTZ: fadd
; CHECK: call float @llvm.nvvm.add.rn.ftz.f(float %a, float %b)
%ret = call float @llvm.nvvm.add.rn.ftz.f(float %a, float %b)
ret float %ret
}

; CHECK-LABEL: @test_mul_rn_d
define double @test_mul_rn_d(double %a, double %b) #0 {
; CHECK: fmul
; CHECK: call double @llvm.nvvm.mul.rn.d
%ret = call double @llvm.nvvm.mul.rn.d(double %a, double %b)
ret double %ret
}
; CHECK-LABEL: @test_mul_rn_f
define float @test_mul_rn_f(float %a, float %b) #0 {
; NOFTZ: fmul
; FTZ: call float @llvm.nvvm.mul.rn.f
; CHECK: call float @llvm.nvvm.mul.rn.f
%ret = call float @llvm.nvvm.mul.rn.f(float %a, float %b)
ret float %ret
}
; CHECK-LABEL: @test_mul_rn_f_ftz
define float @test_mul_rn_f_ftz(float %a, float %b) #0 {
; NOFTZ: call float @llvm.nvvm.mul.rn.f
; FTZ: fmul
; CHECK: call float @llvm.nvvm.mul.rn.ftz.f(float %a, float %b)
%ret = call float @llvm.nvvm.mul.rn.ftz.f(float %a, float %b)
ret float %ret
}
Expand All @@ -340,15 +332,13 @@ define double @test_div_rn_d(double %a, double %b) #0 {
}
; CHECK-LABEL: @test_div_rn_f
define float @test_div_rn_f(float %a, float %b) #0 {
; NOFTZ: fdiv
; FTZ: call float @llvm.nvvm.div.rn.f
; CHECK: call float @llvm.nvvm.div.rn.f
%ret = call float @llvm.nvvm.div.rn.f(float %a, float %b)
ret float %ret
}
; CHECK-LABEL: @test_div_rn_f_ftz
define float @test_div_rn_f_ftz(float %a, float %b) #0 {
; NOFTZ: call float @llvm.nvvm.div.rn.f
; FTZ: fdiv
; CHECK: call float @llvm.nvvm.div.rn.ftz.f(float %a, float %b)
%ret = call float @llvm.nvvm.div.rn.ftz.f(float %a, float %b)
ret float %ret
}
Expand All @@ -357,15 +347,13 @@ define float @test_div_rn_f_ftz(float %a, float %b) #0 {

; CHECK-LABEL: @test_rcp_rn_f
define float @test_rcp_rn_f(float %a) #0 {
; NOFTZ: fdiv float 1.0{{.*}} %a
; FTZ: call float @llvm.nvvm.rcp.rn.f
; CHECK: call float @llvm.nvvm.rcp.rn.f
%ret = call float @llvm.nvvm.rcp.rn.f(float %a)
ret float %ret
}
; CHECK-LABEL: @test_rcp_rn_f_ftz
define float @test_rcp_rn_f_ftz(float %a) #0 {
; NOFTZ: call float @llvm.nvvm.rcp.rn.f
; FTZ: fdiv float 1.0{{.*}} %a
; CHECK: call float @llvm.nvvm.rcp.rn.ftz.f(float %a)
%ret = call float @llvm.nvvm.rcp.rn.ftz.f(float %a)
ret float %ret
}
Expand All @@ -385,15 +373,13 @@ define float @test_sqrt_f(float %a) #0 {
}
; CHECK-LABEL: @test_sqrt_rn_f
define float @test_sqrt_rn_f(float %a) #0 {
; NOFTZ: call float @llvm.sqrt.f32(float %a)
; FTZ: call float @llvm.nvvm.sqrt.rn.f
; CHECK: call float @llvm.nvvm.sqrt.rn.f
%ret = call float @llvm.nvvm.sqrt.rn.f(float %a)
ret float %ret
}
; CHECK-LABEL: @test_sqrt_rn_f_ftz
define float @test_sqrt_rn_f_ftz(float %a) #0 {
; NOFTZ: call float @llvm.nvvm.sqrt.rn.f
; FTZ: call float @llvm.sqrt.f32(float %a)
; CHECK: call float @llvm.nvvm.sqrt.rn.ftz.f(float %a)
%ret = call float @llvm.nvvm.sqrt.rn.ftz.f(float %a)
ret float %ret
}
Expand Down

0 comments on commit f5e4a52

Please sign in to comment.