Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions clang/include/clang/Basic/BuiltinsNVPTX.td
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,10 @@ def __nvvm_ff2bf16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)
// Builtins taking two floats and returning a 2-element __bf16 vector.
// Variants gated on SM_80 / PTX70.
def __nvvm_ff2bf16x2_rn_relu :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2bf16x2_rz :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2bf16x2_rz_relu :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX70>;

// .satfinite variants: same signature, gated on SM_80 / PTX81.
def __nvvm_ff2bf16x2_rn_satfinite :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX81>;
def __nvvm_ff2bf16x2_rn_relu_satfinite :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX81>;
def __nvvm_ff2bf16x2_rz_satfinite :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX81>;
def __nvvm_ff2bf16x2_rz_relu_satfinite :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float)", SM_80, PTX81>;

// The rs variant takes an additional uint32_t operand and is gated on
// SM 100a / 103a with PTX87.
def __nvvm_ff2bf16x2_rs :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(float, float, uint32_t)",
                         SM<"100a", [SM_103a]>, PTX87>;
Expand All @@ -596,6 +600,10 @@ def __nvvm_ff2f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)"
// Builtins taking two floats and returning a 2-element __fp16 vector.
// Variants gated on SM_80 / PTX70.
def __nvvm_ff2f16x2_rn_relu :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rz :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;
def __nvvm_ff2f16x2_rz_relu :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX70>;

// .satfinite variants: same signature, gated on SM_80 / PTX81.
def __nvvm_ff2f16x2_rn_satfinite :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX81>;
def __nvvm_ff2f16x2_rn_relu_satfinite :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX81>;
def __nvvm_ff2f16x2_rz_satfinite :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX81>;
def __nvvm_ff2f16x2_rz_relu_satfinite :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float)", SM_80, PTX81>;

// The rs variant takes an additional uint32_t operand and is gated on
// SM 100a / 103a with PTX87.
def __nvvm_ff2f16x2_rs :
    NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(float, float, uint32_t)",
                         SM<"100a", [SM_103a]>, PTX87>;
Expand All @@ -613,6 +621,19 @@ def __nvvm_f2bf16_rn : NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
// Builtins taking one float and returning __bf16.
// Variants gated on SM_80 / PTX70.
def __nvvm_f2bf16_rn_relu :
    NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
def __nvvm_f2bf16_rz :
    NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;
def __nvvm_f2bf16_rz_relu :
    NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX70>;

// .satfinite variants: same signature, gated on SM_80 / PTX81.
def __nvvm_f2bf16_rn_satfinite :
    NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX81>;
def __nvvm_f2bf16_rn_relu_satfinite :
    NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX81>;
def __nvvm_f2bf16_rz_satfinite :
    NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX81>;
def __nvvm_f2bf16_rz_relu_satfinite :
    NVPTXBuiltinSMAndPTX<"__bf16(float)", SM_80, PTX81>;

// Builtins taking one float and returning __fp16.
// Variants gated on SM_80 / PTX70.
def __nvvm_f2f16_rn :
    NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX70>;
def __nvvm_f2f16_rn_relu :
    NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX70>;
def __nvvm_f2f16_rz :
    NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX70>;
def __nvvm_f2f16_rz_relu :
    NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX70>;

// .satfinite variants: same signature, gated on SM_80 / PTX81.
def __nvvm_f2f16_rn_satfinite :
    NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX81>;
def __nvvm_f2f16_rn_relu_satfinite :
    NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX81>;
def __nvvm_f2f16_rz_satfinite :
    NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX81>;
def __nvvm_f2f16_rz_relu_satfinite :
    NVPTXBuiltinSMAndPTX<"__fp16(float)", SM_80, PTX81>;

// Builtins taking one float and returning int32_t. The plain rna form is
// gated on SM_80 / PTX70, the .satfinite form on SM_80 / PTX81.
def __nvvm_f2tf32_rna :
    NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_80, PTX70>;
