Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,10 +841,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);

// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
setTargetDAGCombine(
{ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
ISD::SREM, ISD::UREM, ISD::VSELECT,
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});

// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
Expand Down Expand Up @@ -5316,6 +5320,56 @@ static SDValue PerformFADDCombine(SDNode *N,
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}

/// Translate a two-operand floating-point min/max opcode into the NVPTX
/// node type that performs the same operation on three operands.
///
/// FMAXNUM and FMAXIMUMNUM both map to FMAXNUM3, FMINNUM and FMINIMUMNUM both
/// map to FMINNUM3, while FMAXIMUM and FMINIMUM map to FMAXIMUM3 and FMINIMUM3
/// respectively. Passing any other opcode is a caller error and reaches
/// llvm_unreachable.
static NVPTXISD::NodeType getMinMax3Opcode(unsigned MinMax2Opcode) {
  switch (MinMax2Opcode) {
  case ISD::FMINIMUM:
    return NVPTXISD::FMINIMUM3;
  case ISD::FMAXIMUM:
    return NVPTXISD::FMAXIMUM3;
  case ISD::FMINIMUMNUM:
  case ISD::FMINNUM:
    return NVPTXISD::FMINNUM3;
  case ISD::FMAXIMUMNUM:
  case ISD::FMAXNUM:
    return NVPTXISD::FMAXNUM3;
  default:
    break;
  }
  llvm_unreachable("Invalid 2-input min/max opcode");
}

/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
///
/// The fold only fires when the nested node has the same opcode as \p N and
/// has a single use, so the 2-input node it replaces is not kept alive by
/// another user.
///
/// \param N          The 2-input min/max node being visited.
/// \param DCI        Combiner info; its DAG is used to create the new node.
/// \param PTXVersion PTX version of the subtarget (88 means PTX 8.8).
/// \param SmVersion  SM version of the subtarget (100 means SM_100).
/// \returns The fused 3-input node, or an empty SDValue if nothing was done.
static SDValue PerformFMinMaxCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     unsigned PTXVersion, unsigned SmVersion) {
  // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
  EVT VT = N->getValueType(0);
  if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
    return SDValue();

  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  unsigned MinMaxOp2 = N->getOpcode();
  NVPTXISD::NodeType MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);

  // The fused node replaces both the outer and the inner 2-input node, so it
  // may only carry flags that were present on both of them. Taking just the
  // outer flags would, for example, let an outer-only `nnan` apply to the
  // inner operands as well.
  if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
    // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
    SDValue A = Op0.getOperand(0);
    SDValue B = Op0.getOperand(1);
    SDValue C = Op1;
    SDNodeFlags Flags = N->getFlags();
    Flags &= Op0->getFlags();
    return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, Flags);
  }

  if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
    // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
    SDValue A = Op0;
    SDValue B = Op1.getOperand(0);
    SDValue C = Op1.getOperand(1);
    SDNodeFlags Flags = N->getFlags();
    Flags &= Op1->getFlags();
    return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, Flags);
  }

  return SDValue();
}

static SDValue PerformREMCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
Expand Down Expand Up @@ -5996,6 +6050,14 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformEXTRACTCombine(N, DCI);
case ISD::FADD:
return PerformFADDCombine(N, DCI, OptLevel);
case ISD::FMAXNUM:
case ISD::FMINNUM:
case ISD::FMAXIMUM:
case ISD::FMINIMUM:
case ISD::FMAXIMUMNUM:
case ISD::FMINIMUMNUM:
return PerformFMinMaxCombine(N, DCI, STI.getPTXVersion(),
STI.getSmVersion());
case ISD::LOAD:
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
Expand Down
260 changes: 260 additions & 0 deletions llvm/test/CodeGen/NVPTX/fmax3.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc -march=nvptx64 -mcpu=sm_100f -o - %s | FileCheck %s

target triple = "nvptx64-nvidia-cuda"
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"

; Two nested llvm.maxnum.f32 calls, with the inner call as the first operand
; of the outer one, are expected to be emitted as one 3-input max.f32.
define void @test_fmaxnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_fmaxnum3(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_param_2];
; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
; CHECK-NEXT: ret;
entry:
%max_ab = call float @llvm.maxnum.f32(float %a, float %b)
%max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
store float %max_abc, ptr addrspace(1) %output, align 4
ret void
}

