Skip to content

Commit

Permalink
[RISCV] Add new entry points to getContainerForFixedLengthVector
Browse files Browse the repository at this point in the history
While working on adding fixed-length vectors to the calling convention,
it was necessary to be able to query for a fixed-length vector container
type without access to an instance of SelectionDAG.

This patch modifies the "main" getContainerForFixedLengthVector function
to use an instance of TargetLowering rather than SelectionDAG, and
preserves the SelectionDAG overload as a wrapper.

An additional non-static version of the function was also added to
simplify the common case in RISCVTargetLowering.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D97925
  • Loading branch information
frasercrmck committed Mar 8, 2021
1 parent 63851a7 commit 18173c5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 36 deletions.
66 changes: 30 additions & 36 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,9 +945,8 @@ RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(

// Return the largest legal scalable vector type that matches VT's element type.
MVT RISCVTargetLowering::getContainerForFixedLengthVector(
SelectionDAG &DAG, MVT VT, const RISCVSubtarget &Subtarget) {
assert(VT.isFixedLengthVector() &&
DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
const TargetLowering &TLI, MVT VT, const RISCVSubtarget &Subtarget) {
assert(VT.isFixedLengthVector() && TLI.isTypeLegal(VT) &&
"Expected legal fixed length vector!");

unsigned LMul = Subtarget.getLMULForFixedLengthVector(VT);
Expand Down Expand Up @@ -976,6 +975,16 @@ MVT RISCVTargetLowering::getContainerForFixedLengthVector(
}
}

MVT RISCVTargetLowering::getContainerForFixedLengthVector(
SelectionDAG &DAG, MVT VT, const RISCVSubtarget &Subtarget) {
return getContainerForFixedLengthVector(DAG.getTargetLoweringInfo(), VT,
Subtarget);
}

MVT RISCVTargetLowering::getContainerForFixedLengthVector(MVT VT) const {
return getContainerForFixedLengthVector(*this, VT, getSubtarget());
}

// Grow V to consume an entire RVV register.
static SDValue convertToScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
Expand Down Expand Up @@ -1250,8 +1259,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,

MVT ContainerVT = SrcVT;
if (SrcVT.isFixedLengthVector()) {
ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, SrcVT, Subtarget);
ContainerVT = getContainerForFixedLengthVector(SrcVT);
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
}

Expand Down Expand Up @@ -1313,8 +1321,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
// Prepare any fixed-length vector operands.
MVT ContainerVT = VT;
if (SrcVT.isFixedLengthVector()) {
ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, VT, Subtarget);
ContainerVT = getContainerForFixedLengthVector(VT);
MVT SrcContainerVT =
ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
Expand Down Expand Up @@ -1354,9 +1361,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
// Prepare any fixed-length vector operands.
MVT ContainerVT = VT;
if (VT.isFixedLengthVector()) {
MVT SrcContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, SrcVT,
Subtarget);
MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
ContainerVT =
SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
Expand Down Expand Up @@ -1474,13 +1479,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
MVT ContainerVT, SrcContainerVT;
// Derive the reference container type from the larger vector type.
if (SrcEltSize > EltSize) {
SrcContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, SrcVT, Subtarget);
SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
ContainerVT =
SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
} else {
ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, VT, Subtarget);
ContainerVT = getContainerForFixedLengthVector(VT);
SrcContainerVT = ContainerVT.changeVectorElementType(SrcEltVT);
}

Expand Down Expand Up @@ -2097,8 +2100,7 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
return DAG.getNode(ISD::VSELECT, DL, VecVT, Src, SplatTrueVal, SplatZero);
}

MVT ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, VecVT, Subtarget);
MVT ContainerVT = getContainerForFixedLengthVector(VecVT);
MVT I1ContainerVT =
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());

