From d54f68ebe777f7ae10485ca4401dfd2acf5c2766 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Mon, 17 Nov 2025 09:14:43 +0000 Subject: [PATCH 1/2] [NVPTX] Fix PTX and SM conditions for narrow FP conversions This change fixes the PTX and SM conditions for narrow FP conversion intrinsics. It also adds the `AnyPred` helper class to make it easier to combine multiple predicates with OR. --- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 6 +++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 62 +++++++++++++----------- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index ff9d9723dddea..dbf1d57b160f3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -101,6 +101,12 @@ def PrmtMode : Operand { // NVPTX Instruction Predicate Definitions //===----------------------------------------------------------------------===// +// AnyPred - helper class to create an OR condition between multiple predicates. +class AnyPred<list<Predicate> predicates> : Predicate<""> { + let CondString = !foldl("false", predicates, acc, pred, + acc # " || (" # pred.CondString # ")"); +} + // Checks PTX version and family-specific and architecture-specific SM versions. // For example, sm_100{f/a} and any future variants in the same family will match // for any PTX version greater than or equal to `PTXVersion`. 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index ea69a54e6db37..ee656307b0a98 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -2008,34 +2008,40 @@ def : Pat<(int_nvvm_ull2d_rp i64:$a), (CVT_f64_u64 $a, CvtRP)>; def : Pat<(int_nvvm_f2h_rn_ftz f32:$a), (CVT_f16_f32 $a, CvtRN_FTZ)>; def : Pat<(int_nvvm_f2h_rn f32:$a), (CVT_f16_f32 $a, CvtRN)>; -def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b), - (CVT_e4m3x2_f32 $a, $b, CvtRN)>; -def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b), - (CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>; -def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b), - (CVT_e5m2x2_f32 $a, $b, CvtRN)>; -def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b), - (CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>; - -def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a), - (CVT_e4m3x2_f16x2 $a, CvtRN)>; -def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a), - (CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>; -def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a), - (CVT_e5m2x2_f16x2 $a, CvtRN)>; -def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a), - (CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>; - -def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a), - (CVT_f16x2_e4m3x2 $a, CvtRN)>; -def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a), - (CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>; -def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a), - (CVT_f16x2_e5m2x2 $a, CvtRN)>; -def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a), - (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>; - -let Predicates = [hasPTX<86>, hasSM<100>, hasArchAccelFeatures] in { +let Predicates = [hasPTX<81>, hasSM<89>] in { + def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b), + (CVT_e4m3x2_f32 $a, $b, CvtRN)>; + def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b), + (CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>; + def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b), + (CVT_e5m2x2_f32 $a, $b, CvtRN)>; + def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b), + 
(CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>; + + def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a), + (CVT_e4m3x2_f16x2 $a, CvtRN)>; + def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a), + (CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>; + def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a), + (CVT_e5m2x2_f16x2 $a, CvtRN)>; + def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a), + (CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>; + + def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a), + (CVT_f16x2_e4m3x2 $a, CvtRN)>; + def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a), + (CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>; + def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a), + (CVT_f16x2_e5m2x2 $a, CvtRN)>; + def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a), + (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>; +} + +let Predicates = [AnyPred<[ + PTXWithFamilySMs<90, [100, 110, 120]>, + PTXWithFamilySMs<88, [100, 101, 120]>, + PTXWithAccelSMs<86, [100, 101, 120]> + ]>] in { def : Pat<(int_nvvm_ff_to_e2m3x2_rn_satfinite f32:$a, f32:$b), (CVT_e2m3x2_f32_sf $a, $b, CvtRN)>; def : Pat<(int_nvvm_ff_to_e2m3x2_rn_relu_satfinite f32:$a, f32:$b), From 2b57b9be61f5485d8fd99da08961c06cd5ea4c37 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi Date: Mon, 17 Nov 2025 11:39:44 +0000 Subject: [PATCH 2/2] remove AnyPred and define predicate in NVPTXSubtarget.h --- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 6 ------ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 8 ++------ llvm/lib/Target/NVPTX/NVPTXSubtarget.h | 21 +++++++++++++++++++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index dbf1d57b160f3..ff9d9723dddea 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -101,12 +101,6 @@ def PrmtMode : Operand { // NVPTX Instruction Predicate Definitions //===----------------------------------------------------------------------===// -// AnyPred - helper class to create an OR condition between multiple
predicates. -class AnyPred<list<Predicate> predicates> : Predicate<""> { - let CondString = !foldl("false", predicates, acc, pred, - acc # " || (" # pred.CondString # ")"); -} - // Checks PTX version and family-specific and architecture-specific SM versions. // For example, sm_100{f/a} and any future variants in the same family will match // for any PTX version greater than or equal to `PTXVersion`. diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index ee656307b0a98..789cb0fdcbdb2 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -2008,7 +2008,7 @@ def : Pat<(int_nvvm_ull2d_rp i64:$a), (CVT_f64_u64 $a, CvtRP)>; def : Pat<(int_nvvm_f2h_rn_ftz f32:$a), (CVT_f16_f32 $a, CvtRN_FTZ)>; def : Pat<(int_nvvm_f2h_rn f32:$a), (CVT_f16_f32 $a, CvtRN)>; -let Predicates = [hasPTX<81>, hasSM<89>] in { +let Predicates = [callSubtarget<"hasFP8ConversionSupport">] in { def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b), (CVT_e4m3x2_f32 $a, $b, CvtRN)>; def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b), @@ -2037,11 +2037,7 @@ let Predicates = [hasPTX<81>, hasSM<89>] in { (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>; } -let Predicates = [AnyPred<[ - PTXWithFamilySMs<90, [100, 110, 120]>, - PTXWithFamilySMs<88, [100, 101, 120]>, - PTXWithAccelSMs<86, [100, 101, 120]> - ]>] in { +let Predicates = [callSubtarget<"hasNarrowFPConversionSupport">] in { def : Pat<(int_nvvm_ff_to_e2m3x2_rn_satfinite f32:$a, f32:$b), (CVT_e2m3x2_f32_sf $a, $b, CvtRN)>; def : Pat<(int_nvvm_ff_to_e2m3x2_rn_relu_satfinite f32:$a, f32:$b), diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index 021b1f6d0bf57..f11d331862081 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -177,6 +177,27 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo { hasPTXWithAccelSMs(86, {100, 101}); } + // Checks support for conversions involving e4m3x2 and
e5m2x2. + bool hasFP8ConversionSupport() const { + if (PTXVersion >= 81) + return SmVersion >= 89; + + if (PTXVersion >= 78) + return SmVersion >= 90; + + return false; + } + + // Checks support for conversions involving the following types: + // - e2m3x2/e3m2x2 + // - e2m1x2 + // - ue8m0x2 + bool hasNarrowFPConversionSupport() const { + return hasPTXWithFamilySMs(90, {100, 110, 120}) || + hasPTXWithFamilySMs(88, {100, 101, 120}) || + hasPTXWithAccelSMs(86, {100, 101, 120}); + } + // Prior to CUDA 12.3 ptxas did not recognize that the trap instruction // terminates a basic block. Instead, it would assume that control flow // continued to the next instruction. The next instruction could be in the