Skip to content

Commit

Permalink
[RISCV][NFC] Simplify lowerVPOp.
Browse files Browse the repository at this point in the history
This patch is similar to https://reviews.llvm.org/D153948, using helper function to get ISD and information.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D154411
  • Loading branch information
jacquesguan committed Jul 27, 2023
1 parent 0d677c8 commit 5d6d649
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 63 deletions.
142 changes: 85 additions & 57 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4664,10 +4664,13 @@ static unsigned getRISCVVLOp(SDValue Op) {
#define OP_CASE(NODE) \
case ISD::NODE: \
return RISCVISD::NODE##_VL;
#define VP_CASE(NODE) \
case ISD::VP_##NODE: \
return RISCVISD::NODE##_VL;
// clang-format off
switch (Op.getOpcode()) {
default:
llvm_unreachable("don't have RISC-V specified VL op for this SDNode");
// clang-format off
OP_CASE(ADD)
OP_CASE(SUB)
OP_CASE(MUL)
Expand Down Expand Up @@ -4702,25 +4705,81 @@ static unsigned getRISCVVLOp(SDValue Op) {
OP_CASE(STRICT_FMUL)
OP_CASE(STRICT_FDIV)
OP_CASE(STRICT_FSQRT)
// clang-format on
#undef OP_CASE
VP_CASE(ADD) // VP_ADD
VP_CASE(SUB) // VP_SUB
VP_CASE(MUL) // VP_MUL
VP_CASE(SDIV) // VP_SDIV
VP_CASE(SREM) // VP_SREM
VP_CASE(UDIV) // VP_UDIV
VP_CASE(UREM) // VP_UREM
VP_CASE(SHL) // VP_SHL
VP_CASE(FADD) // VP_FADD
VP_CASE(FSUB) // VP_FSUB
VP_CASE(FMUL) // VP_FMUL
VP_CASE(FDIV) // VP_FDIV
VP_CASE(FNEG) // VP_FNEG
VP_CASE(FABS) // VP_FABS
VP_CASE(SMIN) // VP_SMIN
VP_CASE(SMAX) // VP_SMAX
VP_CASE(UMIN) // VP_UMIN
VP_CASE(UMAX) // VP_UMAX
VP_CASE(FMINNUM) // VP_FMINNUM
VP_CASE(FMAXNUM) // VP_FMAXNUM
VP_CASE(FCOPYSIGN) // VP_FCOPYSIGN
VP_CASE(SETCC) // VP_SETCC
VP_CASE(SINT_TO_FP) // VP_SINT_TO_FP
VP_CASE(UINT_TO_FP) // VP_UINT_TO_FP
VP_CASE(BITREVERSE) // VP_BITREVERSE
VP_CASE(BSWAP) // VP_BSWAP
VP_CASE(CTLZ) // VP_CTLZ
VP_CASE(CTTZ) // VP_CTTZ
VP_CASE(CTPOP) // VP_CTPOP
case ISD::VP_CTLZ_ZERO_UNDEF:
return RISCVISD::CTLZ_VL;
case ISD::VP_CTTZ_ZERO_UNDEF:
return RISCVISD::CTTZ_VL;
case ISD::FMA:
case ISD::VP_FMA:
return RISCVISD::VFMADD_VL;
case ISD::STRICT_FMA:
return RISCVISD::STRICT_VFMADD_VL;
case ISD::AND:
case ISD::VP_AND:
if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
return RISCVISD::VMAND_VL;
return RISCVISD::AND_VL;
case ISD::OR:
case ISD::VP_OR:
if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
return RISCVISD::VMOR_VL;
return RISCVISD::OR_VL;
case ISD::XOR:
case ISD::VP_XOR:
if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
return RISCVISD::VMXOR_VL;
return RISCVISD::XOR_VL;
case ISD::VP_SELECT:
return RISCVISD::VSELECT_VL;
case ISD::VP_MERGE:
return RISCVISD::VP_MERGE_VL;
case ISD::VP_ASHR:
return RISCVISD::SRA_VL;
case ISD::VP_LSHR:
return RISCVISD::SRL_VL;
case ISD::VP_SQRT:
return RISCVISD::FSQRT_VL;
case ISD::VP_SIGN_EXTEND:
return RISCVISD::VSEXT_VL;
case ISD::VP_ZERO_EXTEND:
return RISCVISD::VZEXT_VL;
case ISD::VP_FP_TO_SINT:
return RISCVISD::VFCVT_RTZ_X_F_VL;
case ISD::VP_FP_TO_UINT:
return RISCVISD::VFCVT_RTZ_XU_F_VL;
}
// clang-format on
#undef OP_CASE
#undef VP_CASE
}

/// Return true if a RISC-V target specified op has a merge operand.
Expand All @@ -4739,6 +4798,8 @@ static bool hasMergeOp(unsigned Opcode) {
return true;
if (Opcode >= RISCVISD::VWMUL_VL && Opcode <= RISCVISD::VFWSUB_W_VL)
return true;
if (Opcode == RISCVISD::SETCC_VL)
return true;
if (Opcode >= RISCVISD::STRICT_FADD_VL && Opcode <= RISCVISD::STRICT_FDIV_VL)
return true;
return false;
Expand Down Expand Up @@ -5476,106 +5537,72 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::EH_DWARF_CFA:
return lowerEH_DWARF_CFA(Op, DAG);
case ISD::VP_SELECT:
return lowerVPOp(Op, DAG, RISCVISD::VSELECT_VL);
case ISD::VP_MERGE:
return lowerVPOp(Op, DAG, RISCVISD::VP_MERGE_VL);
case ISD::VP_ADD:
return lowerVPOp(Op, DAG, RISCVISD::ADD_VL, /*HasMergeOp*/ true);
case ISD::VP_SUB:
return lowerVPOp(Op, DAG, RISCVISD::SUB_VL, /*HasMergeOp*/ true);
case ISD::VP_MUL:
return lowerVPOp(Op, DAG, RISCVISD::MUL_VL, /*HasMergeOp*/ true);
case ISD::VP_SDIV:
return lowerVPOp(Op, DAG, RISCVISD::SDIV_VL, /*HasMergeOp*/ true);
case ISD::VP_UDIV:
return lowerVPOp(Op, DAG, RISCVISD::UDIV_VL, /*HasMergeOp*/ true);
case ISD::VP_SREM:
return lowerVPOp(Op, DAG, RISCVISD::SREM_VL, /*HasMergeOp*/ true);
case ISD::VP_UREM:
return lowerVPOp(Op, DAG, RISCVISD::UREM_VL, /*HasMergeOp*/ true);
return lowerVPOp(Op, DAG);
case ISD::VP_AND:
return lowerLogicVPOp(Op, DAG, RISCVISD::VMAND_VL, RISCVISD::AND_VL);
case ISD::VP_OR:
return lowerLogicVPOp(Op, DAG, RISCVISD::VMOR_VL, RISCVISD::OR_VL);
case ISD::VP_XOR:
return lowerLogicVPOp(Op, DAG, RISCVISD::VMXOR_VL, RISCVISD::XOR_VL);
return lowerLogicVPOp(Op, DAG);
case ISD::VP_ASHR:
return lowerVPOp(Op, DAG, RISCVISD::SRA_VL, /*HasMergeOp*/ true);
case ISD::VP_LSHR:
return lowerVPOp(Op, DAG, RISCVISD::SRL_VL, /*HasMergeOp*/ true);
case ISD::VP_SHL:
return lowerVPOp(Op, DAG, RISCVISD::SHL_VL, /*HasMergeOp*/ true);
case ISD::VP_FADD:
return lowerVPOp(Op, DAG, RISCVISD::FADD_VL, /*HasMergeOp*/ true);
case ISD::VP_FSUB:
return lowerVPOp(Op, DAG, RISCVISD::FSUB_VL, /*HasMergeOp*/ true);
case ISD::VP_FMUL:
return lowerVPOp(Op, DAG, RISCVISD::FMUL_VL, /*HasMergeOp*/ true);
case ISD::VP_FDIV:
return lowerVPOp(Op, DAG, RISCVISD::FDIV_VL, /*HasMergeOp*/ true);
case ISD::VP_FNEG:
return lowerVPOp(Op, DAG, RISCVISD::FNEG_VL);
case ISD::VP_FABS:
return lowerVPOp(Op, DAG, RISCVISD::FABS_VL);
case ISD::VP_SQRT:
return lowerVPOp(Op, DAG, RISCVISD::FSQRT_VL);
case ISD::VP_FMA:
return lowerVPOp(Op, DAG, RISCVISD::VFMADD_VL);
case ISD::VP_FMINNUM:
return lowerVPOp(Op, DAG, RISCVISD::FMINNUM_VL, /*HasMergeOp*/ true);
case ISD::VP_FMAXNUM:
return lowerVPOp(Op, DAG, RISCVISD::FMAXNUM_VL, /*HasMergeOp*/ true);
case ISD::VP_FCOPYSIGN:
return lowerVPOp(Op, DAG, RISCVISD::FCOPYSIGN_VL, /*HasMergeOp*/ true);
return lowerVPOp(Op, DAG);
case ISD::VP_SIGN_EXTEND:
case ISD::VP_ZERO_EXTEND:
if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
return lowerVPExtMaskOp(Op, DAG);
return lowerVPOp(Op, DAG,
Op.getOpcode() == ISD::VP_SIGN_EXTEND
? RISCVISD::VSEXT_VL
: RISCVISD::VZEXT_VL);
return lowerVPOp(Op, DAG);
case ISD::VP_TRUNCATE:
return lowerVectorTruncLike(Op, DAG);
case ISD::VP_FP_EXTEND:
case ISD::VP_FP_ROUND:
return lowerVectorFPExtendOrRoundLike(Op, DAG);
case ISD::VP_FP_TO_SINT:
return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_X_F_VL);
case ISD::VP_FP_TO_UINT:
return lowerVPFPIntConvOp(Op, DAG, RISCVISD::VFCVT_RTZ_XU_F_VL);
case ISD::VP_SINT_TO_FP:
return lowerVPFPIntConvOp(Op, DAG, RISCVISD::SINT_TO_FP_VL);
case ISD::VP_UINT_TO_FP:
return lowerVPFPIntConvOp(Op, DAG, RISCVISD::UINT_TO_FP_VL);
return lowerVPFPIntConvOp(Op, DAG);
case ISD::VP_SETCC:
if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
return lowerVPSetCCMaskOp(Op, DAG);
return lowerVPOp(Op, DAG, RISCVISD::SETCC_VL, /*HasMergeOp*/ true);
[[fallthrough]];
case ISD::VP_SMIN:
return lowerVPOp(Op, DAG, RISCVISD::SMIN_VL, /*HasMergeOp*/ true);
case ISD::VP_SMAX:
return lowerVPOp(Op, DAG, RISCVISD::SMAX_VL, /*HasMergeOp*/ true);
case ISD::VP_UMIN:
return lowerVPOp(Op, DAG, RISCVISD::UMIN_VL, /*HasMergeOp*/ true);
case ISD::VP_UMAX:
return lowerVPOp(Op, DAG, RISCVISD::UMAX_VL, /*HasMergeOp*/ true);
case ISD::VP_BITREVERSE:
return lowerVPOp(Op, DAG, RISCVISD::BITREVERSE_VL, /*HasMergeOp*/ true);
case ISD::VP_BSWAP:
return lowerVPOp(Op, DAG, RISCVISD::BSWAP_VL, /*HasMergeOp*/ true);
return lowerVPOp(Op, DAG);
case ISD::VP_CTLZ:
case ISD::VP_CTLZ_ZERO_UNDEF:
if (Subtarget.hasStdExtZvbb())
return lowerVPOp(Op, DAG, RISCVISD::CTLZ_VL, /*HasMergeOp*/ true);
return lowerVPOp(Op, DAG);
return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
case ISD::VP_CTTZ:
case ISD::VP_CTTZ_ZERO_UNDEF:
if (Subtarget.hasStdExtZvbb())
return lowerVPOp(Op, DAG, RISCVISD::CTTZ_VL, /*HasMergeOp*/ true);
return lowerVPOp(Op, DAG);
return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
case ISD::VP_CTPOP:
return lowerVPOp(Op, DAG, RISCVISD::CTPOP_VL, /*HasMergeOp*/ true);
return lowerVPOp(Op, DAG);
case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
return lowerVPStridedLoad(Op, DAG);
case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
Expand Down Expand Up @@ -8827,9 +8854,10 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op,
// * The EVL operand is promoted from i32 to i64 on RV64.
// * Fixed-length vectors are converted to their scalable-vector container
// types.
SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG,
unsigned RISCVISDOpc,
bool HasMergeOp) const {
SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const {
unsigned RISCVISDOpc = getRISCVVLOp(Op);
bool HasMergeOp = hasMergeOp(RISCVISDOpc);

SDLoc DL(Op);
MVT VT = Op.getSimpleValueType();
SmallVector<SDValue, 4> Ops;
Expand Down Expand Up @@ -8978,13 +9006,14 @@ SDValue RISCVTargetLowering::lowerVPSetCCMaskOp(SDValue Op,
}

// Lower Floating-Point/Integer Type-Convert VP SDNodes
SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
unsigned RISCVISDOpc) const {
SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);

SDValue Src = Op.getOperand(0);
SDValue Mask = Op.getOperand(1);
SDValue VL = Op.getOperand(2);
unsigned RISCVISDOpc = getRISCVVLOp(Op);

MVT DstVT = Op.getSimpleValueType();
MVT SrcVT = Src.getSimpleValueType();
Expand Down Expand Up @@ -9110,12 +9139,11 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
return convertFromScalableVector(VT, Result, DAG, Subtarget);
}

SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, SelectionDAG &DAG,
unsigned MaskOpc,
unsigned VecOpc) const {
SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op,
SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
if (VT.getVectorElementType() != MVT::i1)
return lowerVPOp(Op, DAG, VecOpc, true);
return lowerVPOp(Op, DAG);

// It is safe to drop mask parameter as masked-off elements are undef.
SDValue Op1 = Op->getOperand(0);
Expand All @@ -9131,7 +9159,7 @@ SDValue RISCVTargetLowering::lowerLogicVPOp(SDValue Op, SelectionDAG &DAG,
}

SDLoc DL(Op);
SDValue Val = DAG.getNode(MaskOpc, DL, ContainerVT, Op1, Op2, VL);
SDValue Val = DAG.getNode(getRISCVVLOp(Op), DL, ContainerVT, Op1, Op2, VL);
if (!IsFixed)
return Val;
return convertFromScalableVector(VT, Val, DAG, Subtarget);
Expand Down
9 changes: 3 additions & 6 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -878,14 +878,11 @@ class RISCVTargetLowering : public TargetLowering {
SelectionDAG &DAG) const;
SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG, unsigned RISCVISDOpc,
bool HasMergeOp = false) const;
SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG, unsigned MaskOpc,
unsigned VecOpc) const;
SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPExtMaskOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPSetCCMaskOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
unsigned RISCVISDOpc) const;
SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPStridedLoad(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPStridedStore(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
Expand Down

0 comments on commit 5d6d649

Please sign in to comment.