Skip to content

Conversation

LewisCrawford
Copy link
Contributor

Add DAGCombiner patterns for pairs of 2-operand min/max instructions to be fused into a single 3-operand min/max instruction for f32s (only for PTX 8.8+ and sm100+).

Add DAGCombiner patterns for pairs of 2-operand min/max instructions
to be fused into a single 3-operand min/max instruction for f32s
(only for PTX 8.8+ and sm100+).
@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Lewis Crawford (LewisCrawford)

Changes

Add DAGCombiner patterns for pairs of 2-operand min/max instructions to be fused into a single 3-operand min/max instruction for f32s (only for PTX 8.8+ and sm100+).


Full diff: https://github.com/llvm/llvm-project/pull/159729.diff

2 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+66-4)
  • (added) llvm/test/CodeGen/NVPTX/fmax3.ll (+260)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..307e1c6f7c227 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -841,10 +841,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
 
   // We have some custom DAG combine patterns for these nodes
-  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
-                       ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
-                       ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
+  setTargetDAGCombine(
+      {ISD::ADD,          ISD::AND,           ISD::EXTRACT_VECTOR_ELT,
+       ISD::FADD,         ISD::FMAXNUM,       ISD::FMINNUM,
+       ISD::FMAXIMUM,     ISD::FMINIMUM,      ISD::FMAXIMUMNUM,
+       ISD::FMINIMUMNUM,  ISD::MUL,           ISD::SHL,
+       ISD::SREM,         ISD::UREM,          ISD::VSELECT,
+       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
+       ISD::STORE,        ISD::ZERO_EXTEND,   ISD::SIGN_EXTEND});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5316,6 +5320,56 @@ static SDValue PerformFADDCombine(SDNode *N,
   return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
 }
 
+/// Get 3-input version of a 2-input min/max opcode
+static NVPTXISD::NodeType getMinMax3Opcode(unsigned MinMax2Opcode) {
+  switch (MinMax2Opcode) {
+  case ISD::FMAXNUM:
+  case ISD::FMAXIMUMNUM:
+    return NVPTXISD::FMAXNUM3;
+  case ISD::FMINNUM:
+  case ISD::FMINIMUMNUM:
+    return NVPTXISD::FMINNUM3;
+  case ISD::FMAXIMUM:
+    return NVPTXISD::FMAXIMUM3;
+  case ISD::FMINIMUM:
+    return NVPTXISD::FMINIMUM3;
+  default:
+    llvm_unreachable("Invalid 2-input min/max opcode");
+  }
+}
+
+/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
+/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
+static SDValue PerformFMinMaxCombine(SDNode *N,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     unsigned PTXVersion, unsigned SmVersion) {
+
+  // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
+    return SDValue();
+
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  unsigned MinMaxOp2 = N->getOpcode();
+  NVPTXISD::NodeType MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
+
+  if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
+    // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
+    SDValue A = Op0.getOperand(0);
+    SDValue B = Op0.getOperand(1);
+    SDValue C = Op1;
+    return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
+  } else if (Op1->getOpcode() == MinMaxOp2 && Op1->hasOneUse()) {
+    // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
+    SDValue A = Op0;
+    SDValue B = Op1.getOperand(0);
+    SDValue C = Op1.getOperand(1);
+    return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
+  }
+  return SDValue();
+}
+
 static SDValue PerformREMCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
@@ -5996,6 +6050,14 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformEXTRACTCombine(N, DCI);
   case ISD::FADD:
     return PerformFADDCombine(N, DCI, OptLevel);
