Skip to content

Commit

Permalink
[Hexagon] Introduce PS_vsplat[ir][bhw] pseudo instructions
Browse files Browse the repository at this point in the history
HVX v60 only has splats that take a 32-bit word as input, while v62+
has splats that take an 8- or 16-bit value. This makes writing output
patterns that need to use a splat annoying, because the entire output
pattern needs to be replicated for various versions of HVX.
To avoid this, the patterns will always use the pseudos, and then the
pseudos will be handled using a post-ISel hook.
  • Loading branch information
Krzysztof Parzyszek committed Oct 14, 2022
1 parent 3a33c14 commit 7f4ce3f
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 149 deletions.
5 changes: 5 additions & 0 deletions llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
Expand Up @@ -3683,6 +3683,11 @@ bool HexagonTargetLowering::shouldReduceLoadWidth(SDNode *Load,
return true;
}

// Post-instruction-selection hook for Hexagon. It performs no work of its
// own: both arguments are forwarded unchanged to the HVX-specific handler,
// AdjustHvxInstrPostInstrSelection.
void HexagonTargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
SDNode *Node) const {
AdjustHvxInstrPostInstrSelection(MI, Node);
}

Value *HexagonTargetLowering::emitLoadLinked(IRBuilderBase &Builder,
Type *ValueTy, Value *Addr,
AtomicOrdering Ord) const {
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/Hexagon/HexagonISelLowering.h
Expand Up @@ -332,6 +332,9 @@ class HexagonTargetLowering : public TargetLowering {
bool shouldReduceLoadWidth(SDNode *Load, ISD::LoadExtType ExtTy,
EVT NewVT) const override;

// Override of the post-instruction-selection hook; invoked with the selected
// machine instruction and the SDNode it was created from.
void AdjustInstrPostInstrSelection(MachineInstr &MI,
SDNode *Node) const override;

// Handling of atomic RMW instructions.
Value *emitLoadLinked(IRBuilderBase &Builder, Type *ValueTy, Value *Addr,
AtomicOrdering Ord) const override;
Expand Down Expand Up @@ -433,6 +436,7 @@ class HexagonTargetLowering : public TargetLowering {
bool allowsHvxMisalignedMemoryAccesses(MVT VecTy,
MachineMemOperand::Flags Flags,
bool *Fast) const;
// HVX part of the post-instruction-selection hook: expands the
// PS_vsplat[ir][bhw] pseudo-instructions into real splat instructions.
void AdjustHvxInstrPostInstrSelection(MachineInstr &MI, SDNode *Node) const;

bool isHvxSingleTy(MVT Ty) const;
bool isHvxPairTy(MVT Ty) const;
Expand Down
116 changes: 116 additions & 0 deletions llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
Expand Up @@ -10,6 +10,12 @@
#include "HexagonRegisterInfo.h"
#include "HexagonSubtarget.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/IR/IntrinsicsHexagon.h"
#include "llvm/Support/CommandLine.h"

Expand Down Expand Up @@ -564,6 +570,116 @@ bool HexagonTargetLowering::allowsHvxMisalignedMemoryAccesses(
return true;
}

/// Expand the PS_vsplat[ir][bhw] pseudo-instructions into real instructions.
///
/// Word splats are rewritten in place into V6_lvsplatw. Byte and halfword
/// splats use V6_lvsplatb / V6_lvsplath when Subtarget.useHVXV62Ops() is
/// true; otherwise the value is first replicated across a 32-bit scalar
/// register and then splatted with V6_lvsplatw. Any other opcode is left
/// untouched.
///
/// \param MI   Instruction to inspect. For the byte and halfword pseudos the
///             replacement sequence is inserted immediately before it and MI
///             itself is erased from its block, so the reference must not be
///             used afterwards in those cases.
/// \param Node Not used by this function.
void HexagonTargetLowering::AdjustHvxInstrPostInstrSelection(
    MachineInstr &MI, SDNode *Node) const {
  const unsigned Opc = MI.getOpcode();
  const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
  MachineBasicBlock &MB = *MI.getParent();
  MachineRegisterInfo &MRI = MB.getParent()->getRegInfo();
  // Keep a copy of the debug location: MI may be erased below.
  DebugLoc DL = MI.getDebugLoc();
  auto At = MI.getIterator();

  // Allocate a fresh 32-bit scalar virtual register.
  auto NewScalarReg = [&MRI]() -> Register {
    return MRI.createVirtualRegister(&Hexagon::IntRegsRegClass);
  };
  // Begin building an instruction defining Dst, inserted right before MI.
  auto EmitBefore = [&](unsigned NewOpc, Register Dst) {
    return BuildMI(MB, At, DL, TII.get(NewOpc), Dst);
  };

  switch (Opc) {
  case Hexagon::PS_vsplatib:
  case Hexagon::PS_vsplatih: {
    // Immediate byte/halfword splat.
    const bool IsByte = Opc == Hexagon::PS_vsplatib;
    const MachineOperand &ImmOp = MI.getOperand(1);
    const Register Dst = MI.getOperand(0).getReg();
    const Register Tmp = NewScalarReg();

    if (Subtarget.useHVXV62Ops()) {
      // Tmp = A2_tfrsi #imm
      // Dst = V6_lvsplatb/V6_lvsplath Tmp
      const unsigned SplatOpc =
          IsByte ? Hexagon::V6_lvsplatb : Hexagon::V6_lvsplath;
      EmitBefore(Hexagon::A2_tfrsi, Tmp).add(ImmOp);
      EmitBefore(SplatOpc, Dst).addReg(Tmp);
    } else {
      // Tmp = A2_tfrsi #imm:#imm:#imm:#imm   (byte)
      // Tmp = A2_tfrsi #imm:#imm             (halfword)
      // Dst = V6_lvsplatw Tmp
      assert(ImmOp.isImm());
      uint32_t Word;
      if (IsByte) {
        const uint32_t Byte = ImmOp.getImm() & 0xFF;
        Word = Byte << 24 | Byte << 16 | Byte << 8 | Byte;
      } else {
        const uint32_t Half = ImmOp.getImm() & 0xFFFF;
        Word = Half << 16 | Half;
      }
      EmitBefore(Hexagon::A2_tfrsi, Tmp).addImm(Word);
      EmitBefore(Hexagon::V6_lvsplatw, Dst).addReg(Tmp);
    }
    MB.erase(At);
    break;
  }

  case Hexagon::PS_vsplatrb:
  case Hexagon::PS_vsplatrh: {
    // Register byte/halfword splat.
    const bool IsByte = Opc == Hexagon::PS_vsplatrb;
    const MachineOperand &SrcOp = MI.getOperand(1);
    const Register Dst = MI.getOperand(0).getReg();

    if (Subtarget.useHVXV62Ops()) {
      // Dst = V6_lvsplatb/V6_lvsplath Src
      const unsigned SplatOpc =
          IsByte ? Hexagon::V6_lvsplatb : Hexagon::V6_lvsplath;
      EmitBefore(SplatOpc, Dst).add(SrcOp);
    } else {
      // Tmp = S2_vsplatrb Src          (byte)
      // Tmp = A2_combine_ll Src, Src   (halfword)
      // Dst = V6_lvsplatw Tmp
      const Register Tmp = NewScalarReg();
      if (IsByte) {
        EmitBefore(Hexagon::S2_vsplatrb, Tmp)
            .addReg(SrcOp.getReg(), 0, SrcOp.getSubReg());
      } else {
        EmitBefore(Hexagon::A2_combine_ll, Tmp)
            .addReg(SrcOp.getReg(), 0, SrcOp.getSubReg())
            .addReg(SrcOp.getReg(), 0, SrcOp.getSubReg());
      }
      EmitBefore(Hexagon::V6_lvsplatw, Dst).addReg(Tmp);
    }
    MB.erase(At);
    break;
  }

  case Hexagon::PS_vsplatiw:
  case Hexagon::PS_vsplatrw: {
    // Word splat: MI is kept and converted in place.
    if (Opc == Hexagon::PS_vsplatiw) {
      // Tmp = A2_tfrsi #imm, then make MI read Tmp instead of the immediate.
      MachineOperand &SrcOp = MI.getOperand(1);
      const Register Tmp = NewScalarReg();
      EmitBefore(Hexagon::A2_tfrsi, Tmp).add(SrcOp);
      SrcOp.ChangeToRegister(Tmp, false);
    }
    // Dst = V6_lvsplatw Tmp/Src
    MI.setDesc(TII.get(Hexagon::V6_lvsplatw));
    break;
  }
  }
}

SDValue
HexagonTargetLowering::convertToByteIndex(SDValue ElemIdx, MVT ElemTy,
SelectionDAG &DAG) const {
Expand Down
100 changes: 24 additions & 76 deletions llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
Expand Up @@ -317,72 +317,34 @@ let Predicates = [UseHVX, UseHVXFloatingPoint] in {
(V6_vinsertwr HvxVR:$Vu, I32:$Rt)>;
}

// Splats for HvxV60
def V60splatib: OutPatFrag<(ops node:$V), (V6_lvsplatw (ToI32 (SplatB $V)))>;
def V60splatih: OutPatFrag<(ops node:$V), (V6_lvsplatw (ToI32 (SplatH $V)))>;
def V60splatiw: OutPatFrag<(ops node:$V), (V6_lvsplatw (ToI32 $V))>;
def V60splatrb: OutPatFrag<(ops node:$Rs), (V6_lvsplatw (S2_vsplatrb $Rs))>;
def V60splatrh: OutPatFrag<(ops node:$Rs),
(V6_lvsplatw (A2_combine_ll $Rs, $Rs))>;
def V60splatrw: OutPatFrag<(ops node:$Rs), (V6_lvsplatw $Rs)>;

// Splats for HvxV62+
def V62splatib: OutPatFrag<(ops node:$V), (V6_lvsplatb (ToI32 $V))>;
def V62splatih: OutPatFrag<(ops node:$V), (V6_lvsplath (ToI32 $V))>;
def V62splatiw: OutPatFrag<(ops node:$V), (V6_lvsplatw (ToI32 $V))>;
def V62splatrb: OutPatFrag<(ops node:$Rs), (V6_lvsplatb $Rs)>;
def V62splatrh: OutPatFrag<(ops node:$Rs), (V6_lvsplath $Rs)>;
def V62splatrw: OutPatFrag<(ops node:$Rs), (V6_lvsplatw $Rs)>;

def Rep: OutPatFrag<(ops node:$N), (Combinev $N, $N)>;

let Predicates = [UseHVX,UseHVXV60] in {
let Predicates = [UseHVX] in {
let AddedComplexity = 10 in {
def: Pat<(VecI8 (splat_vector u8_0ImmPred:$V)), (V60splatib $V)>;
def: Pat<(VecI16 (splat_vector u16_0ImmPred:$V)), (V60splatih $V)>;
def: Pat<(VecI32 (splat_vector anyimm:$V)), (V60splatiw $V)>;
def: Pat<(VecPI8 (splat_vector u8_0ImmPred:$V)), (Rep (V60splatib $V))>;
def: Pat<(VecPI16 (splat_vector u16_0ImmPred:$V)), (Rep (V60splatih $V))>;
def: Pat<(VecPI32 (splat_vector anyimm:$V)), (Rep (V60splatiw $V))>;
}
def: Pat<(VecI8 (splat_vector I32:$Rs)), (V60splatrb $Rs)>;
def: Pat<(VecI16 (splat_vector I32:$Rs)), (V60splatrh $Rs)>;
def: Pat<(VecI32 (splat_vector I32:$Rs)), (V60splatrw $Rs)>;
def: Pat<(VecPI8 (splat_vector I32:$Rs)), (Rep (V60splatrb $Rs))>;
def: Pat<(VecPI16 (splat_vector I32:$Rs)), (Rep (V60splatrh $Rs))>;
def: Pat<(VecPI32 (splat_vector I32:$Rs)), (Rep (V60splatrw $Rs))>;
}
let Predicates = [UseHVX,UseHVXV62] in {
let AddedComplexity = 30 in {
def: Pat<(VecI8 (splat_vector u8_0ImmPred:$V)), (V62splatib imm:$V)>;
def: Pat<(VecI16 (splat_vector u16_0ImmPred:$V)), (V62splatih imm:$V)>;
def: Pat<(VecI32 (splat_vector anyimm:$V)), (V62splatiw imm:$V)>;
def: Pat<(VecPI8 (splat_vector u8_0ImmPred:$V)),
(Rep (V62splatib imm:$V))>;
def: Pat<(VecPI16 (splat_vector u16_0ImmPred:$V)),
(Rep (V62splatih imm:$V))>;
def: Pat<(VecPI32 (splat_vector anyimm:$V)),
(Rep (V62splatiw imm:$V))>;
}
let AddedComplexity = 20 in {
def: Pat<(VecI8 (splat_vector I32:$Rs)), (V62splatrb $Rs)>;
def: Pat<(VecI16 (splat_vector I32:$Rs)), (V62splatrh $Rs)>;
def: Pat<(VecI32 (splat_vector I32:$Rs)), (V62splatrw $Rs)>;
def: Pat<(VecPI8 (splat_vector I32:$Rs)), (Rep (V62splatrb $Rs))>;
def: Pat<(VecPI16 (splat_vector I32:$Rs)), (Rep (V62splatrh $Rs))>;
def: Pat<(VecPI32 (splat_vector I32:$Rs)), (Rep (V62splatrw $Rs))>;
def: Pat<(VecI8 (splat_vector u8_0ImmPred:$V)), (PS_vsplatib imm:$V)>;
def: Pat<(VecI16 (splat_vector u16_0ImmPred:$V)), (PS_vsplatih imm:$V)>;
def: Pat<(VecI32 (splat_vector anyimm:$V)), (PS_vsplatiw imm:$V)>;
def: Pat<(VecPI8 (splat_vector u8_0ImmPred:$V)), (Rep (PS_vsplatib imm:$V))>;
def: Pat<(VecPI16 (splat_vector u16_0ImmPred:$V)), (Rep (PS_vsplatih imm:$V))>;
def: Pat<(VecPI32 (splat_vector anyimm:$V)), (Rep (PS_vsplatiw imm:$V))>;
}
def: Pat<(VecI8 (splat_vector I32:$Rs)), (PS_vsplatrb $Rs)>;
def: Pat<(VecI16 (splat_vector I32:$Rs)), (PS_vsplatrh $Rs)>;
def: Pat<(VecI32 (splat_vector I32:$Rs)), (PS_vsplatrw $Rs)>;
def: Pat<(VecPI8 (splat_vector I32:$Rs)), (Rep (PS_vsplatrb $Rs))>;
def: Pat<(VecPI16 (splat_vector I32:$Rs)), (Rep (PS_vsplatrh $Rs))>;
def: Pat<(VecPI32 (splat_vector I32:$Rs)), (Rep (PS_vsplatrw $Rs))>;
}
let Predicates = [UseHVXV68, UseHVXFloatingPoint] in {
let AddedComplexity = 30 in {
def: Pat<(VecF16 (splat_vector u16_0ImmPred:$V)), (V62splatih imm:$V)>;
def: Pat<(VecF32 (splat_vector anyint:$V)), (V62splatiw imm:$V)>;
def: Pat<(VecF32 (splat_vector f32ImmPred:$V)), (V62splatiw (ftoi $V))>;
def: Pat<(VecF16 (splat_vector u16_0ImmPred:$V)), (PS_vsplatih imm:$V)>;
def: Pat<(VecF32 (splat_vector anyint:$V)), (PS_vsplatiw imm:$V)>;
def: Pat<(VecF32 (splat_vector f32ImmPred:$V)), (PS_vsplatiw (ftoi $V))>;
}
let AddedComplexity = 20 in {
def: Pat<(VecF16 (splat_vector I32:$Rs)), (V62splatrh $Rs)>;
def: Pat<(VecF32 (splat_vector I32:$Rs)), (V62splatrw $Rs)>;
def: Pat<(VecF32 (splat_vector F32:$Rs)), (V62splatrw $Rs)>;
def: Pat<(VecF16 (splat_vector I32:$Rs)), (PS_vsplatrh $Rs)>;
def: Pat<(VecF32 (splat_vector I32:$Rs)), (PS_vsplatrw $Rs)>;
def: Pat<(VecF32 (splat_vector F32:$Rs)), (PS_vsplatrw $Rs)>;
}
}

Expand Down Expand Up @@ -672,18 +634,10 @@ let Predicates = [UseHVX] in {
def: Pat<(srl HVI16:$Vs, HVI16:$Vt), (V6_vlsrhv HvxVR:$Vs, HvxVR:$Vt)>;
def: Pat<(srl HVI32:$Vs, HVI32:$Vt), (V6_vlsrwv HvxVR:$Vs, HvxVR:$Vt)>;

let Predicates = [UseHVX,UseHVXV60] in {
def: Pat<(VecI16 (bswap HVI16:$Vs)),
(V6_vdelta HvxVR:$Vs, (V60splatib (i32 0x01)))>;
def: Pat<(VecI32 (bswap HVI32:$Vs)),
(V6_vdelta HvxVR:$Vs, (V60splatib (i32 0x03)))>;
}
let Predicates = [UseHVX,UseHVXV62], AddedComplexity = 10 in {
def: Pat<(VecI16 (bswap HVI16:$Vs)),
(V6_vdelta HvxVR:$Vs, (V62splatib (i32 0x01)))>;
def: Pat<(VecI32 (bswap HVI32:$Vs)),
(V6_vdelta HvxVR:$Vs, (V62splatib (i32 0x03)))>;
}
def: Pat<(VecI16 (bswap HVI16:$Vs)),
(V6_vdelta HvxVR:$Vs, (PS_vsplatib (i32 0x01)))>;
def: Pat<(VecI32 (bswap HVI32:$Vs)),
(V6_vdelta HvxVR:$Vs, (PS_vsplatib (i32 0x03)))>;

def: Pat<(VecI8 (ctpop HVI8:$Vs)),
(V6_vshuffeb (V6_vpopcounth (HiVec (V6_vzb HvxVR:$Vs))),
Expand All @@ -693,16 +647,10 @@ let Predicates = [UseHVX] in {
(V6_vaddw (LoVec (V6_vzh (V6_vpopcounth HvxVR:$Vs))),
(HiVec (V6_vzh (V6_vpopcounth HvxVR:$Vs))))>;

let Predicates = [UseHVX,UseHVXV60] in
def: Pat<(VecI8 (ctlz HVI8:$Vs)),
(V6_vsubb (V6_vshuffeb (V6_vcl0h (HiVec (V6_vzb HvxVR:$Vs))),
(V6_vcl0h (LoVec (V6_vzb HvxVR:$Vs)))),
(V60splatib (i32 0x08)))>;
let Predicates = [UseHVX,UseHVXV62], AddedComplexity = 10 in
def: Pat<(VecI8 (ctlz HVI8:$Vs)),
(V6_vsubb (V6_vshuffeb (V6_vcl0h (HiVec (V6_vzb HvxVR:$Vs))),
(V6_vcl0h (LoVec (V6_vzb HvxVR:$Vs)))),
(V62splatib (i32 0x08)))>;
(PS_vsplatib (i32 0x08)))>;

def: Pat<(VecI16 (ctlz HVI16:$Vs)), (V6_vcl0h HvxVR:$Vs)>;
def: Pat<(VecI32 (ctlz HVI32:$Vs)), (V6_vcl0w HvxVR:$Vs)>;
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/Hexagon/HexagonPseudo.td
Expand Up @@ -427,6 +427,22 @@ let isCall = 1, Uses = [R29, R31], isAsmParserOnly = 1 in {
def SAVE_REGISTERS_CALL_V4STK_EXT_PIC : T_Call<"">, PredRel;
}

// Pseudo-instructions that splat a scalar register (IntRegs:$Rs) into an HVX
// vector register (HvxVR:$Vd). They are codegen-only pseudos with
// hasPostISelHook set, and they reuse the itinerary and instruction type of
// V6_lvsplatw. The b/h/w variants share this one template.
let Predicates = [UseHVX], isPseudo = 1, isCodeGenOnly = 1,
hasSideEffects = 0, hasPostISelHook = 1 in
class Vsplatr_template : InstHexagon<(outs HvxVR:$Vd), (ins IntRegs:$Rs),
"", [], "", V6_lvsplatw.Itinerary, V6_lvsplatw.Type>;
def PS_vsplatrb: Vsplatr_template;
def PS_vsplatrh: Vsplatr_template;
def PS_vsplatrw: Vsplatr_template;

// Pseudo-instructions that splat an immediate (s32_0Imm:$Val) into an HVX
// vector register (HvxVR:$Vd). They are codegen-only pseudos with
// hasPostISelHook set, and they reuse the itinerary and instruction type of
// V6_lvsplatw. The b/h/w variants share this one template.
let Predicates = [UseHVX], isPseudo = 1, isCodeGenOnly = 1,
hasSideEffects = 0, hasPostISelHook = 1 in
class Vsplati_template : InstHexagon<(outs HvxVR:$Vd), (ins s32_0Imm:$Val),
"", [], "", V6_lvsplatw.Itinerary, V6_lvsplatw.Type>;
def PS_vsplatib: Vsplati_template;
def PS_vsplatih: Vsplati_template;
def PS_vsplatiw: Vsplati_template;

// Vector store pseudos
let Predicates = [HasV60,UseHVX], isPseudo = 1, isCodeGenOnly = 1,
mayStore = 1, accessSize = HVXVectorAccess, hasSideEffects = 0 in
Expand Down

0 comments on commit 7f4ce3f

Please sign in to comment.