-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[NVPTX] Fix PTX and SM conditions for narrow FP conversions #168680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This change fixes the PTX and SM conditions for narrow FP conversion intrinsics. It also adds the `AnyPred` helper class to make it easier to combine multiple predicates with OR.
|
@llvm/pr-subscribers-backend-nvptx Author: Srinivasa Ravi (Wolfram70) ChangesThis change fixes the PTX and SM conditions for narrow FP Full diff: https://github.com/llvm/llvm-project/pull/168680.diff 2 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index ea69a54e6db37..789cb0fdcbdb2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2008,34 +2008,36 @@ def : Pat<(int_nvvm_ull2d_rp i64:$a), (CVT_f64_u64 $a, CvtRP)>;
def : Pat<(int_nvvm_f2h_rn_ftz f32:$a), (CVT_f16_f32 $a, CvtRN_FTZ)>;
def : Pat<(int_nvvm_f2h_rn f32:$a), (CVT_f16_f32 $a, CvtRN)>;
-def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b),
- (CVT_e4m3x2_f32 $a, $b, CvtRN)>;
-def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b),
- (CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>;
-def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b),
- (CVT_e5m2x2_f32 $a, $b, CvtRN)>;
-def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b),
- (CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>;
-
-def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a),
- (CVT_e4m3x2_f16x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a),
- (CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>;
-def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a),
- (CVT_e5m2x2_f16x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a),
- (CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>;
-
-def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a),
- (CVT_f16x2_e4m3x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a),
- (CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>;
-def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a),
- (CVT_f16x2_e5m2x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a),
- (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
-
-let Predicates = [hasPTX<86>, hasSM<100>, hasArchAccelFeatures] in {
+let Predicates = [callSubtarget<"hasFP8ConversionSupport">] in {
+ def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b),
+ (CVT_e4m3x2_f32 $a, $b, CvtRN)>;
+ def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b),
+ (CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>;
+ def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b),
+ (CVT_e5m2x2_f32 $a, $b, CvtRN)>;
+ def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b),
+ (CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>;
+
+ def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a),
+ (CVT_e4m3x2_f16x2 $a, CvtRN)>;
+ def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a),
+ (CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>;
+ def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a),
+ (CVT_e5m2x2_f16x2 $a, CvtRN)>;
+ def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a),
+ (CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>;
+
+ def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a),
+ (CVT_f16x2_e4m3x2 $a, CvtRN)>;
+ def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a),
+ (CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>;
+ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a),
+ (CVT_f16x2_e5m2x2 $a, CvtRN)>;
+ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a),
+ (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
+}
+
+let Predicates = [callSubtarget<"hasNarrowFPConversionSupport">] in {
def : Pat<(int_nvvm_ff_to_e2m3x2_rn_satfinite f32:$a, f32:$b),
(CVT_e2m3x2_f32_sf $a, $b, CvtRN)>;
def : Pat<(int_nvvm_ff_to_e2m3x2_rn_relu_satfinite f32:$a, f32:$b),
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 021b1f6d0bf57..f11d331862081 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -177,6 +177,27 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
hasPTXWithAccelSMs(86, {100, 101});
}
+ // Checks support for conversions involving e4m3x2 and e5m2x2.
+ bool hasFP8ConversionSupport() const {
+ if (PTXVersion >= 81)
+ return SmVersion >= 89;
+
+ if (PTXVersion >= 78)
+ return SmVersion >= 90;
+
+ return false;
+ }
+
+ // Checks support for conversions involving the following types:
+ // - e2m3x2/e3m2x2
+ // - e2m1x2
+ // - ue8m0x2
+ bool hasNarrowFPConversionSupport() const {
+ return hasPTXWithFamilySMs(90, {100, 110, 120}) ||
+ hasPTXWithFamilySMs(88, {100, 101, 120}) ||
+ hasPTXWithAccelSMs(86, {100, 101, 120});
+ }
+
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
// terminates a basic block. Instead, it would assume that control flow
// continued to the next instruction. The next instruction could be in the
|
🐧 Linux x64 Test Results
|
| // - e2m3x2/e3m2x2 | ||
| // - e2m1x2 | ||
| // - ue8m0x2 | ||
| bool hasNarrowFPConversionSupport() const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optional:
I wonder if we should name it something like "subbyteFP" instead of "narrowFP".
Change as such LGTM
rajatbajpai
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
This change fixes the PTX and SM conditions for narrow FP
conversion intrinsics and adds support for family-conditionals.