Skip to content

Commit

Permalink
[AMDGPU] Fix predicates on FLAT scratch ST/SVS mode Pseudos (#85442)
Browse files Browse the repository at this point in the history
Definitions like this did not work as intended:

  let is_flat_scratch = 1 in {
    let SubtargetPredicate = HasFlatScratchSVSMode in
def _SVS : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput, 1,
1>,
               FlatScratchInst<opName, "SVS">;

    let SubtargetPredicate = HasFlatScratchSTMode in
def _ST : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput, 0,
0, 0>,
               FlatScratchInst<opName, "ST">;
  }

They tried to override SubtargetPredicate, but then it was overridden
again (back to its default value) by setting is_flat_scratch, which
caused SubtargetPredicate to be recalculated in the base class. (This
patch also removes some overrides of SubtargetPredicate that are
redundant due to being recalculated in the base class.)

Fix this by pushing overrides of is_flat_scratch and is_flat_global "in"
as far as possible. This has the added benefit that there is no need to
override them around groups of Pseudo definitions like this:

let is_flat_global = 1 in {
defm GLOBAL_ATOMIC_CMPSWAP : FLAT_Global_Atomic_Pseudo
<"global_atomic_cmpswap",
                               VGPR_32, i32, v2i32, VReg_64>;
...
}

which are plainly Global instructions anyway.

Verified by inspecting the output of TableGen. It seems to be NFC in
practice.
  • Loading branch information
jayfoad committed Mar 17, 2024
1 parent f3c5278 commit bfb8682
Showing 1 changed file with 69 additions and 77 deletions.
146 changes: 69 additions & 77 deletions llvm/lib/Target/AMDGPU/FLATInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class FLAT_Store_Pseudo <string opName, RegisterClass vdataClass,
}

multiclass FLAT_Global_Load_Pseudo<string opName, RegisterClass regClass, bit HasTiedInput = 0> {
let is_flat_global = 1, SubtargetPredicate = HasFlatGlobalInsts in {
let is_flat_global = 1 in {
def "" : FLAT_Load_Pseudo<opName, regClass, HasTiedInput, 1>,
GlobalSaddrTable<0, opName>;
def _SADDR : FLAT_Load_Pseudo<opName, regClass, HasTiedInput, 1, 1>,
Expand Down Expand Up @@ -276,7 +276,7 @@ multiclass FLAT_Global_Load_AddTid_Pseudo<string opName, RegisterClass regClass,
}

multiclass FLAT_Global_Store_Pseudo<string opName, RegisterClass regClass> {
let is_flat_global = 1, SubtargetPredicate = HasFlatGlobalInsts in {
let is_flat_global = 1 in {
def "" : FLAT_Store_Pseudo<opName, regClass, 1>,
GlobalSaddrTable<0, opName>;
def _SADDR : FLAT_Store_Pseudo<opName, regClass, 1, 1>,
Expand Down Expand Up @@ -389,6 +389,7 @@ class FLAT_Scratch_Load_Pseudo <string opName, RegisterClass regClass,
!if(HasTiedOutput, (ins CPol:$cpol, getLdStRegisterOperand<regClass>.ret:$vdst_in),
(ins CPol_0:$cpol))),
" $vdst, "#!if(EnableVaddr, "$vaddr, ", "off, ")#!if(EnableSaddr, "$saddr", "off")#"$offset$cpol"> {
let is_flat_scratch = 1;
let has_data = 0;
let mayLoad = 1;
let has_saddr = 1;
Expand Down Expand Up @@ -416,6 +417,7 @@ class FLAT_Scratch_Store_Pseudo <string opName, RegisterClass vdataClass, bit En
(ins vdata_op:$vdata, VGPR_32:$vaddr, flat_offset:$offset, CPol_0:$cpol),
(ins vdata_op:$vdata, flat_offset:$offset, CPol_0:$cpol)))),
" "#!if(EnableVaddr, "$vaddr", "off")#", $vdata, "#!if(EnableSaddr, "$saddr", "off")#"$offset$cpol"> {
let is_flat_scratch = 1;
let mayLoad = 0;
let mayStore = 1;
let has_vdst = 0;
Expand All @@ -428,37 +430,33 @@ class FLAT_Scratch_Store_Pseudo <string opName, RegisterClass vdataClass, bit En
}

