-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[AbstractCallSite] Handle Indirect Calls Properly #163003
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-ir Author: Kunqiu Chen (Camsyn) ChangesThis patch fixes a bug in ProblemWhen the underlying call is indirect, some APIs of /// Return the function being called if this is a direct call, otherwise
/// return null (if it's an indirect call).
Function *getCalledFunction() const; This violates the intended contract of Expected Behavior
It should gracefully handle indirect calls. Why It Went UnnoticedCurrently, LLVM does not contain any in-tree use cases where Motivation
FixThis pull request updates the logic in the The main change is that several methods now check Full diff: https://github.com/llvm/llvm-project/pull/163003.diff 2 Files Affected:
diff --git a/llvm/include/llvm/IR/AbstractCallSite.h b/llvm/include/llvm/IR/AbstractCallSite.h
index 9e24ae7d1b431..f431e1d8a38ef 100644
--- a/llvm/include/llvm/IR/AbstractCallSite.h
+++ b/llvm/include/llvm/IR/AbstractCallSite.h
@@ -137,7 +137,7 @@ class AbstractCallSite {
/// Return true if @p U is the use that defines the callee of this ACS.
bool isCallee(const Use *U) const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->isCallee(U);
assert(!CI.ParameterEncoding.empty() &&
@@ -154,7 +154,7 @@ class AbstractCallSite {
/// Return the number of parameters of the callee.
unsigned getNumArgOperands() const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->arg_size();
// Subtract 1 for the callee encoding.
return CI.ParameterEncoding.size() - 1;
@@ -169,7 +169,7 @@ class AbstractCallSite {
/// Return the operand index of the underlying instruction associated with
/// the function parameter number @p ArgNo or -1 if there is none.
int getCallArgOperandNo(unsigned ArgNo) const {
- if (isDirectCall())
+ if (!isCallbackCall())
return ArgNo;
// Add 1 for the callee encoding.
return CI.ParameterEncoding[ArgNo + 1];
@@ -183,7 +183,7 @@ class AbstractCallSite {
/// Return the operand of the underlying instruction associated with the
/// function parameter number @p ArgNo or nullptr if there is none.
Value *getCallArgOperand(unsigned ArgNo) const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->getArgOperand(ArgNo);
// Add 1 for the callee encoding.
return CI.ParameterEncoding[ArgNo + 1] >= 0
@@ -210,7 +210,7 @@ class AbstractCallSite {
/// Return the pointer to function that is being called.
Value *getCalledOperand() const {
- if (isDirectCall())
+ if (!isCallbackCall())
return CB->getCalledOperand();
return CB->getArgOperand(getCallArgOperandNoForCallee());
}
diff --git a/llvm/unittests/IR/AbstractCallSiteTest.cpp b/llvm/unittests/IR/AbstractCallSiteTest.cpp
index ddb10911ad028..c30515a93b339 100644
--- a/llvm/unittests/IR/AbstractCallSiteTest.cpp
+++ b/llvm/unittests/IR/AbstractCallSiteTest.cpp
@@ -53,3 +53,56 @@ TEST(AbstractCallSite, CallbackCall) {
EXPECT_TRUE(ACS.isCallee(CallbackUse));
EXPECT_EQ(ACS.getCalledFunction(), Callback);
}
+
+TEST(AbstractCallSite, DirectCall) {
+ LLVMContext C;
+
+ const char *IR = "declare void @bar()\n"
+ "define void @foo() {\n"
+ " call void @bar()\n"
+ " ret void\n"
+ "}\n";
+
+ std::unique_ptr<Module> M = parseIR(C, IR);
+ ASSERT_TRUE(M);
+
+ Function *Callee = M->getFunction("bar");
+ ASSERT_NE(Callee, nullptr);
+
+ const Use *DirectCallUse = Callee->getSingleUndroppableUse();
+ ASSERT_NE(DirectCallUse, nullptr);
+
+ AbstractCallSite ACS(DirectCallUse);
+ EXPECT_TRUE(ACS);
+ EXPECT_TRUE(ACS.isDirectCall());
+ EXPECT_TRUE(ACS.isCallee(DirectCallUse));
+ EXPECT_EQ(ACS.getCalledFunction(), Callee);
+}
+
+TEST(AbstractCallSite, IndirectCall) {
+ LLVMContext C;
+
+ const char *IR = "define void @foo(ptr %0) {\n"
+ " call void %0()\n"
+ " ret void\n"
+ "}\n";
+
+ std::unique_ptr<Module> M = parseIR(C, IR);
+ ASSERT_TRUE(M);
+
+ Function *Fun = M->getFunction("foo");
+ ASSERT_NE(Fun, nullptr);
+
+ Argument *ArgAsCallee = Fun->getArg(0);
+ ASSERT_NE(ArgAsCallee, nullptr);
+
+ const Use *IndCallUse = ArgAsCallee->getSingleUndroppableUse();
+ ASSERT_NE(IndCallUse, nullptr);
+
+ AbstractCallSite ACS(IndCallUse);
+ EXPECT_TRUE(ACS);
+ EXPECT_TRUE(ACS.isIndirectCall());
+ EXPECT_TRUE(ACS.isCallee(IndCallUse));
+ EXPECT_EQ(ACS.getCalledFunction(), nullptr);
+ EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee);
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Kindly rewrite the PR description in your own words.
Sorry, I originally used an LLM to polish the PR description, thinking it would make the text more structured and easier to read. |
This patch fixes a bug in
AbstractCallSite
where indirect calls cause unexpected behavior and assertion failures.Problem
When the underlying call is indirect, some APIs of
AbstractCallSite
behave unexpectedly.E.g.,
AbstractCallSite::getCalledFunction()
currently triggers an assertion failure, instead of returningnullptr
as documented:Actual unexpected assertion failure:
This violates the intended contract of
AbstractCallSite
and makes it unsafe to use as a general call-site abstraction, which should be able to represent all THREE types of call, i.e., direct call, indirect call, and callback call.How to Fix
The key to the problem lies in dealing with branches of different calls: the original implementation mistakenly used whether it is a direct call as a branch point, but in fact, we should use whether it is a callback call as a branch point, because direct calls and indirect calls can directly use the abstraction of
CallBase
, while callback calls need to be handled separately.The main change of this PR is that several methods now check
if (!isCallbackCall())
rather than justif (isDirectCall())
, ensuring correct behavior for indirect call sites.Moreover, this PR adds 2 unit tests for direct call type and indirect call type.