From 5f3d2b1f627ba2e9da63240297c3d2080e6935f9 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Tue, 25 Nov 2025 06:37:54 +0000 Subject: [PATCH 1/9] [clang][NVPTX] Add missing half-precision add/sub/fma intrinsics This change adds the following missing half-precision add/sub/fma intrinsics for the NVPTX target: - `llvm.nvvm.add.rn{.ftz}.sat.f16` - `llvm.nvvm.add.rn{.ftz}.sat.f16x2` - `llvm.nvvm.sub.rn{.ftz}.sat.f16` - `llvm.nvvm.sub.rn{.ftz}.sat.f16x2` - `llvm.nvvm.fma.rn.oob.*` This also removes some incorrect `bf16` fma intrinsics with no valid lowering. PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions --- clang/include/clang/Basic/BuiltinsNVPTX.td | 20 +++ clang/test/CodeGen/builtins-nvptx.c | 55 ++++++++ llvm/include/llvm/IR/IntrinsicsNVVM.td | 58 ++++++-- llvm/lib/IR/AutoUpgrade.cpp | 8 -- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 41 ++++-- .../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 4 - llvm/test/CodeGen/NVPTX/f16-add-sat.ll | 63 +++++++++ llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 63 +++++++++ llvm/test/CodeGen/NVPTX/fma-oob.ll | 131 ++++++++++++++++++ 9 files changed, 414 insertions(+), 29 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/f16-add-sat.ll create mode 100644 llvm/test/CodeGen/NVPTX/f16-sub-sat.ll create mode 100644 llvm/test/CodeGen/NVPTX/fma-oob.ll diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td index 6fbd2222ab289..a3263f80a76e1 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.td +++ b/clang/include/clang/Basic/BuiltinsNVPTX.td @@ -378,16 +378,24 @@ def __nvvm_fma_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16) def __nvvm_fma_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_53, PTX42>; def __nvvm_fma_rn_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>; def __nvvm_fma_rn_ftz_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>; def __nvvm_fma_rn_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_ftz_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>; def __nvvm_fma_rn_ftz_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>; def __nvvm_fma_rn_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>; def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>; def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>; def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>; def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rz_ftz_f : NVPTXBuiltin<"float(float, float, float)">; @@ -446,6 +454,11 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">; // Add +def __nvvm_add_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_add_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_add_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_add_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; + def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rz_ftz_f : NVPTXBuiltin<"float(float, float)">; @@ -460,6 +473,13 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">; +// Sub + +def __nvvm_sub_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; + // Convert def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">; diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 75f2588f4837b..594cdd4da9ef7 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -31,6 +31,9 @@ // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_80 -target-feature +ptx81 -DPTX=81 \ // RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \ // RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM80 %s +// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx81 -DPTX=81\ +// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \ +// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM90 %s // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx78 -DPTX=78 \ // RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \ // RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX78_SM90 %s @@ -1519,3 +1522,55 @@ __device__ void nvvm_min_max_sm86() { #endif // CHECK: ret void } + +#define F16 (__fp16)0.1f +#define F16_2 (__fp16)0.2f +#define F16X2 {(__fp16)0.1f, (__fp16)0.1f} +#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f} + +// CHECK-LABEL: nvvm_add_sub_f16_sat +__device__ void nvvm_add_sub_f16_sat() { + // CHECK: call half @llvm.nvvm.add.rn.sat.f16 + __nvvm_add_rn_sat_f16(F16, F16_2); + // CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16 + __nvvm_add_rn_ftz_sat_f16(F16, F16_2); + // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.f16x2 + __nvvm_add_rn_sat_f16x2(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2 + __nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2); + + // CHECK: call half @llvm.nvvm.sub.rn.sat.f16 + __nvvm_sub_rn_sat_f16(F16, F16_2); + // CHECK: call half @llvm.nvvm.sub.rn.ftz.sat.f16 + __nvvm_sub_rn_ftz_sat_f16(F16, F16_2); + // CHECK: call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2 + __nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2 + __nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2); + + // CHECK: ret void +} + +// CHECK-LABEL: nvvm_fma_oob +__device__ void nvvm_fma_oob() { +#if __CUDA_ARCH__ >= 900 && (PTX >= 81) + // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.f16 + __nvvm_fma_rn_oob_f16(F16, F16_2, F16_2); + // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.relu.f16 + __nvvm_fma_rn_oob_relu_f16(F16, F16_2, F16_2); + // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2 + __nvvm_fma_rn_oob_f16x2(F16X2, F16X2_2, F16X2_2); + // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2 + __nvvm_fma_rn_oob_relu_f16x2(F16X2, F16X2_2, F16X2_2); + + // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.bf16 + __nvvm_fma_rn_oob_bf16(BF16, BF16_2, BF16_2); + // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16 + __nvvm_fma_rn_oob_relu_bf16(BF16, BF16_2, BF16_2); + // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2 + __nvvm_fma_rn_oob_bf16x2(BF16X2, BF16X2_2, BF16X2_2); + // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2 + __nvvm_fma_rn_oob_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2); +#endif + // CHECK: ret void +} diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index c71f37f671539..e40a5928acff7 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1490,16 +1490,37 @@ let TargetPrefix = "nvvm" in { def int_nvvm_fma_rn # ftz # variant # _f16x2 : PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>; - - def int_nvvm_fma_rn # ftz # variant # _bf16 : NVVMBuiltin, - PureIntrinsic<[llvm_bfloat_ty], - [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>; - - def int_nvvm_fma_rn # ftz # variant # _bf16x2 : NVVMBuiltin, - PureIntrinsic<[llvm_v2bf16_ty], - [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; } // ftz } // variant + + foreach relu = ["", "_relu"] in { + def int_nvvm_fma_rn # relu # _bf16 : NVVMBuiltin, + PureIntrinsic<[llvm_bfloat_ty], + [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>; + + def int_nvvm_fma_rn # relu # _bf16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2bf16_ty], + [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; + } // relu + + // oob + foreach relu = ["", "_relu"] in { + def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin, + PureIntrinsic<[llvm_half_ty], + [llvm_half_ty, llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_fma_rn_oob # relu # _f16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2f16_ty], + [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>; + + def int_nvvm_fma_rn_oob # relu # _bf16 : NVVMBuiltin, + PureIntrinsic<[llvm_bfloat_ty], + [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>; + + def int_nvvm_fma_rn_oob # relu # _bf16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2bf16_ty], + [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; + } // relu foreach rnd = ["rn", "rz", "rm", "rp"] in { foreach ftz = ["", "_ftz"] in @@ -1567,6 +1588,15 @@ let TargetPrefix = "nvvm" in { // // Add // + foreach ftz = ["", "_ftz"] in { + def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin, + PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_add_rn # ftz # _sat_f16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; + + } // ftz + let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in { foreach rnd = ["rn", "rz", "rm", "rp"] in { foreach ftz = ["", "_ftz"] in @@ -1577,6 +1607,18 @@ let TargetPrefix = "nvvm" in { DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>; } } + + + // + // Sub + // + foreach ftz = ["", "_ftz"] in { + def int_nvvm_sub_rn # ftz # _sat_f16 : NVVMBuiltin, + PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_sub_rn # ftz # _sat_f16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; + } // ftz // // Dot Product diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index 487db134b0df3..e0cd82b54ef23 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -1106,16 +1106,8 @@ static Intrinsic::ID shouldUpgradeNVPTXBF16Intrinsic(StringRef Name) { return StringSwitch(Name) .Case("bf16", Intrinsic::nvvm_fma_rn_bf16) .Case("bf16x2", Intrinsic::nvvm_fma_rn_bf16x2) - .Case("ftz.bf16", Intrinsic::nvvm_fma_rn_ftz_bf16) - .Case("ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2) - .Case("ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16) - .Case("ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2) - .Case("ftz.sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16) - .Case("ftz.sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2) .Case("relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16) .Case("relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2) - .Case("sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16) - .Case("sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2) .Default(Intrinsic::not_intrinsic); if (Name.consume_front("fmax.")) diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index d18c7e20df038..e6e8126d0fee8 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1691,18 +1691,18 @@ multiclass FMA_INST { [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, B16, [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_oob_f16", int_nvvm_fma_rn_oob_f16, B16, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_oob_relu_f16", int_nvvm_fma_rn_oob_relu_f16, B16, + [hasPTX<81>, hasSM<90>]>, FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, B16, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, B16, - [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, B16, - [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, B16, - [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, B16, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, B16, - [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_oob_bf16", int_nvvm_fma_rn_oob_bf16, B16, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_oob_relu_bf16", int_nvvm_fma_rn_oob_relu_bf16, B16, + [hasPTX<81>, hasSM<90>]>, FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, B32, [hasPTX<42>, hasSM<53>]>, @@ -1716,10 +1716,19 @@ multiclass FMA_INST { [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2, B32, [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_oob_f16x2", int_nvvm_fma_rn_oob_f16x2, B32, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_oob_relu_f16x2", int_nvvm_fma_rn_oob_relu_f16x2, B32, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, B32, [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, B32, - [hasPTX<70>, hasSM<80>]> + [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_oob_bf16x2", int_nvvm_fma_rn_oob_bf16x2, B32, + [hasPTX<81>, hasSM<90>]>, + FMA_TUPLE<"_rn_oob_relu_bf16x2", int_nvvm_fma_rn_oob_relu_bf16x2, B32, + [hasPTX<81>, hasSM<90>]>, ] in { def P.Variant : F_MATH_3; +def INT_NVVM_ADD_RN_FTZ_SAT_F16 : F_MATH_2<"add.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_add_rn_ftz_sat_f16>; +def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_f16x2>; +def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f16x2>; + def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>; def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>; def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>; @@ -1841,6 +1855,15 @@ def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d> def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>; def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>; +// +// Sub +// + +def INT_NVVM_SUB_RN_SAT_F16 : F_MATH_2<"sub.rn.sat.f16", B16, B16, B16, int_nvvm_sub_rn_sat_f16>; +def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_sub_rn_ftz_sat_f16>; +def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>; +def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>; + // // BFIND // diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp index 5d5553c573b0f..8aad17fe12709 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -207,12 +207,8 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC, return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fma_rn_bf16: return {Intrinsic::fma, FTZ_MustBeOff, true}; - case Intrinsic::nvvm_fma_rn_ftz_bf16: - return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fma_rn_bf16x2: return {Intrinsic::fma, FTZ_MustBeOff, true}; - case Intrinsic::nvvm_fma_rn_ftz_bf16x2: - return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fmax_d: return {Intrinsic::maxnum, FTZ_Any}; case Intrinsic::nvvm_fmax_f: diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll new file mode 100644 index 0000000000000..bf2f938d4d36c --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} + +define half @add_rn_sat_f16(half %a, half %b) { +; CHECK-LABEL: add_rn_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_sat_f16_param_1]; +; CHECK-NEXT: add.rn.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: add_rn_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [add_rn_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [add_rn_sat_f16x2_param_1]; +; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2( <2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} + +define half @add_rn_ftz_sat_f16(half %a, half %b) { +; CHECK-LABEL: add_rn_ftz_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_ftz_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_ftz_sat_f16_param_1]; +; CHECK-NEXT: add.rn.ftz.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: add_rn_ftz_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [add_rn_ftz_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [add_rn_ftz_sat_f16x2_param_1]; +; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll new file mode 100644 index 0000000000000..25f7b63b13db5 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} + +define half @sub_rn_sat_f16(half %a, half %b) { +; CHECK-LABEL: sub_rn_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_sat_f16_param_1]; +; CHECK-NEXT: sub.rn.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.sub.rn.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: sub_rn_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_sat_f16x2_param_1]; +; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2( <2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} + +define half @sub_rn_ftz_sat_f16(half %a, half %b) { +; CHECK-LABEL: sub_rn_ftz_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_ftz_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_ftz_sat_f16_param_1]; +; CHECK-NEXT: sub.rn.ftz.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.sub.rn.ftz.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: sub_rn_ftz_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_ftz_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_ftz_sat_f16x2_param_1]; +; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} diff --git a/llvm/test/CodeGen/NVPTX/fma-oob.ll b/llvm/test/CodeGen/NVPTX/fma-oob.ll new file mode 100644 index 0000000000000..2553c5f298b17 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/fma-oob.ll @@ -0,0 +1,131 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | FileCheck %s +; RUN: %if ptxas-isa-8.1 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | %ptxas-verify -arch=sm_90 %} + +define half @fma_oob_f16(half %a, half %b, half %c) { +; CHECK-LABEL: fma_oob_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_f16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_f16_param_2]; +; CHECK-NEXT: fma.rn.oob.f16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.fma.rn.oob.f16(half %a, half %b, half %c) + ret half %1 +} + +define half @fma_oob_relu_f16(half %a, half %b, half %c) { +; CHECK-LABEL: fma_oob_relu_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_f16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_f16_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.f16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.fma.rn.oob.relu.f16(half %a, half %b, half %c) + ret half %1 +} + +define <2 x half> @fma_oob_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) { +; CHECK-LABEL: fma_oob_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_f16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_f16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.f16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c) + ret <2 x half> %1 +} + +define <2 x half> @fma_oob_relu_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) { +; CHECK-LABEL: fma_oob_relu_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_f16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_f16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.f16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c) + ret <2 x half> %1 +} + +define bfloat @fma_oob_bf16(bfloat %a, bfloat %b, bfloat %c) { +; CHECK-LABEL: fma_oob_bf16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_bf16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_bf16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_bf16_param_2]; +; CHECK-NEXT: fma.rn.oob.bf16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call bfloat @llvm.nvvm.fma.rn.oob.bf16(bfloat %a, bfloat %b, bfloat %c) + ret bfloat %1 +} + +define bfloat @fma_oob_relu_bf16(bfloat %a, bfloat %b, bfloat %c) { +; CHECK-LABEL: fma_oob_relu_bf16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_bf16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_bf16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_bf16_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.bf16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16(bfloat %a, bfloat %b, bfloat %c) + ret bfloat %1 +} + +define <2 x bfloat> @fma_oob_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) { +; CHECK-LABEL: fma_oob_bf16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_bf16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_bf16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_bf16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.bf16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + ret <2 x bfloat> %1 +} + +define <2 x bfloat> @fma_oob_relu_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) { +; CHECK-LABEL: fma_oob_relu_bf16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_bf16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_bf16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_bf16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.bf16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + ret <2 x bfloat> %1 +} From 528ddad0164494f823b8e14340ec36aca4599d83 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Thu, 27 Nov 2025 08:35:37 +0000 Subject: [PATCH 2/9] address comments and add mul --- clang/include/clang/Basic/BuiltinsNVPTX.td | 7 +++ clang/test/CodeGen/builtins-nvptx.c | 13 ++++- llvm/include/llvm/IR/IntrinsicsNVVM.td | 15 +++++- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 9 ++++ llvm/test/CodeGen/NVPTX/f16-mul-sat.ll | 63 ++++++++++++++++++++++ 5 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/f16-mul-sat.ll diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td index a3263f80a76e1..251d8caac6390 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.td +++ b/clang/include/clang/Basic/BuiltinsNVPTX.td @@ -480,6 +480,13 @@ def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", S def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +// Mul + +def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_mul_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_mul_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_mul_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; + // Convert def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">; diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 594cdd4da9ef7..199d79f25bcb0 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -1528,8 +1528,8 @@ __device__ void nvvm_min_max_sm86() { #define F16X2 {(__fp16)0.1f, (__fp16)0.1f} #define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f} -// CHECK-LABEL: nvvm_add_sub_f16_sat -__device__ void nvvm_add_sub_f16_sat() { +// CHECK-LABEL: nvvm_add_sub_mul_f16_sat +__device__ void nvvm_add_sub_mul_f16_sat() { // CHECK: call half @llvm.nvvm.add.rn.sat.f16 __nvvm_add_rn_sat_f16(F16, F16_2); // CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16 @@ -1547,6 +1547,15 @@ __device__ void nvvm_add_sub_f16_sat() { __nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2); // CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2 __nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2); + + // CHECK: call half @llvm.nvvm.mul.rn.sat.f16 + __nvvm_mul_rn_sat_f16(F16, F16_2); + // CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16 + __nvvm_mul_rn_ftz_sat_f16(F16, F16_2); + // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2 + __nvvm_mul_rn_sat_f16x2(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2 + __nvvm_mul_rn_ftz_sat_f16x2(F16X2, F16X2_2); // CHECK: ret void } diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index e40a5928acff7..b1c38f34b1321 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1503,7 +1503,8 @@ let TargetPrefix = "nvvm" in { [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; } // relu - // oob + // oob (out-of-bounds) - clamps the result to 0 if either of the operands is + // OOB NaN value. foreach relu = ["", "_relu"] in { def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin, PureIntrinsic<[llvm_half_ty], @@ -1608,7 +1609,6 @@ let TargetPrefix = "nvvm" in { } } - // // Sub // @@ -1620,6 +1620,17 @@ let TargetPrefix = "nvvm" in { PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; } // ftz + // + // Mul + // + foreach ftz = ["", "_ftz"] in { + def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin, + PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_mul_rn # ftz # _sat_f16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; + } // ftz + // // Dot Product // diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index e6e8126d0fee8..440224bdd1454 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1864,6 +1864,15 @@ def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>; def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>; +// +// Mul +// + +def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>; +def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>; +def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_f16x2>; +def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_f16x2>; + // // BFIND // diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll new file mode 100644 index 0000000000000..77c498b6c3145 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} + +define half @mul_rn_sat_f16(half %a, half %b) { +; CHECK-LABEL: mul_rn_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_sat_f16_param_1]; +; CHECK-NEXT: mul.rn.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.mul.rn.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: mul_rn_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_sat_f16x2_param_1]; +; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2( <2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} + +define half @mul_rn_ftz_sat_f16(half %a, half %b) { +; CHECK-LABEL: mul_rn_ftz_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_ftz_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_ftz_sat_f16_param_1]; +; CHECK-NEXT: mul.rn.ftz.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.mul.rn.ftz.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: mul_rn_ftz_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_ftz_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_ftz_sat_f16x2_param_1]; +; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} From 22fb84aa1c35c9d3f88525dcbb7c7252ed69272b Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Thu, 27 Nov 2025 08:39:07 +0000 Subject: [PATCH 3/9] fix test formatting --- llvm/test/CodeGen/NVPTX/f16-add-sat.ll | 4 ++-- llvm/test/CodeGen/NVPTX/f16-mul-sat.ll | 4 ++-- llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll index bf2f938d4d36c..a623d6e5351ab 100644 --- a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll +++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll @@ -28,7 +28,7 @@ define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2( <2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } @@ -58,6 +58,6 @@ define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll index 77c498b6c3145..68caac8c36e31 100644 --- a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll +++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll @@ -28,7 +28,7 @@ define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2( <2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } @@ -58,6 +58,6 @@ define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll index 25f7b63b13db5..2c02f6aa3160e 100644 --- a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll +++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll @@ -28,7 +28,7 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2( <2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } @@ -58,6 +58,6 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2( <2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } From 210c875bcd1a768b0be2bcdb46a8069840b1d810 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Tue, 2 Dec 2025 12:06:31 +0000 Subject: [PATCH 4/9] fold add with fneg to sub --- clang/include/clang/Basic/BuiltinsNVPTX.td | 7 ---- clang/test/CodeGen/builtins-nvptx.c | 9 ----- llvm/include/llvm/IR/IntrinsicsNVVM.td | 11 ------ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 37 ++++++++++++++++++++- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 33 +++++++++++++++--- llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 22 +++++++----- 6 files changed, 79 insertions(+), 40 deletions(-) diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td index 251d8caac6390..5ab79b326ee0f 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.td +++ b/clang/include/clang/Basic/BuiltinsNVPTX.td @@ -473,13 +473,6 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">; -// Sub - -def __nvvm_sub_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; -def __nvvm_sub_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; -def __nvvm_sub_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; -def __nvvm_sub_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; - // Mul def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 199d79f25bcb0..603f25577eb84 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -1539,15 +1539,6 @@ __device__ void nvvm_add_sub_mul_f16_sat() { // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2 __nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2); - // CHECK: call half @llvm.nvvm.sub.rn.sat.f16 - __nvvm_sub_rn_sat_f16(F16, F16_2); - // CHECK: call half @llvm.nvvm.sub.rn.ftz.sat.f16 - __nvvm_sub_rn_ftz_sat_f16(F16, F16_2); - // CHECK: call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2 - __nvvm_sub_rn_sat_f16x2(F16X2, F16X2_2); - // CHECK: call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2 - __nvvm_sub_rn_ftz_sat_f16x2(F16X2, F16X2_2); - // CHECK: call half @llvm.nvvm.mul.rn.sat.f16 __nvvm_mul_rn_sat_f16(F16, F16_2); // CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16 diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index b1c38f34b1321..d24511f371e02 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1608,17 +1608,6 @@ let TargetPrefix = "nvvm" in { DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>; } } - - // - // Sub - // - foreach ftz = ["", "_ftz"] in { - def int_nvvm_sub_rn # ftz # _sat_f16 : NVVMBuiltin, - PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; - - def int_nvvm_sub_rn # ftz # _sat_f16x2 : NVVMBuiltin, - PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; - } // ftz // // Mul diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 8b72b1e1f3a52..4636ef5fc88b5 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -873,7 +873,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT, ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD, - ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND}); + ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND, + ISD::INTRINSIC_WO_CHAIN}); // setcc for f16x2 and bf16x2 needs special handling to prevent // legalizer's attempt to scalarize it due to v2i1 not being legal. @@ -6504,6 +6505,38 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain, } } +// Combine add.sat(a, fneg(b)) -> sub.sat(a, b) +static SDValue combineAddSatWithNeg(SDNode *N, SelectionDAG &DAG, + unsigned SubOpc) { + SDValue Op2 = N->getOperand(2); + + if (Op2.getOpcode() != ISD::FNEG) + return SDValue(); + + SDLoc DL(N); + return DAG.getNode(SubOpc, DL, N->getValueType(0), N->getOperand(1), + Op2.getOperand(0)); +} + +static SDValue combineIntrinsicWOChain(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const NVPTXSubtarget &STI) { + unsigned IntID = N->getConstantOperandVal(0); + + switch (IntID) { + case Intrinsic::nvvm_add_rn_sat_f16: + return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16); + case Intrinsic::nvvm_add_rn_ftz_sat_f16: + return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16); + case Intrinsic::nvvm_add_rn_sat_f16x2: + return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16X2); + case Intrinsic::nvvm_add_rn_ftz_sat_f16x2: + return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16X2); + default: + return SDValue(); + } +} + static SDValue combineProxyReg(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { @@ -6570,6 +6603,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return combineSTORE(N, DCI, STI); case ISD::VSELECT: return PerformVSELECTCombine(N, DCI); + case ISD::INTRINSIC_WO_CHAIN: + return combineIntrinsicWOChain(N, DCI, STI); } return SDValue(); } diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 440224bdd1454..f5ca88c9cc717 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1859,10 +1859,34 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d> // Sub // -def INT_NVVM_SUB_RN_SAT_F16 : F_MATH_2<"sub.rn.sat.f16", B16, B16, B16, int_nvvm_sub_rn_sat_f16>; -def INT_NVVM_SUB_RN_FTZ_SAT_F16 : F_MATH_2<"sub.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_sub_rn_ftz_sat_f16>; -def INT_NVVM_SUB_RN_SAT_F16X2 : F_MATH_2<"sub.rn.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_sat_f16x2>; -def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : F_MATH_2<"sub.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f16x2>; +def SUB_RN_SAT_F16_NODE : SDNode<"NVPTXISD::SUB_RN_SAT_F16", SDTFPBinOp>; +def SUB_RN_FTZ_SAT_F16_NODE : + SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16", SDTFPBinOp>; +def SUB_RN_SAT_F16X2_NODE : + SDNode<"NVPTXISD::SUB_RN_SAT_F16X2", SDTFPBinOp>; +def SUB_RN_FTZ_SAT_F16X2_NODE : + SDNode<"NVPTXISD::SUB_RN_FTZ_SAT_F16X2", SDTFPBinOp>; + +def INT_NVVM_SUB_RN_SAT_F16 : + BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b), + "sub.rn.sat.f16", + [(set f16:$dst, (SUB_RN_SAT_F16_NODE f16:$a, f16:$b))]>; + +def INT_NVVM_SUB_RN_FTZ_SAT_F16 : + BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B16:$b), + "sub.rn.ftz.sat.f16", + [(set f16:$dst, (SUB_RN_FTZ_SAT_F16_NODE f16:$a, f16:$b))]>; + +def INT_NVVM_SUB_RN_SAT_F16X2 : + BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b), + "sub.rn.sat.f16x2", + [(set v2f16:$dst, (SUB_RN_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>; + +def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : + BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b), + "sub.rn.ftz.sat.f16x2", + [(set v2f16:$dst, (SUB_RN_FTZ_SAT_F16X2_NODE v2f16:$a, v2f16:$b))]>; + // // Mul @@ -6154,3 +6178,4 @@ foreach sp = [0, 1] in { } } } + diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll index 2c02f6aa3160e..035c36553605d 100644 --- a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll +++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll @@ -1,6 +1,8 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 ; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | FileCheck %s ; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} +; RUN: %if ptxas-isa-6.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | %ptxas-verify%} define half @sub_rn_sat_f16(half %a, half %b) { ; CHECK-LABEL: sub_rn_sat_f16( @@ -13,8 +15,9 @@ define half @sub_rn_sat_f16(half %a, half %b) { ; CHECK-NEXT: sub.rn.sat.f16 %rs3, %rs1, %rs2; ; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; ; CHECK-NEXT: ret; - %1 = call half @llvm.nvvm.sub.rn.sat.f16(half %a, half %b) - ret half %1 + %1 = fneg half %b + %res = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %1) + ret half %res } define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { @@ -28,8 +31,9 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.sub.rn.sat.f16x2(<2 x half> %a, <2 x half> %b) - ret <2 x half> %1 + %1 = fneg <2 x half> %b + %res = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %1) + ret <2 x half> %res } define half @sub_rn_ftz_sat_f16(half %a, half %b) { @@ -43,8 +47,9 @@ define half @sub_rn_ftz_sat_f16(half %a, half %b) { ; CHECK-NEXT: sub.rn.ftz.sat.f16 %rs3, %rs1, %rs2; ; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; ; CHECK-NEXT: ret; - %1 = call half @llvm.nvvm.sub.rn.ftz.sat.f16(half %a, half %b) - ret half %1 + %1 = fneg half %b + %res = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %1) + ret half %res } define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { @@ -58,6 +63,7 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.sub.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b) - ret <2 x half> %1 + %1 = fneg <2 x half> %b + %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %1) + ret <2 x half> %res } From 5a167f8f318e481118a728a4fda00d6d2413731b Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Tue, 2 Dec 2025 13:20:04 +0000 Subject: [PATCH 5/9] fix formatting --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 31 +++++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 4636ef5fc88b5..eae3d4684798d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -866,15 +866,28 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand); // We have some custom DAG combine patterns for these nodes - setTargetDAGCombine( - {ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, - ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM, - ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM, - ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL, - ISD::SREM, ISD::UREM, ISD::VSELECT, - ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD, - ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND, - ISD::INTRINSIC_WO_CHAIN}); + setTargetDAGCombine({ISD::ADD, + ISD::AND, + ISD::EXTRACT_VECTOR_ELT, + ISD::FADD, + ISD::FMAXNUM, + ISD::FMINNUM, + ISD::FMAXIMUM, + ISD::FMINIMUM, + ISD::FMAXIMUMNUM, + ISD::FMINIMUMNUM, + ISD::MUL, + ISD::SHL, + ISD::SREM, + ISD::UREM, + ISD::VSELECT, + ISD::BUILD_VECTOR, + ISD::ADDRSPACECAST, + ISD::LOAD, + ISD::STORE, + ISD::ZERO_EXTEND, + ISD::SIGN_EXTEND, + ISD::INTRINSIC_WO_CHAIN}); // setcc for f16x2 and bf16x2 needs special handling to prevent // legalizer's attempt to scalarize it due to v2i1 not being legal. From 19f0fa0d95eaf1a244a57a7ceea9aec8f4607a9b Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Tue, 2 Dec 2025 13:58:15 +0000 Subject: [PATCH 6/9] rename test appropriately --- clang/test/CodeGen/builtins-nvptx.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 603f25577eb84..c25ff876b6f93 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -1528,8 +1528,8 @@ __device__ void nvvm_min_max_sm86() { #define F16X2 {(__fp16)0.1f, (__fp16)0.1f} #define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f} -// CHECK-LABEL: nvvm_add_sub_mul_f16_sat -__device__ void nvvm_add_sub_mul_f16_sat() { +// CHECK-LABEL: nvvm_add_mul_f16_sat +__device__ void nvvm_add_mul_f16_sat() { // CHECK: call half @llvm.nvvm.add.rn.sat.f16 __nvvm_add_rn_sat_f16(F16, F16_2); // CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16 From e794a56da072f4c4324b15857018f6c73cdc1f04 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Wed, 3 Dec 2025 09:08:26 +0000 Subject: [PATCH 7/9] overload fma.rn.oob intrinsics --- clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp | 36 +++++++++++++++++++ clang/test/CodeGen/builtins-nvptx.c | 8 ++--- llvm/include/llvm/IR/IntrinsicsNVVM.td | 17 ++------- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 41 +++++++++++++--------- llvm/test/CodeGen/NVPTX/fma-oob.ll | 8 ++--- 5 files changed, 71 insertions(+), 39 deletions(-) diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp index 8a1cab3417d98..9988faea50d14 100644 --- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp @@ -415,6 +415,14 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID, return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF); } +static Value *MakeFMAOOB(unsigned IntrinsicID, llvm::Type *Ty, + const CallExpr *E, CodeGenFunction &CGF) { + return CGF.Builder.CreateCall(CGF.CGM.getIntrinsic(IntrinsicID, {Ty}), + {CGF.EmitScalarExpr(E->getArg(0)), + CGF.EmitScalarExpr(E->getArg(1)), + CGF.EmitScalarExpr(E->getArg(2))}); +} + } // namespace Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, @@ -963,6 +971,34 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, return MakeHalfType(Intrinsic::nvvm_fma_rn_sat_f16, BuiltinID, E, *this); case NVPTX::BI__nvvm_fma_rn_sat_f16x2: return MakeHalfType(Intrinsic::nvvm_fma_rn_sat_f16x2, BuiltinID, E, *this); + case NVPTX::BI__nvvm_fma_rn_oob_f16: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, Builder.getHalfTy(), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_f16x2: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, + llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_bf16: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, Builder.getBFloatTy(), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_bf16x2: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, + llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_relu_f16: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(), + E, *this); + case NVPTX::BI__nvvm_fma_rn_oob_relu_f16x2: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, + llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, + Builder.getBFloatTy(), E, *this); + case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16x2: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, + llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E, + *this); case NVPTX::BI__nvvm_fmax_f16: return MakeHalfType(Intrinsic::nvvm_fmax_f16, BuiltinID, E, *this); case NVPTX::BI__nvvm_fmax_f16x2: diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index c25ff876b6f93..4e123ec7617a3 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -1558,18 +1558,18 @@ __device__ void nvvm_fma_oob() { __nvvm_fma_rn_oob_f16(F16, F16_2, F16_2); // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.relu.f16 __nvvm_fma_rn_oob_relu_f16(F16, F16_2, F16_2); - // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2 + // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16 __nvvm_fma_rn_oob_f16x2(F16X2, F16X2_2, F16X2_2); - // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2 + // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16 __nvvm_fma_rn_oob_relu_f16x2(F16X2, F16X2_2, F16X2_2); // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.bf16 __nvvm_fma_rn_oob_bf16(BF16, BF16_2, BF16_2); // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16 __nvvm_fma_rn_oob_relu_bf16(BF16, BF16_2, BF16_2); - // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2 + // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16 __nvvm_fma_rn_oob_bf16x2(BF16X2, BF16X2_2, BF16X2_2); - // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2 + // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16 __nvvm_fma_rn_oob_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2); #endif // CHECK: ret void diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index d24511f371e02..97ae8fad0781a 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1506,21 +1506,8 @@ let TargetPrefix = "nvvm" in { // oob (out-of-bounds) - clamps the result to 0 if either of the operands is // OOB NaN value. foreach relu = ["", "_relu"] in { - def int_nvvm_fma_rn_oob # relu # _f16 : NVVMBuiltin, - PureIntrinsic<[llvm_half_ty], - [llvm_half_ty, llvm_half_ty, llvm_half_ty]>; - - def int_nvvm_fma_rn_oob # relu # _f16x2 : NVVMBuiltin, - PureIntrinsic<[llvm_v2f16_ty], - [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>; - - def int_nvvm_fma_rn_oob # relu # _bf16 : NVVMBuiltin, - PureIntrinsic<[llvm_bfloat_ty], - [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>; - - def int_nvvm_fma_rn_oob # relu # _bf16x2 : NVVMBuiltin, - PureIntrinsic<[llvm_v2bf16_ty], - [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; + def int_nvvm_fma_rn_oob # relu : PureIntrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; } // relu foreach rnd = ["rn", "rz", "rm", "rp"] in { diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index f5ca88c9cc717..60cd78dc56eae 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1691,18 +1691,10 @@ multiclass FMA_INST { [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, B16, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_oob_f16", int_nvvm_fma_rn_oob_f16, B16, - [hasPTX<81>, hasSM<90>]>, - FMA_TUPLE<"_rn_oob_relu_f16", int_nvvm_fma_rn_oob_relu_f16, B16, - [hasPTX<81>, hasSM<90>]>, FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, B16, [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, B16, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_oob_bf16", int_nvvm_fma_rn_oob_bf16, B16, - [hasPTX<81>, hasSM<90>]>, - FMA_TUPLE<"_rn_oob_relu_bf16", int_nvvm_fma_rn_oob_relu_bf16, B16, - [hasPTX<81>, hasSM<90>]>, FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, B32, [hasPTX<42>, hasSM<53>]>, @@ -1716,19 +1708,11 @@ multiclass FMA_INST { [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2, B32, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_oob_f16x2", int_nvvm_fma_rn_oob_f16x2, B32, - [hasPTX<81>, hasSM<90>]>, - FMA_TUPLE<"_rn_oob_relu_f16x2", int_nvvm_fma_rn_oob_relu_f16x2, B32, - [hasPTX<81>, hasSM<90>]>, FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, B32, [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, B32, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_oob_bf16x2", int_nvvm_fma_rn_oob_bf16x2, B32, - [hasPTX<81>, hasSM<90>]>, - FMA_TUPLE<"_rn_oob_relu_bf16x2", int_nvvm_fma_rn_oob_relu_bf16x2, B32, - [hasPTX<81>, hasSM<90>]>, ] in { def P.Variant : F_MATH_3 : + BasicNVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), + "fma.rn.oob" # suffix>; + +class FMA_OOB_TYPE { + ValueType Type = VT; + NVPTXRegClass RegClass = RC; + string TypeStr = TypeName; +} + +let Predicates = [hasPTX<81>, hasSM<90>] in { + foreach relu = ["", "_relu"] in { + foreach ty = [ + FMA_OOB_TYPE, + FMA_OOB_TYPE, + FMA_OOB_TYPE, + FMA_OOB_TYPE + ] in { + defvar Intr = !cast("int_nvvm_fma_rn_oob" # relu); + defvar suffix = !subst("_", ".", relu # "_" # ty.TypeStr); + def : Pat<(ty.Type (Intr ty.Type:$a, ty.Type:$b, ty.Type:$c)), + (FMA_OOB_INST $a, $b, $c)>; + } + } +} // // Rcp // diff --git a/llvm/test/CodeGen/NVPTX/fma-oob.ll b/llvm/test/CodeGen/NVPTX/fma-oob.ll index 2553c5f298b17..7fd9ae13d1998 100644 --- a/llvm/test/CodeGen/NVPTX/fma-oob.ll +++ b/llvm/test/CodeGen/NVPTX/fma-oob.ll @@ -46,7 +46,7 @@ define <2 x half> @fma_oob_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) { ; CHECK-NEXT: fma.rn.oob.f16x2 %r4, %r1, %r2, %r3; ; CHECK-NEXT: st.param.b32 [func_retval0], %r4; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c) + %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16( <2 x half> %a, <2 x half> %b, <2 x half> %c) ret <2 x half> %1 } @@ -62,7 +62,7 @@ define <2 x half> @fma_oob_relu_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> % ; CHECK-NEXT: fma.rn.oob.relu.f16x2 %r4, %r1, %r2, %r3; ; CHECK-NEXT: st.param.b32 [func_retval0], %r4; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.f16x2( <2 x half> %a, <2 x half> %b, <2 x half> %c) + %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16( <2 x half> %a, <2 x half> %b, <2 x half> %c) ret <2 x half> %1 } @@ -110,7 +110,7 @@ define <2 x bfloat> @fma_oob_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloa ; CHECK-NEXT: fma.rn.oob.bf16x2 %r4, %r1, %r2, %r3; ; CHECK-NEXT: st.param.b32 [func_retval0], %r4; ; CHECK-NEXT: ret; - %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) ret <2 x bfloat> %1 } @@ -126,6 +126,6 @@ define <2 x bfloat> @fma_oob_relu_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x ; CHECK-NEXT: fma.rn.oob.relu.bf16x2 %r4, %r1, %r2, %r3; ; CHECK-NEXT: st.param.b32 [func_retval0], %r4; ; CHECK-NEXT: ret; - %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.bf16x2( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) ret <2 x bfloat> %1 } From bf3fa9259bd4e16a05826c0b0ff04cd95480a7e9 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Wed, 3 Dec 2025 09:40:56 +0000 Subject: [PATCH 8/9] rename f16x2 to v2f16 in new intrinsic names --- clang/include/clang/Basic/BuiltinsNVPTX.td | 8 ++++---- clang/test/CodeGen/builtins-nvptx.c | 16 ++++++++-------- llvm/include/llvm/IR/IntrinsicsNVVM.td | 4 ++-- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 9 +++++---- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 8 ++++---- llvm/test/CodeGen/NVPTX/f16-add-sat.ll | 4 ++-- llvm/test/CodeGen/NVPTX/f16-mul-sat.ll | 4 ++-- llvm/test/CodeGen/NVPTX/f16-sub-sat.ll | 4 ++-- 8 files changed, 29 insertions(+), 28 deletions(-) diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td index 5ab79b326ee0f..62b528da8440e 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.td +++ b/clang/include/clang/Basic/BuiltinsNVPTX.td @@ -456,8 +456,8 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">; def __nvvm_add_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; def __nvvm_add_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; -def __nvvm_add_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; -def __nvvm_add_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_add_rn_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_add_rn_ftz_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">; @@ -477,8 +477,8 @@ def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; def __nvvm_mul_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; -def __nvvm_mul_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; -def __nvvm_mul_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_mul_rn_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_mul_rn_ftz_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; // Convert diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 4e123ec7617a3..7c2a71dd5abd5 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -1534,19 +1534,19 @@ __device__ void nvvm_add_mul_f16_sat() { __nvvm_add_rn_sat_f16(F16, F16_2); // CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16 __nvvm_add_rn_ftz_sat_f16(F16, F16_2); - // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.f16x2 - __nvvm_add_rn_sat_f16x2(F16X2, F16X2_2); - // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2 - __nvvm_add_rn_ftz_sat_f16x2(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.v2f16 + __nvvm_add_rn_sat_v2f16(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16 + __nvvm_add_rn_ftz_sat_v2f16(F16X2, F16X2_2); // CHECK: call half @llvm.nvvm.mul.rn.sat.f16 __nvvm_mul_rn_sat_f16(F16, F16_2); // CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16 __nvvm_mul_rn_ftz_sat_f16(F16, F16_2); - // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2 - __nvvm_mul_rn_sat_f16x2(F16X2, F16X2_2); - // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2 - __nvvm_mul_rn_ftz_sat_f16x2(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.v2f16 + __nvvm_mul_rn_sat_v2f16(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16 + __nvvm_mul_rn_ftz_sat_v2f16(F16X2, F16X2_2); // CHECK: ret void } diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 97ae8fad0781a..201aad321a331 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1580,7 +1580,7 @@ let TargetPrefix = "nvvm" in { def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin, PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; - def int_nvvm_add_rn # ftz # _sat_f16x2 : NVVMBuiltin, + def int_nvvm_add_rn # ftz # _sat_v2f16 : NVVMBuiltin, PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; } // ftz @@ -1603,7 +1603,7 @@ let TargetPrefix = "nvvm" in { def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin, PureIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; - def int_nvvm_mul_rn # ftz # _sat_f16x2 : NVVMBuiltin, + def int_nvvm_mul_rn # ftz # _sat_v2f16 : NVVMBuiltin, PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; } // ftz diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index eae3d4684798d..df1f3f680641c 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -6537,17 +6537,18 @@ static SDValue combineIntrinsicWOChain(SDNode *N, unsigned IntID = N->getConstantOperandVal(0); switch (IntID) { + default: + break; case Intrinsic::nvvm_add_rn_sat_f16: return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16); case Intrinsic::nvvm_add_rn_ftz_sat_f16: return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16); - case Intrinsic::nvvm_add_rn_sat_f16x2: + case Intrinsic::nvvm_add_rn_sat_v2f16: return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_SAT_F16X2); - case Intrinsic::nvvm_add_rn_ftz_sat_f16x2: + case Intrinsic::nvvm_add_rn_ftz_sat_v2f16: return combineAddSatWithNeg(N, DCI.DAG, NVPTXISD::SUB_RN_FTZ_SAT_F16X2); - default: - return SDValue(); } + return SDValue(); } static SDValue combineProxyReg(SDNode *N, diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 60cd78dc56eae..a297803761072 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1847,8 +1847,8 @@ let Predicates = [doRsqrtOpt] in { def INT_NVVM_ADD_RN_SAT_F16 : F_MATH_2<"add.rn.sat.f16", B16, B16, B16, int_nvvm_add_rn_sat_f16>; def INT_NVVM_ADD_RN_FTZ_SAT_F16 : F_MATH_2<"add.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_add_rn_ftz_sat_f16>; -def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_f16x2>; -def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f16x2>; +def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_v2f16>; +def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_v2f16>; def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>; def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>; @@ -1903,8 +1903,8 @@ def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>; def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>; -def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_f16x2>; -def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_f16x2>; +def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_v2f16>; +def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_v2f16>; // // BFIND diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll index a623d6e5351ab..c2ffc126694c4 100644 --- a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll +++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll @@ -28,7 +28,7 @@ define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } @@ -58,6 +58,6 @@ define <2 x half> @add_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll index 68caac8c36e31..4bcc018f290d7 100644 --- a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll +++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll @@ -28,7 +28,7 @@ define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.f16x2(<2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.v2f16(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } @@ -58,6 +58,6 @@ define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2; ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; - %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %b) + %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b) ret <2 x half> %1 } diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll index 035c36553605d..774ce7ccb2f95 100644 --- a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll +++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll @@ -32,7 +32,7 @@ define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; %1 = fneg <2 x half> %b - %res = call <2 x half> @llvm.nvvm.add.rn.sat.f16x2(<2 x half> %a, <2 x half> %1) + %res = call <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %1) ret <2 x half> %res } @@ -64,6 +64,6 @@ define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { ; CHECK-NEXT: st.param.b32 [func_retval0], %r3; ; CHECK-NEXT: ret; %1 = fneg <2 x half> %b - %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.f16x2(<2 x half> %a, <2 x half> %1) + %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %1) ret <2 x half> %res } From d671857715c7f46c78b92df6d1c8cbd79ea7bbec Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Wed, 3 Dec 2025 09:48:52 +0000 Subject: [PATCH 9/9] fix formatting --- clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp index 9988faea50d14..eb027cee601ac 100644 --- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp @@ -986,15 +986,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E, *this); case NVPTX::BI__nvvm_fma_rn_oob_relu_f16: - return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(), - E, *this); + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(), E, + *this); case NVPTX::BI__nvvm_fma_rn_oob_relu_f16x2: return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E, *this); case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16: - return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, - Builder.getBFloatTy(), E, *this); + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getBFloatTy(), E, + *this); case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16x2: return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E,