diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index b0c5fcd53c41e..520f593341190 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -48,6 +48,23 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
        I != E;) {
     SDNode *N = &*I++; // Preincrement iterator to avoid invalidation issues.
 
+    // Convert integer SPLAT_VECTOR to VMV_V_X_VL to reduce isel burden.
+    if (N->getOpcode() == ISD::SPLAT_VECTOR &&
+        N->getSimpleValueType(0).isInteger()) {
+      MVT VT = N->getSimpleValueType(0);
+      SDLoc DL(N);
+      SDValue VL = CurDAG->getTargetConstant(RISCV::VLMaxSentinel, DL,
+                                             Subtarget->getXLenVT());
+      SDValue Result =
+          CurDAG->getNode(RISCVISD::VMV_V_X_VL, DL, VT, N->getOperand(0), VL);
+
+      --I;
+      CurDAG->ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
+      ++I;
+      CurDAG->DeleteNode(N);
+      continue;
+    }
+
     // Lower SPLAT_VECTOR_SPLIT_I64 to two scalar stores and a stride 0 vector
     // load. Done after lowering and combining so that we have a chance to
     // optimize this to VMV_V_X_VL when the upper bits aren't needed.
@@ -1881,8 +1898,7 @@ bool RISCVDAGToDAGISel::selectVLOp(SDValue N, SDValue &VL) {
 }
 
 bool RISCVDAGToDAGISel::selectVSplat(SDValue N, SDValue &SplatVal) {
-  if (N.getOpcode() != ISD::SPLAT_VECTOR &&
-      N.getOpcode() != RISCVISD::VMV_V_X_VL)
+  if (N.getOpcode() != RISCVISD::VMV_V_X_VL)
     return false;
   SplatVal = N.getOperand(0);
   return true;
@@ -1894,14 +1910,13 @@ static bool selectVSplatSimmHelper(SDValue N, SDValue &SplatVal,
                                    SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget,
                                    ValidateFn ValidateImm) {
-  if ((N.getOpcode() != ISD::SPLAT_VECTOR &&
-       N.getOpcode() != RISCVISD::VMV_V_X_VL) ||
+  if (N.getOpcode() != RISCVISD::VMV_V_X_VL ||
       !isa<ConstantSDNode>(N.getOperand(0)))
     return false;
 
   int64_t SplatImm = cast<ConstantSDNode>(N.getOperand(0))->getSExtValue();
 
-  // ISD::SPLAT_VECTOR, RISCVISD::VMV_V_X_VL share semantics when the operand
+  // The semantics of RISCVISD::VMV_V_X_VL is that when the operand
   // type is wider than the resulting vector element type: an implicit
   // truncation first takes place. Therefore, perform a manual
   // truncation/sign-extension in order to ignore any truncated bits and catch
@@ -1942,8 +1957,7 @@ bool RISCVDAGToDAGISel::selectVSplatSimm5Plus1NonZero(SDValue N,
 }
 
 bool RISCVDAGToDAGISel::selectVSplatUimm5(SDValue N, SDValue &SplatVal) {
-  if ((N.getOpcode() != ISD::SPLAT_VECTOR &&
-       N.getOpcode() != RISCVISD::VMV_V_X_VL) ||
+  if (N.getOpcode() != RISCVISD::VMV_V_X_VL ||
       !isa<ConstantSDNode>(N.getOperand(0)))
     return false;
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 6fa2d55b80c01..b74bef673575c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -583,12 +583,6 @@ defm : VPatBinarySDNode_VV_VX_VI<sra, "PseudoVSRA", uimm5>;
 
 foreach vti = AllIntegerVectors in {
   // Emit shift by 1 as an add since it might be faster.
-  def : Pat<(shl (vti.Vector vti.RegClass:$rs1),
-                 (vti.Vector (splat_vector (XLenVT 1)))),
-            (!cast<Instruction>("PseudoVADD_VV_"# vti.LMul.MX)
-                 vti.RegClass:$rs1, vti.RegClass:$rs1, vti.AVL, vti.Log2SEW)>;
-}
-foreach vti = [VI64M1, VI64M2, VI64M4, VI64M8] in {
   def : Pat<(shl (vti.Vector vti.RegClass:$rs1),
                  (vti.Vector (riscv_vmv_v_x_vl 1, (XLenVT srcvalue)))),
             (!cast<Instruction>("PseudoVADD_VV_"# vti.LMul.MX)
@@ -943,17 +937,6 @@ foreach fvtiToFWti = AllWidenableFloatVectors in {
 //===----------------------------------------------------------------------===//
 // Vector Splats
 //===----------------------------------------------------------------------===//
-let Predicates = [HasVInstructions] in {
-foreach vti = AllIntegerVectors in {
-  def : Pat<(vti.Vector (SplatPat GPR:$rs1)),
-            (!cast<Instruction>("PseudoVMV_V_X_" # vti.LMul.MX)
-             GPR:$rs1, vti.AVL, vti.Log2SEW)>;
-  def : Pat<(vti.Vector (SplatPat_simm5 simm5:$rs1)),
-            (!cast<Instruction>("PseudoVMV_V_I_" # vti.LMul.MX)
-             simm5:$rs1, vti.AVL, vti.Log2SEW)>;
-}
-} // Predicates = [HasVInstructions]
-
 let Predicates = [HasVInstructionsAnyF] in {
 foreach fvti = AllFloatVectors in {
   def : Pat<(fvti.Vector (splat_vector fvti.ScalarRegClass:$rs1)),
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index f146242e01de2..ddaf393112f98 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -279,15 +279,13 @@ foreach kind = ["ADD", "UMAX", "SMAX", "UMIN", "SMIN", "AND", "OR", "XOR",
   def rvv_vecreduce_#kind#_vl : SDNode<"RISCVISD::VECREDUCE_"#kind#"_VL",
                                        SDTRVVVecReduce>;
 // Give explicit Complexity to prefer simm5/uimm5.
-def SplatPat       : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector, riscv_vmv_v_x_vl], [], 1>;
-def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, riscv_vmv_v_x_vl], [], 2>;
-def SplatPat_uimm5 : ComplexPattern<vAny, 1, "selectVSplatUimm5", [splat_vector, riscv_vmv_v_x_vl], [], 2>;
+def SplatPat       : ComplexPattern<vAny, 1, "selectVSplat", [riscv_vmv_v_x_vl], [], 1>;
+def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [riscv_vmv_v_x_vl], [], 2>;
+def SplatPat_uimm5 : ComplexPattern<vAny, 1, "selectVSplatUimm5", [riscv_vmv_v_x_vl], [], 2>;
 def SplatPat_simm5_plus1
-    : ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1",
-                     [splat_vector, riscv_vmv_v_x_vl], [], 2>;
+    : ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1", [riscv_vmv_v_x_vl], [], 2>;
 def SplatPat_simm5_plus1_nonzero
-    : ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1NonZero",
-                     [splat_vector, riscv_vmv_v_x_vl], [], 2>;
+    : ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1NonZero", [riscv_vmv_v_x_vl], [], 2>;
 
 // Ignore the vl operand.
 def SplatFPOp : PatFrag<(ops node:$op),
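
Note (illustrative, not part of the patch): after this change, a scalable integer splat reaches instruction selection as a single RISCVISD::VMV_V_X_VL node, so one pattern set covers what previously needed both the ISD::SPLAT_VECTOR and VMV_V_X_VL forms. A minimal IR reproducer of the affected case, assuming an invocation like llc -mtriple=riscv64 -mattr=+v; the function name and vector type are chosen arbitrarily:

; Illustrative example, not a test taken from this commit.
define <vscale x 4 x i32> @splat_i32(i32 %x) {
  ; SelectionDAG builds ISD::SPLAT_VECTOR from this insertelement +
  ; shufflevector idiom; PreprocessISelDAG now rewrites it to
  ; RISCVISD::VMV_V_X_VL before any patterns are matched.
  %head = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
  %splat = shufflevector <vscale x 4 x i32> %head, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
  ret <vscale x 4 x i32> %splat
}

The selected code should be unchanged by this patch: a vsetvli followed by vmv.v.x.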