Skip to content

Commit

Permalink
[AMDGPU] Handle min(max(x, y), max(min(x, y), z)) in med3 combines
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D139508
  • Loading branch information
bogner committed Dec 7, 2022
1 parent 916ae0a commit bcfdaa9
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 27 deletions.
22 changes: 18 additions & 4 deletions llvm/lib/Target/AMDGPU/SIInstructions.td
Expand Up @@ -3228,9 +3228,10 @@ multiclass IntMed3Pat<Instruction med3Inst,
defm : IntMed3Pat<V_MED3_I32_e64, smin, smax>;
defm : IntMed3Pat<V_MED3_U32_e64, umin, umax>;

// This matches 16 permutations of max(min(x, y), min(max(x, y), z))
class FPMed3Pat<ValueType vt,
Instruction med3Inst> : GCNPat<
multiclass FPMed3Pat<ValueType vt,
Instruction med3Inst> {
// This matches 16 permutations of max(min(x, y), min(max(x, y), z))
def : GCNPat<
(fmaxnum_like_nnan
(fminnum_like (VOP3Mods vt:$src0, i32:$src0_mods),
(VOP3Mods vt:$src1, i32:$src1_mods)),
Expand All @@ -3240,6 +3241,19 @@ class FPMed3Pat<ValueType vt,
(med3Inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
DSTCLAMP.NONE, DSTOMOD.NONE)>;


// This matches 16 permutations of min(max(x, y), max(min(x, y), z))
def : GCNPat<
(fminnum_like_nnan
(fmaxnum_like (VOP3Mods vt:$src0, i32:$src0_mods),
(VOP3Mods vt:$src1, i32:$src1_mods)),
(fmaxnum_like (fminnum_like (VOP3Mods vt:$src0, i32:$src0_mods),
(VOP3Mods vt:$src1, i32:$src1_mods)),
(vt (VOP3Mods vt:$src2, i32:$src2_mods)))),
(med3Inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
DSTCLAMP.NONE, DSTOMOD.NONE)>;
}

class FP16Med3Pat<ValueType vt,
Instruction med3Inst> : GCNPat<
(fmaxnum_like_nnan (fminnum_like (VOP3Mods vt:$src0, i32:$src0_mods),
Expand Down Expand Up @@ -3270,7 +3284,7 @@ multiclass Int16Med3Pat<Instruction med3Inst,
>;
}

def : FPMed3Pat<f32, V_MED3_F32_e64>;
defm : FPMed3Pat<f32, V_MED3_F32_e64>;

class
IntMinMaxPat<Instruction minmaxInst, SDPatternOperator min_or_max,
Expand Down
32 changes: 9 additions & 23 deletions llvm/test/CodeGen/AMDGPU/GlobalISel/fmed3.ll
Expand Up @@ -146,10 +146,7 @@ define amdgpu_kernel void @v_test_global_nnans_med3_f32_pat1_srcmod0(ptr addrspa
; SI-NEXT: buffer_load_dword v4, v[0:1], s[8:11], 0 addr64 glc
; SI-NEXT: s_waitcnt vmcnt(0)
; SI-NEXT: v_mul_f32_e32 v2, -1.0, v2
; SI-NEXT: v_max_f32_e32 v5, v2, v3
; SI-NEXT: v_min_f32_e32 v2, v2, v3
; SI-NEXT: v_max_f32_e32 v2, v2, v4
; SI-NEXT: v_min_f32_e32 v2, v5, v2
; SI-NEXT: v_med3_f32 v2, v2, v3, v4
; SI-NEXT: s_mov_b64 s[2:3], s[10:11]
; SI-NEXT: buffer_store_dword v2, v[0:1], s[0:3], 0 addr64
; SI-NEXT: s_endpgm
Expand All @@ -169,23 +166,20 @@ define amdgpu_kernel void @v_test_global_nnans_med3_f32_pat1_srcmod0(ptr addrspa
; VI-NEXT: v_addc_u32_e32 v3, vcc, 0, v3, vcc
; VI-NEXT: v_mov_b32_e32 v4, s6
; VI-NEXT: v_mov_b32_e32 v5, s7
; VI-NEXT: v_add_u32_e32 v4, vcc, v4, v6
; VI-NEXT: v_addc_u32_e32 v5, vcc, 0, v5, vcc
; VI-NEXT: flat_load_dword v7, v[0:1] glc
; VI-NEXT: s_waitcnt vmcnt(0)
; VI-NEXT: flat_load_dword v2, v[2:3] glc
; VI-NEXT: s_waitcnt vmcnt(0)
; VI-NEXT: v_add_u32_e32 v0, vcc, v4, v6
; VI-NEXT: v_addc_u32_e32 v1, vcc, 0, v5, vcc
; VI-NEXT: flat_load_dword v3, v[0:1] glc
; VI-NEXT: flat_load_dword v3, v[4:5] glc
; VI-NEXT: s_waitcnt vmcnt(0)
; VI-NEXT: v_mov_b32_e32 v0, s0
; VI-NEXT: v_mov_b32_e32 v1, s1
; VI-NEXT: v_add_u32_e32 v0, vcc, v0, v6
; VI-NEXT: v_addc_u32_e32 v1, vcc, 0, v1, vcc
; VI-NEXT: v_mul_f32_e32 v4, -1.0, v7
; VI-NEXT: v_max_f32_e32 v5, v4, v2
; VI-NEXT: v_min_f32_e32 v2, v4, v2
; VI-NEXT: v_max_f32_e32 v2, v2, v3
; VI-NEXT: v_min_f32_e32 v2, v5, v2
; VI-NEXT: v_med3_f32 v2, v4, v2, v3
; VI-NEXT: flat_store_dword v[0:1], v2
; VI-NEXT: s_endpgm
;
Expand All @@ -201,10 +195,7 @@ define amdgpu_kernel void @v_test_global_nnans_med3_f32_pat1_srcmod0(ptr addrspa
; GFX9-NEXT: global_load_dword v3, v0, s[6:7] glc
; GFX9-NEXT: s_waitcnt vmcnt(0)
; GFX9-NEXT: v_max_f32_e64 v1, -v1, -v1
; GFX9-NEXT: v_max_f32_e32 v4, v1, v2
; GFX9-NEXT: v_min_f32_e32 v1, v1, v2
; GFX9-NEXT: v_max_f32_e32 v1, v1, v3
; GFX9-NEXT: v_min_f32_e32 v1, v4, v1
; GFX9-NEXT: v_med3_f32 v1, v1, v2, v3
; GFX9-NEXT: global_store_dword v0, v1, s[0:1]
; GFX9-NEXT: s_endpgm
;
Expand All @@ -220,10 +211,7 @@ define amdgpu_kernel void @v_test_global_nnans_med3_f32_pat1_srcmod0(ptr addrspa
; GFX10-NEXT: global_load_dword v3, v0, s[6:7] glc dlc
; GFX10-NEXT: s_waitcnt vmcnt(0)
; GFX10-NEXT: v_max_f32_e64 v1, -v1, -v1
; GFX10-NEXT: v_min_f32_e32 v4, v1, v2
; GFX10-NEXT: v_max_f32_e32 v1, v1, v2
; GFX10-NEXT: v_max_f32_e32 v2, v4, v3
; GFX10-NEXT: v_min_f32_e32 v1, v1, v2
; GFX10-NEXT: v_med3_f32 v1, v1, v2, v3
; GFX10-NEXT: global_store_dword v0, v1, s[0:1]
; GFX10-NEXT: s_endpgm
;
Expand All @@ -239,10 +227,8 @@ define amdgpu_kernel void @v_test_global_nnans_med3_f32_pat1_srcmod0(ptr addrspa
; GFX11-NEXT: global_load_b32 v3, v0, s[6:7] glc dlc
; GFX11-NEXT: s_waitcnt vmcnt(0)
; GFX11-NEXT: v_max_f32_e64 v1, -v1, -v1
; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
; GFX11-NEXT: v_max_f32_e32 v4, v1, v2
; GFX11-NEXT: v_min_f32_e32 v1, v1, v2
; GFX11-NEXT: v_maxmin_f32 v1, v1, v3, v4
; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1)
; GFX11-NEXT: v_med3_f32 v1, v1, v2, v3
; GFX11-NEXT: global_store_b32 v0, v1, s[0:1]
; GFX11-NEXT: s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
; GFX11-NEXT: s_endpgm
Expand Down
25 changes: 25 additions & 0 deletions llvm/test/CodeGen/AMDGPU/fmed3.ll
Expand Up @@ -746,6 +746,31 @@ define amdgpu_kernel void @v_test_global_nnans_med3_f32_pat15(float addrspace(1)
ret void
}

; Also handle `min` at the root:
; min(max(x, y), max(min(x, y), z))

; GCN-LABEL: {{^}}v_test_global_nnans_med3_f32_pat16:
; GCN: {{buffer|flat|global}}_load_dword [[A:v[0-9]+]]
; GCN: {{buffer|flat|global}}_load_dword [[B:v[0-9]+]]
; GCN: {{buffer|flat|global}}_load_dword [[C:v[0-9]+]]
; GCN: v_med3_f32 v{{[0-9]+}}, [[A]], [[B]], [[C]]
define amdgpu_kernel void @v_test_global_nnans_med3_f32_pat16(float addrspace(1)* %out, float addrspace(1)* %aptr, float addrspace(1)* %bptr, float addrspace(1)* %cptr) #2 {
%tid = call i32 @llvm.amdgcn.workitem.id.x()
%gep0 = getelementptr float, float addrspace(1)* %aptr, i32 %tid
%gep1 = getelementptr float, float addrspace(1)* %bptr, i32 %tid
%gep2 = getelementptr float, float addrspace(1)* %cptr, i32 %tid
%outgep = getelementptr float, float addrspace(1)* %out, i32 %tid
%a = load volatile float, float addrspace(1)* %gep0
%b = load volatile float, float addrspace(1)* %gep1
%c = load volatile float, float addrspace(1)* %gep2
%tmp0 = call float @llvm.maxnum.f32(float %a, float %b)
%tmp1 = call float @llvm.minnum.f32(float %a, float %b)
%tmp2 = call float @llvm.maxnum.f32(float %tmp1, float %c)
%med3 = call float @llvm.minnum.f32(float %tmp0, float %tmp2)
store float %med3, float addrspace(1)* %outgep
ret void
}

; ---------------------------------------------------------------------
; Negative patterns
; ---------------------------------------------------------------------
Expand Down

0 comments on commit bcfdaa9

Please sign in to comment.