; Two nested llvm.minnum.f32 calls are expected to be emitted as one 3-input
; min.f32.
define void @test_fminnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_fminnum3(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_fminnum3_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_fminnum3_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [test_fminnum3_param_2];
; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: ld.param.b64 %rd1, [test_fminnum3_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
; CHECK-NEXT: ret;
entry:
%min_ab = call float @llvm.minnum.f32(float %a, float %b)
%min_abc = call float @llvm.minnum.f32(float %min_ab, float %c)
store float %min_abc, ptr addrspace(1) %output, align 4
ret void
}

; Two nested llvm.maximum.f32 calls are expected to be emitted as one 3-input
; max.NaN.f32.
define void @test_fmaximum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_fmaximum3(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximum3_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximum3_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximum3_param_2];
; CHECK-NEXT: max.NaN.f32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximum3_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
; CHECK-NEXT: ret;
entry:
%max_ab = call float @llvm.maximum.f32(float %a, float %b)
%max_abc = call float @llvm.maximum.f32(float %max_ab, float %c)
store float %max_abc, ptr addrspace(1) %output, align 4
ret void
}

; Two nested llvm.minimum.f32 calls are expected to be emitted as one 3-input
; min.NaN.f32.
define void @test_fminimum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_fminimum3(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_fminimum3_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_fminimum3_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [test_fminimum3_param_2];
; CHECK-NEXT: min.NaN.f32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimum3_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
; CHECK-NEXT: ret;
entry:
%min_ab = call float @llvm.minimum.f32(float %a, float %b)
%min_abc = call float @llvm.minimum.f32(float %min_ab, float %c)
store float %min_abc, ptr addrspace(1) %output, align 4
ret void
}

; Two nested llvm.maximumnum.f32 calls are expected to be emitted as one
; 3-input max.f32 (the same instruction expected for llvm.maxnum.f32).
define void @test_fmaximumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_fmaximumnum3(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximumnum3_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximumnum3_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximumnum3_param_2];
; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximumnum3_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
; CHECK-NEXT: ret;
entry:
%max_ab = call float @llvm.maximumnum.f32(float %a, float %b)
%max_abc = call float @llvm.maximumnum.f32(float %max_ab, float %c)
store float %max_abc, ptr addrspace(1) %output, align 4
ret void
}

; Two nested llvm.minimumnum.f32 calls are expected to be emitted as one
; 3-input min.f32 (the same instruction expected for llvm.minnum.f32).
define void @test_fminimumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_fminimumnum3(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_fminimumnum3_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_fminimumnum3_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [test_fminimumnum3_param_2];
; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimumnum3_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
; CHECK-NEXT: ret;
entry:
%min_ab = call float @llvm.minimumnum.f32(float %a, float %b)
%min_abc = call float @llvm.minimumnum.f32(float %min_ab, float %c)
store float %min_abc, ptr addrspace(1) %output, align 4
ret void
}

; Test commuted operands (second operand is the nested operation).
; maxnum(a, maxnum(b, c)) is still expected to become one 3-input max.f32,
; with the operands in a, b, c order.
define void @test_fmaxnum3_commuted(float %a, float %b, float %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_fmaxnum3_commuted(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_commuted_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_commuted_param_1];
; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_commuted_param_2];
; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_commuted_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r4;
; CHECK-NEXT: ret;
entry:
%max_bc = call float @llvm.maxnum.f32(float %b, float %c)
%max_abc = call float @llvm.maxnum.f32(float %a, float %max_bc)
store float %max_abc, ptr addrspace(1) %output, align 4
ret void
}

; NEGATIVE TEST: Mixed min/max operations should not combine.
; A minnum feeding a maxnum is expected to stay as a separate 2-input min.f32
; followed by a 2-input max.f32.
define void @test_mixed_minmax_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_mixed_minmax_no_combine(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<6>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_minmax_no_combine_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_minmax_no_combine_param_1];
; CHECK-NEXT: min.f32 %r3, %r1, %r2;
; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_minmax_no_combine_param_2];
; CHECK-NEXT: max.f32 %r5, %r3, %r4;
; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_minmax_no_combine_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r5;
; CHECK-NEXT: ret;
entry:
%min_ab = call float @llvm.minnum.f32(float %a, float %b)
%max_result = call float @llvm.maxnum.f32(float %min_ab, float %c)
store float %max_result, ptr addrspace(1) %output, align 4
ret void
}

