[RISCV][VP] Lower VP_MERGE to RVV instructions

This patch adds lowering of the llvm.vp.merge.* intrinsic
(ISD::VP_MERGE) to RVV vmerge/vfmerge instructions. It introduces a
special pseudo form of vmerge with a tied merge operand, which lets us
force the tail elements to equal the "on false" operand via a tied-def
constraint and a "tail undisturbed" policy.
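
As a rough sketch of the intended codegen (hand-written for illustration; the function name, register choices, and exact vsetvli are assumptions, not output quoted from this patch), a scalable-vector vp.merge such as:

declare <vscale x 2 x i32> @llvm.vp.merge.nxv2i32(<vscale x 2 x i1>, <vscale x 2 x i32>, <vscale x 2 x i32>, i32)

; Lanes 0..%evl-1 take %a where %m is set and %b elsewhere; lanes at %evl
; and beyond must equal the "on false" operand %b.
define <vscale x 2 x i32> @vpmerge_vv(<vscale x 2 x i1> %m, <vscale x 2 x i32> %a, <vscale x 2 x i32> %b, i32 zeroext %evl) {
  %v = call <vscale x 2 x i32> @llvm.vp.merge.nxv2i32(<vscale x 2 x i1> %m, <vscale x 2 x i32> %a, <vscale x 2 x i32> %b, i32 %evl)
  ret <vscale x 2 x i32> %v
}

should now select to something of the shape

vsetvli zero, a0, e32, m1, tu, mu
vmerge.vvm v9, v9, v8, v0

(possibly plus a register copy, depending on allocation): the tied first source holds %b and doubles as the destination, so the tail-undisturbed policy leaves the tail elements equal to %b.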

While this strategy often allows us to lower the intrinsic to just one
instruction, it may be less efficient for fixed-vector types, as the
number of tail elements may extend far beyond the length of the fixed
vector. Another strategy would be to use a vmerge/vfmerge instruction
with an AVL equal to the length of the vector type, and to manipulate
the condition operand so that mask elements beyond the operation's EVL
are false.
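
To make that alternative concrete, here is a hypothetical expansion (sketched by hand; this patch does not implement it) for a <4 x i32> vp.merge, which clamps the mask against the EVL so that a merge over the full fixed-vector length is then correct:

; Hypothetical: lane indices >= %evl have their mask bit cleared, so those
; lanes take %b, and the merge can use AVL = 4, the full fixed length.
define <4 x i32> @vpmerge_alt(<4 x i1> %m, <4 x i32> %a, <4 x i32> %b, i32 %evl) {
  %evl.head = insertelement <4 x i32> poison, i32 %evl, i32 0
  %evl.splat = shufflevector <4 x i32> %evl.head, <4 x i32> poison, <4 x i32> zeroinitializer
  %in.bounds = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %evl.splat
  %m.clamped = and <4 x i1> %m, %in.bounds
  %v = select <4 x i1> %m.clamped, <4 x i32> %a, <4 x i32> %b
  ret <4 x i32> %v
}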

I've also observed inefficient codegen where our 'VF' patterns fail to
match raw floating-point SPLAT_VECTORs, which occur in scalable-vector
code.
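
For instance (illustrative IR, not a test from this patch), a vp.merge whose "on true" operand is a scalar float splat would ideally select a single vfmerge.vfm; if the splat instead reaches selection as a raw SPLAT_VECTOR that the 'VF' pattern doesn't match, codegen can end up materializing the splat into a register first (e.g. with vfmv.v.f) and then using vmerge.vvm:

declare <vscale x 2 x float> @llvm.vp.merge.nxv2f32(<vscale x 2 x i1>, <vscale x 2 x float>, <vscale x 2 x float>, i32)

define <vscale x 2 x float> @vpmerge_vf(<vscale x 2 x i1> %m, float %s, <vscale x 2 x float> %b, i32 zeroext %evl) {
  ; This insert/shuffle idiom is a splat of %s; for scalable vectors it
  ; legalizes to a SPLAT_VECTOR node.
  %head = insertelement <vscale x 2 x float> poison, float %s, i32 0
  %splat = shufflevector <vscale x 2 x float> %head, <vscale x 2 x float> poison, <vscale x 2 x i32> zeroinitializer
  %v = call <vscale x 2 x float> @llvm.vp.merge.nxv2f32(<vscale x 2 x i1> %m, <vscale x 2 x float> %splat, <vscale x 2 x float> %b, i32 %evl)
  ret <vscale x 2 x float> %v
}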

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D117561
frasercrmck committed Jan 24, 2022
1 parent e7926e8 commit af773a1
Showing 6 changed files with 2,372 additions and 12 deletions.
8 changes: 6 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -521,12 +521,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::VP_SHL, ISD::VP_REDUCE_ADD, ISD::VP_REDUCE_AND,
ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR, ISD::VP_REDUCE_SMAX,
ISD::VP_REDUCE_SMIN, ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN,
- ISD::VP_SELECT};
+ ISD::VP_MERGE, ISD::VP_SELECT};

static const unsigned FloatingPointVPOps[] = {
ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
ISD::VP_FDIV, ISD::VP_REDUCE_FADD, ISD::VP_REDUCE_SEQ_FADD,
- ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX, ISD::VP_SELECT};
+ ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX, ISD::VP_MERGE,
+ ISD::VP_SELECT};

if (!Subtarget.is64Bit()) {
// We must custom-lower certain vXi64 operations on RV32 due to the vector
@@ -3441,6 +3442,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerSET_ROUNDING(Op, DAG);
case ISD::VP_SELECT:
return lowerVPOp(Op, DAG, RISCVISD::VSELECT_VL);
case ISD::VP_MERGE:
return lowerVPOp(Op, DAG, RISCVISD::VP_MERGE_VL);
case ISD::VP_ADD:
return lowerVPOp(Op, DAG, RISCVISD::ADD_VL);
case ISD::VP_SUB:
@@ -10087,6 +10090,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VWADDU_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VSELECT_VL)
NODE_NAME_CASE(VP_MERGE_VL)
NODE_NAME_CASE(VMAND_VL)
NODE_NAME_CASE(VMOR_VL)
NODE_NAME_CASE(VMXOR_VL)
4 changes: 4 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -253,6 +253,10 @@ enum NodeType : unsigned {

// Vector select with an additional VL operand. This operation is unmasked.
VSELECT_VL,
// Vector select with operand #2 (the value when the condition is false) tied
// to the destination and an additional VL operand. This operation is
// unmasked.
VP_MERGE_VL,

// Mask binary operators.
VMAND_VL,
75 changes: 73 additions & 2 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -579,10 +579,11 @@ class PseudoToVInst<string PseudoInst> {
!subst("_B64", "",
!subst("_MASK", "",
!subst("_TIED", "",
!subst("_TU", "",
!subst("F16", "F",
!subst("F32", "F",
!subst("F64", "F",
!subst("Pseudo", "", PseudoInst))))))))))))))))))));
!subst("Pseudo", "", PseudoInst)))))))))))))))))))));
}

// The destination vector register group for a masked vector instruction cannot
@@ -928,6 +929,9 @@ class VPseudoBinaryNoMask<VReg RetClass,
let BaseInstr = !cast<Instruction>(PseudoToVInst<NAME>.VInst);
}

// Special version of VPseudoBinaryNoMask where we pretend the first source is
// tied to the destination.
// This allows maskedoff and rs2 to be the same register.
class VPseudoTiedBinaryNoMask<VReg RetClass,
DAGOperand Op2Class,
string Constraint> :
@@ -1079,6 +1083,30 @@ class VPseudoBinaryCarryIn<VReg RetClass,
let VLMul = MInfo.value;
}

class VPseudoTiedBinaryCarryIn<VReg RetClass,
VReg Op1Class,
DAGOperand Op2Class,
LMULInfo MInfo,
bit CarryIn,
string Constraint> :
Pseudo<(outs RetClass:$rd),
!if(CarryIn,
(ins RetClass:$merge, Op1Class:$rs2, Op2Class:$rs1, VMV0:$carry, AVL:$vl,
ixlenimm:$sew),
(ins RetClass:$merge, Op1Class:$rs2, Op2Class:$rs1, AVL:$vl, ixlenimm:$sew)), []>,
RISCVVPseudo {
let mayLoad = 0;
let mayStore = 0;
let hasSideEffects = 0;
let Constraints = Join<[Constraint, "$rd = $merge"], ",">.ret;
let HasVLOp = 1;
let HasSEWOp = 1;
let HasMergeOp = 1;
let HasVecPolicyOp = 0;
let BaseInstr = !cast<Instruction>(PseudoToVInst<NAME>.VInst);
let VLMul = MInfo.value;
}

class VPseudoTernaryNoMask<VReg RetClass,
RegisterClass Op1Class,
DAGOperand Op2Class,
@@ -1741,6 +1769,16 @@ multiclass VPseudoBinaryV_VM<bit CarryOut = 0, bit CarryIn = 1,
m.vrclass, m.vrclass, m, CarryIn, Constraint>;
}

multiclass VPseudoTiedBinaryV_VM<bit CarryOut = 0, bit CarryIn = 1,
string Constraint = ""> {
foreach m = MxList in
def "_VV" # !if(CarryIn, "M", "") # "_" # m.MX # "_TU" :
VPseudoTiedBinaryCarryIn<!if(CarryOut, VR,
!if(!and(CarryIn, !not(CarryOut)),
GetVRegNoV0<m.vrclass>.R, m.vrclass)),
m.vrclass, m.vrclass, m, CarryIn, Constraint>;
}

multiclass VPseudoBinaryV_XM<bit CarryOut = 0, bit CarryIn = 1,
string Constraint = ""> {
foreach m = MxList in
@@ -1751,13 +1789,29 @@ multiclass VPseudoBinaryV_XM<bit CarryOut = 0, bit CarryIn = 1,
m.vrclass, GPR, m, CarryIn, Constraint>;
}

multiclass VPseudoTiedBinaryV_XM<bit CarryOut = 0, bit CarryIn = 1,
string Constraint = ""> {
foreach m = MxList in
def "_VX" # !if(CarryIn, "M", "") # "_" # m.MX # "_TU":
VPseudoTiedBinaryCarryIn<!if(CarryOut, VR,
!if(!and(CarryIn, !not(CarryOut)),
GetVRegNoV0<m.vrclass>.R, m.vrclass)),
m.vrclass, GPR, m, CarryIn, Constraint>;
}

multiclass VPseudoVMRG_FM {
foreach f = FPList in
- foreach m = f.MxList in
+ foreach m = f.MxList in {
def "_V" # f.FX # "M_" # m.MX :
VPseudoBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
m.vrclass, f.fprclass, m, /*CarryIn=*/1, "">,
Sched<[WriteVFMergeV, ReadVFMergeV, ReadVFMergeF, ReadVMask]>;
// Tied version to allow codegen control over the tail elements
def "_V" # f.FX # "M_" # m.MX # "_TU":
VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
m.vrclass, f.fprclass, m, /*CarryIn=*/1, "">,
Sched<[WriteVFMergeV, ReadVFMergeV, ReadVFMergeF, ReadVMask]>;
}
}

multiclass VPseudoBinaryV_IM<bit CarryOut = 0, bit CarryIn = 1,
@@ -1770,6 +1824,16 @@ multiclass VPseudoBinaryV_IM<bit CarryOut = 0, bit CarryIn = 1,
m.vrclass, simm5, m, CarryIn, Constraint>;
}

multiclass VPseudoTiedBinaryV_IM<bit CarryOut = 0, bit CarryIn = 1,
string Constraint = ""> {
foreach m = MxList in
def "_VI" # !if(CarryIn, "M", "") # "_" # m.MX # "_TU":
VPseudoTiedBinaryCarryIn<!if(CarryOut, VR,
!if(!and(CarryIn, !not(CarryOut)),
GetVRegNoV0<m.vrclass>.R, m.vrclass)),
m.vrclass, simm5, m, CarryIn, Constraint>;
}

multiclass VPseudoUnaryVMV_V_X_I {
foreach m = MxList in {
let VLMul = m.value in {
@@ -2104,6 +2168,13 @@ multiclass VPseudoVMRG_VM_XM_IM {
Sched<[WriteVIMergeX, ReadVIMergeV, ReadVIMergeX, ReadVMask]>;
defm "" : VPseudoBinaryV_IM,
Sched<[WriteVIMergeI, ReadVIMergeV, ReadVMask]>;
// Tied versions to allow codegen control over the tail elements
defm "" : VPseudoTiedBinaryV_VM,
Sched<[WriteVIMergeV, ReadVIMergeV, ReadVIMergeV, ReadVMask]>;
defm "" : VPseudoTiedBinaryV_XM,
Sched<[WriteVIMergeX, ReadVIMergeV, ReadVIMergeX, ReadVMask]>;
defm "" : VPseudoTiedBinaryV_IM,
Sched<[WriteVIMergeI, ReadVIMergeV, ReadVMask]>;
}

multiclass VPseudoVCALU_VM_XM_IM {
64 changes: 56 additions & 8 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -177,14 +177,13 @@ def riscv_vrgatherei16_vv_vl : SDNode<"RISCVISD::VRGATHEREI16_VV_VL",
SDTCisSameNumEltsAs<0, 3>,
SDTCisVT<4, XLenVT>]>>;

- def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL",
-                               SDTypeProfile<1, 4, [SDTCisVec<0>,
-                                                    SDTCisVec<1>,
-                                                    SDTCisSameNumEltsAs<0, 1>,
-                                                    SDTCVecEltisVT<1, i1>,
-                                                    SDTCisSameAs<0, 2>,
-                                                    SDTCisSameAs<2, 3>,
-                                                    SDTCisVT<4, XLenVT>]>>;
+ def SDT_RISCVSelect_VL : SDTypeProfile<1, 4, [
+   SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
+   SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisVT<4, XLenVT>
+ ]>;
+
+ def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL", SDT_RISCVSelect_VL>;
+ def riscv_vp_merge_vl : SDNode<"RISCVISD::VP_MERGE_VL", SDT_RISCVSelect_VL>;

def SDT_RISCVMaskBinOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
SDTCisSameAs<0, 2>,
@@ -976,6 +975,30 @@ foreach vti = AllIntegerVectors in {
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX)
vti.RegClass:$rs2, simm5:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;

def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
vti.RegClass:$rs1,
vti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VVM_"#vti.LMul.MX#"_TU")
vti.RegClass:$rs2, vti.RegClass:$rs2, vti.RegClass:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;

def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
(SplatPat XLenVT:$rs1),
vti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VXM_"#vti.LMul.MX#"_TU")
vti.RegClass:$rs2, vti.RegClass:$rs2, GPR:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;

def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
(SplatPat_simm5 simm5:$rs1),
vti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX#"_TU")
vti.RegClass:$rs2, vti.RegClass:$rs2, simm5:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
}

// 12.16. Vector Integer Move Instructions
@@ -1223,6 +1246,31 @@ foreach fvti = AllFloatVectors in {
(!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
fvti.RegClass:$rs2, 0, (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;

def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
fvti.RegClass:$rs1,
fvti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX#"_TU")
fvti.RegClass:$rs2, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
GPR:$vl, fvti.Log2SEW)>;

def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
(SplatFPOp fvti.ScalarRegClass:$rs1),
fvti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX#"_TU")
fvti.RegClass:$rs2, fvti.RegClass:$rs2,
(fvti.Scalar fvti.ScalarRegClass:$rs1),
(fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;

def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
(SplatFPOp (fvti.Scalar fpimm0)),
fvti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX#"_TU")
fvti.RegClass:$rs2, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
GPR:$vl, fvti.Log2SEW)>;

// 14.16. Vector Floating-Point Move Instruction
// If we're splatting fpimm0, use vmv.v.x vd, x0.
def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl
