Skip to content

Commit

Permalink
[AArch64] Try to combine MULL with uzp1.
Browse files Browse the repository at this point in the history
For example,

 smull(trunc(x), extract_high(y))
 ==>
 smull(extract_high(uzp1(undef,x)), extract_high(y))

 -> It will be matched to smull2

Differential Revision: https://reviews.llvm.org/D150969
  • Loading branch information
jaykang10 committed Jun 13, 2023
1 parent 5cdb906 commit 16daaf0
Show file tree
Hide file tree
Showing 2 changed files with 377 additions and 11 deletions.
148 changes: 147 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -22391,6 +22391,152 @@ static SDValue performDupLane128Combine(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::BITCAST, DL, VT, NewDuplane128);
}

// Try to combine mull with uzp1.
static SDValue tryCombineMULLWithUZP1(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
if (DCI.isBeforeLegalizeOps())
return SDValue();

SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);

SDValue ExtractHigh;
SDValue ExtractLow;
SDValue TruncHigh;
SDValue TruncLow;
SDLoc DL(N);

// Check the operands are trunc and extract_high.
if (isEssentiallyExtractHighSubvector(LHS) &&
RHS.getOpcode() == ISD::TRUNCATE) {
TruncHigh = RHS;
if (LHS.getOpcode() == ISD::BITCAST)
ExtractHigh = LHS.getOperand(0);
else
ExtractHigh = LHS;
} else if (isEssentiallyExtractHighSubvector(RHS) &&
LHS.getOpcode() == ISD::TRUNCATE) {
TruncHigh = LHS;
if (LHS.getOpcode() == ISD::BITCAST)
ExtractHigh = RHS.getOperand(0);
else
ExtractHigh = RHS;
} else
return SDValue();

// If the truncate's operand is BUILD_VECTOR with DUP, do not combine the op
// with uzp1.
// You can see the regressions on test/CodeGen/AArch64/aarch64-smull.ll
SDValue TruncHighOp = TruncHigh.getOperand(0);
EVT TruncHighOpVT = TruncHighOp.getValueType();
if (TruncHighOp.getOpcode() == AArch64ISD::DUP ||
DAG.isSplatValue(TruncHighOp, false))
return SDValue();

// Check there is other extract_high with same source vector.
// For example,
//
// t18: v4i16 = extract_subvector t2, Constant:i64<0>
// t12: v4i16 = truncate t11
// t31: v4i32 = AArch64ISD::SMULL t18, t12
// t23: v4i16 = extract_subvector t2, Constant:i64<4>
// t16: v4i16 = truncate t15
// t30: v4i32 = AArch64ISD::SMULL t23, t1
//
// This dagcombine assumes the two extract_high uses same source vector in
// order to detect the pair of the mull. If they have different source vector,
// this code will not work.
bool HasFoundMULLow = true;
SDValue ExtractHighSrcVec = ExtractHigh.getOperand(0);
if (ExtractHighSrcVec->use_size() != 2)
HasFoundMULLow = false;

// Find ExtractLow.
for (SDNode *User : ExtractHighSrcVec.getNode()->uses()) {
if (User == ExtractHigh.getNode())
continue;

if (User->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
!isNullConstant(User->getOperand(1))) {
HasFoundMULLow = false;
break;
}

ExtractLow.setNode(User);
}

if (!ExtractLow || !ExtractLow->hasOneUse())
HasFoundMULLow = false;

// Check ExtractLow's user.
if (HasFoundMULLow) {
SDNode *ExtractLowUser = *ExtractLow.getNode()->use_begin();
if (ExtractLowUser->getOpcode() != N->getOpcode())
HasFoundMULLow = false;

if (ExtractLowUser->getOperand(0) == ExtractLow) {
if (ExtractLowUser->getOperand(1).getOpcode() == ISD::TRUNCATE)
TruncLow = ExtractLowUser->getOperand(1);
else
HasFoundMULLow = false;
} else {
if (ExtractLowUser->getOperand(0).getOpcode() == ISD::TRUNCATE)
TruncLow = ExtractLowUser->getOperand(0);
else
HasFoundMULLow = false;
}
}

// If the truncate's operand is BUILD_VECTOR with DUP, do not combine the op
// with uzp1.
// You can see the regressions on test/CodeGen/AArch64/aarch64-smull.ll
EVT TruncHighVT = TruncHigh.getValueType();
EVT UZP1VT = TruncHighVT.getDoubleNumVectorElementsVT(*DAG.getContext());
SDValue TruncLowOp =
HasFoundMULLow ? TruncLow.getOperand(0) : DAG.getUNDEF(UZP1VT);
EVT TruncLowOpVT = TruncLowOp.getValueType();
if (HasFoundMULLow && (TruncLowOp.getOpcode() == AArch64ISD::DUP ||
DAG.isSplatValue(TruncLowOp, false)))
return SDValue();

// Create uzp1, extract_high and extract_low.
if (TruncHighOpVT != UZP1VT)
TruncHighOp = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TruncHighOp);
if (TruncLowOpVT != UZP1VT)
TruncLowOp = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TruncLowOp);

SDValue UZP1 =
DAG.getNode(AArch64ISD::UZP1, DL, UZP1VT, TruncLowOp, TruncHighOp);
SDValue HighIdxCst =
DAG.getConstant(TruncHighVT.getVectorNumElements(), DL, MVT::i64);
SDValue NewTruncHigh =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TruncHighVT, UZP1, HighIdxCst);
DAG.ReplaceAllUsesWith(TruncHigh, NewTruncHigh);

if (HasFoundMULLow) {
EVT TruncLowVT = TruncLow.getValueType();
SDValue NewTruncLow = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TruncLowVT,
UZP1, ExtractLow.getOperand(1));
DAG.ReplaceAllUsesWith(TruncLow, NewTruncLow);
}

return SDValue(N, 0);
}

static SDValue performMULLCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
if (SDValue Val =
tryCombineLongOpWithDup(Intrinsic::not_intrinsic, N, DCI, DAG))
return Val;

if (SDValue Val = tryCombineMULLWithUZP1(N, DCI, DAG))
return Val;

return SDValue();
}

SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
Expand Down Expand Up @@ -22535,7 +22681,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case AArch64ISD::SMULL:
case AArch64ISD::UMULL:
case AArch64ISD::PMULL:
return tryCombineLongOpWithDup(Intrinsic::not_intrinsic, N, DCI, DAG);
return performMULLCombine(N, DCI, DAG);
case ISD::INTRINSIC_VOID:
case ISD::INTRINSIC_W_CHAIN:
switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
Expand Down

0 comments on commit 16daaf0

Please sign in to comment.