Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
[RISCV] Add support for fixed vector vselect
This patch adds support for fixed-length vector vselect. It does so by
lowering them to a custom unmasked VSELECT_VL node with a vector length
operand.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D96768
  • Loading branch information
frasercrmck committed Feb 17, 2021
1 parent f0d8e73 commit d811616
Show file tree
Hide file tree
Showing 4 changed files with 354 additions and 1 deletion.
32 changes: 32 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -558,6 +558,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SMAX, VT, Custom);
setOperationAction(ISD::UMIN, VT, Custom);
setOperationAction(ISD::UMAX, VT, Custom);

setOperationAction(ISD::VSELECT, VT, Custom);
}

for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
Expand Down Expand Up @@ -587,6 +589,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,

for (auto CC : VFPCCToExpand)
setCondCodeAction(CC, VT, Expand);

setOperationAction(ISD::VSELECT, VT, Custom);
}
}
}
Expand Down Expand Up @@ -1258,6 +1262,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerToScalableOp(Op, DAG, RISCVISD::UMIN_VL);
case ISD::UMAX:
return lowerToScalableOp(Op, DAG, RISCVISD::UMAX_VL);
case ISD::VSELECT:
return lowerFixedLengthVectorSelectToRVV(Op, DAG);
}
}

Expand Down Expand Up @@ -2247,6 +2253,31 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorLogicOpToRVV(
return lowerToScalableOp(Op, DAG, VecOpc, /*HasMask*/ true);
}

SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
SDValue Op, SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);

MVT I1ContainerVT =
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());

SDValue CC =
convertToScalableVector(I1ContainerVT, Op.getOperand(0), DAG, Subtarget);
SDValue Op1 =
convertToScalableVector(ContainerVT, Op.getOperand(1), DAG, Subtarget);
SDValue Op2 =
convertToScalableVector(ContainerVT, Op.getOperand(2), DAG, Subtarget);

SDLoc DL(Op);
SDValue Mask, VL;
std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);

SDValue Select =
DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, CC, Op1, Op2, VL);

return convertFromScalableVector(VT, Select, DAG, Subtarget);
}

SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
unsigned NewOpc,
bool HasMask) const {
Expand Down Expand Up @@ -4906,6 +4937,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(UMIN_VL)
NODE_NAME_CASE(UMAX_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VSELECT_VL)
NODE_NAME_CASE(VMAND_VL)
NODE_NAME_CASE(VMOR_VL)
NODE_NAME_CASE(VMXOR_VL)
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Expand Up @@ -175,6 +175,9 @@ enum NodeType : unsigned {
// operand is VL.
SETCC_VL,

// Vector select with an additional VL operand. This operation is unmasked.
VSELECT_VL,

// Mask binary operators.
VMAND_VL,
VMOR_VL,
Expand Down Expand Up @@ -410,6 +413,8 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerFixedLengthVectorLogicOpToRVV(SDValue Op, SelectionDAG &DAG,
unsigned MaskOpc,
unsigned VecOpc) const;
SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op,
SelectionDAG &DAG) const;
SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG, unsigned NewOpc,
bool HasMask = true) const;

Expand Down
63 changes: 62 additions & 1 deletion llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Expand Up @@ -111,6 +111,15 @@ def riscv_vrgather_vx_vl : SDNode<"RISCVISD::VRGATHER_VX_VL",
SDTCisSameNumEltsAs<0, 3>,
SDTCisVT<4, XLenVT>]>>;

def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL",
SDTypeProfile<1, 4, [SDTCisVec<0>,
SDTCisVec<1>,
SDTCisSameNumEltsAs<0, 1>,
SDTCVecEltisVT<1, i1>,
SDTCisSameAs<0, 2>,
SDTCisSameAs<2, 3>,
SDTCisVT<4, XLenVT>]>>;

def SDT_RISCVMaskBinOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
SDTCisSameAs<0, 2>,
SDTCVecEltisVT<0, i1>,
Expand Down Expand Up @@ -441,6 +450,31 @@ defm "" : VPatBinaryVL_VV_VX<riscv_sdiv_vl, "PseudoVDIV">;
defm "" : VPatBinaryVL_VV_VX<riscv_urem_vl, "PseudoVREMU">;
defm "" : VPatBinaryVL_VV_VX<riscv_srem_vl, "PseudoVREM">;

// 12.16. Vector Integer Merge Instructions
foreach vti = AllIntegerVectors in {
def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask VMV0:$vm),
vti.RegClass:$rs1,
vti.RegClass:$rs2,
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMERGE_VVM_"#vti.LMul.MX)
vti.RegClass:$rs2, vti.RegClass:$rs1, VMV0:$vm,
GPR:$vl, vti.SEW)>;

def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask VMV0:$vm),
(SplatPat XLenVT:$rs1),
vti.RegClass:$rs2,
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMERGE_VXM_"#vti.LMul.MX)
vti.RegClass:$rs2, GPR:$rs1, VMV0:$vm, GPR:$vl, vti.SEW)>;

def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask VMV0:$vm),
(SplatPat_simm5 simm5:$rs1),
vti.RegClass:$rs2,
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX)
vti.RegClass:$rs2, simm5:$rs1, VMV0:$vm, GPR:$vl, vti.SEW)>;
}

// 12.17. Vector Integer Move Instructions
foreach vti = AllIntegerVectors in {
def : Pat<(vti.Vector (riscv_vmv_v_x_vl GPR:$rs2, (XLenVT (VLOp GPR:$vl)))),
Expand Down Expand Up @@ -609,8 +643,35 @@ foreach vti = AllFloatVectors in {
vti.RegClass:$rs, vti.RegClass:$rs, GPR:$vl, vti.SEW)>;
}

// 14.16. Vector Floating-Point Move Instruction
foreach fvti = AllFloatVectors in {
// Floating-point vselects:
// 12.16. Vector Integer Merge Instructions
// 14.13. Vector Floating-Point Merge Instruction
def : Pat<(fvti.Vector (riscv_vselect_vl (fvti.Mask VMV0:$vm),
fvti.RegClass:$rs1,
fvti.RegClass:$rs2,
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
fvti.RegClass:$rs2, fvti.RegClass:$rs1, VMV0:$vm,
GPR:$vl, fvti.SEW)>;

def : Pat<(fvti.Vector (riscv_vselect_vl (fvti.Mask VMV0:$vm),
(SplatFPOp fvti.ScalarRegClass:$rs1),
fvti.RegClass:$rs2,
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX)
fvti.RegClass:$rs2,
(fvti.Scalar fvti.ScalarRegClass:$rs1),
VMV0:$vm, GPR:$vl, fvti.SEW)>;

def : Pat<(fvti.Vector (riscv_vselect_vl (fvti.Mask VMV0:$vm),
(SplatFPOp (fvti.Scalar fpimm0)),
fvti.RegClass:$rs2,
(XLenVT (VLOp GPR:$vl)))),
(!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
fvti.RegClass:$rs2, 0, VMV0:$vm, GPR:$vl, fvti.SEW)>;

// 14.16. Vector Floating-Point Move Instruction
// If we're splatting fpimm0, use vmv.v.x vd, x0.
def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl
(fvti.Scalar (fpimm0)), (XLenVT (VLOp GPR:$vl)))),
Expand Down

0 comments on commit d811616

Please sign in to comment.