multiclass FLAT_Scratch_Load_Pseudo<string opName, RegisterClass regClass, bit HasTiedOutput = 0> {
let is_flat_scratch = 1 in {
def "" : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput>,
FlatScratchInst<opName, "SV">;
def _SADDR : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput, 1>,
FlatScratchInst<opName, "SS">;

let SubtargetPredicate = HasFlatScratchSVSMode in
def _SVS : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput, 1, 1>,
FlatScratchInst<opName, "SVS">;
def "" : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput>,
FlatScratchInst<opName, "SV">;
def _SADDR : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput, 1>,
FlatScratchInst<opName, "SS">;

let SubtargetPredicate = HasFlatScratchSTMode in
def _ST : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput, 0, 0, 0>,
FlatScratchInst<opName, "ST">;
}
let SubtargetPredicate = HasFlatScratchSVSMode in
def _SVS : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput, 1, 1>,
FlatScratchInst<opName, "SVS">;

let SubtargetPredicate = HasFlatScratchSTMode in
def _ST : FLAT_Scratch_Load_Pseudo<opName, regClass, HasTiedOutput, 0, 0, 0>,
FlatScratchInst<opName, "ST">;
}

multiclass FLAT_Scratch_Store_Pseudo<string opName, RegisterClass regClass> {
let is_flat_scratch = 1 in {
def "" : FLAT_Scratch_Store_Pseudo<opName, regClass>,
FlatScratchInst<opName, "SV">;
def _SADDR : FLAT_Scratch_Store_Pseudo<opName, regClass, 1>,
FlatScratchInst<opName, "SS">;

let SubtargetPredicate = HasFlatScratchSVSMode in
def _SVS : FLAT_Scratch_Store_Pseudo<opName, regClass, 1, 1>,
FlatScratchInst<opName, "SVS">;
def "" : FLAT_Scratch_Store_Pseudo<opName, regClass>,
FlatScratchInst<opName, "SV">;
def _SADDR : FLAT_Scratch_Store_Pseudo<opName, regClass, 1>,
FlatScratchInst<opName, "SS">;

let SubtargetPredicate = HasFlatScratchSTMode in
def _ST : FLAT_Scratch_Store_Pseudo<opName, regClass, 0, 0, 0>,
FlatScratchInst<opName, "ST">;
}
let SubtargetPredicate = HasFlatScratchSVSMode in
def _SVS : FLAT_Scratch_Store_Pseudo<opName, regClass, 1, 1>,
FlatScratchInst<opName, "SVS">;

let SubtargetPredicate = HasFlatScratchSTMode in
def _ST : FLAT_Scratch_Store_Pseudo<opName, regClass, 0, 0, 0>,
FlatScratchInst<opName, "ST">;
}

