diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 0ca5e89263b72..0d64f062df39e 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -22515,138 +22515,6 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc, return SDValue(); } -SDValue tryLowerPartialReductionToDot(SDNode *N, - const AArch64Subtarget *Subtarget, - SelectionDAG &DAG) { - - assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && - getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add && - "Expected a partial reduction node"); - - bool Scalable = N->getValueType(0).isScalableVector(); - if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable()) - return SDValue(); - if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd())) - return SDValue(); - - SDLoc DL(N); - - SDValue Op2 = N->getOperand(2); - unsigned Op2Opcode = Op2->getOpcode(); - SDValue MulOpLHS, MulOpRHS; - bool MulOpLHSIsSigned, MulOpRHSIsSigned; - if (ISD::isExtOpcode(Op2Opcode)) { - MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND); - MulOpLHS = Op2->getOperand(0); - MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType()); - } else if (Op2Opcode == ISD::MUL) { - SDValue ExtMulOpLHS = Op2->getOperand(0); - SDValue ExtMulOpRHS = Op2->getOperand(1); - - unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode(); - unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode(); - if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) || - !ISD::isExtOpcode(ExtMulOpRHSOpcode)) - return SDValue(); - - MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND; - MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND; - - MulOpLHS = ExtMulOpLHS->getOperand(0); - MulOpRHS = ExtMulOpRHS->getOperand(0); - - if (MulOpLHS.getValueType() != MulOpRHS.getValueType()) - return SDValue(); - } else - return SDValue(); - - SDValue Acc = N->getOperand(1); - EVT ReducedVT = N->getValueType(0); - EVT MulSrcVT = MulOpLHS.getValueType(); - - // Dot products operate on chunks of four elements so there must be four times - // as many elements in the wide type - if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) && - !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) && - !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) && - !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) && - !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) && - !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8)) - return SDValue(); - - // If the extensions are mixed, we should lower it to a usdot instead - unsigned Opcode = 0; - if (MulOpLHSIsSigned != MulOpRHSIsSigned) { - if (!Subtarget->hasMatMulInt8()) - return SDValue(); - - bool Scalable = N->getValueType(0).isScalableVT(); - // There's no nxv2i64 version of usdot - if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64) - return SDValue(); - - Opcode = AArch64ISD::USDOT; - // USDOT expects the signed operand to be last - if (!MulOpRHSIsSigned) - std::swap(MulOpLHS, MulOpRHS); - } else - Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT; - - // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot - // product followed by a zero / sign extension - if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) || - (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) { - EVT ReducedVTI32 = - (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32; - - SDValue DotI32 = - DAG.getNode(Opcode, DL, ReducedVTI32, - DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS); - SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT); - return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended); - } - - return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS); -} - -SDValue tryLowerPartialReductionToWideAdd(SDNode *N, - const AArch64Subtarget *Subtarget, - SelectionDAG &DAG) { - - assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && - getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add && - "Expected a partial reduction node"); - - if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable()) - return SDValue(); - - SDLoc DL(N); - - if (!ISD::isExtOpcode(N->getOperand(2).getOpcode())) - return SDValue(); - SDValue Acc = N->getOperand(1); - SDValue Ext = N->getOperand(2); - EVT AccVT = Acc.getValueType(); - EVT ExtVT = Ext.getValueType(); - if (ExtVT.getVectorElementType() != AccVT.getVectorElementType()) - return SDValue(); - - SDValue ExtOp = Ext->getOperand(0); - EVT ExtOpVT = ExtOp.getValueType(); - - if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) && - !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) && - !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16)) - return SDValue(); - - bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND; - unsigned BottomOpcode = - ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB; - unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT; - SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp); - return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp); -} - static SDValue combineSVEBitSel(unsigned IID, SDNode *N, SelectionDAG &DAG) { SDLoc DL(N); EVT VT = N->getValueType(0); @@ -22679,17 +22547,6 @@ static SDValue performIntrinsicCombine(SDNode *N, switch (IID) { default: break; - case Intrinsic::vector_partial_reduce_add: { - if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG)) - return Dot; - if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG)) - return WideAdd; - SDLoc DL(N); - SDValue Input = N->getOperand(2); - return DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, N->getValueType(0), - N->getOperand(1), Input, - DAG.getConstant(1, DL, Input.getValueType())); - } case Intrinsic::aarch64_neon_vcvtfxs2fp: case Intrinsic::aarch64_neon_vcvtfxu2fp: return tryCombineFixedPointConvert(N, DCI, DAG);