-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[NVPTX] Add 3-operand fmin/fmax DAGCombines #159729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add DAGCombiner patterns for pairs of 2-operand min/max instructions to be fused into a single 3-operand min/max instruction for f32s (only for PTX 8.8+ and sm100+).
@llvm/pr-subscribers-backend-nvptx Author: Lewis Crawford (LewisCrawford) ChangesAdd DAGCombiner patterns for pairs of 2-operand min/max instructions to be fused into a single 3-operand min/max instruction for f32s (only for PTX 8.8+ and sm100+). Full diff: https://github.com/llvm/llvm-project/pull/159729.diff 2 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..307e1c6f7c227 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -841,10 +841,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
// We have some custom DAG combine patterns for these nodes
- setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
- ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
- ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
+ setTargetDAGCombine(
+ {ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
+ ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
+ ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
+ ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
+ ISD::SREM, ISD::UREM, ISD::VSELECT,
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
+ ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5316,6 +5320,56 @@ static SDValue PerformFADDCombine(SDNode *N,
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}
+/// Get 3-input version of a 2-input min/max opcode
+static NVPTXISD::NodeType getMinMax3Opcode(unsigned MinMax2Opcode) {
+ switch (MinMax2Opcode) {
+ case ISD::FMAXNUM:
+ case ISD::FMAXIMUMNUM:
+ return NVPTXISD::FMAXNUM3;
+ case ISD::FMINNUM:
+ case ISD::FMINIMUMNUM:
+ return NVPTXISD::FMINNUM3;
+ case ISD::FMAXIMUM:
+ return NVPTXISD::FMAXIMUM3;
+ case ISD::FMINIMUM:
+ return NVPTXISD::FMINIMUM3;
+ default:
+ llvm_unreachable("Invalid 2-input min/max opcode");
+ }
+}
+
+/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
+/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
+static SDValue PerformFMinMaxCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ unsigned PTXVersion, unsigned SmVersion) {
+
+ // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
+ return SDValue();
+
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ unsigned MinMaxOp2 = N->getOpcode();
+ NVPTXISD::NodeType MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
+
+ if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
+ // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
+ SDValue A = Op0.getOperand(0);
+ SDValue B = Op0.getOperand(1);
+ SDValue C = Op1;
+ return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
+ } else if (Op1->getOpcode() == MinMaxOp2 && Op1->hasOneUse()) {
+ // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
+ SDValue A = Op0;
+ SDValue B = Op1.getOperand(0);
+ SDValue C = Op1.getOperand(1);
+ return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
+ }
+ return SDValue();
+}
+
static SDValue PerformREMCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
@@ -5996,6 +6050,14 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformEXTRACTCombine(N, DCI);
case ISD::FADD:
return PerformFADDCombine(N, DCI, OptLevel);
+ case ISD::FMAXNUM:
+ case ISD::FMINNUM:
+ case ISD::FMAXIMUM:
+ case ISD::FMINIMUM:
+ case ISD::FMAXIMUMNUM:
+ case ISD::FMINIMUMNUM:
+ return PerformFMinMaxCombine(N, DCI, STI.getPTXVersion(),
+ STI.getSmVersion());
case ISD::LOAD:
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
diff --git a/llvm/test/CodeGen/NVPTX/fmax3.ll b/llvm/test/CodeGen/NVPTX/fmax3.ll
new file mode 100644
index 0000000000000..9339b2e247af4
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fmax3.ll
@@ -0,0 +1,260 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -march=nvptx64 -mcpu=sm_100f -o - %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+
+define void @test_fmaxnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaxnum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_param_2];
+; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call float @llvm.maxnum.f32(float %a, float %b)
+ %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
+ store float %max_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+define void @test_fminnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminnum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fminnum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fminnum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fminnum3_param_2];
+; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fminnum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %min_ab = call float @llvm.minnum.f32(float %a, float %b)
+ %min_abc = call float @llvm.minnum.f32(float %min_ab, float %c)
+ store float %min_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+define void @test_fmaximum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaximum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximum3_param_2];
+; CHECK-NEXT: max.NaN.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call float @llvm.maximum.f32(float %a, float %b)
+ %max_abc = call float @llvm.maximum.f32(float %max_ab, float %c)
+ store float %max_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+define void @test_fminimum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminimum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fminimum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fminimum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fminimum3_param_2];
+; CHECK-NEXT: min.NaN.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %min_ab = call float @llvm.minimum.f32(float %a, float %b)
+ %min_abc = call float @llvm.minimum.f32(float %min_ab, float %c)
+ store float %min_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+define void @test_fmaximumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaximumnum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximumnum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximumnum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximumnum3_param_2];
+; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximumnum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call float @llvm.maximumnum.f32(float %a, float %b)
+ %max_abc = call float @llvm.maximumnum.f32(float %max_ab, float %c)
+ store float %max_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+define void @test_fminimumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminimumnum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fminimumnum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fminimumnum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fminimumnum3_param_2];
+; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimumnum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %min_ab = call float @llvm.minimumnum.f32(float %a, float %b)
+ %min_abc = call float @llvm.minimumnum.f32(float %min_ab, float %c)
+ store float %min_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; Test commuted operands (second operand is the nested operation)
+define void @test_fmaxnum3_commuted(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaxnum3_commuted(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_commuted_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_commuted_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_commuted_param_2];
+; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_commuted_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %max_bc = call float @llvm.maxnum.f32(float %b, float %c)
+ %max_abc = call float @llvm.maxnum.f32(float %a, float %max_bc)
+ store float %max_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; NEGATIVE TEST: Mixed min/max operations should not combine
+define void @test_mixed_minmax_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_mixed_minmax_no_combine(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_minmax_no_combine_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_minmax_no_combine_param_1];
+; CHECK-NEXT: min.f32 %r3, %r1, %r2;
+; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_minmax_no_combine_param_2];
+; CHECK-NEXT: max.f32 %r5, %r3, %r4;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_minmax_no_combine_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r5;
+; CHECK-NEXT: ret;
+entry:
+ %min_ab = call float @llvm.minnum.f32(float %a, float %b)
+ %max_result = call float @llvm.maxnum.f32(float %min_ab, float %c)
+ store float %max_result, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; NEGATIVE TEST: Mixed maxnum/maximum operations should not combine
+define void @test_mixed_maxnum_maximum_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_mixed_maxnum_maximum_no_combine(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_maxnum_maximum_no_combine_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_maxnum_maximum_no_combine_param_1];
+; CHECK-NEXT: max.f32 %r3, %r1, %r2;
+; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_maxnum_maximum_no_combine_param_2];
+; CHECK-NEXT: max.NaN.f32 %r5, %r3, %r4;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_maxnum_maximum_no_combine_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r5;
+; CHECK-NEXT: ret;
+entry:
+ %maxnum_ab = call float @llvm.maxnum.f32(float %a, float %b)
+ %maximum_result = call float @llvm.maximum.f32(float %maxnum_ab, float %c)
+ store float %maximum_result, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; NEGATIVE TEST: f16 should not be combined (only f32 supported)
+define void @test_f16_no_combine(half %a, half %b, half %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_f16_no_combine(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<6>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b16 %rs1, [test_f16_no_combine_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_f16_no_combine_param_1];
+; CHECK-NEXT: max.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: ld.param.b16 %rs4, [test_f16_no_combine_param_2];
+; CHECK-NEXT: max.f16 %rs5, %rs3, %rs4;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_f16_no_combine_param_3];
+; CHECK-NEXT: st.global.b16 [%rd1], %rs5;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call half @llvm.maxnum.f16(half %a, half %b)
+ %max_abc = call half @llvm.maxnum.f16(half %max_ab, half %c)
+ store half %max_abc, ptr addrspace(1) %output, align 2
+ ret void
+}
+
+; NEGATIVE TEST: Multiple uses of intermediate result should not combine
+define void @test_multiple_uses_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output1, ptr addrspace(1) %output2) {
+; CHECK-LABEL: test_multiple_uses_no_combine(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_multiple_uses_no_combine_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_multiple_uses_no_combine_param_1];
+; CHECK-NEXT: max.f32 %r3, %r1, %r2;
+; CHECK-NEXT: ld.param.b32 %r4, [test_multiple_uses_no_combine_param_2];
+; CHECK-NEXT: max.f32 %r5, %r3, %r4;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_multiple_uses_no_combine_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r3;
+; CHECK-NEXT: ld.param.b64 %rd2, [test_multiple_uses_no_combine_param_4];
+; CHECK-NEXT: st.global.b32 [%rd2], %r5;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call float @llvm.maxnum.f32(float %a, float %b)
+ %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
+ ; Multiple uses of %max_ab should prevent combining
+ store float %max_ab, ptr addrspace(1) %output1, align 4
+ store float %max_abc, ptr addrspace(1) %output2, align 4
+ ret void
+}
+
+; Declare all the intrinsics we need
+declare float @llvm.maxnum.f32(float, float) #0
+declare float @llvm.minnum.f32(float, float) #0
+declare float @llvm.maximum.f32(float, float) #0
+declare float @llvm.minimum.f32(float, float) #0
+declare float @llvm.maximumnum.f32(float, float) #0
+declare float @llvm.minimumnum.f32(float, float) #0
+declare half @llvm.maxnum.f16(half, half) #0
+
+attributes #0 = { nounwind readnone speculatable willreturn }
|
Use . instead of -> operators on the SDValue Op1 for consistency. Both are equivalent, as the * operator on an SDValue returns the inner SDNode, and the getOperand function on an SDValue just calls getOperand on the inner node anyway, but we should use the same approach consistently.
durga4github
approved these changes
Sep 19, 2025
Artem-B
approved these changes
Sep 19, 2025
Prince781
approved these changes
Sep 20, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Add DAGCombiner patterns for pairs of 2-operand min/max instructions to be fused into a single 3-operand min/max instruction for f32s (only for PTX 8.8+ and sm100+).