class FLAT_Scratch_Load_LDS_Pseudo <string opName, bit EnableSaddr = 0,
Expand Down Expand Up @@ -583,25 +581,27 @@ multiclass FLAT_Global_Atomic_Pseudo_NO_RTN<
RegisterClass data_rc = vdst_rc,
RegisterOperand data_op = getLdStRegisterOperand<data_rc>.ret> {

def "" : FLAT_AtomicNoRet_Pseudo <opName,
(outs),
(ins VReg_64:$vaddr, data_op:$vdata, flat_offset:$offset, CPol_0:$cpol),
" $vaddr, $vdata, off$offset$cpol">,
GlobalSaddrTable<0, opName> {
let has_saddr = 1;
let PseudoInstr = NAME;
let FPAtomic = data_vt.isFP;
}
let is_flat_global = 1 in {
def "" : FLAT_AtomicNoRet_Pseudo <opName,
(outs),
(ins VReg_64:$vaddr, data_op:$vdata, flat_offset:$offset, CPol_0:$cpol),
" $vaddr, $vdata, off$offset$cpol">,
GlobalSaddrTable<0, opName> {
let has_saddr = 1;
let PseudoInstr = NAME;
let FPAtomic = data_vt.isFP;
}

def _SADDR : FLAT_AtomicNoRet_Pseudo <opName,
(outs),
(ins VGPR_32:$vaddr, data_op:$vdata, SReg_64:$saddr, flat_offset:$offset, CPol_0:$cpol),
" $vaddr, $vdata, $saddr$offset$cpol">,
GlobalSaddrTable<1, opName> {
let has_saddr = 1;
let enabled_saddr = 1;
let PseudoInstr = NAME#"_SADDR";
let FPAtomic = data_vt.isFP;
def _SADDR : FLAT_AtomicNoRet_Pseudo <opName,
(outs),
(ins VGPR_32:$vaddr, data_op:$vdata, SReg_64:$saddr, flat_offset:$offset, CPol_0:$cpol),
" $vaddr, $vdata, $saddr$offset$cpol">,
GlobalSaddrTable<1, opName> {
let has_saddr = 1;
let enabled_saddr = 1;
let PseudoInstr = NAME#"_SADDR";
let FPAtomic = data_vt.isFP;
}
}
}

Expand All @@ -614,24 +614,26 @@ multiclass FLAT_Global_Atomic_Pseudo_RTN<
RegisterOperand data_op = getLdStRegisterOperand<data_rc>.ret,
RegisterOperand vdst_op = getLdStRegisterOperand<vdst_rc>.ret> {

def _RTN : FLAT_AtomicRet_Pseudo <opName,
(outs vdst_op:$vdst),
(ins VReg_64:$vaddr, data_op:$vdata, flat_offset:$offset, CPol_GLC1:$cpol),
" $vdst, $vaddr, $vdata, off$offset$cpol">,
GlobalSaddrTable<0, opName#"_rtn"> {
let has_saddr = 1;
let FPAtomic = data_vt.isFP;
}
let is_flat_global = 1 in {
def _RTN : FLAT_AtomicRet_Pseudo <opName,
(outs vdst_op:$vdst),
(ins VReg_64:$vaddr, data_op:$vdata, flat_offset:$offset, CPol_GLC1:$cpol),
" $vdst, $vaddr, $vdata, off$offset$cpol">,
GlobalSaddrTable<0, opName#"_rtn"> {
let has_saddr = 1;
let FPAtomic = data_vt.isFP;
}

def _SADDR_RTN : FLAT_AtomicRet_Pseudo <opName,
(outs vdst_op:$vdst),
(ins VGPR_32:$vaddr, data_op:$vdata, SReg_64:$saddr, flat_offset:$offset, CPol_GLC1:$cpol),
" $vdst, $vaddr, $vdata, $saddr$offset$cpol">,
GlobalSaddrTable<1, opName#"_rtn"> {
let has_saddr = 1;
let enabled_saddr = 1;
let PseudoInstr = NAME#"_SADDR_RTN";
let FPAtomic = data_vt.isFP;
def _SADDR_RTN : FLAT_AtomicRet_Pseudo <opName,
(outs vdst_op:$vdst),
(ins VGPR_32:$vaddr, data_op:$vdata, SReg_64:$saddr, flat_offset:$offset, CPol_GLC1:$cpol),
" $vdst, $vaddr, $vdata, $saddr$offset$cpol">,
GlobalSaddrTable<1, opName#"_rtn"> {
let has_saddr = 1;
let enabled_saddr = 1;
let PseudoInstr = NAME#"_SADDR_RTN";
let FPAtomic = data_vt.isFP;
}
}
}

