diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp index 6da65b681df1e..5a8dd2153595c 100644 --- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp @@ -375,28 +375,28 @@ static Value *MakeCpAsync(unsigned IntrinsicID, unsigned IntrinsicIDS, CGF.EmitScalarExpr(E->getArg(1))}); } -static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID, - const CallExpr *E, CodeGenFunction &CGF) { +static bool EnsureNativeHalfSupport(unsigned BuiltinID, const CallExpr *E, + CodeGenFunction &CGF) { auto &C = CGF.CGM.getContext(); if (!(C.getLangOpts().NativeHalfType || !C.getTargetInfo().useFP16ConversionIntrinsics())) { CGF.CGM.Error(E->getExprLoc(), C.BuiltinInfo.getQuotedName(BuiltinID) + " requires native half type support."); - return nullptr; + return false; } + return true; +} - if (BuiltinID == NVPTX::BI__nvvm_ldg_h || BuiltinID == NVPTX::BI__nvvm_ldg_h2) - return MakeLdg(CGF, E); - - if (IntrinsicID == Intrinsic::nvvm_ldu_global_f) - return MakeLdu(IntrinsicID, CGF, E); +static Value *MakeHalfType(Function *Intrinsic, unsigned BuiltinID, + const CallExpr *E, CodeGenFunction &CGF) { + if (!EnsureNativeHalfSupport(BuiltinID, E, CGF)) + return nullptr; SmallVector Args; - auto *F = CGF.CGM.getIntrinsic(IntrinsicID); - auto *FTy = F->getFunctionType(); + auto *FTy = Intrinsic->getFunctionType(); unsigned ICEArguments = 0; ASTContext::GetBuiltinTypeError Error; - C.GetBuiltinType(BuiltinID, Error, &ICEArguments); + CGF.CGM.getContext().GetBuiltinType(BuiltinID, Error, &ICEArguments); assert(Error == ASTContext::GE_None && "Should not codegen an error"); for (unsigned i = 0, e = E->getNumArgs(); i != e; ++i) { assert((ICEArguments & (1 << i)) == 0); @@ -407,8 +407,14 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID, Args.push_back(ArgValue); } - return CGF.Builder.CreateCall(F, Args); + return CGF.Builder.CreateCall(Intrinsic, Args); +} + +static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID, + const CallExpr *E, CodeGenFunction &CGF) { + return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF); } + } // namespace Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, @@ -913,9 +919,14 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, } // The following builtins require half type support case NVPTX::BI__nvvm_ex2_approx_f16: - return MakeHalfType(Intrinsic::nvvm_ex2_approx_f16, BuiltinID, E, *this); + return MakeHalfType( + CGM.getIntrinsic(Intrinsic::nvvm_ex2_approx, Builder.getHalfTy()), + BuiltinID, E, *this); case NVPTX::BI__nvvm_ex2_approx_f16x2: - return MakeHalfType(Intrinsic::nvvm_ex2_approx_f16x2, BuiltinID, E, *this); + return MakeHalfType( + CGM.getIntrinsic(Intrinsic::nvvm_ex2_approx, + FixedVectorType::get(Builder.getHalfTy(), 2)), + BuiltinID, E, *this); case NVPTX::BI__nvvm_ff2f16x2_rn: return MakeHalfType(Intrinsic::nvvm_ff2f16x2_rn, BuiltinID, E, *this); case NVPTX::BI__nvvm_ff2f16x2_rn_relu: @@ -1049,12 +1060,23 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, case NVPTX::BI__nvvm_fabs_d: return Builder.CreateUnaryIntrinsic(Intrinsic::fabs, EmitScalarExpr(E->getArg(0))); + case NVPTX::BI__nvvm_ex2_approx_d: + case NVPTX::BI__nvvm_ex2_approx_f: + return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_ex2_approx, + EmitScalarExpr(E->getArg(0))); + case NVPTX::BI__nvvm_ex2_approx_ftz_f: + return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_ex2_approx_ftz, + EmitScalarExpr(E->getArg(0))); case NVPTX::BI__nvvm_ldg_h: case NVPTX::BI__nvvm_ldg_h2: - return MakeHalfType(Intrinsic::not_intrinsic, BuiltinID, E, *this); + if (!EnsureNativeHalfSupport(BuiltinID, E, *this)) + return nullptr; + return MakeLdg(*this, E); case NVPTX::BI__nvvm_ldu_h: case NVPTX::BI__nvvm_ldu_h2: - return MakeHalfType(Intrinsic::nvvm_ldu_global_f, BuiltinID, E, *this); + if (!EnsureNativeHalfSupport(BuiltinID, E, *this)) + return nullptr; + return MakeLdu(Intrinsic::nvvm_ldu_global_f, *this, E); case NVPTX::BI__nvvm_cp_async_ca_shared_global_4: return MakeCpAsync(Intrinsic::nvvm_cp_async_ca_shared_global_4, Intrinsic::nvvm_cp_async_ca_shared_global_4_s, *this, E, diff --git a/clang/test/CodeGen/builtins-nvptx-native-half-type-native.c b/clang/test/CodeGen/builtins-nvptx-native-half-type-native.c index 035c4c6066be2..60a35f4fe0c37 100644 --- a/clang/test/CodeGen/builtins-nvptx-native-half-type-native.c +++ b/clang/test/CodeGen/builtins-nvptx-native-half-type-native.c @@ -8,7 +8,7 @@ typedef __fp16 __fp16v2 __attribute__((ext_vector_type(2))); // CHECK: call half @llvm.nvvm.ex2.approx.f16(half {{.*}}) -// CHECK: call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> {{.*}}) +// CHECK: call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> {{.*}}) // CHECK: call half @llvm.nvvm.fma.rn.relu.f16(half {{.*}}, half {{.*}}, half {{.*}}) // CHECK: call half @llvm.nvvm.fma.rn.ftz.relu.f16(half {{.*}}, half {{.*}}, half {{.*}}) // CHECK: call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) diff --git a/clang/test/CodeGen/builtins-nvptx-native-half-type.c b/clang/test/CodeGen/builtins-nvptx-native-half-type.c index 01a004efd71e4..1f16c7e54b85d 100644 --- a/clang/test/CodeGen/builtins-nvptx-native-half-type.c +++ b/clang/test/CodeGen/builtins-nvptx-native-half-type.c @@ -41,7 +41,7 @@ __device__ void nvvm_ex2_sm75() { #if __CUDA_ARCH__ >= 750 // CHECK_PTX70_SM75: call half @llvm.nvvm.ex2.approx.f16 __nvvm_ex2_approx_f16(0.1f16); - // CHECK_PTX70_SM75: call <2 x half> @llvm.nvvm.ex2.approx.f16x2 + // CHECK_PTX70_SM75: call <2 x half> @llvm.nvvm.ex2.approx.v2f16 __nvvm_ex2_approx_f16x2({0.1f16, 0.7f16}); #endif // CHECK: ret void diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 719181a09f475..2710853e17688 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1334,15 +1334,8 @@ let TargetPrefix = "nvvm" in { // let IntrProperties = [IntrNoMem] in { foreach ftz = ["", "_ftz"] in - def int_nvvm_ex2_approx # ftz # _f : NVVMBuiltin, - DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty]>; - - def int_nvvm_ex2_approx_d : NVVMBuiltin, - DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty]>; - def int_nvvm_ex2_approx_f16 : - DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty]>; - def int_nvvm_ex2_approx_f16x2 : - DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty]>; + def int_nvvm_ex2_approx # ftz : + DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>; foreach ftz = ["", "_ftz"] in def int_nvvm_lg2_approx # ftz # _f : NVVMBuiltin, diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index b838e36c8824f..4d4e9f9b31fcf 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -1504,6 +1504,10 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn, else if (Name.consume_front("fabs.")) // nvvm.fabs.{f,ftz.f,d} Expand = Name == "f" || Name == "ftz.f" || Name == "d"; + else if (Name.consume_front("ex2.approx.")) + // nvvm.ex2.approx.{f,ftz.f,d,f16x2} + Expand = + Name == "f" || Name == "ftz.f" || Name == "d" || Name == "f16x2"; else if (Name.consume_front("max.") || Name.consume_front("min.")) // nvvm.{min,max}.{i,ii,ui,ull} Expand = Name == "s" || Name == "i" || Name == "ll" || Name == "us" || @@ -2550,6 +2554,11 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, Intrinsic::ID IID = (Name == "fabs.ftz.f") ? Intrinsic::nvvm_fabs_ftz : Intrinsic::nvvm_fabs; Rep = Builder.CreateUnaryIntrinsic(IID, CI->getArgOperand(0)); + } else if (Name.consume_front("ex2.approx.")) { + // nvvm.ex2.approx.{f,ftz.f,d,f16x2} + Intrinsic::ID IID = Name.starts_with("ftz") ? Intrinsic::nvvm_ex2_approx_ftz + : Intrinsic::nvvm_ex2_approx; + Rep = Builder.CreateUnaryIntrinsic(IID, CI->getArgOperand(0)); } else if (Name.starts_with("atomic.load.add.f32.p") || Name.starts_with("atomic.load.add.f64.p")) { Value *Ptr = CI->getArgOperand(0); diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index c923f0ec907e7..8ff6cae94dd4b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1605,12 +1605,17 @@ def : Pat<(int_nvvm_saturate_d f64:$a), (CVT_f64_f64 $a, CvtSAT)>; // Exp2 Log2 // -def : Pat<(int_nvvm_ex2_approx_ftz_f f32:$a), (EX2_APPROX_f32 $a, FTZ)>; -def : Pat<(int_nvvm_ex2_approx_f f32:$a), (EX2_APPROX_f32 $a, NoFTZ)>; +def : Pat<(f32 (int_nvvm_ex2_approx_ftz f32:$a)), (EX2_APPROX_f32 $a, FTZ)>; +def : Pat<(f32 (int_nvvm_ex2_approx f32:$a)), (EX2_APPROX_f32 $a, NoFTZ)>; let Predicates = [hasPTX<70>, hasSM<75>] in { - def : Pat<(int_nvvm_ex2_approx_f16 f16:$a), (EX2_APPROX_f16 $a)>; - def : Pat<(int_nvvm_ex2_approx_f16x2 v2f16:$a), (EX2_APPROX_f16x2 $a)>; + def : Pat<(f16 (int_nvvm_ex2_approx f16:$a)), (EX2_APPROX_f16 $a)>; + def : Pat<(v2f16 (int_nvvm_ex2_approx v2f16:$a)), (EX2_APPROX_f16x2 $a)>; +} + +let Predicates = [hasPTX<78>, hasSM<90>] in { + def : Pat<(bf16 (int_nvvm_ex2_approx_ftz bf16:$a)), (EX2_APPROX_bf16 $a)>; + def : Pat<(v2bf16 (int_nvvm_ex2_approx_ftz v2bf16:$a)), (EX2_APPROX_bf16x2 $a)>; } def LG2_APPROX_f32 : diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp index 4029e143ae2a4..37f53f5b6f0a2 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -318,7 +318,7 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC, // answer. These include: // // - nvvm_cos_approx_{f,ftz_f} - // - nvvm_ex2_approx_{d,f,ftz_f} + // - nvvm_ex2_approx(_ftz) // - nvvm_lg2_approx_{d,f,ftz_f} // - nvvm_sin_approx_{f,ftz_f} // - nvvm_sqrt_approx_{f,ftz_f} diff --git a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll index 362586af4f9b7..4fc506f1f5edf 100644 --- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll @@ -87,6 +87,11 @@ declare void @llvm.nvvm.barrier(i32, i32) declare void @llvm.nvvm.barrier.sync(i32) declare void @llvm.nvvm.barrier.sync.cnt(i32, i32) +declare float @llvm.nvvm.ex2.approx.f(float) +declare double @llvm.nvvm.ex2.approx.d(double) +declare <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half>) +declare float @llvm.nvvm.ex2.approx.ftz.f(float) + ; CHECK-LABEL: @simple_upgrade define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) @@ -355,3 +360,15 @@ define void @cta_barriers(i32 %x, i32 %y) { call void @llvm.nvvm.barrier.sync.cnt(i32 %x, i32 %y) ret void } + +define void @nvvm_ex2_approx(float %a, double %b, half %c, <2 x half> %d) { +; CHECK: call float @llvm.nvvm.ex2.approx.f32(float %a) +; CHECK: call double @llvm.nvvm.ex2.approx.f64(double %b) +; CHECK: call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> %d) +; CHECK: call float @llvm.nvvm.ex2.approx.ftz.f32(float %a) + %r1 = call float @llvm.nvvm.ex2.approx.f(float %a) + %r2 = call double @llvm.nvvm.ex2.approx.d(double %b) + %r3 = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %d) + %r4 = call float @llvm.nvvm.ex2.approx.ftz.f(float %a) + ret void +} diff --git a/llvm/test/CodeGen/NVPTX/f16-ex2.ll b/llvm/test/CodeGen/NVPTX/f16-ex2.ll index ee79f9d6d056f..af3fe67269205 100644 --- a/llvm/test/CodeGen/NVPTX/f16-ex2.ll +++ b/llvm/test/CodeGen/NVPTX/f16-ex2.ll @@ -1,12 +1,13 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -; RUN: llc < %s -mcpu=sm_75 -mattr=+ptx70 | FileCheck --check-prefixes=CHECK-FP16 %s -; RUN: %if ptxas-sm_75 && ptxas-isa-7.0 %{ llc < %s -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %} +; RUN: llc < %s -mcpu=sm_90 -mattr=+ptx78 | FileCheck --check-prefixes=CHECK-FP16 %s +; RUN: %if ptxas-sm_90 && ptxas-isa-7.8 %{ llc < %s -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %} target triple = "nvptx64-nvidia-cuda" declare half @llvm.nvvm.ex2.approx.f16(half) -declare <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half>) +declare <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half>) +declare bfloat @llvm.nvvm.ex2.approx.ftz.bf16(bfloat) +declare <2 x bfloat> @llvm.nvvm.ex2.approx.ftz.v2bf16(<2 x bfloat>) -; CHECK-LABEL: ex2_half define half @ex2_half(half %0) { ; CHECK-FP16-LABEL: ex2_half( ; CHECK-FP16: { @@ -21,7 +22,6 @@ define half @ex2_half(half %0) { ret half %res } -; CHECK-LABEL: ex2_2xhalf define <2 x half> @ex2_2xhalf(<2 x half> %0) { ; CHECK-FP16-LABEL: ex2_2xhalf( ; CHECK-FP16: { @@ -32,6 +32,34 @@ define <2 x half> @ex2_2xhalf(<2 x half> %0) { ; CHECK-FP16-NEXT: ex2.approx.f16x2 %r2, %r1; ; CHECK-FP16-NEXT: st.param.b32 [func_retval0], %r2; ; CHECK-FP16-NEXT: ret; - %res = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %0) + %res = call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> %0) ret <2 x half> %res } + +define bfloat @ex2_bfloat(bfloat %0) { +; CHECK-FP16-LABEL: ex2_bfloat( +; CHECK-FP16: { +; CHECK-FP16-NEXT: .reg .b16 %rs<3>; +; CHECK-FP16-EMPTY: +; CHECK-FP16-NEXT: // %bb.0: +; CHECK-FP16-NEXT: ld.param.b16 %rs1, [ex2_bfloat_param_0]; +; CHECK-FP16-NEXT: ex2.approx.ftz.bf16 %rs2, %rs1; +; CHECK-FP16-NEXT: st.param.b16 [func_retval0], %rs2; +; CHECK-FP16-NEXT: ret; + %res = call bfloat @llvm.nvvm.ex2.approx.ftz.bf16(bfloat %0) + ret bfloat %res +} + +define <2 x bfloat> @ex2_2xbfloat(<2 x bfloat> %0) { +; CHECK-FP16-LABEL: ex2_2xbfloat( +; CHECK-FP16: { +; CHECK-FP16-NEXT: .reg .b32 %r<3>; +; CHECK-FP16-EMPTY: +; CHECK-FP16-NEXT: // %bb.0: +; CHECK-FP16-NEXT: ld.param.b32 %r1, [ex2_2xbfloat_param_0]; +; CHECK-FP16-NEXT: ex2.approx.ftz.bf16x2 %r2, %r1; +; CHECK-FP16-NEXT: st.param.b32 [func_retval0], %r2; +; CHECK-FP16-NEXT: ret; + %res = call <2 x bfloat> @llvm.nvvm.ex2.approx.ftz.v2bf16(<2 x bfloat> %0) + ret <2 x bfloat> %res +} diff --git a/llvm/test/CodeGen/NVPTX/f32-ex2.ll b/llvm/test/CodeGen/NVPTX/f32-ex2.ll index 796d80d3c2c39..97b9d35be371e 100644 --- a/llvm/test/CodeGen/NVPTX/f32-ex2.ll +++ b/llvm/test/CodeGen/NVPTX/f32-ex2.ll @@ -3,7 +3,8 @@ ; RUN: %if ptxas-sm_50 && ptxas-isa-3.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_50 -mattr=+ptx32 | %ptxas-verify -arch=sm_50 %} target triple = "nvptx-nvidia-cuda" -declare float @llvm.nvvm.ex2.approx.f(float) +declare float @llvm.nvvm.ex2.approx.f32(float) +declare float @llvm.nvvm.ex2.approx.ftz.f32(float) ; CHECK-LABEL: ex2_float define float @ex2_float(float %0) { @@ -16,7 +17,7 @@ define float @ex2_float(float %0) { ; CHECK-NEXT: ex2.approx.f32 %r2, %r1; ; CHECK-NEXT: st.param.b32 [func_retval0], %r2; ; CHECK-NEXT: ret; - %res = call float @llvm.nvvm.ex2.approx.f(float %0) + %res = call float @llvm.nvvm.ex2.approx.f32(float %0) ret float %res } @@ -31,6 +32,6 @@ define float @ex2_float_ftz(float %0) { ; CHECK-NEXT: ex2.approx.ftz.f32 %r2, %r1; ; CHECK-NEXT: st.param.b32 [func_retval0], %r2; ; CHECK-NEXT: ret; - %res = call float @llvm.nvvm.ex2.approx.ftz.f(float %0) + %res = call float @llvm.nvvm.ex2.approx.ftz.f32(float %0) ret float %res }