diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 99f70b101c2ed..e3ec7e1764da7 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -198,6 +198,26 @@ struct constantexpr_match {
 /// expression.
 inline constantexpr_match m_ConstantExpr() { return constantexpr_match(); }
 
+template <typename SubPattern_t> struct Splat_match {
+  SubPattern_t SubPattern;
+  Splat_match(const SubPattern_t &SP) : SubPattern(SP) {}
+
+  template <typename OpTy> bool match(OpTy *V) const {
+    if (auto *C = dyn_cast<Constant>(V)) {
+      auto *Splat = C->getSplatValue();
+      return Splat ? SubPattern.match(Splat) : false;
+    }
+    // TODO: Extend to other cases (e.g. shufflevectors).
+    return false;
+  }
+};
+
+/// Match a constant splat. TODO: Extend this to non-constant splats.
+template <typename T>
+inline Splat_match<T> m_ConstantSplat(const T &SubPattern) {
+  return SubPattern;
+}
+
 /// Match an arbitrary basic block value and ignore it.
 inline class_match<BasicBlock> m_BasicBlock() {
   return class_match<BasicBlock>();
@@ -2925,6 +2945,12 @@ inline typename m_Intrinsic_Ty<Opnd0>::Ty m_VecReverse(const Opnd0 &Op0) {
   return m_Intrinsic<Intrinsic::vector_reverse>(Op0);
 }
 
+template <typename Opnd0, typename Opnd1, typename Opnd2>
+inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2>::Ty
+m_VectorInsert(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2) {
+  return m_Intrinsic<Intrinsic::vector_insert>(Op0, Op1, Op2);
+}
+
 //===----------------------------------------------------------------------===//
 // Matchers for two-operands operators with the operators in either order
 //
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 3f11cae143b81..becc1888152d7 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2323,6 +2323,18 @@ Constant *InstCombinerImpl::unshuffleConstant(ArrayRef<int> ShMask, Constant *C,
   return ConstantVector::get(NewVecC);
 }
 
+// Get the result of `Vector Op Splat` (or `Splat Op Vector` if \p SplatLHS).
+static Constant *constantFoldBinOpWithSplat(unsigned Opcode, Constant *Vector,
+                                            Constant *Splat, bool SplatLHS,
+                                            const DataLayout &DL) {
+  ElementCount EC = cast<VectorType>(Vector->getType())->getElementCount();
+  Constant *LHS = ConstantVector::getSplat(EC, Splat);
+  Constant *RHS = Vector;
+  if (!SplatLHS)
+    std::swap(LHS, RHS);
+  return ConstantFoldBinaryOpOperands(Opcode, LHS, RHS, DL);
+}
+
 Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
   if (!isa<VectorType>(Inst.getType()))
     return nullptr;
@@ -2334,6 +2346,37 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
   assert(cast<VectorType>(RHS->getType())->getElementCount() ==
          cast<VectorType>(Inst.getType())->getElementCount());
 
+  auto foldConstantsThroughSubVectorInsertSplat =
+      [&](Value *MaybeSubVector, Value *MaybeSplat,
+          bool SplatLHS) -> Instruction * {
+    Value *Idx;
+    Constant *Splat, *SubVector, *Dest;
+    if (!match(MaybeSplat, m_ConstantSplat(m_Constant(Splat))) ||
+        !match(MaybeSubVector,
+               m_VectorInsert(m_Constant(Dest), m_Constant(SubVector),
+                              m_Value(Idx))))
+      return nullptr;
+    SubVector =
+        constantFoldBinOpWithSplat(Opcode, SubVector, Splat, SplatLHS, DL);
+    Dest = constantFoldBinOpWithSplat(Opcode, Dest, Splat, SplatLHS, DL);
+    if (!SubVector || !Dest)
+      return nullptr;
+    auto *InsertVector =
+        Builder.CreateInsertVector(Dest->getType(), Dest, SubVector, Idx);
+    return replaceInstUsesWith(Inst, InsertVector);
+  };
+
+  // If one operand is a constant splat and the other operand is a
+  // `vector.insert` where both the destination and subvector are constant,
+  // apply the operation to both the destination and subvector, returning a new
+  // constant `vector.insert`. This helps constant folding for scalable vectors.
+  if (Instruction *Folded = foldConstantsThroughSubVectorInsertSplat(
+          /*MaybeSubVector=*/LHS, /*MaybeSplat=*/RHS, /*SplatLHS=*/false))
+    return Folded;
+  if (Instruction *Folded = foldConstantsThroughSubVectorInsertSplat(
+          /*MaybeSubVector=*/RHS, /*MaybeSplat=*/LHS, /*SplatLHS=*/true))
+    return Folded;
+
   // If both operands of the binop are vector concatenations, then perform the
   // narrow binop on each pair of the source operands followed by concatenation
   // of the results.
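
Note (illustrative, not part of the patch): the new fold rewrites a binary op between a constant splat and an `llvm.vector.insert` of constants by applying the splat to the insert's destination and subvector separately; element-wise folding of the whole result is otherwise impossible because the result type is scalable. A minimal before/after sketch, mirroring the first test added below (the value name %v is made up for readability):

; Before: %div cannot be folded to a single constant of scalable type.
%v   = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 9), i64 0)
%div = udiv <vscale x 4 x i32> %v, splat (i32 3)

; After: the udiv is folded into the destination (poison) and the subvector
; (9 / 3 = 3), leaving a single constant vector.insert.
%div = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 3), i64 0)
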
diff --git a/llvm/test/Transforms/InstCombine/constant-vector-insert.ll b/llvm/test/Transforms/InstCombine/constant-vector-insert.ll
new file mode 100644
index 0000000000000..268854054bd7f
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/constant-vector-insert.ll
@@ -0,0 +1,156 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine %s | FileCheck %s
+; RUN: opt -S -passes=instcombine %s \
+; RUN:   -use-constant-int-for-fixed-length-splat \
+; RUN:   -use-constant-fp-for-fixed-length-splat \
+; RUN:   -use-constant-int-for-scalable-splat \
+; RUN:   -use-constant-fp-for-scalable-splat | FileCheck %s
+
+define <vscale x 4 x i32> @insert_div() {
+; CHECK-LABEL: @insert_div(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 3), i64 0)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[DIV]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 9), i64 0)
+  %div = udiv <vscale x 4 x i32> %0, splat (i32 3)
+  ret <vscale x 4 x i32> %div
+}
+
+define <vscale x 4 x i32> @insert_div_splat_lhs() {
+; CHECK-LABEL: @insert_div_splat_lhs(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 5), <4 x i32> splat (i32 2), i64 0)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[DIV]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat(i32 2), <4 x i32> splat (i32 5), i64 0)
+  %div = udiv <vscale x 4 x i32> splat (i32 10), %0
+  ret <vscale x 4 x i32> %div
+}
+
+define <vscale x 4 x i32> @insert_div_mixed_splat() {
+; CHECK-LABEL: @insert_div_mixed_splat(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 6), <4 x i32> splat (i32 3), i64 0)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[DIV]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 18), <4 x i32> splat (i32 9), i64 0)
+  %div = udiv <vscale x 4 x i32> %0, splat (i32 3)
+  ret <vscale x 4 x i32> %div
+}
+
+define <vscale x 4 x i32> @insert_mul() {
+; CHECK-LABEL: @insert_mul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 7), i64 4)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[MUL]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 1), i64 4)
+  %mul = mul <vscale x 4 x i32> %0, splat (i32 7)
+  ret <vscale x 4 x i32> %mul
+}
+
+define <vscale x 4 x i32> @insert_add() {
+; CHECK-LABEL: @insert_add(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 16), i64 0)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[ADD]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 5), i64 0)
+  %add = add <vscale x 4 x i32> %0, splat (i32 11)
+  ret <vscale x 4 x i32> %add
+}
+
+define <vscale x 4 x i32> @insert_add_non_splat_subvector() {
+; CHECK-LABEL: @insert_add_non_splat_subvector(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> , i64 0)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[ADD]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> , i64 0)
+  %add = add <vscale x 4 x i32> %0, splat (i32 100)
+  ret <vscale x 4 x i32> %add
+}
+
+define <vscale x 4 x float> @insert_add_fp() {
+; CHECK-LABEL: @insert_add_fp(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> splat (float 6.250000e+00), <4 x float> splat (float 5.500000e+00), i64 0)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[ADD]]
+;
+entry:
+  %0 = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> splat(float 1.25), <4 x float> splat (float 0.5), i64 0)
+  %add = fadd <vscale x 4 x float> %0, splat (float 5.0)
+  ret <vscale x 4 x float> %add
+}
+
+define <vscale x 8 x i32> @insert_add_scalable_subvector() {
+; CHECK-LABEL: @insert_add_scalable_subvector(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> splat (i32 20), <vscale x 4 x i32> splat (i32 -4), i64 0)
+; CHECK-NEXT:    ret <vscale x 8 x i32> [[ADD]]
+;
+entry:
+  %0 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> splat(i32 16), <vscale x 4 x i32> splat (i32 -8), i64 0)
+  %add = add <vscale x 8 x i32> %0, splat (i32 4)
+  ret <vscale x 8 x i32> %add
+}
+
+define <vscale x 4 x i32> @insert_sub() {
+; CHECK-LABEL: @insert_sub(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SUB:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> zeroinitializer, i64 8)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[SUB]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 11), i64 8)
+  %sub = add <vscale x 4 x i32> %0, splat (i32 -11)
+  ret <vscale x 4 x i32> %sub
+}
+
+define <vscale x 4 x i32> @insert_and_partially_undef() {
+; CHECK-LABEL: @insert_and_partially_undef(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AND:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> zeroinitializer, <4 x i32> splat (i32 4), i64 0)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[AND]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> undef, <4 x i32> splat (i32 6), i64 0)
+  %and = and <vscale x 4 x i32> %0, splat (i32 4)
+  ret <vscale x 4 x i32> %and
+}
+
+define <vscale x 4 x i32> @insert_fold_chain() {
+; CHECK-LABEL: @insert_fold_chain(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 11), <4 x i32> splat (i32 8), i64 0)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[ADD]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 21), <4 x i32> splat (i32 12), i64 0)
+  %div = udiv <vscale x 4 x i32> %0, splat (i32 3)
+  %add = add <vscale x 4 x i32> %div, splat (i32 4)
+  ret <vscale x 4 x i32> %add
+}
+
+; TODO: This could be folded more.
+define <vscale x 4 x i32> @insert_add_both_insert_vector() {
+; CHECK-LABEL: @insert_add_both_insert_vector(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 10), <4 x i32> splat (i32 5), i64 0)
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 -1), <4 x i32> splat (i32 2), i64 0)
+; CHECK-NEXT:    [[ADD:%.*]] = add <vscale x 4 x i32> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[ADD]]
+;
+entry:
+  %0 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat(i32 10), <4 x i32> splat (i32 5), i64 0)
+  %1 = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat(i32 -1), <4 x i32> splat (i32 2), i64 0)
+  %add = add <vscale x 4 x i32> %0, %1
+  ret <vscale x 4 x i32> %add
+}
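
Note (illustrative, not part of the patch): the trailing TODO in @insert_add_both_insert_vector is hinting that when both operands are constant vector.inserts at the same index, the add could in principle also be distributed over the two destinations (10 + -1 = 9) and the two subvectors (5 + 2 = 7), which would leave a single constant insert along the lines of:

%add = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> splat (i32 9), <4 x i32> splat (i32 7), i64 0)
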