Expand All @@ -641,10 +643,8 @@ multiclass FLAT_Global_Atomic_Pseudo<
ValueType vt,
ValueType data_vt = vt,
RegisterClass data_rc = vdst_rc> {
let is_flat_global = 1, SubtargetPredicate = HasFlatGlobalInsts in {
defm "" : FLAT_Global_Atomic_Pseudo_NO_RTN<opName, vdst_rc, vt, data_vt, data_rc>;
defm "" : FLAT_Global_Atomic_Pseudo_RTN<opName, vdst_rc, vt, data_vt, data_rc>;
}
defm "" : FLAT_Global_Atomic_Pseudo_NO_RTN<opName, vdst_rc, vt, data_vt, data_rc>;
defm "" : FLAT_Global_Atomic_Pseudo_RTN<opName, vdst_rc, vt, data_vt, data_rc>;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -848,7 +848,6 @@ defm GLOBAL_STORE_DWORD_ADDTID : FLAT_Global_Store_AddTid_Pseudo <"global_store_
defm GLOBAL_STORE_BYTE_D16_HI : FLAT_Global_Store_Pseudo <"global_store_byte_d16_hi", VGPR_32>;
defm GLOBAL_STORE_SHORT_D16_HI : FLAT_Global_Store_Pseudo <"global_store_short_d16_hi", VGPR_32>;

let is_flat_global = 1 in {
defm GLOBAL_ATOMIC_CMPSWAP : FLAT_Global_Atomic_Pseudo <"global_atomic_cmpswap",
VGPR_32, i32, v2i32, VReg_64>;

Expand Down Expand Up @@ -947,9 +946,6 @@ let SubtargetPredicate = isGFX12Plus in {
def GLOBAL_WBINV : FLAT_Global_Invalidate_Writeback<"global_wbinv">;
} // End SubtargetPredicate = isGFX12Plus

} // End is_flat_global = 1

let SubtargetPredicate = HasFlatScratchInsts in {
defm SCRATCH_LOAD_UBYTE : FLAT_Scratch_Load_Pseudo <"scratch_load_ubyte", VGPR_32>;
defm SCRATCH_LOAD_SBYTE : FLAT_Scratch_Load_Pseudo <"scratch_load_sbyte", VGPR_32>;
defm SCRATCH_LOAD_USHORT : FLAT_Scratch_Load_Pseudo <"scratch_load_ushort", VGPR_32>;
Expand Down Expand Up @@ -984,8 +980,6 @@ defm SCRATCH_LOAD_LDS_USHORT : FLAT_Scratch_Load_LDS_Pseudo <"scratch_load_lds_u
defm SCRATCH_LOAD_LDS_SSHORT : FLAT_Scratch_Load_LDS_Pseudo <"scratch_load_lds_sshort">;
defm SCRATCH_LOAD_LDS_DWORD : FLAT_Scratch_Load_LDS_Pseudo <"scratch_load_lds_dword">;

} // End SubtargetPredicate = HasFlatScratchInsts

let SubtargetPredicate = isGFX12Plus in {
let WaveSizePredicate = isWave32 in {
defm GLOBAL_LOAD_TR_B128_w32 : FLAT_Global_Load_Pseudo <"global_load_tr_b128_w32", VReg_128>;
Expand All @@ -997,7 +991,7 @@ let SubtargetPredicate = isGFX12Plus in {
}
} // End SubtargetPredicate = isGFX12Plus

let SubtargetPredicate = isGFX10Plus, is_flat_global = 1 in {
let SubtargetPredicate = isGFX10Plus in {
defm GLOBAL_ATOMIC_FCMPSWAP :
FLAT_Global_Atomic_Pseudo<"global_atomic_fcmpswap", VGPR_32, f32, v2f32, VReg_64>;
defm GLOBAL_ATOMIC_FMIN :
Expand All @@ -1010,9 +1004,8 @@ let SubtargetPredicate = isGFX10Plus, is_flat_global = 1 in {
FLAT_Global_Atomic_Pseudo<"global_atomic_fmin_x2", VReg_64, f64>;
defm GLOBAL_ATOMIC_FMAX_X2 :
FLAT_Global_Atomic_Pseudo<"global_atomic_fmax_x2", VReg_64, f64>;
} // End SubtargetPredicate = isGFX10Plus, is_flat_global = 1
} // End SubtargetPredicate = isGFX10Plus

let is_flat_global = 1 in {
let OtherPredicates = [HasAtomicFaddNoRtnInsts] in
defm GLOBAL_ATOMIC_ADD_F32 : FLAT_Global_Atomic_Pseudo_NO_RTN <
"global_atomic_add_f32", VGPR_32, f32
Expand All @@ -1029,7 +1022,6 @@ let OtherPredicates = [HasAtomicBufferGlobalPkAddF16Insts] in
defm GLOBAL_ATOMIC_PK_ADD_F16 : FLAT_Global_Atomic_Pseudo_RTN <
"global_atomic_pk_add_f16", VGPR_32, v2f16
>;
} // End is_flat_global = 1

//===----------------------------------------------------------------------===//
// Flat Patterns
Expand Down

0 comments on commit bfb8682

Please sign in to comment.