diff --git a/llvm/include/llvm/IR/AbstractCallSite.h b/llvm/include/llvm/IR/AbstractCallSite.h index 9e24ae7d1b431..f431e1d8a38ef 100644 --- a/llvm/include/llvm/IR/AbstractCallSite.h +++ b/llvm/include/llvm/IR/AbstractCallSite.h @@ -137,7 +137,7 @@ class AbstractCallSite { /// Return true if @p U is the use that defines the callee of this ACS. bool isCallee(const Use *U) const { - if (isDirectCall()) + if (!isCallbackCall()) return CB->isCallee(U); assert(!CI.ParameterEncoding.empty() && @@ -154,7 +154,7 @@ class AbstractCallSite { /// Return the number of parameters of the callee. unsigned getNumArgOperands() const { - if (isDirectCall()) + if (!isCallbackCall()) return CB->arg_size(); // Subtract 1 for the callee encoding. return CI.ParameterEncoding.size() - 1; @@ -169,7 +169,7 @@ class AbstractCallSite { /// Return the operand index of the underlying instruction associated with /// the function parameter number @p ArgNo or -1 if there is none. int getCallArgOperandNo(unsigned ArgNo) const { - if (isDirectCall()) + if (!isCallbackCall()) return ArgNo; // Add 1 for the callee encoding. return CI.ParameterEncoding[ArgNo + 1]; @@ -183,7 +183,7 @@ class AbstractCallSite { /// Return the operand of the underlying instruction associated with the /// function parameter number @p ArgNo or nullptr if there is none. Value *getCallArgOperand(unsigned ArgNo) const { - if (isDirectCall()) + if (!isCallbackCall()) return CB->getArgOperand(ArgNo); // Add 1 for the callee encoding. return CI.ParameterEncoding[ArgNo + 1] >= 0 @@ -210,7 +210,7 @@ class AbstractCallSite { /// Return the pointer to function that is being called. Value *getCalledOperand() const { - if (isDirectCall()) + if (!isCallbackCall()) return CB->getCalledOperand(); return CB->getArgOperand(getCallArgOperandNoForCallee()); } diff --git a/llvm/unittests/IR/AbstractCallSiteTest.cpp b/llvm/unittests/IR/AbstractCallSiteTest.cpp index ddb10911ad028..c30515a93b339 100644 --- a/llvm/unittests/IR/AbstractCallSiteTest.cpp +++ b/llvm/unittests/IR/AbstractCallSiteTest.cpp @@ -53,3 +53,56 @@ TEST(AbstractCallSite, CallbackCall) { EXPECT_TRUE(ACS.isCallee(CallbackUse)); EXPECT_EQ(ACS.getCalledFunction(), Callback); } + +TEST(AbstractCallSite, DirectCall) { + LLVMContext C; + + const char *IR = "declare void @bar()\n" + "define void @foo() {\n" + " call void @bar()\n" + " ret void\n" + "}\n"; + + std::unique_ptr M = parseIR(C, IR); + ASSERT_TRUE(M); + + Function *Callee = M->getFunction("bar"); + ASSERT_NE(Callee, nullptr); + + const Use *DirectCallUse = Callee->getSingleUndroppableUse(); + ASSERT_NE(DirectCallUse, nullptr); + + AbstractCallSite ACS(DirectCallUse); + EXPECT_TRUE(ACS); + EXPECT_TRUE(ACS.isDirectCall()); + EXPECT_TRUE(ACS.isCallee(DirectCallUse)); + EXPECT_EQ(ACS.getCalledFunction(), Callee); +} + +TEST(AbstractCallSite, IndirectCall) { + LLVMContext C; + + const char *IR = "define void @foo(ptr %0) {\n" + " call void %0()\n" + " ret void\n" + "}\n"; + + std::unique_ptr M = parseIR(C, IR); + ASSERT_TRUE(M); + + Function *Fun = M->getFunction("foo"); + ASSERT_NE(Fun, nullptr); + + Argument *ArgAsCallee = Fun->getArg(0); + ASSERT_NE(ArgAsCallee, nullptr); + + const Use *IndCallUse = ArgAsCallee->getSingleUndroppableUse(); + ASSERT_NE(IndCallUse, nullptr); + + AbstractCallSite ACS(IndCallUse); + EXPECT_TRUE(ACS); + EXPECT_TRUE(ACS.isIndirectCall()); + EXPECT_TRUE(ACS.isCallee(IndCallUse)); + EXPECT_EQ(ACS.getCalledFunction(), nullptr); + EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee); +}