Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 0 additions & 143 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22515,138 +22515,6 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
return SDValue();
}

/// Try to lower a vector_partial_reduce_add intrinsic node to an AArch64 dot
/// product node (SDOT, UDOT or USDOT).
///
/// Operand 1 of \p N is the accumulator and operand 2 is the value being
/// reduced. Operand 2 must be either an extend (per ISD::isExtOpcode), in
/// which case it is treated as a multiply by a constant 1, or an ISD::MUL
/// whose operands are both extends of values of the same type. An extend
/// counts as signed only when it is ISD::SIGN_EXTEND.
///
/// \returns the replacement node, or an empty SDValue when the subtarget lacks
/// the required features or the operand pattern / types are not handled here.
SDValue tryLowerPartialReductionToDot(SDNode *N,
                                      const AArch64Subtarget *Subtarget,
                                      SelectionDAG &DAG) {

  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
         getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
         "Expected a partial reduction node");

  // The result type drives both the feature checks and the type matching
  // below, so query it (and its scalability) exactly once.
  EVT ReducedVT = N->getValueType(0);
  const bool Scalable = ReducedVT.isScalableVector();
  if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
    return SDValue();
  if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
    return SDValue();

  SDLoc DL(N);

  SDValue Op2 = N->getOperand(2);
  unsigned Op2Opcode = Op2->getOpcode();
  SDValue MulOpLHS, MulOpRHS;
  bool MulOpLHSIsSigned, MulOpRHSIsSigned;
  if (ISD::isExtOpcode(Op2Opcode)) {
    // A bare extend is handled as "extend(x) * 1"; both sides then share the
    // signedness of the extend.
    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
    MulOpLHS = Op2->getOperand(0);
    MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
  } else if (Op2Opcode == ISD::MUL) {
    SDValue ExtMulOpLHS = Op2->getOperand(0);
    SDValue ExtMulOpRHS = Op2->getOperand(1);

    unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
    unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
    if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
        !ISD::isExtOpcode(ExtMulOpRHSOpcode))
      return SDValue();

    MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
    MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;

    // Look through the extends to the narrow multiplicands.
    MulOpLHS = ExtMulOpLHS->getOperand(0);
    MulOpRHS = ExtMulOpRHS->getOperand(0);

    if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
      return SDValue();
  } else {
    return SDValue();
  }

  SDValue Acc = N->getOperand(1);
  EVT MulSrcVT = MulOpLHS.getValueType();

  // Dot products operate on chunks of four elements so there must be four times
  // as many elements in the wide type.
  //
  // (nx)v16i8 -> (nx)v4i64 has no direct node here: it goes through an i32 dot
  // product that is widened afterwards.
  const bool NeedsI32Intermediate =
      (ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
      (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8);
  const bool IsDirectDot =
      (ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) ||
      (ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) ||
      (ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) ||
      (ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8);
  if (!NeedsI32Intermediate && !IsDirectDot)
    return SDValue();

  // If the extensions are mixed, we should lower it to a usdot instead.
  unsigned Opcode = 0;
  if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
    if (!Subtarget->hasMatMulInt8())
      return SDValue();

    // There's no nxv2i64 version of usdot
    if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
      return SDValue();

    Opcode = AArch64ISD::USDOT;
    // USDOT expects the signed operand to be last
    if (!MulOpRHSIsSigned)
      std::swap(MulOpLHS, MulOpRHS);
  } else {
    Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
  }

  if (NeedsI32Intermediate) {
    // Compute the dot product into a zero i32 accumulator, widen the result to
    // the i64 result type, then add the real accumulator.
    // NOTE(review): the widening is always a sign extension, including for
    // UDOT; presumably safe because a sum of four u8*u8 products starting from
    // zero cannot reach the i32 sign bit - confirm.
    EVT ReducedVTI32 = Scalable ? MVT::nxv4i32 : MVT::v4i32;

    SDValue DotI32 =
        DAG.getNode(Opcode, DL, ReducedVTI32,
                    DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
    SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
    return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
  }

  return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
}

/// Try to lower a vector_partial_reduce_add intrinsic node to a pair of
/// widening add nodes: a "bottom" add (SADDWB / UADDWB) of the accumulator and
/// the narrow input, whose result feeds a "top" add (SADDWT / UADDWT) with the
/// same narrow input.
///
/// Operand 1 of \p N is the accumulator and operand 2 must be an extend (per
/// ISD::isExtOpcode) whose element type matches the accumulator's. The signed
/// node pair is used only when the extend is ISD::SIGN_EXTEND.
///
/// \returns the top add node, or an empty SDValue when the subtarget has
/// neither SVE2 nor streaming SVE available, or the pattern / types are not
/// handled here.
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
                                          const AArch64Subtarget *Subtarget,
                                          SelectionDAG &DAG) {

  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
         getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
         "Expected a partial reduction node");

  if (!(Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()))
    return SDValue();

  // The value being reduced has to be an extend of a narrower vector.
  SDValue Input = N->getOperand(2);
  if (!ISD::isExtOpcode(Input.getOpcode()))
    return SDValue();

  SDValue Accumulator = N->getOperand(1);
  EVT AccTy = Accumulator.getValueType();
  EVT InputTy = Input.getValueType();
  if (InputTy.getVectorElementType() != AccTy.getVectorElementType())
    return SDValue();

  // Only these (narrow source, accumulator) type pairs are lowered here.
  SDValue Narrow = Input->getOperand(0);
  EVT NarrowTy = Narrow.getValueType();
  const bool IsSupportedPair =
      (NarrowTy == MVT::nxv4i32 && AccTy == MVT::nxv2i64) ||
      (NarrowTy == MVT::nxv8i16 && AccTy == MVT::nxv4i32) ||
      (NarrowTy == MVT::nxv16i8 && AccTy == MVT::nxv8i16);
  if (!IsSupportedPair)
    return SDValue();

  // Choose the signed or unsigned flavour of the bottom/top node pair.
  unsigned BottomOpc;
  unsigned TopOpc;
  if (Input.getOpcode() == ISD::SIGN_EXTEND) {
    BottomOpc = AArch64ISD::SADDWB;
    TopOpc = AArch64ISD::SADDWT;
  } else {
    BottomOpc = AArch64ISD::UADDWB;
    TopOpc = AArch64ISD::UADDWT;
  }

  SDLoc DL(N);
  SDValue BottomSum = DAG.getNode(BottomOpc, DL, AccTy, Accumulator, Narrow);
  return DAG.getNode(TopOpc, DL, AccTy, BottomSum, Narrow);
}

static SDValue combineSVEBitSel(unsigned IID, SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
Expand Down Expand Up @@ -22679,17 +22547,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
switch (IID) {
default:
break;
case Intrinsic::vector_partial_reduce_add: {
if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
SDLoc DL(N);
SDValue Input = N->getOperand(2);
return DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, N->getValueType(0),
N->getOperand(1), Input,
DAG.getConstant(1, DL, Input.getValueType()));
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
Expand Down