[clang][NVPTX] Add support for mixed-precision FP arithmetic #168359
base: main
Conversation
This change adds NVVM intrinsics and clang builtins for mixed-precision FP arithmetic instructions. Tests are added in `mixed-precision-fp.ll` and `builtins-nvptx.c` and verified through `ptxas-13.0`. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions
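As a minimal sketch, the new intrinsics surface in IR like this (mirroring the CHECK lines in the `builtins-nvptx.c` hunk below; the wrapper function is illustrative):

```llvm
define float @demo_add_mixed(half %a, float %b) {
  ; the intrinsic is overloaded on the narrow type: .f16 for half, .bf16 for bfloat
  %r = call float @llvm.nvvm.add.mixed.rn.f32.f16(half %a, float %b)
  ret float %r
}

declare float @llvm.nvvm.add.mixed.rn.f32.f16(half, float)
```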
@llvm/pr-subscribers-backend-nvptx @llvm/pr-subscribers-clang

Author: Srinivasa Ravi (Wolfram70)

Patch is 37.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168359.diff

6 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index d923d2a90e908..47ba12bef058c 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -401,6 +401,24 @@ def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rp_d : NVPTXBuiltin<"double(double, double, double)">;
+def __nvvm_fma_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+
+def __nvvm_fma_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+
// Rcp
def __nvvm_rcp_rn_ftz_f : NVPTXBuiltin<"float(float)">;
@@ -460,6 +478,52 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
+def __nvvm_add_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+
+def __nvvm_add_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+
+// Sub
+
+def __nvvm_sub_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+
+def __nvvm_sub_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+def __nvvm_sub_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>;
+
// Convert
def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index 8a1cab3417d98..6f57620f0fb00 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -415,6 +415,17 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}
+static Value *MakeMixedPrecisionFPArithmetic(unsigned IntrinsicID,
+ const CallExpr *E,
+ CodeGenFunction &CGF) {
+ SmallVector<llvm::Value *, 3> Args;
+ for (unsigned i = 0; i < E->getNumArgs(); ++i) {
+ Args.push_back(CGF.EmitScalarExpr(E->getArg(i)));
+ }
+ return CGF.Builder.CreateCall(
+ CGF.CGM.getIntrinsic(IntrinsicID, {Args[0]->getType()}), Args);
+}
+
} // namespace
Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
@@ -1197,6 +1208,118 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
return Builder.CreateCall(
CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count),
{EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))});
+ case NVPTX::BI__nvvm_add_mixed_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_sat_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_add_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_add_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_add_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_sat_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_sub_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_sub_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_sub_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rn_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rn_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rz_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rz_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rm_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rm_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rp_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rp_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_f32, E,
+ *this);
+ case NVPTX::BI__nvvm_fma_mixed_rn_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rn_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rz_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rz_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rm_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rm_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_sat_f32,
+ E, *this);
+ case NVPTX::BI__nvvm_fma_mixed_rp_sat_f16_f32:
+ case NVPTX::BI__nvvm_fma_mixed_rp_sat_bf16_f32:
+ return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_sat_f32,
+ E, *this);
default:
return nullptr;
}
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index e3be262622844..1753b4c7767e9 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1466,3 +1466,136 @@ __device__ void nvvm_min_max_sm86() {
#endif
// CHECK: ret void
}
+
+#define F16 (__fp16)0.1f
+#define F16_2 (__fp16)0.2f
+
+__device__ void nvvm_add_mixed_precision_sm100() {
+#if __CUDA_ARCH__ >= 1000
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rn_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rz_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rm_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rp_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rn_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rz_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rm_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_add_mixed_rp_sat_f16_f32(F16, 1.0f);
+
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rn_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rz_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rm_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rp_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rn_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rz_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rm_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_add_mixed_rp_sat_bf16_f32(BF16, 1.0f);
+#endif
+}
+
+__device__ void nvvm_sub_mixed_precision_sm100() {
+#if __CUDA_ARCH__ >= 1000
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_sat_f16_f32(F16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_sat_f16_f32(F16, 1.0f);
+
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rm_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rp_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rn_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ __nvvm_sub_mixed_rz_sat_bf16_f32(BF16, 1.0f);
+ // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
+ _...
[truncated]
```
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
  foreach sat = ["", "_sat"] in {
    def int_nvvm_fma_mixed # rnd # sat # _f32 :
        PureIntrinsic<[llvm_float_ty],
                      [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty]>;
  }
}
```
Do we really need intrinsics for these operations? Is this equivalent to (fma $a, (fpext $b), $c)? If this can already be simply represented by a couple of existing instructions, I'd lean towards using those idioms instead of adding more intrinsics.
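For reference, the idiom suggested here, written out in IR (operand types taken from the builtin signatures above; `@llvm.fma.f32` is LLVM's generic fma intrinsic):

```llvm
define float @fma_via_fpext(half %a, half %b, float %c) {
  %aw = fpext half %a to float   ; widen both f16 operands
  %bw = fpext half %b to float
  %r  = call float @llvm.fma.f32(float %aw, float %bw, float %c)
  ret float %r
}

declare float @llvm.fma.f32(float, float, float)
```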
Agreed, that makes a lot of sense. We could do the same for the add and sub operations too, since they are likewise equivalent to a conversion followed by the FP operation.
It looks like we are currently missing intrinsics for some variants of these base operations (fma and add with .sat, and all of sub), so I've added those and used this pattern to lower to the mixed-precision instructions on supported architectures.
Please take a look, thanks!
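The base-operation variants being filled in might look like the following declarations (a sketch; the exact names are assumed instances of the intrinsic families listed in the updated description below):

```llvm
declare float @llvm.nvvm.add.rn.sat.f(float, float)        ; add with .sat
declare float @llvm.nvvm.fma.rn.sat.f(float, float, float) ; fma with .sat
declare float @llvm.nvvm.sub.rn.f(float, float)            ; one of the new sub variants
```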
And now that we are no longer adding mixed-mode builtins, we should also update the patch title and description.
Never mind, it has been updated.
This change adds support for mixed-precision floating-point arithmetic for `f16` and `bf16`, where a widening conversion feeding an `f32` operation (see the IR sketch below) is lowered to the corresponding mixed-precision instruction, which combines the conversion and the operation into one instruction, from `sm_100` onwards.

This also adds the following intrinsics to complete support for all variants of the floating-point `add`/`sub`/`fma` operations, in order to support the corresponding mixed-precision instructions:

- `llvm.nvvm.add.(rn/rz/rm/rp){.ftz}.sat.f`
- `llvm.nvvm.fma.(rn/rz/rm/rp){.ftz}.sat.f`
- `llvm.nvvm.sub*`

Tests are added in `fp-arith-sat.ll`, `fp-sub-intrins.ll`, and `builtins-nvptx.c` for the newly added intrinsics and builtins, and in `mixed-precision-fp.ll` for the mixed-precision instructions.
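A sketch of the matched pattern (reconstructed from the review discussion above; the concrete tests are in `mixed-precision-fp.ll`, and the wrapper function name is illustrative):

```llvm
define float @add_mixed_pattern(half %a, float %b) {
  %aw = fpext half %a to float   ; widen the narrow operand (bf16 uses bfloat)
  %r  = fadd float %aw, %b       ; fsub and fma are matched analogously
  ret float %r
}
```

On `sm_100` and later, the backend folds the `fpext` into the operation and emits the single mixed-precision instruction instead.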
PTX spec reference for mixed precision instructions: https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions