diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index fd98b9e7b753b..cfc4671eaa0e4 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -18566,19 +18566,27 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { } } - // Combine: - // (extract_subvec (concat V1, V2, ...), i) - // Into: - // Vi if possible - // Only operand 0 is checked as 'concat' assumes all inputs of the same - // type. - if (V.getOpcode() == ISD::CONCAT_VECTORS && isa(Index) && - V.getOperand(0).getValueType() == NVT) { - unsigned Idx = N->getConstantOperandVal(1); - unsigned NumElems = NVT.getVectorNumElements(); - assert((Idx % NumElems) == 0 && - "IDX in concat is not a multiple of the result vector length."); - return V->getOperand(Idx / NumElems); + if (V.getOpcode() == ISD::CONCAT_VECTORS && isa(Index)) { + EVT ConcatSrcVT = V.getOperand(0).getValueType(); + assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() && + "Concat and extract subvector do not change element type"); + + unsigned ExtIdx = N->getConstantOperandVal(1); + unsigned ExtNumElts = NVT.getVectorNumElements(); + assert(ExtIdx % ExtNumElts == 0 && + "Extract index is not a multiple of the input vector length."); + + unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorNumElements(); + unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts; + + // If the concatenated source types match this extract, it's a direct + // simplification: + // extract_subvec (concat V1, V2, ...), i --> Vi + if (ConcatSrcNumElts == ExtNumElts) + return V.getOperand(ConcatOpIdx); + + // TODO: Handle the case where the concat operands are larger than the + // result of this extract by extracting directly from a concat op. } V = peekThroughBitcasts(V);