diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 90edaf3ef5471..b283d4ddd907a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7338,16 +7338,23 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
            Op.getValueType().getVectorElementCount() == NumElts;
   };
 
-  auto IsBuildVectorSplatVectorOrUndef = [](const SDValue &Op) {
+  // UNDEF: folds to undef
+  // BUILD_VECTOR: may have constant elements
+  // SPLAT_VECTOR: could be a splat of a constant
+  // INSERT_SUBVECTOR: could be inserting a constant splat into an undef vector
+  //  - This pattern occurs when a fixed-length vector splat is inserted into
+  //    a scalable vector
+  auto VectorOpMayConstantFold = [](const SDValue &Op) {
     return Op.isUndef() || Op.getOpcode() == ISD::CONDCODE ||
            Op.getOpcode() == ISD::BUILD_VECTOR ||
-           Op.getOpcode() == ISD::SPLAT_VECTOR;
+           Op.getOpcode() == ISD::SPLAT_VECTOR ||
+           Op.getOpcode() == ISD::INSERT_SUBVECTOR;
   };
 
   // All operands must be vector types with the same number of elements as
   // the result type and must be either UNDEF or a build/splat vector
   // or UNDEF scalars.
-  if (!llvm::all_of(Ops, IsBuildVectorSplatVectorOrUndef) ||
+  if (!llvm::all_of(Ops, VectorOpMayConstantFold) ||
       !llvm::all_of(Ops, IsScalarOrSameVectorSize))
     return SDValue();
 
@@ -7374,14 +7381,28 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
   // a combination of BUILD_VECTOR and SPLAT_VECTOR.
   unsigned NumVectorElts = NumElts.isScalable() ? 1 : NumElts.getFixedValue();
 
+  // Preprocess insert_subvector to avoid repeatedly matching the splat.
+  SmallVector<SDValue, 4> PreprocessedOps;
+  for (SDValue Op : Ops) {
+    if (Op.getOpcode() == ISD::INSERT_SUBVECTOR) {
+      // match: `insert_subvector undef, (splat X), N2` as `splat X`
+      SDValue N0 = Op.getOperand(0);
+      auto *BV = dyn_cast<BuildVectorSDNode>(Op.getOperand(1));
+      if (!N0.isUndef() || !BV || !(Op = BV->getSplatValue()))
+        return SDValue();
+    }
+    PreprocessedOps.push_back(Op);
+  }
+
   // Constant fold each scalar lane separately.
   SmallVector<SDValue, 4> ScalarResults;
   for (unsigned I = 0; I != NumVectorElts; I++) {
     SmallVector<SDValue, 4> ScalarOps;
-    for (SDValue Op : Ops) {
+    for (SDValue Op : PreprocessedOps) {
       EVT InSVT = Op.getValueType().getScalarType();
       if (Op.getOpcode() != ISD::BUILD_VECTOR &&
-          Op.getOpcode() != ISD::SPLAT_VECTOR) {
+          Op.getOpcode() != ISD::SPLAT_VECTOR &&
+          Op.getOpcode() != ISD::INSERT_SUBVECTOR) {
         if (Op.isUndef())
           ScalarOps.push_back(getUNDEF(InSVT));
         else
@@ -7389,8 +7410,13 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
         continue;
       }
 
+      // insert_subvector has been preprocessed, so if it was of the form
+      // `insert_subvector undef, (splat X), N2`, it has been replaced with the
+      // splat value (X).
       SDValue ScalarOp =
-          Op.getOperand(Op.getOpcode() == ISD::SPLAT_VECTOR ? 0 : I);
+          Op.getOpcode() == ISD::INSERT_SUBVECTOR
+              ? Op
+              : Op.getOperand(Op.getOpcode() == ISD::SPLAT_VECTOR ? 0 : I);
       EVT ScalarVT = ScalarOp.getValueType();
 
       // Build vector (integer) scalar operands may need implicit
diff --git a/llvm/test/CodeGen/AArch64/fixed-subvector-insert-into-scalable.ll b/llvm/test/CodeGen/AArch64/fixed-subvector-insert-into-scalable.ll
new file mode 100644
index 0000000000000..8758f5a4e244d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fixed-subvector-insert-into-scalable.ll
@@ -0,0 +1,46 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 4 x i32> @insert_div() {
+; CHECK-LABEL: insert_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z0.s, #3 // =0x3
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 9), i64 0)
+  %div = udiv <vscale x 4 x i32> %0, splat (i32 3)
+  ret <vscale x 4 x i32> %div
+}
+
+define <vscale x 4 x i32> @insert_mul() {
+; CHECK-LABEL: insert_mul:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z0.s, #7 // =0x7
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 1), i64 0)
+  %mul = mul <vscale x 4 x i32> %0, splat (i32 7)
+  ret <vscale x 4 x i32> %mul
+}
+
+define <vscale x 4 x i32> @insert_add() {
+; CHECK-LABEL: insert_add:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z0.s, #16 // =0x10
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 5), i64 0)
+  %add = add <vscale x 4 x i32> %0, splat (i32 11)
+  ret <vscale x 4 x i32> %add
+}
+
+define <vscale x 4 x i32> @insert_sub() {
+; CHECK-LABEL: insert_sub:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 11), i64 0)
+  %sub = sub <vscale x 4 x i32> %0, splat (i32 11)
+  ret <vscale x 4 x i32> %sub
+}
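
For context, a minimal standalone sketch of the match the new preprocessing loop performs, factored into a hypothetical helper (`peelInsertedSplat` is illustrative only, not part of the patch): it peels `insert_subvector undef, (splat X), Idx` down to the scalar splat value X, returning an empty SDValue on any mismatch, which corresponds to the patch's `return SDValue()` bail-out.

```cpp
// Sketch only: a hypothetical helper equivalent to the inline match in the
// preprocessing loop above. Assumes LLVM's SelectionDAG headers.
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"

using namespace llvm;

// If Op is `insert_subvector undef, (build_vector splat X), Idx`, return the
// scalar splat value X; otherwise return an empty SDValue.
static SDValue peelInsertedSplat(SDValue Op) {
  if (Op.getOpcode() != ISD::INSERT_SUBVECTOR)
    return SDValue();
  // The destination must be entirely undef: inserting a splat of X then
  // leaves every defined lane equal to X, and the remaining undef lanes are
  // free to be folded as X too.
  if (!Op.getOperand(0).isUndef())
    return SDValue();
  // The inserted fixed-length subvector must itself be a uniform splat.
  auto *BV = dyn_cast<BuildVectorSDNode>(Op.getOperand(1));
  if (!BV)
    return SDValue();
  // BuildVectorSDNode::getSplatValue() returns an empty SDValue when the
  // BUILD_VECTOR is not a splat, matching the bail-out semantics above.
  return BV->getSplatValue();
}
```

Note that once preprocessing has replaced a matching operand with the scalar X, the per-lane loop consumes it through its existing scalar path (the `ScalarOps.push_back(Op)` branch), so every lane of the fold sees X; the additional `ISD::INSERT_SUBVECTOR` checks inside the per-lane loop appear to be defensive rather than load-bearing.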