diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td index ad448766e665f..052c20455b373 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.td +++ b/clang/include/clang/Basic/BuiltinsNVPTX.td @@ -378,16 +378,24 @@ def __nvvm_fma_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16) def __nvvm_fma_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_53, PTX42>; def __nvvm_fma_rn_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>; def __nvvm_fma_rn_ftz_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>; def __nvvm_fma_rn_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_ftz_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>; def __nvvm_fma_rn_ftz_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, 
_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>; def __nvvm_fma_rn_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>; def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>; def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>; def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>; def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rz_ftz_f : NVPTXBuiltin<"float(float, float, float)">; @@ -446,6 +454,11 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">; // Add +def __nvvm_add_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_add_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_add_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_add_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; + def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rz_ftz_f : NVPTXBuiltin<"float(float, float)">; @@ -460,6 
+473,20 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">; +// Sub + +def __nvvm_sub_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; + +// Mul + +def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_mul_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_mul_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_mul_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; + // Convert def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">; diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index c0ed799970122..d705bcbe208d1 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -31,6 +31,9 @@ // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_80 -target-feature +ptx81 -DPTX=81 \ // RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \ // RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM80 %s +// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx81 -DPTX=81 \ +// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \ +// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM90 %s // RUN: %clang_cc1 -ffp-contract=off 
-triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx78 -DPTX=78 \ // RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \ // RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX78_SM90 %s @@ -1470,3 +1473,64 @@ __device__ void nvvm_min_max_sm86() { #endif // CHECK: ret void } + +#define F16 (__fp16)0.1f +#define F16_2 (__fp16)0.2f +#define F16X2 {(__fp16)0.1f, (__fp16)0.1f} +#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f} + +// CHECK-LABEL: nvvm_add_sub_mul_f16_sat +__device__ void nvvm_add_sub_mul_f16_sat() { + // CHECK: call half @llvm.nvvm.add.rn.sat.f16 + __nvvm_add_rn_sat_f16(F16, F16_2); + // CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16 + __nvvm_add_rn_ftz_sat_f16(F16, F16_2); + // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.f16x2 + __nvvm_add_rn_sat_f16x2(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2 + __nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2); + + // CHECK: call half @llvm.nvvm.sub.rn.sat.f16 + __nvvm_sub_rn_sat_f16(F16, F16_2); + // CHECK: call half @llvm.nvvm.sub.rn.ftz.sat.f16 + __nvvm_sub_rn_ftz_sat_f16(F16, F16_2); + // CHECK: call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2 + __nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2 + __nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2); + + // CHECK: call half @llvm.nvvm.mul.rn.sat.f16 + __nvvm_mul_rn_sat_f16(F16, F16_2); + // CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16 + __nvvm_mul_rn_ftz_sat_f16(F16, F16_2); + // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2 + __nvvm_mul_rn_sat_f16x2(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2 + __nvvm_mul_rn_ftz_sat_f16x2(F16X2, F16X2_2); + + // CHECK: ret void +} + +// CHECK-LABEL: nvvm_fma_oob +__device__ void nvvm_fma_oob() { +#if __CUDA_ARCH__ >= 900 && (PTX >= 81) + // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.f16 + __nvvm_fma_rn_oob_f16(F16, F16_2, F16_2); + // CHECK_PTX81_SM90: call half 
@llvm.nvvm.fma.rn.oob.relu.f16 + __nvvm_fma_rn_oob_relu_f16(F16, F16_2, F16_2); + // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2 + __nvvm_fma_rn_oob_f16x2(F16X2, F16X2_2, F16X2_2); + // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2 + __nvvm_fma_rn_oob_relu_f16x2(F16X2, F16X2_2, F16X2_2); + + // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.bf16 + __nvvm_fma_rn_oob_bf16(BF16, BF16_2, BF16_2); + // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16 + __nvvm_fma_rn_oob_relu_bf16(BF16, BF16_2, BF16_2); + // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2 + __nvvm_fma_rn_oob_bf16x2(BF16X2, BF16X2_2, BF16X2_2); + // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2 + __nvvm_fma_rn_oob_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2); +#endif + // CHECK: ret void +} diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 1b485dc8ccd1e..65303ecb48dd8 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1365,16 +1365,38 @@ let TargetPrefix = "nvvm" in { def int_nvvm_fma_rn # ftz # variant # _f16x2 : PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>; - - def int_nvvm_fma_rn # ftz # variant # _bf16 : NVVMBuiltin, - PureIntrinsic<[llvm_bfloat_ty], - [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>; - - def int_nvvm_fma_rn # ftz # variant # _bf16x2 : NVVMBuiltin, - PureIntrinsic<[llvm_v2bf16_ty], - [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; } // ftz } // variant + + foreach relu = ["", "_relu"] in { + def int_nvvm_fma_rn # relu # _bf16 : NVVMBuiltin, + PureIntrinsic<[llvm_bfloat_ty], + [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>; + + def int_nvvm_fma_rn # relu # _bf16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2bf16_ty], + [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; + } // relu + + // oob (out-of-bounds) - clamps the result to 0 if either of the operands is + 
// an OOB NaN value. + foreach relu = ["", "_relu"] in { + def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin, + PureIntrinsic<[llvm_half_ty], + [llvm_half_ty, llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_fma_rn_oob # relu # _f16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2f16_ty], + [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>; + + def int_nvvm_fma_rn_oob # relu # _bf16 : NVVMBuiltin, + PureIntrinsic<[llvm_bfloat_ty], + [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>; + + def int_nvvm_fma_rn_oob # relu # _bf16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2bf16_ty], + [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; + } // relu foreach rnd = ["rn", "rz", "rm", "rp"] in { foreach ftz = ["", "_ftz"] in @@ -1442,6 +1464,15 @@ let TargetPrefix = "nvvm" in { // // Add // + foreach ftz = ["", "_ftz"] in { + def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin, + PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_add_rn # ftz # _sat_f16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; + + } // ftz + let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in { foreach rnd = ["rn", "rz", "rm", "rp"] in { foreach ftz = ["", "_ftz"] in @@ -1452,6 +1483,28 @@ let TargetPrefix = "nvvm" in { DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>; } } + + // + // Sub + // + foreach ftz = ["", "_ftz"] in { + def int_nvvm_sub_rn # ftz # _sat_f16 : NVVMBuiltin, + PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_sub_rn # ftz # _sat_f16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; + } // ftz + + // + // Mul + // + foreach ftz = ["", "_ftz"] in { + def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin, + PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_mul_rn # ftz # _sat_f16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; + } // ftz // // Dot Product diff --git 
a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index 58b7ddd0381e5..1e40242213b99 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -1098,16 +1098,8 @@ static Intrinsic::ID shouldUpgradeNVPTXBF16Intrinsic(StringRef Name) { return StringSwitch(Name) .Case("bf16", Intrinsic::nvvm_fma_rn_bf16) .Case("bf16x2", Intrinsic::nvvm_fma_rn_bf16x2) - .Case("ftz.bf16", Intrinsic::nvvm_fma_rn_ftz_bf16) - .Case("ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2) - .Case("ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16) - .Case("ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2) - .Case("ftz.sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16) - .Case("ftz.sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2) .Case("relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16) .Case("relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2) - .Case("sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16) - .Case("sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2) .Default(Intrinsic::not_intrinsic); if (Name.consume_front("fmax.")) diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index ea69a54e6db37..57fdd4dc3c388 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1656,18 +1656,18 @@ multiclass FMA_INST { [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, B16, [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_oob_f16", int_nvvm_fma_rn_oob_f16, B16, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_oob_relu_f16", int_nvvm_fma_rn_oob_relu_f16, B16, + [hasPTX<81>, hasSM<90>]>, FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, B16, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, B16, - [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, B16, - [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, B16, - [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_relu_bf16", 
int_nvvm_fma_rn_relu_bf16, B16, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, B16, - [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_oob_bf16", int_nvvm_fma_rn_oob_bf16, B16, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_oob_relu_bf16", int_nvvm_fma_rn_oob_relu_bf16, B16, + [hasPTX<81>, hasSM<90>]>, FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, B32, [hasPTX<42>, hasSM<53>]>, @@ -1681,10 +1681,19 @@ multiclass FMA_INST { [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2, B32, [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_oob_f16x2", int_nvvm_fma_rn_oob_f16x2, B32, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_oob_relu_f16x2", int_nvvm_fma_rn_oob_relu_f16x2, B32, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, B32, [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, B32, - [hasPTX<70>, hasSM<80>]> + [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_oob_bf16x2", int_nvvm_fma_rn_oob_bf16x2, B32, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_oob_relu_bf16x2", int_nvvm_fma_rn_oob_relu_bf16x2, B32, + [hasPTX<81>, hasSM<90>]>, ] in { def P.Variant : F_MATH_3; +def INT_NVVM_ADD_RN_FTZ_SAT_F16 : F_MATH_2<"add.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_add_rn_ftz_sat_f16>; +def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_f16x2>; +def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f16x2>; + def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>; def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>; def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>; @@ -1806,6 +1820,24 @@ def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d> def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>; def 
INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>; +// +// Sub +// + +def INT_NVVM_SUB_RN_SAT_F16 : F_MATH_2<"sub.rn.sat.f16", B16, B16, B16, int_nvvm_sub_rn_sat_f16>; +def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_sub_rn_ftz_sat_f16>; +def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>; +def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>; + +// +// Mul +// + +def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>; +def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>; +def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_f16x2>; +def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_f16x2>; + // // BFIND // diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp index 64593e6439184..29a81c04395e3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -207,12 +207,8 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC, return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fma_rn_bf16: return {Intrinsic::fma, FTZ_MustBeOff, true}; - case Intrinsic::nvvm_fma_rn_ftz_bf16: - return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fma_rn_bf16x2: return {Intrinsic::fma, FTZ_MustBeOff, true}; - case Intrinsic::nvvm_fma_rn_ftz_bf16x2: - return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fmax_d: return {Intrinsic::maxnum, FTZ_Any}; case Intrinsic::nvvm_fmax_f: diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll new file mode 100644 index 0000000000000..a623d6e5351ab --- /dev/null 
+++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} + +define half @add_rn_sat_f16(half %a, half %b) { +; CHECK-LABEL: add_rn_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_sat_f16_param_1]; +; CHECK-NEXT: add.rn.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: add_rn_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [add_rn_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [add_rn_sat_f16x2_param_1]; +; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} + +define half @add_rn_ftz_sat_f16(half %a, half %b) { +; CHECK-LABEL: add_rn_ftz_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_ftz_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_ftz_sat_f16_param_1]; +; CHECK-NEXT: add.rn.ftz.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: add_rn_ftz_sat_f16x2( +; CHECK: { +; 
CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [add_rn_ftz_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [add_rn_ftz_sat_f16x2_param_1]; +; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll new file mode 100644 index 0000000000000..68caac8c36e31 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} + +define half @mul_rn_sat_f16(half %a, half %b) { +; CHECK-LABEL: mul_rn_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_sat_f16_param_1]; +; CHECK-NEXT: mul.rn.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.mul.rn.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: mul_rn_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_sat_f16x2_param_1]; +; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} + +define half @mul_rn_ftz_sat_f16(half %a, half %b) { 
+; CHECK-LABEL: mul_rn_ftz_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_ftz_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_ftz_sat_f16_param_1]; +; CHECK-NEXT: mul.rn.ftz.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.mul.rn.ftz.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: mul_rn_ftz_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_ftz_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_ftz_sat_f16x2_param_1]; +; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll new file mode 100644 index 0000000000000..2c02f6aa3160e --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} + +define half @sub_rn_sat_f16(half %a, half %b) { +; CHECK-LABEL: sub_rn_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_sat_f16_param_1]; +; CHECK-NEXT: sub.rn.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.sub.rn.sat.f16(half %a, half %b) + ret half %1 +} + 
+define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: sub_rn_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_sat_f16x2_param_1]; +; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} + +define half @sub_rn_ftz_sat_f16(half %a, half %b) { +; CHECK-LABEL: sub_rn_ftz_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_ftz_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_ftz_sat_f16_param_1]; +; CHECK-NEXT: sub.rn.ftz.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.sub.rn.ftz.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: sub_rn_ftz_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_ftz_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_ftz_sat_f16x2_param_1]; +; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} diff --git a/llvm/test/CodeGen/NVPTX/fma-oob.ll b/llvm/test/CodeGen/NVPTX/fma-oob.ll new file mode 100644 index 0000000000000..2553c5f298b17 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/fma-oob.ll @@ -0,0 +1,131 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | FileCheck %s +; RUN: %if 
ptxas-isa-8.1 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | %ptxas-verify -arch=sm_90 %} + +define half @fma_oob_f16(half %a, half %b, half %c) { +; CHECK-LABEL: fma_oob_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_f16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_f16_param_2]; +; CHECK-NEXT: fma.rn.oob.f16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.fma.rn.oob.f16(half %a, half %b, half %c) + ret half %1 +} + +define half @fma_oob_relu_f16(half %a, half %b, half %c) { +; CHECK-LABEL: fma_oob_relu_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_f16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_f16_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.f16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.fma.rn.oob.relu.f16(half %a, half %b, half %c) + ret half %1 +} + +define <2 x half> @fma_oob_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) { +; CHECK-LABEL: fma_oob_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_f16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_f16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.f16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c) + ret <2 x half> %1 +} + +define <2 x half> @fma_oob_relu_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) { +; CHECK-LABEL: fma_oob_relu_f16x2( +; CHECK: { +; 
CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_f16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_f16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.f16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c) + ret <2 x half> %1 +} + +define bfloat @fma_oob_bf16(bfloat %a, bfloat %b, bfloat %c) { +; CHECK-LABEL: fma_oob_bf16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_bf16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_bf16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_bf16_param_2]; +; CHECK-NEXT: fma.rn.oob.bf16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call bfloat @llvm.nvvm.fma.rn.oob.bf16(bfloat %a, bfloat %b, bfloat %c) + ret bfloat %1 +} + +define bfloat @fma_oob_relu_bf16(bfloat %a, bfloat %b, bfloat %c) { +; CHECK-LABEL: fma_oob_relu_bf16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_bf16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_bf16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_bf16_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.bf16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16(bfloat %a, bfloat %b, bfloat %c) + ret bfloat %1 +} + +define <2 x bfloat> @fma_oob_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) { +; CHECK-LABEL: fma_oob_bf16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_bf16x2_param_0]; +; CHECK-NEXT: ld.param.b32 
%r2, [fma_oob_bf16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_bf16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.bf16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + ret <2 x bfloat> %1 +} + +define <2 x bfloat> @fma_oob_relu_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) { +; CHECK-LABEL: fma_oob_relu_bf16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_bf16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_bf16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_bf16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.bf16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + ret <2 x bfloat> %1 +}