diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 53291a931530f..db0f6dea254e8 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -1799,6 +1799,9 @@ class LLVM_ABI_FOR_TEST VPWidenGEPRecipe : public VPRecipeWithIRFlags { VP_CLASSOF_IMPL(VPDef::VPWidenGEPSC) + /// This recipe generates a GEP instruction. + unsigned getOpcode() const { return Instruction::GetElementPtr; } + /// Generate the gep nodes. void execute(VPTransformState &State) override; diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index 109156c1469c5..b3735786585cf 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -252,10 +252,9 @@ struct Recipe_match { static bool matchRecipeAndOpcode(const VPRecipeBase *R) { auto *DefR = dyn_cast(R); // Check for recipes that do not have opcodes. - if constexpr (std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value) + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) return DefR; else return DefR && DefR->getOpcode() == Opcode; @@ -524,15 +523,24 @@ m_SpecificCmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) { } template -using GEPLikeRecipe_match = +using GEPLikeRecipe_match = match_combine_or< Recipe_match, Instruction::GetElementPtr, - /*Commutative*/ false, VPWidenRecipe, VPReplicateRecipe, - VPWidenGEPRecipe, VPInstruction>; + /*Commutative*/ false, VPReplicateRecipe, VPWidenGEPRecipe>, + match_combine_or< + VPInstruction_match, + VPInstruction_match>>; template inline GEPLikeRecipe_match m_GetElementPtr(const Op0_t &Op0, const Op1_t &Op1) { - return GEPLikeRecipe_match(Op0, Op1); + return m_CombineOr( + Recipe_match, Instruction::GetElementPtr, + /*Commutative*/ false, VPReplicateRecipe, VPWidenGEPRecipe>( + Op0, Op1), + m_CombineOr( + VPInstruction_match(Op0, Op1), + VPInstruction_match(Op0, + Op1))); } template diff --git a/llvm/unittests/Transforms/Vectorize/VPlanPatternMatchTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanPatternMatchTest.cpp index e38b4fad80b0e..582094bed3ef7 100644 --- a/llvm/unittests/Transforms/Vectorize/VPlanPatternMatchTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/VPlanPatternMatchTest.cpp @@ -51,5 +51,29 @@ TEST_F(VPPatternMatchTest, ScalarIVSteps) { m_SpecificInt(2), m_Specific(VF)))); } +TEST_F(VPPatternMatchTest, GetElementPtr) { + VPlan &Plan = getPlan(); + VPBasicBlock *VPBB = Plan.createVPBasicBlock("entry"); + VPBuilder Builder(VPBB); + + IntegerType *I64Ty = IntegerType::get(C, 64); + VPValue *One = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 1)); + VPValue *Two = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 2)); + VPValue *Ptr = + Plan.getOrAddLiveIn(Constant::getNullValue(PointerType::get(C, 0))); + + VPInstruction *PtrAdd = Builder.createPtrAdd(Ptr, One); + VPInstruction *WidePtrAdd = Builder.createWidePtrAdd(Ptr, Two); + + using namespace VPlanPatternMatch; + ASSERT_TRUE( + match(PtrAdd, m_GetElementPtr(m_Specific(Ptr), m_SpecificInt(1)))); + ASSERT_FALSE( + match(PtrAdd, m_GetElementPtr(m_Specific(Ptr), m_SpecificInt(2)))); + ASSERT_TRUE( + match(WidePtrAdd, m_GetElementPtr(m_Specific(Ptr), m_SpecificInt(2)))); + ASSERT_FALSE( + match(WidePtrAdd, m_GetElementPtr(m_Specific(Ptr), m_SpecificInt(1)))); +} } // namespace } // namespace llvm