diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a371d3bef15d9..9f8b183635012 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -13352,7 +13352,7 @@ static bool IsSVECntIntrinsic(SDValue S) { /// /// \returns The type representing the \p Extend source type, or \p MVT::Other /// if no valid type can be determined -static EVT calculatePreExtendType(SDValue Extend, SelectionDAG &DAG) { +static EVT calculatePreExtendType(SDValue Extend) { switch (Extend.getOpcode()) { case ISD::SIGN_EXTEND: case ISD::ZERO_EXTEND: @@ -13385,15 +13385,12 @@ static EVT calculatePreExtendType(SDValue Extend, SelectionDAG &DAG) { default: return MVT::Other; } - - llvm_unreachable("Code path unhandled in calculatePreExtendType!"); } /// Combines a dup(sext/zext) node pattern into sext/zext(dup) /// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle, SelectionDAG &DAG) { - ShuffleVectorSDNode *ShuffleNode = dyn_cast(VectorShuffle.getNode()); if (!ShuffleNode) @@ -13424,24 +13421,14 @@ static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle, ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND) return SDValue(); - EVT TargetType = VectorShuffle.getValueType(); - EVT PreExtendType = calculatePreExtendType(Extend, DAG); - - if ((TargetType != MVT::v8i16 && TargetType != MVT::v4i32 && - TargetType != MVT::v2i64) || - (PreExtendType == MVT::Other)) - return SDValue(); - // Restrict valid pre-extend data type + EVT PreExtendType = calculatePreExtendType(Extend); if (PreExtendType != MVT::i8 && PreExtendType != MVT::i16 && PreExtendType != MVT::i32) return SDValue(); + EVT TargetType = VectorShuffle.getValueType(); EVT PreExtendVT = TargetType.changeVectorElementType(PreExtendType); - - if (PreExtendVT.getVectorElementCount() != TargetType.getVectorElementCount()) - return SDValue(); - if (TargetType.getScalarSizeInBits() != PreExtendVT.getScalarSizeInBits() * 2) return SDValue(); @@ -13458,17 +13445,16 @@ static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle, DAG.getVectorShuffle(PreExtendVT, DL, InsertVectorNode, DAG.getUNDEF(PreExtendVT), ShuffleMask); - SDValue ExtendNode = DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, - DL, TargetType, VectorShuffleNode); - - return ExtendNode; + return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, + TargetType, VectorShuffleNode); } /// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup)) /// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) { // If the value type isn't a vector, none of the operands are going to be dups - if (!Mul->getValueType(0).isVector()) + EVT VT = Mul->getValueType(0); + if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64) return SDValue(); SDValue Op0 = performCommonVectorExtendCombine(Mul->getOperand(0), DAG); @@ -13479,8 +13465,7 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) { return SDValue(); SDLoc DL(Mul); - return DAG.getNode(Mul->getOpcode(), DL, Mul->getValueType(0), - Op0 ? Op0 : Mul->getOperand(0), + return DAG.getNode(Mul->getOpcode(), DL, VT, Op0 ? Op0 : Mul->getOperand(0), Op1 ? Op1 : Mul->getOperand(1)); }