diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 52799cfd443b15..f59f9d2ad87cae 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -945,9 +945,8 @@ RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs( // Return the largest legal scalable vector type that matches VT's element type. MVT RISCVTargetLowering::getContainerForFixedLengthVector( - SelectionDAG &DAG, MVT VT, const RISCVSubtarget &Subtarget) { - assert(VT.isFixedLengthVector() && - DAG.getTargetLoweringInfo().isTypeLegal(VT) && + const TargetLowering &TLI, MVT VT, const RISCVSubtarget &Subtarget) { + assert(VT.isFixedLengthVector() && TLI.isTypeLegal(VT) && "Expected legal fixed length vector!"); unsigned LMul = Subtarget.getLMULForFixedLengthVector(VT); @@ -976,6 +975,16 @@ MVT RISCVTargetLowering::getContainerForFixedLengthVector( } } +MVT RISCVTargetLowering::getContainerForFixedLengthVector( + SelectionDAG &DAG, MVT VT, const RISCVSubtarget &Subtarget) { + return getContainerForFixedLengthVector(DAG.getTargetLoweringInfo(), VT, + Subtarget); +} + +MVT RISCVTargetLowering::getContainerForFixedLengthVector(MVT VT) const { + return getContainerForFixedLengthVector(*this, VT, getSubtarget()); +} + // Grow V to consume an entire RVV register. static SDValue convertToScalableVector(EVT VT, SDValue V, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { @@ -1250,8 +1259,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, MVT ContainerVT = SrcVT; if (SrcVT.isFixedLengthVector()) { - ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector( - DAG, SrcVT, Subtarget); + ContainerVT = getContainerForFixedLengthVector(SrcVT); Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget); } @@ -1313,8 +1321,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, // Prepare any fixed-length vector operands. MVT ContainerVT = VT; if (SrcVT.isFixedLengthVector()) { - ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector( - DAG, VT, Subtarget); + ContainerVT = getContainerForFixedLengthVector(VT); MVT SrcContainerVT = ContainerVT.changeVectorElementType(SrcVT.getVectorElementType()); Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget); @@ -1354,9 +1361,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, // Prepare any fixed-length vector operands. MVT ContainerVT = VT; if (VT.isFixedLengthVector()) { - MVT SrcContainerVT = - RISCVTargetLowering::getContainerForFixedLengthVector(DAG, SrcVT, - Subtarget); + MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT); ContainerVT = SrcContainerVT.changeVectorElementType(VT.getVectorElementType()); Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget); @@ -1474,13 +1479,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, MVT ContainerVT, SrcContainerVT; // Derive the reference container type from the larger vector type. if (SrcEltSize > EltSize) { - SrcContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector( - DAG, SrcVT, Subtarget); + SrcContainerVT = getContainerForFixedLengthVector(SrcVT); ContainerVT = SrcContainerVT.changeVectorElementType(VT.getVectorElementType()); } else { - ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector( - DAG, VT, Subtarget); + ContainerVT = getContainerForFixedLengthVector(VT); SrcContainerVT = ContainerVT.changeVectorElementType(SrcEltVT); } @@ -2097,8 +2100,7 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ISD::VSELECT, DL, VecVT, Src, SplatTrueVal, SplatZero); } - MVT ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector( - DAG, VecVT, Subtarget); + MVT ContainerVT = getContainerForFixedLengthVector(VecVT); MVT I1ContainerVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); @@ -2126,8 +2128,7 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorExtendToRVV( // Grab the canonical container type for the extended type. Infer the smaller // type from that to ensure the same number of vector elements, as we know // the LMUL will be sufficient to hold the smaller type. - MVT ContainerExtVT = RISCVTargetLowering::getContainerForFixedLengthVector( - DAG, ExtVT, Subtarget); + MVT ContainerExtVT = getContainerForFixedLengthVector(ExtVT); // Get the extended container type manually to ensure the same number of // vector elements between source and dest. MVT ContainerVT = MVT::getVectorVT(VT.getVectorElementType(), @@ -2161,7 +2162,7 @@ SDValue RISCVTargetLowering::lowerVectorMaskTrunc(SDValue Op, // If this is a fixed vector, we need to convert it to a scalable vector. MVT ContainerVT = VecVT; if (VecVT.isFixedLengthVector()) { - ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget); + ContainerVT = getContainerForFixedLengthVector(VecVT); Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget); } @@ -2198,7 +2199,7 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, MVT ContainerVT = VecVT; // If the operand is a fixed-length vector, convert to a scalable one. if (VecVT.isFixedLengthVector()) { - ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget); + ContainerVT = getContainerForFixedLengthVector(VecVT); Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } @@ -2265,7 +2266,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, // If this is a fixed vector, we need to convert it to a scalable vector. MVT ContainerVT = VecVT; if (VecVT.isFixedLengthVector()) { - ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget); + ContainerVT = getContainerForFixedLengthVector(VecVT); Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } @@ -2530,8 +2531,7 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op, return Op; MVT ContainerVT = VecVT; if (VecVT.isFixedLengthVector()) { - ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector( - DAG, VecVT, Subtarget); + ContainerVT = getContainerForFixedLengthVector(VecVT); Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, @@ -2682,8 +2682,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op, return Op; MVT ContainerVT = VecVT; if (VecVT.isFixedLengthVector()) { - ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector( - DAG, VecVT, Subtarget); + ContainerVT = getContainerForFixedLengthVector(VecVT); Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } SDValue Mask = @@ -2753,8 +2752,7 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op, SDLoc DL(Op); MVT VT = Op.getSimpleValueType(); - MVT ContainerVT = - RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget); + MVT ContainerVT = getContainerForFixedLengthVector(VT); SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT()); @@ -2779,8 +2777,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op, // FIXME: We probably need to zero any extra bits in a byte for mask stores. // This is tricky to do. - MVT ContainerVT = - RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget); + MVT ContainerVT = getContainerForFixedLengthVector(VT); SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT()); @@ -2797,8 +2794,7 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const { MVT InVT = Op.getOperand(0).getSimpleValueType(); - MVT ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector( - DAG, InVT, Subtarget); + MVT ContainerVT = getContainerForFixedLengthVector(InVT); MVT VT = Op.getSimpleValueType(); @@ -2911,8 +2907,7 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorLogicOpToRVV( SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV( SDValue Op, SelectionDAG &DAG) const { MVT VT = Op.getSimpleValueType(); - MVT ContainerVT = - RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget); + MVT ContainerVT = getContainerForFixedLengthVector(VT); MVT I1ContainerVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); @@ -2940,8 +2935,7 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG, MVT VT = Op.getSimpleValueType(); assert(useRVVForFixedLengthVectorVT(VT) && "Only expected to lower fixed length vector operation!"); - MVT ContainerVT = - RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget); + MVT ContainerVT = getContainerForFixedLengthVector(VT); // Create list of operands by converting existing ones to scalable types. SmallVector Ops; @@ -3242,7 +3236,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, // If this is a fixed vector, we need to convert it to a scalable vector. MVT ContainerVT = VecVT; if (VecVT.isFixedLengthVector()) { - ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget); + ContainerVT = getContainerForFixedLengthVector(VecVT); Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 49f1767dc5d95e..abbbb914f21c60 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -401,6 +401,9 @@ class RISCVTargetLowering : public TargetLowering { decomposeSubvectorInsertExtractToSubRegs(MVT VecVT, MVT SubVecVT, unsigned InsertExtractIdx, const RISCVRegisterInfo *TRI); + MVT getContainerForFixedLengthVector(MVT VT) const; + static MVT getContainerForFixedLengthVector(const TargetLowering &TLI, MVT VT, + const RISCVSubtarget &Subtarget); static MVT getContainerForFixedLengthVector(SelectionDAG &DAG, MVT VT, const RISCVSubtarget &Subtarget);