+  case ISD::FMAXNUM:
+  case ISD::FMINNUM:
+  case ISD::FMAXIMUM:
+  case ISD::FMINIMUM:
+  case ISD::FMAXIMUMNUM:
+  case ISD::FMINIMUMNUM:
+    return PerformFMinMaxCombine(N, DCI, STI.getPTXVersion(),
+                                 STI.getSmVersion());
   case ISD::LOAD:
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
diff --git a/llvm/test/CodeGen/NVPTX/fmax3.ll b/llvm/test/CodeGen/NVPTX/fmax3.ll
new file mode 100644
index 0000000000000..9339b2e247af4
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fmax3.ll
@@ -0,0 +1,260 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -march=nvptx64 -mcpu=sm_100f -o - %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+
+define void @test_fmaxnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaxnum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fmaxnum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fmaxnum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fmaxnum3_param_2];
+; CHECK-NEXT:    max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fmaxnum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call float @llvm.maxnum.f32(float %a, float %b)
+  %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
+  store float %max_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fminnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminnum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fminnum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fminnum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fminnum3_param_2];
+; CHECK-NEXT:    min.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fminnum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %min_ab = call float @llvm.minnum.f32(float %a, float %b)
+  %min_abc = call float @llvm.minnum.f32(float %min_ab, float %c)
+  store float %min_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fmaximum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaximum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fmaximum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fmaximum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fmaximum3_param_2];
+; CHECK-NEXT:    max.NaN.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fmaximum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call float @llvm.maximum.f32(float %a, float %b)
+  %max_abc = call float @llvm.maximum.f32(float %max_ab, float %c)
+  store float %max_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fminimum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminimum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fminimum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fminimum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fminimum3_param_2];
+; CHECK-NEXT:    min.NaN.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fminimum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %min_ab = call float @llvm.minimum.f32(float %a, float %b)
+  %min_abc = call float @llvm.minimum.f32(float %min_ab, float %c)
+  store float %min_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fmaximumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaximumnum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fmaximumnum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fmaximumnum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fmaximumnum3_param_2];
+; CHECK-NEXT:    max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fmaximumnum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call float @llvm.maximumnum.f32(float %a, float %b)
+  %max_abc = call float @llvm.maximumnum.f32(float %max_ab, float %c)
+  store float %max_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fminimumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminimumnum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fminimumnum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fminimumnum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fminimumnum3_param_2];
+; CHECK-NEXT:    min.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fminimumnum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %min_ab = call float @llvm.minimumnum.f32(float %a, float %b)
+  %min_abc = call float @llvm.minimumnum.f32(float %min_ab, float %c)
+  store float %min_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+; Test commuted operands (second operand is the nested operation)
+define void @test_fmaxnum3_commuted(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaxnum3_commuted(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fmaxnum3_commuted_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fmaxnum3_commuted_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fmaxnum3_commuted_param_2];
+; CHECK-NEXT:    max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fmaxnum3_commuted_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %max_bc = call float @llvm.maxnum.f32(float %b, float %c)
+  %max_abc = call float @llvm.maxnum.f32(float %a, float %max_bc)
+  store float %max_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+; NEGATIVE TEST: Mixed min/max operations should not combine
+define void @test_mixed_minmax_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_mixed_minmax_no_combine(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_mixed_minmax_no_combine_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_mixed_minmax_no_combine_param_1];
+; CHECK-NEXT:    min.f32 %r3, %r1, %r2;
+; CHECK-NEXT:    ld.param.b32 %r4, [test_mixed_minmax_no_combine_param_2];
+; CHECK-NEXT:    max.f32 %r5, %r3, %r4;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_mixed_minmax_no_combine_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r5;
+; CHECK-NEXT:    ret;
+entry:
+  %min_ab = call float @llvm.minnum.f32(float %a, float %b)
+  %max_result = call float @llvm.maxnum.f32(float %min_ab, float %c)
+  store float %max_result, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+; NEGATIVE TEST: Mixed maxnum/maximum operations should not combine
+define void @test_mixed_maxnum_maximum_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_mixed_maxnum_maximum_no_combine(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_mixed_maxnum_maximum_no_combine_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_mixed_maxnum_maximum_no_combine_param_1];
+; CHECK-NEXT:    max.f32 %r3, %r1, %r2;
+; CHECK-NEXT:    ld.param.b32 %r4, [test_mixed_maxnum_maximum_no_combine_param_2];
+; CHECK-NEXT:    max.NaN.f32 %r5, %r3, %r4;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_mixed_maxnum_maximum_no_combine_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r5;
+; CHECK-NEXT:    ret;
+entry:
+  %maxnum_ab = call float @llvm.maxnum.f32(float %a, float %b)
+  %maximum_result = call float @llvm.maximum.f32(float %maxnum_ab, float %c)
+  store float %maximum_result, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+; NEGATIVE TEST: f16 should not be combined (only f32 supported)
+define void @test_f16_no_combine(half %a, half %b, half %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_f16_no_combine(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<6>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_f16_no_combine_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_f16_no_combine_param_1];
+; CHECK-NEXT:    max.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT:    ld.param.b16 %rs4, [test_f16_no_combine_param_2];
+; CHECK-NEXT:    max.f16 %rs5, %rs3, %rs4;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_f16_no_combine_param_3];
+; CHECK-NEXT:    st.global.b16 [%rd1], %rs5;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call half @llvm.maxnum.f16(half %a, half %b)
+  %max_abc = call half @llvm.maxnum.f16(half %max_ab, half %c)
+  store half %max_abc, ptr addrspace(1) %output, align 2
+  ret void
+}
+
+; NEGATIVE TEST: Multiple uses of intermediate result should not combine
+define void @test_multiple_uses_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output1, ptr addrspace(1) %output2) {
+; CHECK-LABEL: test_multiple_uses_no_combine(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_multiple_uses_no_combine_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_multiple_uses_no_combine_param_1];
+; CHECK-NEXT:    max.f32 %r3, %r1, %r2;
+; CHECK-NEXT:    ld.param.b32 %r4, [test_multiple_uses_no_combine_param_2];
+; CHECK-NEXT:    max.f32 %r5, %r3, %r4;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_multiple_uses_no_combine_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r3;
+; CHECK-NEXT:    ld.param.b64 %rd2, [test_multiple_uses_no_combine_param_4];
+; CHECK-NEXT:    st.global.b32 [%rd2], %r5;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call float @llvm.maxnum.f32(float %a, float %b)
+  %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
+  ; Multiple uses of %max_ab should prevent combining
+  store float %max_ab, ptr addrspace(1) %output1, align 4
+  store float %max_abc, ptr addrspace(1) %output2, align 4
+  ret void
+}
+
+; Declare all the intrinsics we need
+declare float @llvm.maxnum.f32(float, float) #0
+declare float @llvm.minnum.f32(float, float) #0
+declare float @llvm.maximum.f32(float, float) #0
+declare float @llvm.minimum.f32(float, float) #0
+declare float @llvm.maximumnum.f32(float, float) #0
+declare float @llvm.minimumnum.f32(float, float) #0
+declare half @llvm.maxnum.f16(half, half) #0
+
+attributes #0 = { nounwind readnone speculatable willreturn }

Use . instead of -> operators on the SDValue Op1 for
consistency. Both are equivalent, as the * operator on
an SDValue returns the inner SDNode, and the getOperand
function on an SDValue just calls getOperand on the inner
node anyway, but we should use the same approach consistently.
@LewisCrawford LewisCrawford merged commit 0dc2148 into llvm:main Sep 22, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants