diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index d3fb657851fe2..ca8a3f69f991d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -841,10 +841,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand); // We have some custom DAG combine patterns for these nodes - setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD, - ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT, - ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD, - ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND}); + setTargetDAGCombine( + {ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, + ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM, + ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM, + ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL, + ISD::SREM, ISD::UREM, ISD::VSELECT, + ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD, + ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND}); // setcc for f16x2 and bf16x2 needs special handling to prevent // legalizer's attempt to scalarize it due to v2i1 not being legal. @@ -5316,6 +5320,56 @@ static SDValue PerformFADDCombine(SDNode *N, return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel); } +/// Get 3-input version of a 2-input min/max opcode +static NVPTXISD::NodeType getMinMax3Opcode(unsigned MinMax2Opcode) { + switch (MinMax2Opcode) { + case ISD::FMAXNUM: + case ISD::FMAXIMUMNUM: + return NVPTXISD::FMAXNUM3; + case ISD::FMINNUM: + case ISD::FMINIMUMNUM: + return NVPTXISD::FMINNUM3; + case ISD::FMAXIMUM: + return NVPTXISD::FMAXIMUM3; + case ISD::FMINIMUM: + return NVPTXISD::FMINIMUM3; + default: + llvm_unreachable("Invalid 2-input min/max opcode"); + } +} + +/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into +/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics. +static SDValue PerformFMinMaxCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + unsigned PTXVersion, unsigned SmVersion) { + + // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s + EVT VT = N->getValueType(0); + if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100) + return SDValue(); + + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + unsigned MinMaxOp2 = N->getOpcode(); + NVPTXISD::NodeType MinMaxOp3 = getMinMax3Opcode(MinMaxOp2); + + if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) { + // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c) + SDValue A = Op0.getOperand(0); + SDValue B = Op0.getOperand(1); + SDValue C = Op1; + return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags()); + } else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) { + // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c) + SDValue A = Op0; + SDValue B = Op1.getOperand(0); + SDValue C = Op1.getOperand(1); + return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags()); + } + return SDValue(); +} + static SDValue PerformREMCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOptLevel OptLevel) { @@ -5996,6 +6050,14 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return PerformEXTRACTCombine(N, DCI); case ISD::FADD: return PerformFADDCombine(N, DCI, OptLevel); + case ISD::FMAXNUM: + case ISD::FMINNUM: + case ISD::FMAXIMUM: + case ISD::FMINIMUM: + case ISD::FMAXIMUMNUM: + case ISD::FMINIMUMNUM: + return PerformFMinMaxCombine(N, DCI, STI.getPTXVersion(), + STI.getSmVersion()); case ISD::LOAD: case NVPTXISD::LoadV2: case NVPTXISD::LoadV4: diff --git a/llvm/test/CodeGen/NVPTX/fmax3.ll b/llvm/test/CodeGen/NVPTX/fmax3.ll new file mode 100644 index 0000000000000..9339b2e247af4 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/fmax3.ll @@ -0,0 +1,260 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc -march=nvptx64 -mcpu=sm_100f -o - %s | FileCheck %s + +target triple = "nvptx64-nvidia-cuda" +target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +define void @test_fmaxnum3(float %a, float %b, float %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_fmaxnum3( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_param_2]; +; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3; +; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r4; +; CHECK-NEXT: ret; +entry: + %max_ab = call float @llvm.maxnum.f32(float %a, float %b) + %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c) + store float %max_abc, ptr addrspace(1) %output, align 4 + ret void +} + +define void @test_fminnum3(float %a, float %b, float %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_fminnum3( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_fminnum3_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_fminnum3_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [test_fminnum3_param_2]; +; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3; +; CHECK-NEXT: ld.param.b64 %rd1, [test_fminnum3_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r4; +; CHECK-NEXT: ret; +entry: + %min_ab = call float @llvm.minnum.f32(float %a, float %b) + %min_abc = call float @llvm.minnum.f32(float %min_ab, float %c) + store float %min_abc, ptr addrspace(1) %output, align 4 + ret void +} + +define void @test_fmaximum3(float %a, float %b, float %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_fmaximum3( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximum3_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximum3_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximum3_param_2]; +; CHECK-NEXT: max.NaN.f32 %r4, %r1, %r2, %r3; +; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximum3_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r4; +; CHECK-NEXT: ret; +entry: + %max_ab = call float @llvm.maximum.f32(float %a, float %b) + %max_abc = call float @llvm.maximum.f32(float %max_ab, float %c) + store float %max_abc, ptr addrspace(1) %output, align 4 + ret void +} + +define void @test_fminimum3(float %a, float %b, float %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_fminimum3( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_fminimum3_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_fminimum3_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [test_fminimum3_param_2]; +; CHECK-NEXT: min.NaN.f32 %r4, %r1, %r2, %r3; +; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimum3_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r4; +; CHECK-NEXT: ret; +entry: + %min_ab = call float @llvm.minimum.f32(float %a, float %b) + %min_abc = call float @llvm.minimum.f32(float %min_ab, float %c) + store float %min_abc, ptr addrspace(1) %output, align 4 + ret void +} + +define void @test_fmaximumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_fmaximumnum3( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximumnum3_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximumnum3_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximumnum3_param_2]; +; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3; +; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximumnum3_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r4; +; CHECK-NEXT: ret; +entry: + %max_ab = call float @llvm.maximumnum.f32(float %a, float %b) + %max_abc = call float @llvm.maximumnum.f32(float %max_ab, float %c) + store float %max_abc, ptr addrspace(1) %output, align 4 + ret void +} + +define void @test_fminimumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_fminimumnum3( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_fminimumnum3_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_fminimumnum3_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [test_fminimumnum3_param_2]; +; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3; +; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimumnum3_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r4; +; CHECK-NEXT: ret; +entry: + %min_ab = call float @llvm.minimumnum.f32(float %a, float %b) + %min_abc = call float @llvm.minimumnum.f32(float %min_ab, float %c) + store float %min_abc, ptr addrspace(1) %output, align 4 + ret void +} + +; Test commuted operands (second operand is the nested operation) +define void @test_fmaxnum3_commuted(float %a, float %b, float %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_fmaxnum3_commuted( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_commuted_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_commuted_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_commuted_param_2]; +; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3; +; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_commuted_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r4; +; CHECK-NEXT: ret; +entry: + %max_bc = call float @llvm.maxnum.f32(float %b, float %c) + %max_abc = call float @llvm.maxnum.f32(float %a, float %max_bc) + store float %max_abc, ptr addrspace(1) %output, align 4 + ret void +} + +; NEGATIVE TEST: Mixed min/max operations should not combine +define void @test_mixed_minmax_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_mixed_minmax_no_combine( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<6>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_minmax_no_combine_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_minmax_no_combine_param_1]; +; CHECK-NEXT: min.f32 %r3, %r1, %r2; +; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_minmax_no_combine_param_2]; +; CHECK-NEXT: max.f32 %r5, %r3, %r4; +; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_minmax_no_combine_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r5; +; CHECK-NEXT: ret; +entry: + %min_ab = call float @llvm.minnum.f32(float %a, float %b) + %max_result = call float @llvm.maxnum.f32(float %min_ab, float %c) + store float %max_result, ptr addrspace(1) %output, align 4 + ret void +} + +; NEGATIVE TEST: Mixed maxnum/maximum operations should not combine +define void @test_mixed_maxnum_maximum_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_mixed_maxnum_maximum_no_combine( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<6>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_maxnum_maximum_no_combine_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_maxnum_maximum_no_combine_param_1]; +; CHECK-NEXT: max.f32 %r3, %r1, %r2; +; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_maxnum_maximum_no_combine_param_2]; +; CHECK-NEXT: max.NaN.f32 %r5, %r3, %r4; +; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_maxnum_maximum_no_combine_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r5; +; CHECK-NEXT: ret; +entry: + %maxnum_ab = call float @llvm.maxnum.f32(float %a, float %b) + %maximum_result = call float @llvm.maximum.f32(float %maxnum_ab, float %c) + store float %maximum_result, ptr addrspace(1) %output, align 4 + ret void +} + +; NEGATIVE TEST: f16 should not be combined (only f32 supported) +define void @test_f16_no_combine(half %a, half %b, half %c, ptr addrspace(1) %output) { +; CHECK-LABEL: test_f16_no_combine( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<6>; +; CHECK-NEXT: .reg .b64 %rd<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b16 %rs1, [test_f16_no_combine_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_f16_no_combine_param_1]; +; CHECK-NEXT: max.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: ld.param.b16 %rs4, [test_f16_no_combine_param_2]; +; CHECK-NEXT: max.f16 %rs5, %rs3, %rs4; +; CHECK-NEXT: ld.param.b64 %rd1, [test_f16_no_combine_param_3]; +; CHECK-NEXT: st.global.b16 [%rd1], %rs5; +; CHECK-NEXT: ret; +entry: + %max_ab = call half @llvm.maxnum.f16(half %a, half %b) + %max_abc = call half @llvm.maxnum.f16(half %max_ab, half %c) + store half %max_abc, ptr addrspace(1) %output, align 2 + ret void +} + +; NEGATIVE TEST: Multiple uses of intermediate result should not combine +define void @test_multiple_uses_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output1, ptr addrspace(1) %output2) { +; CHECK-LABEL: test_multiple_uses_no_combine( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<6>; +; CHECK-NEXT: .reg .b64 %rd<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.b32 %r1, [test_multiple_uses_no_combine_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [test_multiple_uses_no_combine_param_1]; +; CHECK-NEXT: max.f32 %r3, %r1, %r2; +; CHECK-NEXT: ld.param.b32 %r4, [test_multiple_uses_no_combine_param_2]; +; CHECK-NEXT: max.f32 %r5, %r3, %r4; +; CHECK-NEXT: ld.param.b64 %rd1, [test_multiple_uses_no_combine_param_3]; +; CHECK-NEXT: st.global.b32 [%rd1], %r3; +; CHECK-NEXT: ld.param.b64 %rd2, [test_multiple_uses_no_combine_param_4]; +; CHECK-NEXT: st.global.b32 [%rd2], %r5; +; CHECK-NEXT: ret; +entry: + %max_ab = call float @llvm.maxnum.f32(float %a, float %b) + %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c) + ; Multiple uses of %max_ab should prevent combining + store float %max_ab, ptr addrspace(1) %output1, align 4 + store float %max_abc, ptr addrspace(1) %output2, align 4 + ret void +} + +; Declare all the intrinsics we need +declare float @llvm.maxnum.f32(float, float) #0 +declare float @llvm.minnum.f32(float, float) #0 +declare float @llvm.maximum.f32(float, float) #0 +declare float @llvm.minimum.f32(float, float) #0 +declare float @llvm.maximumnum.f32(float, float) #0 +declare float @llvm.minimumnum.f32(float, float) #0 +declare half @llvm.maxnum.f16(half, half) #0 + +attributes #0 = { nounwind readnone speculatable willreturn }