def __nvvm_f2tf32_rna_satfinite :
    NVPTXBuiltinSMAndPTX<"int32_t(float)", SM_80, PTX81>;
Expand Down
49 changes: 49 additions & 0 deletions clang/test/CodeGen/builtins-nvptx.c
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,16 @@ __device__ void nvvm_cvt_sm80() {
__nvvm_ff2bf16x2_rz(1, 1);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz_relu(1, 1);
#if PTX >= 81
// CHECK_PTX81_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn_satfinite(1, 1);
// CHECK_PTX81_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn_relu_satfinite(1, 1);
// CHECK_PTX81_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz_satfinite(1, 1);
// CHECK_PTX81_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz_relu_satfinite(1, 1);
#endif

// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rn(1, 1);
Expand All @@ -1016,6 +1026,16 @@ __device__ void nvvm_cvt_sm80() {
__nvvm_ff2f16x2_rz(1, 1);
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rz_relu(1, 1);
#if PTX >= 81
// CHECK_PTX81_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rn_satfinite(1, 1);
// CHECK_PTX81_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rn_relu_satfinite(1, 1);
// CHECK_PTX81_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rz_satfinite(1, 1);
// CHECK_PTX81_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rz_relu_satfinite(1, 1);
#endif

// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
__nvvm_f2bf16_rn(1);
Expand All @@ -1025,6 +1045,35 @@ __device__ void nvvm_cvt_sm80() {
__nvvm_f2bf16_rz(1);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
__nvvm_f2bf16_rz_relu(1);
#if PTX >= 81
// CHECK_PTX81_SM80: call bfloat @llvm.nvvm.f2bf16.rn.satfinite(float 1.000000e+00)
__nvvm_f2bf16_rn_satfinite(1);
// CHECK_PTX81_SM80: call bfloat @llvm.nvvm.f2bf16.rn.relu.satfinite(float 1.000000e+00)
__nvvm_f2bf16_rn_relu_satfinite(1);
// CHECK_PTX81_SM80: call bfloat @llvm.nvvm.f2bf16.rz.satfinite(float 1.000000e+00)
__nvvm_f2bf16_rz_satfinite(1);
// CHECK_PTX81_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu.satfinite(float 1.000000e+00)
__nvvm_f2bf16_rz_relu_satfinite(1);
#endif

// CHECK_PTX70_SM80: call half @llvm.nvvm.f2f16.rn(float 1.000000e+00)
__nvvm_f2f16_rn(1);
// CHECK_PTX70_SM80: call half @llvm.nvvm.f2f16.rn.relu(float 1.000000e+00)
__nvvm_f2f16_rn_relu(1);
// CHECK_PTX70_SM80: call half @llvm.nvvm.f2f16.rz(float 1.000000e+00)
__nvvm_f2f16_rz(1);
// CHECK_PTX70_SM80: call half @llvm.nvvm.f2f16.rz.relu(float 1.000000e+00)
__nvvm_f2f16_rz_relu(1);
#if PTX >= 81
// CHECK_PTX81_SM80: call half @llvm.nvvm.f2f16.rn.satfinite(float 1.000000e+00)
__nvvm_f2f16_rn_satfinite(1);
// CHECK_PTX81_SM80: call half @llvm.nvvm.f2f16.rn.relu.satfinite(float 1.000000e+00)
__nvvm_f2f16_rn_relu_satfinite(1);
// CHECK_PTX81_SM80: call half @llvm.nvvm.f2f16.rz.satfinite(float 1.000000e+00)
__nvvm_f2f16_rz_satfinite(1);
// CHECK_PTX81_SM80: call half @llvm.nvvm.f2f16.rz.relu.satfinite(float 1.000000e+00)
__nvvm_f2f16_rz_relu_satfinite(1);
#endif

// CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
__nvvm_f2tf32_rna(1);
Expand Down
21 changes: 13 additions & 8 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -1566,14 +1566,19 @@ let TargetPrefix = "nvvm" in {

// f32 -> {bf16x2, f16x2, bf16, f16} conversion intrinsics.
//
// One record is generated per (rounding mode, relu, satfinite) combination.
// The empty string in the relu / satfinite lists produces the plain variant,
// so the innermost loop is the only place these records are defined. Defining
// them again directly under the relu loop would create duplicate records for
// the satfinite == "" iteration, which TableGen rejects.
foreach rnd = ["rn", "rz"] in {
  foreach relu = ["", "_relu"] in {
    foreach satfinite = ["", "_satfinite"] in {
      // (float, float) -> <2 x bfloat>
      def int_nvvm_ff2bf16x2_ # rnd # relu # satfinite : NVVMBuiltin,
          PureIntrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty]>;

      // (float, float) -> <2 x half>
      def int_nvvm_ff2f16x2_ # rnd # relu # satfinite : NVVMBuiltin,
          PureIntrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty]>;

      // float -> bfloat
      def int_nvvm_f2bf16_ # rnd # relu # satfinite : NVVMBuiltin,
          PureIntrinsic<[llvm_bfloat_ty], [llvm_float_ty]>;

      // float -> half
      def int_nvvm_f2f16_ # rnd # relu # satfinite : NVVMBuiltin,
          PureIntrinsic<[llvm_half_ty], [llvm_float_ty]>;
    }
  }
}

Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,15 @@ let hasSideEffects = false in {
// Instantiate CVT_FROM_ALL for the "bf16", "f32" and "f64" type names. Only
// the bf16 instantiation passes the extra [hasPTX<78>, hasSM<90>] list.
// NOTE(review): presumably that list is applied as a predicate requirement
// inside CVT_FROM_ALL -- its definition is not visible here; confirm.
defm CVT_bf16 : CVT_FROM_ALL<"bf16", B16, [hasPTX<78>, hasSM<90>]>;
defm CVT_f32 : CVT_FROM_ALL<"f32", B32>;
defm CVT_f64 : CVT_FROM_ALL<"f64", B64>;

// Single-source f32 -> ToName conversion printed with the .satfinite
// modifier. The "base" and "relu" parts of the mnemonic are printed from the
// CvtMode flag operand. The instruction itself carries no Requires<> clause.
//
//   ToName - destination type suffix used in the mnemonic (e.g. "bf16").
//   RC     - register class of the $dst operand.
multiclass CVT_FROM_FLOAT_SATFINITE<string ToName, RegisterClass RC> {
  def _f32_sf :
    BasicFlagsNVPTXInst<(outs RC:$dst),
                        (ins B32:$src),
                        (ins CvtMode:$mode),
                        !strconcat("cvt${mode:base}${mode:relu}.satfinite.",
                                   ToName, ".f32")>;
}

// Produces CVT_bf16_f32_sf and CVT_f16_f32_sf, both writing a B16 register.
defm CVT_bf16 : CVT_FROM_FLOAT_SATFINITE<"bf16", B16>;
defm CVT_f16  : CVT_FROM_FLOAT_SATFINITE<"f16", B16>;

// These cvts are different from those above: The source and dest registers
// are of the same type.
Expand All @@ -611,6 +620,11 @@ let hasSideEffects = false in {
(ins B32:$src1, B32:$src2), (ins CvtMode:$mode),
"cvt${mode:base}${mode:relu}." # FromName # ".f32">,
Requires<[hasPTX<70>, hasSM<80>]>;

// Two-source f32 conversion printed with the .satfinite modifier. Operands
// match the sibling _f32 def; no Requires<> clause is attached here.
def _f32_sf :
  BasicFlagsNVPTXInst<(outs RC:$dst),
                      (ins B32:$src1, B32:$src2),
                      (ins CvtMode:$mode),
                      !strconcat("cvt${mode:base}${mode:relu}.satfinite.",
                                 FromName, ".f32")>;
}

defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", B32>;
Expand Down
30 changes: 29 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1917,7 +1917,12 @@ def : Pat<(int_nvvm_ff2bf16x2_rn f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, C
// ff2bf16x2 intrinsics -> CVT_bf16x2_f32, with the rounding/relu variant
// passed as the CvtMode operand.
def : Pat<(int_nvvm_ff2bf16x2_rn_relu f32:$a, f32:$b),
          (CVT_bf16x2_f32 $a, $b, CvtRN_RELU)>;
def : Pat<(int_nvvm_ff2bf16x2_rz f32:$a, f32:$b),
          (CVT_bf16x2_f32 $a, $b, CvtRZ)>;
def : Pat<(int_nvvm_ff2bf16x2_rz_relu f32:$a, f32:$b),
          (CVT_bf16x2_f32 $a, $b, CvtRZ_RELU)>;

// The satfinite intrinsics select the _sf instruction variant and are guarded
// by hasPTX<81> and hasSM<80>.
let Predicates = [hasPTX<81>, hasSM<80>] in {
  def : Pat<(int_nvvm_ff2bf16x2_rn_satfinite f32:$a, f32:$b),
            (CVT_bf16x2_f32_sf $a, $b, CvtRN)>;
  def : Pat<(int_nvvm_ff2bf16x2_rn_relu_satfinite f32:$a, f32:$b),
            (CVT_bf16x2_f32_sf $a, $b, CvtRN_RELU)>;
  def : Pat<(int_nvvm_ff2bf16x2_rz_satfinite f32:$a, f32:$b),
            (CVT_bf16x2_f32_sf $a, $b, CvtRZ)>;
  def : Pat<(int_nvvm_ff2bf16x2_rz_relu_satfinite f32:$a, f32:$b),
            (CVT_bf16x2_f32_sf $a, $b, CvtRZ_RELU)>;
}
let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
def : Pat<(int_nvvm_ff2bf16x2_rs f32:$a, f32:$b, i32:$c),
(CVT_bf16x2_f32_rs $a, $b, $c, CvtRS)>;
Expand All @@ -1933,6 +1938,12 @@ def : Pat<(int_nvvm_ff2f16x2_rn f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, Cvt
// ff2f16x2 intrinsics -> CVT_f16x2_f32, with the rounding/relu variant passed
// as the CvtMode operand.
def : Pat<(int_nvvm_ff2f16x2_rn_relu f32:$a, f32:$b),
          (CVT_f16x2_f32 $a, $b, CvtRN_RELU)>;
def : Pat<(int_nvvm_ff2f16x2_rz f32:$a, f32:$b),
          (CVT_f16x2_f32 $a, $b, CvtRZ)>;
def : Pat<(int_nvvm_ff2f16x2_rz_relu f32:$a, f32:$b),
          (CVT_f16x2_f32 $a, $b, CvtRZ_RELU)>;

// The satfinite intrinsics select the _sf instruction variant and are guarded
// by hasPTX<81> and hasSM<80>.
let Predicates = [hasPTX<81>, hasSM<80>] in {
  def : Pat<(int_nvvm_ff2f16x2_rn_satfinite f32:$a, f32:$b),
            (CVT_f16x2_f32_sf $a, $b, CvtRN)>;
  def : Pat<(int_nvvm_ff2f16x2_rn_relu_satfinite f32:$a, f32:$b),
            (CVT_f16x2_f32_sf $a, $b, CvtRN_RELU)>;
  def : Pat<(int_nvvm_ff2f16x2_rz_satfinite f32:$a, f32:$b),
            (CVT_f16x2_f32_sf $a, $b, CvtRZ)>;
  def : Pat<(int_nvvm_ff2f16x2_rz_relu_satfinite f32:$a, f32:$b),
            (CVT_f16x2_f32_sf $a, $b, CvtRZ_RELU)>;
}

let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
def : Pat<(int_nvvm_ff2f16x2_rs f32:$a, f32:$b, i32:$c),
Expand All @@ -1948,6 +1959,23 @@ def : Pat<(int_nvvm_f2bf16_rn f32:$a), (CVT_bf16_f32 $a, CvtRN)>;
// f2bf16 intrinsics -> CVT_bf16_f32, with the rounding/relu variant passed as
// the CvtMode operand.
def : Pat<(int_nvvm_f2bf16_rn_relu f32:$a),
          (CVT_bf16_f32 $a, CvtRN_RELU)>;
def : Pat<(int_nvvm_f2bf16_rz f32:$a),
          (CVT_bf16_f32 $a, CvtRZ)>;
def : Pat<(int_nvvm_f2bf16_rz_relu f32:$a),
          (CVT_bf16_f32 $a, CvtRZ_RELU)>;

// The satfinite intrinsics select the _sf instruction variant and are guarded
// by hasPTX<81> and hasSM<80>. Each pattern matches a distinct intrinsic, so
// their relative order is irrelevant.
let Predicates = [hasPTX<81>, hasSM<80>] in {
  def : Pat<(int_nvvm_f2bf16_rn_satfinite f32:$a),
            (CVT_bf16_f32_sf $a, CvtRN)>;
  def : Pat<(int_nvvm_f2bf16_rn_relu_satfinite f32:$a),
            (CVT_bf16_f32_sf $a, CvtRN_RELU)>;
  def : Pat<(int_nvvm_f2bf16_rz_satfinite f32:$a),
            (CVT_bf16_f32_sf $a, CvtRZ)>;
  def : Pat<(int_nvvm_f2bf16_rz_relu_satfinite f32:$a),
            (CVT_bf16_f32_sf $a, CvtRZ_RELU)>;
}

// f2f16 intrinsics -> CVT_f16_f32, with the rounding/relu variant passed as
// the CvtMode operand.
def : Pat<(int_nvvm_f2f16_rn f32:$a),
          (CVT_f16_f32 $a, CvtRN)>;
def : Pat<(int_nvvm_f2f16_rn_relu f32:$a),
          (CVT_f16_f32 $a, CvtRN_RELU)>;
def : Pat<(int_nvvm_f2f16_rz f32:$a),
          (CVT_f16_f32 $a, CvtRZ)>;
def : Pat<(int_nvvm_f2f16_rz_relu f32:$a),
          (CVT_f16_f32 $a, CvtRZ_RELU)>;

// The satfinite intrinsics select the _sf instruction variant and are guarded
// by hasPTX<81> and hasSM<80>. Each pattern matches a distinct intrinsic, so
// their relative order is irrelevant.
let Predicates = [hasPTX<81>, hasSM<80>] in {
  def : Pat<(int_nvvm_f2f16_rn_satfinite f32:$a),
            (CVT_f16_f32_sf $a, CvtRN)>;
  def : Pat<(int_nvvm_f2f16_rn_relu_satfinite f32:$a),
            (CVT_f16_f32_sf $a, CvtRN_RELU)>;
  def : Pat<(int_nvvm_f2f16_rz_satfinite f32:$a),
            (CVT_f16_f32_sf $a, CvtRZ)>;
  def : Pat<(int_nvvm_f2f16_rz_relu_satfinite f32:$a),
            (CVT_f16_f32_sf $a, CvtRZ_RELU)>;
}

// int_nvvm_lohi_i2d (two i32 operands) selects V2I32toI64; int_nvvm_d2i_lo
// (one f64 operand) selects I64toI32L.
def : Pat<(int_nvvm_lohi_i2d i32:$a, i32:$b),
          (V2I32toI64 $a, $b)>;
def : Pat<(int_nvvm_d2i_lo f64:$a),
          (I64toI32L $a)>;
Expand Down
Loading
Loading