-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[clang][NVPTX] Add support for mixed-precision FP arithmetic #168359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1376,13 +1376,15 @@ let TargetPrefix = "nvvm" in { | |
| } // ftz | ||
| } // variant | ||
|
|
||
| foreach rnd = ["rn", "rz", "rm", "rp"] in { | ||
| foreach ftz = ["", "_ftz"] in | ||
| def int_nvvm_fma_ # rnd # ftz # _f : NVVMBuiltin, | ||
| PureIntrinsic<[llvm_float_ty], | ||
| [llvm_float_ty, llvm_float_ty, llvm_float_ty]>; | ||
| foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in { | ||
| foreach ftz = ["", "_ftz"] in { | ||
| foreach sat = ["", "_sat"] in | ||
| def int_nvvm_fma # rnd # ftz # sat # _f : NVVMBuiltin, | ||
| PureIntrinsic<[llvm_float_ty], | ||
| [llvm_float_ty, llvm_float_ty, llvm_float_ty]>; | ||
| } | ||
|
|
||
| def int_nvvm_fma_ # rnd # _d : NVVMBuiltin, | ||
| def int_nvvm_fma # rnd # _d : NVVMBuiltin, | ||
| PureIntrinsic<[llvm_double_ty], | ||
| [llvm_double_ty, llvm_double_ty, llvm_double_ty]>; | ||
| } | ||
|
|
@@ -1443,12 +1445,30 @@ let TargetPrefix = "nvvm" in { | |
| // Add | ||
| // | ||
| let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in { | ||
| foreach rnd = ["rn", "rz", "rm", "rp"] in { | ||
| foreach ftz = ["", "_ftz"] in | ||
| def int_nvvm_add_ # rnd # ftz # _f : NVVMBuiltin, | ||
| foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in { | ||
| foreach ftz = ["", "_ftz"] in { | ||
| foreach sat = ["", "_sat"] in | ||
| def int_nvvm_add # rnd # ftz # sat # _f : NVVMBuiltin, | ||
| DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>; | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We don't need this new line here. |
||
| def int_nvvm_add # rnd # _d : NVVMBuiltin, | ||
| DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>; | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Suggested trailing comment for this closing brace: `} // rnd` |
||
| } | ||
|
|
||
| // | ||
| // Sub | ||
| // | ||
| let IntrProperties = [IntrNoMem, IntrSpeculatable] in { | ||
| foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in { | ||
| foreach ftz = ["", "_ftz"] in { | ||
| foreach sat = ["", "_sat"] in | ||
| def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin, | ||
| DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>; | ||
| } | ||
|
|
||
| def int_nvvm_add_ # rnd # _d : NVVMBuiltin, | ||
| def int_nvvm_sub # rnd # _d : NVVMBuiltin, | ||
| DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>; | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1637,13 +1637,21 @@ multiclass FMA_INST { | |
| FMA_TUPLE<"_rp_f64", int_nvvm_fma_rp_d, B64>, | ||
|
|
||
| FMA_TUPLE<"_rn_ftz_f32", int_nvvm_fma_rn_ftz_f, B32>, | ||
| FMA_TUPLE<"_rn_ftz_sat_f32", int_nvvm_fma_rn_ftz_sat_f, B32>, | ||
| FMA_TUPLE<"_rn_f32", int_nvvm_fma_rn_f, B32>, | ||
| FMA_TUPLE<"_rn_sat_f32", int_nvvm_fma_rn_sat_f, B32>, | ||
| FMA_TUPLE<"_rz_ftz_f32", int_nvvm_fma_rz_ftz_f, B32>, | ||
| FMA_TUPLE<"_rz_ftz_sat_f32", int_nvvm_fma_rz_ftz_sat_f, B32>, | ||
| FMA_TUPLE<"_rz_f32", int_nvvm_fma_rz_f, B32>, | ||
| FMA_TUPLE<"_rz_sat_f32", int_nvvm_fma_rz_sat_f, B32>, | ||
| FMA_TUPLE<"_rm_f32", int_nvvm_fma_rm_f, B32>, | ||
| FMA_TUPLE<"_rm_sat_f32", int_nvvm_fma_rm_sat_f, B32>, | ||
| FMA_TUPLE<"_rm_ftz_f32", int_nvvm_fma_rm_ftz_f, B32>, | ||
| FMA_TUPLE<"_rm_ftz_sat_f32", int_nvvm_fma_rm_ftz_sat_f, B32>, | ||
| FMA_TUPLE<"_rp_f32", int_nvvm_fma_rp_f, B32>, | ||
| FMA_TUPLE<"_rp_sat_f32", int_nvvm_fma_rp_sat_f, B32>, | ||
| FMA_TUPLE<"_rp_ftz_f32", int_nvvm_fma_rp_ftz_f, B32>, | ||
| FMA_TUPLE<"_rp_ftz_sat_f32", int_nvvm_fma_rp_ftz_sat_f, B32>, | ||
|
|
||
| FMA_TUPLE<"_rn_f16", int_nvvm_fma_rn_f16, B16, [hasPTX<42>, hasSM<53>]>, | ||
| FMA_TUPLE<"_rn_ftz_f16", int_nvvm_fma_rn_ftz_f16, B16, | ||
|
|
@@ -1694,6 +1702,22 @@ multiclass FMA_INST { | |
|
|
||
| defm INT_NVVM_FMA : FMA_INST; | ||
|
|
||
| foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in { | ||
| foreach sat = ["", "_SAT"] in { | ||
| foreach type = ["F16", "BF16"] in { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think keeping these values in lower-case will remove a few `!tolower` calls. We could just have a prefix in upper-case (`_FMA` or `_MIXED_FMA` or something like that), if you want the opcode to be more visible. |
||
| def INT_NVVM_FMA # rnd # sat # _F32_ # type : | ||
| BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c), | ||
| !tolower(!subst("_", ".", "fma" # rnd # sat # "_f32_" # type)), | ||
| [(set f32:$dst, | ||
| (!cast<Intrinsic>(!tolower("int_nvvm_fma" # rnd # sat # "_f")) | ||
| (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)), | ||
| (f32 (fpextend !cast<ValueType>(!tolower(type)):$b)), | ||
| f32:$c))]>, | ||
| Requires<[hasSM<100>, hasPTX<86>]>; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // | ||
| // Rcp | ||
| // | ||
|
|
@@ -1793,19 +1817,81 @@ let Predicates = [doRsqrtOpt] in { | |
| // | ||
|
|
||
| def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>; | ||
| def INT_NVVM_ADD_RN_SAT_FTZ_F : F_MATH_2<"add.rn.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f>; | ||
| def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>; | ||
| def INT_NVVM_ADD_RN_SAT_F : F_MATH_2<"add.rn.sat.f32", B32, B32, B32, int_nvvm_add_rn_sat_f>; | ||
| def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>; | ||
| def INT_NVVM_ADD_RZ_SAT_FTZ_F : F_MATH_2<"add.rz.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_sat_f>; | ||
| def INT_NVVM_ADD_RZ_F : F_MATH_2<"add.rz.f32", B32, B32, B32, int_nvvm_add_rz_f>; | ||
| def INT_NVVM_ADD_RZ_SAT_F : F_MATH_2<"add.rz.sat.f32", B32, B32, B32, int_nvvm_add_rz_sat_f>; | ||
| def INT_NVVM_ADD_RM_FTZ_F : F_MATH_2<"add.rm.ftz.f32", B32, B32, B32, int_nvvm_add_rm_ftz_f>; | ||
| def INT_NVVM_ADD_RM_SAT_FTZ_F : F_MATH_2<"add.rm.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rm_ftz_sat_f>; | ||
| def INT_NVVM_ADD_RM_F : F_MATH_2<"add.rm.f32", B32, B32, B32, int_nvvm_add_rm_f>; | ||
| def INT_NVVM_ADD_RM_SAT_F : F_MATH_2<"add.rm.sat.f32", B32, B32, B32, int_nvvm_add_rm_sat_f>; | ||
| def INT_NVVM_ADD_RP_FTZ_F : F_MATH_2<"add.rp.ftz.f32", B32, B32, B32, int_nvvm_add_rp_ftz_f>; | ||
| def INT_NVVM_ADD_RP_SAT_FTZ_F : F_MATH_2<"add.rp.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rp_ftz_sat_f>; | ||
| def INT_NVVM_ADD_RP_F : F_MATH_2<"add.rp.f32", B32, B32, B32, int_nvvm_add_rp_f>; | ||
| def INT_NVVM_ADD_RP_SAT_F : F_MATH_2<"add.rp.sat.f32", B32, B32, B32, int_nvvm_add_rp_sat_f>; | ||
|
|
||
| def INT_NVVM_ADD_RN_D : F_MATH_2<"add.rn.f64", B64, B64, B64, int_nvvm_add_rn_d>; | ||
| def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>; | ||
| def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>; | ||
| def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>; | ||
|
|
||
| foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in { | ||
| foreach sat = ["", "_SAT"] in { | ||
| foreach type = ["F16", "BF16"] in { | ||
| def INT_NVVM_ADD # rnd # sat # _F32_ # type : | ||
| BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b), | ||
| !tolower(!subst("_", ".", "add" # rnd # sat # "_f32_" # type)), | ||
| [(set f32:$dst, | ||
| (!cast<Intrinsic>(!tolower("int_nvvm_add" # rnd # sat # "_f")) | ||
| (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For FMAs, the PTX ISA clearly states that the rounding mode is applied only to the "result" (i.e., the given rounding mode has no impact on the conversion of the inputs from f16 to f32; it applies only to the final result).
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From what I remember, |
||
| f32:$b))]>, | ||
| Requires<[hasSM<100>, hasPTX<86>]>; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Sub | ||
| // | ||
|
|
||
| def INT_NVVM_SUB_RN_FTZ_F : F_MATH_2<"sub.rn.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_f>; | ||
| def INT_NVVM_SUB_RN_SAT_FTZ_F : F_MATH_2<"sub.rn.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f>; | ||
| def INT_NVVM_SUB_RN_F : F_MATH_2<"sub.rn.f32", B32, B32, B32, int_nvvm_sub_rn_f>; | ||
| def INT_NVVM_SUB_RN_SAT_F : F_MATH_2<"sub.rn.sat.f32", B32, B32, B32, int_nvvm_sub_rn_sat_f>; | ||
| def INT_NVVM_SUB_RZ_FTZ_F : F_MATH_2<"sub.rz.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_f>; | ||
| def INT_NVVM_SUB_RZ_SAT_FTZ_F : F_MATH_2<"sub.rz.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_sat_f>; | ||
| def INT_NVVM_SUB_RZ_F : F_MATH_2<"sub.rz.f32", B32, B32, B32, int_nvvm_sub_rz_f>; | ||
| def INT_NVVM_SUB_RZ_SAT_F : F_MATH_2<"sub.rz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_sat_f>; | ||
| def INT_NVVM_SUB_RM_FTZ_F : F_MATH_2<"sub.rm.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_f>; | ||
| def INT_NVVM_SUB_RM_SAT_FTZ_F : F_MATH_2<"sub.rm.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_sat_f>; | ||
| def INT_NVVM_SUB_RM_F : F_MATH_2<"sub.rm.f32", B32, B32, B32, int_nvvm_sub_rm_f>; | ||
| def INT_NVVM_SUB_RM_SAT_F : F_MATH_2<"sub.rm.sat.f32", B32, B32, B32, int_nvvm_sub_rm_sat_f>; | ||
| def INT_NVVM_SUB_RP_FTZ_F : F_MATH_2<"sub.rp.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_f>; | ||
| def INT_NVVM_SUB_RP_SAT_FTZ_F : F_MATH_2<"sub.rp.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_sat_f>; | ||
| def INT_NVVM_SUB_RP_F : F_MATH_2<"sub.rp.f32", B32, B32, B32, int_nvvm_sub_rp_f>; | ||
| def INT_NVVM_SUB_RP_SAT_F : F_MATH_2<"sub.rp.sat.f32", B32, B32, B32, int_nvvm_sub_rp_sat_f>; | ||
|
|
||
| def INT_NVVM_SUB_RN_D : F_MATH_2<"sub.rn.f64", B64, B64, B64, int_nvvm_sub_rn_d>; | ||
| def INT_NVVM_SUB_RZ_D : F_MATH_2<"sub.rz.f64", B64, B64, B64, int_nvvm_sub_rz_d>; | ||
| def INT_NVVM_SUB_RM_D : F_MATH_2<"sub.rm.f64", B64, B64, B64, int_nvvm_sub_rm_d>; | ||
| def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>; | ||
|
|
||
| foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in { | ||
| foreach sat = ["", "_SAT"] in { | ||
| foreach type = ["F16", "BF16"] in { | ||
| def INT_NVVM_SUB # rnd # sat # _F32_ # type : | ||
| BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b), | ||
| !tolower(!subst("_", ".", "sub" # rnd # sat # "_f32_" # type)), | ||
| [(set f32:$dst, | ||
| (!cast<Intrinsic>(!tolower("int_nvvm_sub" # rnd # sat # "_f")) | ||
| (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)), | ||
| f32:$b))]>, | ||
| Requires<[hasSM<100>, hasPTX<86>]>; | ||
| } | ||
| } | ||
| } | ||
| // | ||
| // BFIND | ||
| // | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us add a brace at the end here as well.
Also, add trailing comments on the loop-ending braces (e.g. `} // rnd`).
This way, it is easier to spot the nesting.