diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 378571da1a6741..fc12f88b970120 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -49631,6 +49631,13 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, if (llvm::all_of(Ops, [Op0](SDValue Op) { return Op.getOpcode() == Op0.getOpcode(); })) { + auto ConcatSubOperand = [&](MVT VT, ArrayRef SubOps, unsigned I) { + SmallVector Subs; + for (SDValue SubOp : SubOps) + Subs.push_back(SubOp.getOperand(I)); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs); + }; + unsigned NumOps = Ops.size(); switch (Op0.getOpcode()) { case X86ISD::SHUFP: { @@ -49639,15 +49646,9 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, llvm::all_of(Ops, [Op0](SDValue Op) { return Op.getOperand(2) == Op0.getOperand(2); })) { - SmallVector LHS, RHS; - for (unsigned i = 0; i != NumOps; ++i) { - LHS.push_back(Ops[i].getOperand(0)); - RHS.push_back(Ops[i].getOperand(1)); - } return DAG.getNode(Op0.getOpcode(), DL, VT, - DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LHS), - DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, RHS), - Op0.getOperand(2)); + ConcatSubOperand(VT, Ops, 0), + ConcatSubOperand(VT, Ops, 1), Op0.getOperand(2)); } break; } @@ -49656,22 +49657,15 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, case X86ISD::PSHUFD: if (!IsSplat && NumOps == 2 && VT.is256BitVector() && Subtarget.hasInt256() && Op0.getOperand(1) == Ops[1].getOperand(1)) { - SmallVector Src; - for (unsigned i = 0; i != NumOps; ++i) - Src.push_back(Ops[i].getOperand(0)); return DAG.getNode(Op0.getOpcode(), DL, VT, - DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Src), - Op0.getOperand(1)); + ConcatSubOperand(VT, Ops, 0), Op0.getOperand(1)); } LLVM_FALLTHROUGH; case X86ISD::VPERMILPI: // TODO - add support for vXf64/vXi64 shuffles. if (!IsSplat && NumOps == 2 && (VT == MVT::v8f32 || VT == MVT::v8i32) && Subtarget.hasAVX() && Op0.getOperand(1) == Ops[1].getOperand(1)) { - SmallVector Src; - for (unsigned i = 0; i != NumOps; ++i) - Src.push_back(DAG.getBitcast(MVT::v4f32, Ops[i].getOperand(0))); - SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8f32, Src); + SDValue Res = DAG.getBitcast(MVT::v8f32, ConcatSubOperand(VT, Ops, 0)); Res = DAG.getNode(X86ISD::VPERMILPI, DL, MVT::v8f32, Res, Op0.getOperand(1)); return DAG.getBitcast(VT, Res); @@ -49717,12 +49711,8 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, llvm::all_of(Ops, [Op0](SDValue Op) { return Op0.getOperand(1) == Op.getOperand(1); })) { - SmallVector Src; - for (unsigned i = 0; i != NumOps; ++i) - Src.push_back(Ops[i].getOperand(0)); return DAG.getNode(Op0.getOpcode(), DL, VT, - DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Src), - Op0.getOperand(1)); + ConcatSubOperand(VT, Ops, 0), Op0.getOperand(1)); } break; case X86ISD::VPERMI: @@ -49732,12 +49722,8 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, llvm::all_of(Ops, [Op0](SDValue Op) { return Op0.getOperand(1) == Op.getOperand(1); })) { - SmallVector Src; - for (unsigned i = 0; i != NumOps; ++i) - Src.push_back(Ops[i].getOperand(0)); return DAG.getNode(Op0.getOpcode(), DL, VT, - DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Src), - Op0.getOperand(1)); + ConcatSubOperand(VT, Ops, 0), Op0.getOperand(1)); } break; case ISD::AND: @@ -49746,17 +49732,12 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, case X86ISD::ANDNP: // TODO: Add 256-bit support. if (!IsSplat && VT.is512BitVector()) { - SmallVector LHS, RHS; - for (unsigned i = 0; i != NumOps; ++i) { - LHS.push_back(Ops[i].getOperand(0)); - RHS.push_back(Ops[i].getOperand(1)); - } MVT SrcVT = Op0.getOperand(0).getSimpleValueType(); SrcVT = MVT::getVectorVT(SrcVT.getScalarType(), NumOps * SrcVT.getVectorNumElements()); return DAG.getNode(Op0.getOpcode(), DL, VT, - DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, LHS), - DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, RHS)); + ConcatSubOperand(SrcVT, Ops, 0), + ConcatSubOperand(SrcVT, Ops, 1)); } break; case X86ISD::HADD: @@ -49767,17 +49748,12 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, case X86ISD::PACKUS: if (!IsSplat && VT.is256BitVector() && (VT.isFloatingPoint() || Subtarget.hasInt256())) { - SmallVector LHS, RHS; - for (unsigned i = 0; i != NumOps; ++i) { - LHS.push_back(Ops[i].getOperand(0)); - RHS.push_back(Ops[i].getOperand(1)); - } MVT SrcVT = Op0.getOperand(0).getSimpleValueType(); SrcVT = MVT::getVectorVT(SrcVT.getScalarType(), NumOps * SrcVT.getVectorNumElements()); return DAG.getNode(Op0.getOpcode(), DL, VT, - DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, LHS), - DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, RHS)); + ConcatSubOperand(SrcVT, Ops, 0), + ConcatSubOperand(SrcVT, Ops, 1)); } break; case X86ISD::PALIGNR: @@ -49787,15 +49763,9 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, llvm::all_of(Ops, [Op0](SDValue Op) { return Op0.getOperand(2) == Op.getOperand(2); })) { - SmallVector LHS, RHS; - for (unsigned i = 0; i != NumOps; ++i) { - LHS.push_back(Ops[i].getOperand(0)); - RHS.push_back(Ops[i].getOperand(1)); - } return DAG.getNode(Op0.getOpcode(), DL, VT, - DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LHS), - DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, RHS), - Op0.getOperand(2)); + ConcatSubOperand(VT, Ops, 0), + ConcatSubOperand(VT, Ops, 1), Op0.getOperand(2)); } break; }