Expand Down Expand Up @@ -2126,8 +2128,7 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorExtendToRVV(
// Grab the canonical container type for the extended type. Infer the smaller
// type from that to ensure the same number of vector elements, as we know
// the LMUL will be sufficient to hold the smaller type.
MVT ContainerExtVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, ExtVT, Subtarget);
MVT ContainerExtVT = getContainerForFixedLengthVector(ExtVT);
// Get the extended container type manually to ensure the same number of
// vector elements between source and dest.
MVT ContainerVT = MVT::getVectorVT(VT.getVectorElementType(),
Expand Down Expand Up @@ -2161,7 +2162,7 @@ SDValue RISCVTargetLowering::lowerVectorMaskTrunc(SDValue Op,
// If this is a fixed vector, we need to convert it to a scalable vector.
MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget);
ContainerVT = getContainerForFixedLengthVector(VecVT);
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
}

Expand Down Expand Up @@ -2198,7 +2199,7 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
MVT ContainerVT = VecVT;
// If the operand is a fixed-length vector, convert to a scalable one.
if (VecVT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget);
ContainerVT = getContainerForFixedLengthVector(VecVT);
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
}

Expand Down Expand Up @@ -2265,7 +2266,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
// If this is a fixed vector, we need to convert it to a scalable vector.
MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget);
ContainerVT = getContainerForFixedLengthVector(VecVT);
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
}

Expand Down Expand Up @@ -2530,8 +2531,7 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
return Op;
MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, VecVT, Subtarget);
ContainerVT = getContainerForFixedLengthVector(VecVT);
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
}
SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT,
Expand Down Expand Up @@ -2682,8 +2682,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
return Op;
MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, VecVT, Subtarget);
ContainerVT = getContainerForFixedLengthVector(VecVT);
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
}
SDValue Mask =
Expand Down Expand Up @@ -2753,8 +2752,7 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,

SDLoc DL(Op);
MVT VT = Op.getSimpleValueType();
MVT ContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT = getContainerForFixedLengthVector(VT);

SDValue VL =
DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
Expand All @@ -2779,8 +2777,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
// FIXME: We probably need to zero any extra bits in a byte for mask stores.
// This is tricky to do.

MVT ContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT = getContainerForFixedLengthVector(VT);

SDValue VL =
DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
Expand All @@ -2797,8 +2794,7 @@ SDValue
RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
SelectionDAG &DAG) const {
MVT InVT = Op.getOperand(0).getSimpleValueType();
MVT ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
DAG, InVT, Subtarget);
MVT ContainerVT = getContainerForFixedLengthVector(InVT);

MVT VT = Op.getSimpleValueType();

Expand Down Expand Up @@ -2911,8 +2907,7 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorLogicOpToRVV(
SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
SDValue Op, SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
MVT ContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT = getContainerForFixedLengthVector(VT);

MVT I1ContainerVT =
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
Expand Down Expand Up @@ -2940,8 +2935,7 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
MVT VT = Op.getSimpleValueType();
assert(useRVVForFixedLengthVectorVT(VT) &&
"Only expected to lower fixed length vector operation!");
MVT ContainerVT =
RISCVTargetLowering::getContainerForFixedLengthVector(DAG, VT, Subtarget);
MVT ContainerVT = getContainerForFixedLengthVector(VT);

// Create list of operands by converting existing ones to scalable types.
SmallVector<SDValue, 6> Ops;
Expand Down Expand Up @@ -3242,7 +3236,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
// If this is a fixed vector, we need to convert it to a scalable vector.
MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget);
ContainerVT = getContainerForFixedLengthVector(VecVT);
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
}

Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@ class RISCVTargetLowering : public TargetLowering {
decomposeSubvectorInsertExtractToSubRegs(MVT VecVT, MVT SubVecVT,
unsigned InsertExtractIdx,
const RISCVRegisterInfo *TRI);
MVT getContainerForFixedLengthVector(MVT VT) const;
static MVT getContainerForFixedLengthVector(const TargetLowering &TLI, MVT VT,
const RISCVSubtarget &Subtarget);
static MVT getContainerForFixedLengthVector(SelectionDAG &DAG, MVT VT,
const RISCVSubtarget &Subtarget);

Expand Down

0 comments on commit 18173c5

Please sign in to comment.