; NEGATIVE TEST: Mixed maxnum/maximum operations should not combine.
; A maxnum feeding a maximum is expected to stay as a 2-input max.f32 followed
; by a 2-input max.NaN.f32.
define void @test_mixed_maxnum_maximum_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_mixed_maxnum_maximum_no_combine(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<6>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_maxnum_maximum_no_combine_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_maxnum_maximum_no_combine_param_1];
; CHECK-NEXT: max.f32 %r3, %r1, %r2;
; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_maxnum_maximum_no_combine_param_2];
; CHECK-NEXT: max.NaN.f32 %r5, %r3, %r4;
; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_maxnum_maximum_no_combine_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r5;
; CHECK-NEXT: ret;
entry:
%maxnum_ab = call float @llvm.maxnum.f32(float %a, float %b)
%maximum_result = call float @llvm.maximum.f32(float %maxnum_ab, float %c)
store float %maximum_result, ptr addrspace(1) %output, align 4
ret void
}

; NEGATIVE TEST: f16 should not be combined (only f32 supported).
; Two nested llvm.maxnum.f16 calls are expected to stay as two 2-input max.f16
; instructions.
define void @test_f16_no_combine(half %a, half %b, half %c, ptr addrspace(1) %output) {
; CHECK-LABEL: test_f16_no_combine(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<6>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b16 %rs1, [test_f16_no_combine_param_0];
; CHECK-NEXT: ld.param.b16 %rs2, [test_f16_no_combine_param_1];
; CHECK-NEXT: max.f16 %rs3, %rs1, %rs2;
; CHECK-NEXT: ld.param.b16 %rs4, [test_f16_no_combine_param_2];
; CHECK-NEXT: max.f16 %rs5, %rs3, %rs4;
; CHECK-NEXT: ld.param.b64 %rd1, [test_f16_no_combine_param_3];
; CHECK-NEXT: st.global.b16 [%rd1], %rs5;
; CHECK-NEXT: ret;
entry:
%max_ab = call half @llvm.maxnum.f16(half %a, half %b)
%max_abc = call half @llvm.maxnum.f16(half %max_ab, half %c)
store half %max_abc, ptr addrspace(1) %output, align 2
ret void
}

; NEGATIVE TEST: Multiple uses of intermediate result should not combine.
; The inner maxnum result is stored to %output1 as well as feeding the outer
; maxnum, so both 2-input max.f32 instructions are expected to remain.
define void @test_multiple_uses_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output1, ptr addrspace(1) %output2) {
; CHECK-LABEL: test_multiple_uses_no_combine(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<6>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b32 %r1, [test_multiple_uses_no_combine_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_multiple_uses_no_combine_param_1];
; CHECK-NEXT: max.f32 %r3, %r1, %r2;
; CHECK-NEXT: ld.param.b32 %r4, [test_multiple_uses_no_combine_param_2];
; CHECK-NEXT: max.f32 %r5, %r3, %r4;
; CHECK-NEXT: ld.param.b64 %rd1, [test_multiple_uses_no_combine_param_3];
; CHECK-NEXT: st.global.b32 [%rd1], %r3;
; CHECK-NEXT: ld.param.b64 %rd2, [test_multiple_uses_no_combine_param_4];
; CHECK-NEXT: st.global.b32 [%rd2], %r5;
; CHECK-NEXT: ret;
entry:
%max_ab = call float @llvm.maxnum.f32(float %a, float %b)
%max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
; Multiple uses of %max_ab should prevent combining
store float %max_ab, ptr addrspace(1) %output1, align 4
store float %max_abc, ptr addrspace(1) %output2, align 4
ret void
}

; Declare all the intrinsics we need
declare float @llvm.maxnum.f32(float, float) #0
declare float @llvm.minnum.f32(float, float) #0
declare float @llvm.maximum.f32(float, float) #0
declare float @llvm.minimum.f32(float, float) #0
declare float @llvm.maximumnum.f32(float, float) #0
declare float @llvm.minimumnum.f32(float, float) #0
declare half @llvm.maxnum.f16(half, half) #0

attributes #0 = { nounwind readnone speculatable willreturn }