Skip to content

Commit

Permalink
[ARM] Expand types handled in VQDMULH recognition
Browse files Browse the repository at this point in the history
We have a DAG combine for recognizing the sequence of nodes that make up
an MVE VQDMULH, but it currently only handles the specifically legal types.
This patch expands that to other power-of-2 vector types. For smaller than
legal types this means any_extending the type and casting it to a legal
type, using a VQDMULH where we only use some of the lanes. The result is
sign extended back to the original type, to properly set the invalid
lanes. Larger than legal types are split into chunks with extracts and
concatenated back together.

Differential Revision: https://reviews.llvm.org/D105814
  • Loading branch information
davemgreen committed Jul 15, 2021
1 parent 5d7632e commit dad506b
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 636 deletions.
41 changes: 37 additions & 4 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12933,17 +12933,50 @@ static SDValue PerformVQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
Ext1.getOpcode() != ISD::SIGN_EXTEND)
return SDValue();
EVT VecVT = Ext0.getOperand(0).getValueType();
if (VecVT != MVT::v4i32 && VecVT != MVT::v8i16 && VecVT != MVT::v16i8)
if (!VecVT.isPow2VectorType() || VecVT.getVectorNumElements() == 1)
return SDValue();
if (Ext1.getOperand(0).getValueType() != VecVT ||
VecVT.getScalarType() != ScalarType ||
VT.getScalarSizeInBits() < ScalarType.getScalarSizeInBits() * 2)
return SDValue();

SDLoc DL(Mul);
SDValue VQDMULH = DAG.getNode(ARMISD::VQDMULH, DL, VecVT, Ext0.getOperand(0),
Ext1.getOperand(0));
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, VQDMULH);
unsigned LegalLanes = 128 / (ShftAmt + 1);
EVT LegalVecVT = MVT::getVectorVT(ScalarType, LegalLanes);
// For types smaller than legal vectors extend to be legal and only use needed
// lanes.
if (VecVT.getSizeInBits() < 128) {
EVT ExtVecVT =
MVT::getVectorVT(MVT::getIntegerVT(128 / VecVT.getVectorNumElements()),
VecVT.getVectorNumElements());
SDValue Inp0 =
DAG.getNode(ISD::ANY_EXTEND, DL, ExtVecVT, Ext0.getOperand(0));
SDValue Inp1 =
DAG.getNode(ISD::ANY_EXTEND, DL, ExtVecVT, Ext1.getOperand(0));
Inp0 = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, LegalVecVT, Inp0);
Inp1 = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, LegalVecVT, Inp1);
SDValue VQDMULH = DAG.getNode(ARMISD::VQDMULH, DL, LegalVecVT, Inp0, Inp1);
SDValue Trunc = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, ExtVecVT, VQDMULH);
Trunc = DAG.getNode(ISD::TRUNCATE, DL, VecVT, Trunc);
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Trunc);
}

// For larger types, split into legal sized chunks.
assert(VecVT.getSizeInBits() % 128 == 0 && "Expected a power2 type");
unsigned NumParts = VecVT.getSizeInBits() / 128;
SmallVector<SDValue> Parts;
for (unsigned I = 0; I < NumParts; ++I) {
SDValue Inp0 =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LegalVecVT, Ext0.getOperand(0),
DAG.getVectorIdxConstant(I * LegalLanes, DL));
SDValue Inp1 =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LegalVecVT, Ext1.getOperand(0),
DAG.getVectorIdxConstant(I * LegalLanes, DL));
SDValue VQDMULH = DAG.getNode(ARMISD::VQDMULH, DL, LegalVecVT, Inp0, Inp1);
Parts.push_back(VQDMULH);
}
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT,
DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Parts));
}

static SDValue PerformVSELECTCombine(SDNode *N,
Expand Down
Loading

0 comments on commit dad506b

Please sign in to comment.